Unverified commit dddacd2d, authored by HELSON and committed by GitHub
Browse files

[hotfix] add norm clearing for the overflow step (#2416)

parent 57b6157b
......@@ -140,6 +140,10 @@ class ZeroOptimizer(ColossalaiOptimizer):
return self._found_overflow.item() > 0
def _clear_global_norm(self) -> None:
    """Reset the recorded ``l2_norm`` of every chunk in ``self.chunk16_set`` to ``None``."""
    for chunk in self.chunk16_set:
        # Drop the stored value so a norm recorded earlier is not left behind.
        chunk.l2_norm = None
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
......@@ -201,6 +205,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step')
self._clear_global_norm() # clear recorded norm
self.zero_grad() # reset all gradients
self._update_fp16_params()
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment