Commit 02fd7341 authored by Thor Johnsen

Add optional accumulation step

parent 9a09107c
@@ -94,9 +94,9 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import amp_C
         self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
 
+        self._is_accumulation_step = False
         self._last_step = False
         self._overlap_reductions = overlap_reductions
-        self._global_scale = None
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._e5m2_allgather = e5m2_allgather
@@ -363,6 +363,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import inspect
         assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
 
+    def set_is_accumulation_step(self, is_accumulation_step):
+        self._is_accumulation_step = is_accumulation_step
+
     def set_last_step(self, last_step):
         self._last_step = last_step
@@ -492,58 +494,19 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._grads_fp32 = []
 
     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
-        # handle overlapped reductions
-        if param.dtype == torch.float16:
-            self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )
-        else:
-            self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
-        self._grads_generated[param_i]=True
-        if self._overlap_reductions and not self._last_step:
-            flush_block = self._get_flush_block()
-            while flush_block:
-                block_id = flush_block[0] // self._block_size
-                self._pipeline_block_reductions(block_id)
-                flush_block = self._get_flush_block()
-
-    def set_global_scale(self, global_scale):
-        """Set global scale.
-        """
-        self._global_scale = global_scale
-
-    @property
-    def global_scale(self):
-        return self._global_scale
-
-    @property
-    def has_overflow(self):
-        """Check if overflows were detected by any call to step(...) method.
-        Clears the overflow flag.
-        """
-        has_overflow = self._has_overflow
-        self._has_overflow = False
-        return has_overflow
-
-    @property
-    def peek_overflow(self):
-        """Check if overflows were detected by any call to step(...) method.
-        Does not clear overflow flag.
-        """
-        return self._has_overflow
-
-    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
-        """Strided check for overflow.
-        You can get status by calling has_overflow.
-        """
-        if start >= 0 and start < end:
-            out_p = output_params[start:end]
-        else:
-            out_p = output_params
-        fused_adam_cuda.strided_check_finite(self._overflow_buf,
-                out_p,
-                stride,
-                1 if clear else 0)
-        self._has_overflow = False if self._overflow_buf.item() == 0 else True
-        return self._has_overflow
+        if not self._is_accumulation_step:
+            # handle overlapped reductions
+            if param.dtype == torch.float16:
+                self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )
+            else:
+                self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
+            self._grads_generated[param_i]=True
+            if self._overlap_reductions and not self._last_step:
+                flush_block = self._get_flush_block()
+                while flush_block:
+                    block_id = flush_block[0] // self._block_size
+                    self._pipeline_block_reductions(block_id)
+                    flush_block = self._get_flush_block()
 
     @property
     def L2_grad_norm(self):
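
For context, the sketch below shows one way a gradient-accumulation training loop could drive the new flag. It is an illustration, not code from this commit: model, data_loader, criterion, the already-constructed optimizer (a DistributedFusedLAMB instance), and accumulation_steps are assumed to exist, and the exact stepping/zeroing convention depends on the surrounding training harness.

accumulation_steps = 4  # hypothetical accumulation window

for step, (inputs, targets) in enumerate(data_loader):
    # Mark every micro-step except the last one in the window as an
    # accumulation step; while the flag is True, _do_overlapped_reduction
    # leaves the freshly produced gradients alone instead of queueing them
    # for overlapped reduction.
    is_accumulation_step = (step + 1) % accumulation_steps != 0
    optimizer.set_is_accumulation_step(is_accumulation_step)

    loss = criterion(model(inputs), targets)
    loss.backward()

    # Only the final micro-step of the window triggers the reductions and
    # applies the optimizer update.
    if not is_accumulation_step:
        optimizer.step()
        optimizer.zero_grad()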