@@ -18,6 +18,17 @@ class TorchAMPOptimizer(ColossalaiOptimizer):
:param optim: A normal optimizer like Adam or SGD
:type optim: torch.optim.Optimizer
:param init_scale: Initial scale factor
:type init_scale: float, optional, default=2.**16
:param growth_factor: Factor by which the scale is multiplied during :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
:type growth_factor: float, optional, default=2.0
:param backoff_factor: Factor by which the scale is multiplied during :meth:`update` if inf/NaN gradients occur in an iteration.
:type backoff_factor: float, optional, default=0.5
:param growth_interval: Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``.
:type growth_interval: int, optional, default=2000
:param enabled: If ``False``, disables gradient scaling. :meth:`step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops.
:type enabled: bool, optional, default=True
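
For context, these arguments are the standard knobs of ``torch.cuda.amp.GradScaler``, which this wrapper builds on. A minimal sketch of the dynamic loss-scaling loop they control, written in plain PyTorch for illustration (the exact ColossalAI wrapper API may differ), looks like this::

    import torch
    from torch import nn
    from torch.cuda.amp import GradScaler, autocast

    model = nn.Linear(16, 4).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # The documented parameters map directly onto GradScaler's constructor.
    scaler = GradScaler(init_scale=2.**16, growth_factor=2.0,
                        backoff_factor=0.5, growth_interval=2000, enabled=True)

    for _ in range(10):
        x = torch.randn(8, 16, device='cuda')
        target = torch.randn(8, 4, device='cuda')
        optimizer.zero_grad()
        with autocast():                  # run the forward pass in mixed precision
            loss = nn.functional.mse_loss(model(x), target)
        scaler.scale(loss).backward()     # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)            # unscale grads; skip the step on inf/NaN
        scaler.update()                   # grow or back off the scale factor

If no inf/NaN gradients occur for ``growth_interval`` consecutive iterations, :meth:`update` multiplies the scale by ``growth_factor``; when inf/NaN gradients do occur, the step is skipped and the scale is multiplied by ``backoff_factor`` instead.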