Commit 91a5a87e authored by Thor Johnsen's avatar Thor Johnsen

Slight improvements

parent 25c80afe
@@ -67,6 +67,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self._overflow_buf = torch.cuda.IntTensor([0])
+        self._has_overflow = False
         assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
@@ -299,7 +300,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -311,10 +312,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             end = (self._current_block+1) * self._block_size
             flush_block = [start, end]
-            if self._current_block == 0:
-                # reset
-                self._grads_generated = [False]*len(self._grads_info)
 
         return flush_block
 
     def _pipeline_block_reductions(self, block_id):
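Why the added self._current_block > 0 guard in _get_flush_block matters: Python's negative indexing means that when self._current_block is 0, the old check read self._low_param_i[-1], i.e. the entry for the last block, rather than failing loudly. A minimal sketch of that pitfall, using made-up values rather than the optimizer's real state:

    low_param_i = [7, 5, 3, 1]          # lowest param index owned by each block (illustrative)
    grads_generated = [False] * 8

    current_block = 0
    # Old check: low_param_i[current_block - 1] is low_param_i[-1] == 1,
    # so the last block's entry gets inspected even though no block precedes block 0.
    print(grads_generated[low_param_i[current_block - 1]])   # runs, but tests the wrong entry

    # New check short-circuits before indexing when there is no block left to flush.
    if current_block > 0 and grads_generated[low_param_i[current_block - 1]]:
        print("flush block", current_block - 1)
    else:
        print("nothing to flush")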
@@ -351,7 +348,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 l2_grad_norm_sq = torch.empty([1], device='cuda')
                 l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                 torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-                self._L2_grad_norm = l2_grad_norm_sq.sqrt()
+                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
 
     def __launch_step_kernel(self, p, p_copy, m, v, g):
         combined_scale = self._global_scale
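Tensor.item() copies a zero-dimensional tensor's value back to the host and returns a plain Python number, so after this change self._L2_grad_norm (surfaced through the L2_grad_norm property) is a float rather than a CUDA tensor, and the call synchronizes with the stream that produced it. A small standalone illustration of the difference (CPU tensor, so it runs anywhere):

    import torch

    g = torch.randn(10)                               # stands in for the flattened gradients
    norm_sq = g.norm(dtype=torch.float32, p=2) ** 2   # 0-dim tensor
    norm = norm_sq.sqrt().item()                      # plain Python float on the host

    print(type(norm_sq), type(norm))                  # <class 'torch.Tensor'> <class 'float'>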
@@ -448,8 +445,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Clears the overflow flag.
         """
-        has_overflow = self._overflow_buf.item()
-        self._overflow_buf.zero_()
+        has_overflow = self._has_overflow
+        self._has_overflow = False
         return has_overflow
 
     @property
@@ -457,7 +454,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Does not clear overflow flag.
         """
-        return self._overflow_buf.item()
+        return self._has_overflow
 
     def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
         """Strided check for overflow.
@@ -471,6 +468,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 out_p,
                 stride,
                 1 if clear else 0)
+        self._has_overflow = False if self._overflow_buf.item() == 0 else True
+        return self._has_overflow
 
     @property
     def L2_grad_norm(self):
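Taken together, the overflow changes in this file keep a host-side copy of the device flag: strided_check_finite reads self._overflow_buf once (the only remaining .item() call) and caches the result in self._has_overflow, which has_overflow and peek_overflow then return without touching the GPU again. A minimal sketch of that pattern, with a toy finiteness check standing in for the fused CUDA kernel:

    import torch

    class OverflowTracker:
        # Sketch of the host-side caching pattern, not the real optimizer.
        def __init__(self):
            self._overflow_buf = torch.zeros(1, dtype=torch.int32)  # device flag in the real code
            self._has_overflow = False

        def strided_check_finite(self, params):
            # Stand-in for the kernel that writes into self._overflow_buf.
            self._overflow_buf.fill_(0 if torch.isfinite(params).all() else 1)
            # Single device-to-host read, cached for later queries.
            self._has_overflow = False if self._overflow_buf.item() == 0 else True
            return self._has_overflow

        @property
        def peek_overflow(self):
            return self._has_overflow                    # no extra .item() / sync

        @property
        def has_overflow(self):
            flag, self._has_overflow = self._has_overflow, False   # read and clear
            return flag

    tracker = OverflowTracker()
    print(tracker.strided_check_finite(torch.tensor([1.0, float('inf')])))    # True
    print(tracker.peek_overflow, tracker.has_overflow, tracker.peek_overflow) # True True False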
@@ -542,13 +541,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         with torch.cuda.stream(self._completion_st):
             # Check for overflow
             # Store state for loss scaler calculation
-            if skip_overflow_check:
-                has_overflow = False
-            else:
-                self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-                has_overflow = self.peek_overflow
+            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
             if has_overflow:
-                print("Reverting step")
                 self.revert_step()
             else:
                 # Copy self._new_params to model params
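With strided_check_finite now returning the cached flag, the five-line if/else in step collapses into one conditional expression; the only real difference is that the flag comes from the return value instead of a follow-up read of peek_overflow. The same collapse in isolation, with a toy check in place of strided_check_finite:

    def toy_check(params):
        # stand-in for strided_check_finite: flags NaN or infinity
        return any(x != x or x in (float('inf'), float('-inf')) for x in params)

    def has_overflow_old(params, skip_overflow_check):
        if skip_overflow_check:
            has_overflow = False
        else:
            has_overflow = toy_check(params)
        return has_overflow

    def has_overflow_new(params, skip_overflow_check):
        return False if skip_overflow_check else toy_check(params)

    params = [1.0, float('inf')]
    assert has_overflow_old(params, False) == has_overflow_new(params, False) == True
    assert has_overflow_old(params, True) == has_overflow_new(params, True) == False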
@@ -67,6 +67,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self._overflow_buf = torch.cuda.IntTensor([0])
+        self._has_overflow = False
         assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
@@ -352,7 +353,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -364,10 +365,6 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
             end = (self._current_block+1) * self._block_size
             flush_block = [start, end]
-            if self._current_block == 0:
-                # reset
-                self._grads_generated = [False]*len(self._grads_info)
 
         return flush_block
 
     def _pipeline_block_reductions(self, block_id):
@@ -404,7 +401,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                 l2_grad_norm_sq = torch.empty([1], device='cuda')
                 l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                 torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-                self._L2_grad_norm = l2_grad_norm_sq.sqrt()
+                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
 
     def __launch_step_kernel(self, p, p_copy, m, v, g):
         combined_scale = self._global_scale
@@ -501,8 +498,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Clears the overflow flag.
         """
-        has_overflow = self._overflow_buf.item()
-        self._overflow_buf.zero_()
+        has_overflow = self._has_overflow
+        self._has_overflow = False
         return has_overflow
 
     @property
@@ -510,7 +507,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Does not clear overflow flag.
         """
-        return self._overflow_buf.item()
+        return self._has_overflow
 
     def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
         """Strided check for overflow.
@@ -524,6 +521,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                 out_p,
                 stride,
                 1 if clear else 0)
+        self._has_overflow = False if self._overflow_buf.item() == 0 else True
+        return self._has_overflow
 
     @property
     def L2_grad_norm(self):
@@ -595,13 +594,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         with torch.cuda.stream(self._completion_st):
             # Check for overflow
             # Store state for loss scaler calculation
-            if skip_overflow_check:
-                has_overflow = False
-            else:
-                self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-                has_overflow = self.peek_overflow
+            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
             if has_overflow:
-                print("Reverting step")
                 self.revert_step()
             else:
                 # Copy self._new_params to model params

@@ -86,7 +86,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         self._e5m2_allgather = e5m2_allgather
         self._do_not_flatten_model = do_not_flatten_model
         self._full_pipeline = full_pipeline
-        self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = None
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
         self._world_size = torch.distributed.get_world_size()
@@ -202,7 +201,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -214,10 +213,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
             end = (self._current_block+1) * self._block_size
             flush_block = [start, end]
-            if self._current_block == 0:
-                # reset
-                self._grads_generated = [False]*len(self._grads_info)
 
         return flush_block
 
     def __launch_step_kernel(self, p, p_copy, m, v, g):
@@ -267,7 +262,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         if block_id == 0:
             self._l2_grad_norm_st.wait_stream(self._dwu_st)
             with torch.cuda.stream(self._l2_grad_norm_st):
-                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2)
+                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
 
         flush_block = self._get_flush_block()
 
     def set_global_scale(self, global_scale):
@@ -303,7 +298,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
             torch.distributed.all_reduce(self._flat_grads)
             self._l2_grad_norm_st.wait_stream(self._dwu_st)
             with torch.cuda.stream(self._l2_grad_norm_st):
-                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2)
+                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
         self._current_block = self._num_blocks
         self._grads_generated = [False]*len(self._grads_info)
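In the V3 optimizer the L2 gradient norm is computed on a dedicated stream that first waits on the reduction stream, so the norm overlaps with other work, and the added .item() brings the scalar to the host once that stream has produced it. A minimal sketch of the wait_stream pattern (needs a CUDA device; the stream roles mirror self._dwu_st and self._l2_grad_norm_st, the tensor size is made up):

    import torch

    if torch.cuda.is_available():
        work_st = torch.cuda.Stream()   # plays the role of self._dwu_st
        norm_st = torch.cuda.Stream()   # plays the role of self._l2_grad_norm_st

        flat_grads = torch.randn(1 << 20, device='cuda')

        with torch.cuda.stream(work_st):
            flat_grads.mul_(0.5)        # stand-in for the gradient reduction

        norm_st.wait_stream(work_st)    # the norm must see the finished reduction
        with torch.cuda.stream(norm_st):
            l2_norm = flat_grads.norm(dtype=torch.float32, p=2).item()  # .item() syncs with norm_st

        print(l2_norm)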
@@ -313,7 +308,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
 
-        if not self.has_overflow:
         with torch.cuda.stream(self._dwu_st):
             self.__launch_step_kernel(
                 self._fp32_p,
@@ -325,8 +319,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
             for p in self._model_params: self.state[p]['step'] += 1
         torch.cuda.current_stream().wait_stream(self._dwu_st)
-        else:
-            print("Overflow detected, skipping step")
 
         return loss