Commit 7ba6a038 authored by Thor Johnsen

Add option to skip overflow check in step() method

parent c7b34549
@@ -538,7 +538,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         for block_id in range(self._num_blocks):
             self._partial_step_single_shard(block_id, undo=True)
 
-    def step(self, closure=None):
+    def step(self, closure=None, skip_overflow_check=False):
         loss = None
         if closure is not None:
             loss = closure()
@@ -560,8 +560,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         # Check for overflow
         # Store state for loss scaler calculation
-        self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-        if self.peek_overflow:
+        if skip_overflow_check:
+            has_overflow = False
+        else:
+            self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
+            has_overflow = self.peek_overflow
+        if has_overflow:
             print("Reverting step")
             self.revert_step()
         else:
...
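For context, a minimal usage sketch of the new parameter follows. The import path, constructor arguments, and single-GPU setup are assumptions for illustration only; DistributedFusedAdam normally requires an initialized distributed process group and its own gradient flattening and hook setup, which this sketch omits.

# Hypothetical sketch, not part of the commit: shows the effect of the new
# skip_overflow_check flag on step(). Import path and setup are assumptions.
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

model = torch.nn.Linear(1024, 1024).cuda()                # placeholder model
optimizer = DistributedFusedAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 1024, device="cuda")).sum()   # placeholder forward pass
loss.backward()

# Default behavior: step() runs strided_check_finite over the new parameters
# and reverts the step if an overflow (inf/NaN) is detected.
optimizer.step()

# With the option added in this commit, the check is bypassed entirely:
# has_overflow is forced to False, so the step is never reverted. This can be
# useful when the caller (e.g. a loss scaler) has already verified the
# gradients are finite.
optimizer.step(skip_overflow_check=True)

Note that the default behavior is unchanged: callers that do not pass the flag still get the overflow check and the automatic revert.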