Commit 91a5a87e authored by Thor Johnsen

Slight improvements

parent 25c80afe
@@ -67,6 +67,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self._overflow_buf = torch.cuda.IntTensor([0])
+        self._has_overflow = False
         assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
@@ -299,7 +300,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -311,10 +312,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 end = (self._current_block+1) * self._block_size
                 flush_block = [start, end]
-                if self._current_block == 0:
-                    # reset
-                    self._grads_generated = [False]*len(self._grads_info)
         return flush_block

     def _pipeline_block_reductions(self, block_id):
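Note on the two _get_flush_block hunks above: the added `self._current_block > 0` guard stops block 0 from testing `self._low_param_i[-1]`, i.e. the last block's entry, and the reset of `self._grads_generated` is no longer done inside this method when block 0 is flushed. A small standalone illustration of the indexing problem the guard avoids (the values below are hypothetical, not taken from the code):

low_param_i = [5, 3, 0]        # hypothetical values: one parameter index per block
grads_generated = [False] * 6  # one flag per parameter gradient
grads_generated[0] = True      # suppose only parameter 0's gradient has arrived

current_block = 0
# Old condition: with current_block == 0 this reads low_param_i[-1], i.e. the
# entry for the *last* block, so block 0 could be flushed on the wrong signal.
old = grads_generated[low_param_i[current_block - 1]]                        # True
# New condition: short-circuits before the wrap-around index is ever used.
new = current_block > 0 and grads_generated[low_param_i[current_block - 1]]  # False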
@@ -351,7 +348,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 l2_grad_norm_sq = torch.empty([1], device='cuda')
                 l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                 torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-                self._L2_grad_norm = l2_grad_norm_sq.sqrt()
+                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()

     def __launch_step_kernel(self, p, p_copy, m, v, g):
         combined_scale = self._global_scale
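The `.item()` added above converts the reduced norm from a one-element CUDA tensor into a host-side Python float, so later reads of `L2_grad_norm` (e.g. by a loss scaler) do not touch the GPU tensor again. A standalone sketch of the same reduction; `fp16_g_shard` and `process_group` are illustrative stand-ins for the optimizer's `self._fp16_g` and `self._l2_grad_norm_pg`:

import torch
import torch.distributed

def sharded_grad_norm(fp16_g_shard, process_group):
    # Each rank squares the L2 norm of its local gradient shard ...
    l2_sq = fp16_g_shard.norm(dtype=torch.float32, p=2) ** 2
    # ... the squared norms are summed across ranks ...
    torch.distributed.all_reduce(l2_sq, group=process_group)
    # ... and the root is synced to the host once, as a plain float.
    return l2_sq.sqrt().item()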
@@ -448,8 +445,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Clears the overflow flag.
         """
-        has_overflow = self._overflow_buf.item()
-        self._overflow_buf.zero_()
+        has_overflow = self._has_overflow
+        self._has_overflow = False
         return has_overflow

     @property
@@ -457,7 +454,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Does not clear overflow flag.
         """
-        return self._overflow_buf.item()
+        return self._has_overflow

     def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
         """Strided check for overflow.
@@ -471,6 +468,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 out_p,
                 stride,
                 1 if clear else 0)
+        self._has_overflow = False if self._overflow_buf.item() == 0 else True
+        return self._has_overflow

     @property
     def L2_grad_norm(self):
@@ -542,13 +541,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         with torch.cuda.stream(self._completion_st):
             # Check for overflow
             # Store state for loss scaler calculation
-            if skip_overflow_check:
-                has_overflow = False
-            else:
-                self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-                has_overflow = self.peek_overflow
+            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
             if has_overflow:
                 print("Reverting step")
                 self.revert_step()
             else:
                 # Copy self._new_params to model params
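Taken together, the overflow hunks above move the overflow state onto the host: `strided_check_finite` still launches the fused CUDA check into `self._overflow_buf`, but now mirrors the result once into the boolean `self._has_overflow` and returns it, so `has_overflow`, `peek_overflow`, and the one-line ternary in the completion stream no longer call `.item()` on the device buffer. A minimal sketch of the pattern outside the optimizer class, with the CUDA check kernel replaced by a plain `torch.isfinite` stand-in (a GPU is assumed only because the buffer mirrors the original):

import torch

class OverflowState:
    def __init__(self):
        self._overflow_buf = torch.cuda.IntTensor([0])  # written by the real check kernel
        self._has_overflow = False                      # host-side mirror of the buffer

    def strided_check_finite(self, output_params, clear=True):
        # Stand-in for the fused strided check; the real code launches a CUDA
        # kernel that marks the buffer when it sees a non-finite value.
        if not torch.isfinite(output_params).all():
            self._overflow_buf.fill_(1)
        # One device-to-host sync, cached for later queries.
        self._has_overflow = False if self._overflow_buf.item() == 0 else True
        if clear:
            self._overflow_buf.zero_()
        return self._has_overflow

    @property
    def peek_overflow(self):
        # Reads the cached flag without clearing it.
        return self._has_overflow

    @property
    def has_overflow(self):
        # Reads the cached flag and clears it, matching the updated property.
        flag, self._has_overflow = self._has_overflow, False
        return flag

With the check returning the cached flag, `has_overflow = False if skip_overflow_check else self.strided_check_finite(...)` behaves like the removed if/else, minus the separate `peek_overflow` read.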
@@ -67,6 +67,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self._overflow_buf = torch.cuda.IntTensor([0])
+        self._has_overflow = False
         assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
@@ -352,7 +353,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -364,10 +365,6 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                 end = (self._current_block+1) * self._block_size
                 flush_block = [start, end]
-                if self._current_block == 0:
-                    # reset
-                    self._grads_generated = [False]*len(self._grads_info)
         return flush_block

     def _pipeline_block_reductions(self, block_id):
@@ -404,7 +401,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                 l2_grad_norm_sq = torch.empty([1], device='cuda')
                 l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                 torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-                self._L2_grad_norm = l2_grad_norm_sq.sqrt()
+                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()

     def __launch_step_kernel(self, p, p_copy, m, v, g):
         combined_scale = self._global_scale
@@ -501,8 +498,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Clears the overflow flag.
         """
-        has_overflow = self._overflow_buf.item()
-        self._overflow_buf.zero_()
+        has_overflow = self._has_overflow
+        self._has_overflow = False
         return has_overflow

     @property
@@ -510,7 +507,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Does not clear overflow flag.
         """
-        return self._overflow_buf.item()
+        return self._has_overflow

     def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
         """Strided check for overflow.
@@ -524,6 +521,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                 out_p,
                 stride,
                 1 if clear else 0)
+        self._has_overflow = False if self._overflow_buf.item() == 0 else True
+        return self._has_overflow

     @property
     def L2_grad_norm(self):
@@ -595,13 +594,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         with torch.cuda.stream(self._completion_st):
             # Check for overflow
             # Store state for loss scaler calculation
-            if skip_overflow_check:
-                has_overflow = False
-            else:
-                self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-                has_overflow = self.peek_overflow
+            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
             if has_overflow:
                 print("Reverting step")
                 self.revert_step()
             else:
                 # Copy self._new_params to model params
@@ -86,7 +86,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         self._e5m2_allgather = e5m2_allgather
         self._do_not_flatten_model = do_not_flatten_model
         self._full_pipeline = full_pipeline
-        self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = None
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
         self._world_size = torch.distributed.get_world_size()
@@ -202,7 +201,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -214,10 +213,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
                 end = (self._current_block+1) * self._block_size
                 flush_block = [start, end]
-                if self._current_block == 0:
-                    # reset
-                    self._grads_generated = [False]*len(self._grads_info)
         return flush_block

     def __launch_step_kernel(self, p, p_copy, m, v, g):
@@ -267,7 +262,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
             if block_id == 0:
                 self._l2_grad_norm_st.wait_stream(self._dwu_st)
                 with torch.cuda.stream(self._l2_grad_norm_st):
-                    self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2)
+                    self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
             flush_block = self._get_flush_block()

     def set_global_scale(self, global_scale):
@@ -303,7 +298,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
             torch.distributed.all_reduce(self._flat_grads)
         self._l2_grad_norm_st.wait_stream(self._dwu_st)
         with torch.cuda.stream(self._l2_grad_norm_st):
-            self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2)
+            self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
         self._current_block = self._num_blocks
         self._grads_generated = [False]*len(self._grads_info)
@@ -313,20 +308,17 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
-        if not self.has_overflow:
-            with torch.cuda.stream(self._dwu_st):
-                self.__launch_step_kernel(
-                    self._fp32_p,
-                    self._flat_params_shards[self._rank_in_group],
-                    self._fp32_m,
-                    self._fp32_v,
-                    self._flat_grads_shards[self._rank_in_group])
-                torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, no_copy=True)
-                for p in self._model_params: self.state[p]['step'] += 1
-            torch.cuda.current_stream().wait_stream(self._dwu_st)
-        else:
-            print("Overflow detected, skipping step")
+        with torch.cuda.stream(self._dwu_st):
+            self.__launch_step_kernel(
+                self._fp32_p,
+                self._flat_params_shards[self._rank_in_group],
+                self._fp32_m,
+                self._fp32_v,
+                self._flat_grads_shards[self._rank_in_group])
+            torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, no_copy=True)
+            for p in self._model_params: self.state[p]['step'] += 1
+        torch.cuda.current_stream().wait_stream(self._dwu_st)
         return loss
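A caller-side note on the hunk above, not part of this commit: V3's `step()` now applies the fused update and the all-gather unconditionally, so a training loop that still wants to drop steps on overflow has to gate the call itself. A purely hypothetical fragment (`loss_scaler` and `update_scale` are stand-in names, not apex APIs):

loss.backward()
if optimizer.has_overflow:                     # the property the removed branch consulted
    loss_scaler.update_scale(overflow=True)    # hypothetical loss-scaler hook
else:
    optimizer.step()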