Commit 7ba6a038 authored by Thor Johnsen

Add option to skip overflow check in step() method

parent c7b34549
@@ -538,7 +538,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         for block_id in range(self._num_blocks):
             self._partial_step_single_shard(block_id, undo=True)
 
-    def step(self, closure=None):
+    def step(self, closure=None, skip_overflow_check=False):
         loss = None
         if closure is not None:
             loss = closure()
@@ -560,8 +560,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
 
         # Check for overflow
         # Store state for loss scaler calculation
-        self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-        if self.peek_overflow:
+        if skip_overflow_check:
+            has_overflow = False
+        else:
+            self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
+            has_overflow = self.peek_overflow
+        if has_overflow:
             print("Reverting step")
             self.revert_step()
         else:
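For context, a minimal sketch of how a caller might use the new flag. It is not part of the commit; it assumes `optimizer` is an already-constructed DistributedFusedAdam instance, and the `training_step` and `overflow_already_checked` names are illustrative only.

# Hedged usage sketch (assumption, not from the commit):
# `optimizer` is a DistributedFusedAdam instance; `overflow_already_checked`
# is bookkeeping done by the caller's own loss scaler.
def training_step(model, optimizer, batch, overflow_already_checked):
    loss = model(batch)
    loss.backward()
    if overflow_already_checked:
        # The caller has already verified the gradients are finite,
        # so the strided_check_finite pass inside step() can be skipped.
        optimizer.step(skip_overflow_check=True)
    else:
        # Default behaviour: step() runs strided_check_finite and
        # reverts the step if an overflow is detected.
        optimizer.step()
    optimizer.zero_grad()
    return loss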