Commit 02fd7341 authored by Thor Johnsen

Add optional accumulation step

parent 9a09107c
@@ -94,9 +94,9 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import amp_C
         self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
+        self._is_accumulation_step = False
         self._last_step = False
         self._overlap_reductions = overlap_reductions
-        self._global_scale = None
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._e5m2_allgather = e5m2_allgather
@@ -363,6 +363,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             import inspect
             assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
 
+    def set_is_accumulation_step(self, is_accumulation_step):
+        self._is_accumulation_step = is_accumulation_step
 
     def set_last_step(self, last_step):
         self._last_step = last_step
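The new setter gives the training loop a way to tell the optimizer which backward passes are accumulation-only micro-batches. Below is a minimal sketch of the intended call pattern; `model`, `loader`, and the constructed `optimizer` are hypothetical placeholders, and the micro-batch count is illustrative:

import torch
import torch.nn.functional as F

ACCUMULATION_STEPS = 4  # illustrative number of micro-batches per update

for step, (inputs, targets) in enumerate(loader):
    is_accumulation_step = (step + 1) % ACCUMULATION_STEPS != 0
    # Tell the optimizer, before backward, whether this pass should
    # trigger overlapped gradient reductions (only the final micro-batch).
    optimizer.set_is_accumulation_step(is_accumulation_step)
    loss = F.cross_entropy(model(inputs), targets)
    loss.backward()  # grad hooks route into _do_overlapped_reduction
    if not is_accumulation_step:
        optimizer.step()
        optimizer.zero_grad()

The key point is that the flag must be set before `backward()`, since it is consulted inside the gradient hooks rather than inside `step()`.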
@@ -492,6 +494,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._grads_fp32 = []
 
     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
+        if not self._is_accumulation_step:
             # handle overlapped reductions
             if param.dtype == torch.float16:
                 self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )
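The effect of the new guard: on an accumulation step the backward hook does nothing, so PyTorch's default accumulation into `param.grad` is the only thing that happens. A commented sketch of the resulting control flow, reconstructed from the hunk above (the full method body is not shown in this diff):

def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
    if not self._is_accumulation_step:
        # Final micro-batch: queue this gradient so its block can be
        # reduce-scattered as soon as all grads in the block have arrived.
        if param.dtype == torch.float16:
            self._grads_fp16.append((param.grad, self._individual_flat_grads[param_i]))
        ...
    # Accumulation micro-batch: fall through without queuing anything;
    # autograd keeps summing new gradients into param.grad until the
    # last micro-batch flips the flag back.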
@@ -505,46 +508,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 self._pipeline_block_reductions(block_id)
                 flush_block = self._get_flush_block()
 
-    def set_global_scale(self, global_scale):
-        """Set global scale.
-        """
-        self._global_scale = global_scale
-
-    @property
-    def global_scale(self):
-        return self._global_scale
-
-    @property
-    def has_overflow(self):
-        """Check if overflows were detected by any call to step(...) method.
-        Clears the overflow flag.
-        """
-        has_overflow = self._has_overflow
-        self._has_overflow = False
-        return has_overflow
-
-    @property
-    def peek_overflow(self):
-        """Check if overflows were detected by any call to step(...) method.
-        Does not clear overflow flag.
-        """
-        return self._has_overflow
-
-    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
-        """Strided check for overflow.
-        You can get status by calling has_overflow.
-        """
-        if start >= 0 and start < end:
-            out_p = output_params[start:end]
-        else:
-            out_p = output_params
-        fused_adam_cuda.strided_check_finite(self._overflow_buf,
-                                             out_p,
-                                             stride,
-                                             1 if clear else 0)
-        self._has_overflow = False if self._overflow_buf.item() == 0 else True
-        return self._has_overflow
-
     @property
     def L2_grad_norm(self):
         torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
...
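Note that this commit also removes the loss-scale and overflow-checking surface (`set_global_scale`, `global_scale`, `has_overflow`, `peek_overflow`, `strided_check_finite`). Any caller that drove the optimizer with a pattern like the following hypothetical sketch would need updating; the names below are taken verbatim from the removed code:

# Hypothetical pre-commit caller pattern; every optimizer call below
# is removed by this commit.
optimizer.set_global_scale(loss_scale)  # removed
optimizer.step()
if optimizer.has_overflow:              # removed (property; also cleared the flag)
    loss_scale /= 2.0                   # illustrative caller-side scale backoff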