@@ -16,19 +16,11 @@ from colossalai.utils import clip_grad_norm_fp32
 class TorchAMPOptimizer(ColossalaiOptimizer):
     """A wrapper class which integrates PyTorch AMP with an optimizer

-    :param optim: a normal optimizer like Adam or SGD
-    :type optim: torch.optim.Optimizer
-    :param init_scale: Initial scale factor
-    :type init_scale: float, optional, default=2.**16
-    :param growth_factor: Factor by which the scale is multiplied during :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
-    :type growth_factor: float, optional, default=2.0
-    :param backoff_factor: Factor by which the scale is multiplied during :meth:`update` if inf/NaN gradients occur in an iteration.
-    :param growth_interval: Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``.
-    :param enabled: If ``False``, disables gradient scaling. :meth:`step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops.
-    :type enabled: bool, optional, default=True
+    :param optim: A normal optimizer like Adam or SGD
+    :param args: Args used to initialize gradient scaler
+    :param kwargs: Kwargs used to initialize gradient scaler
+    :type optim: torch.optim.Optimizer
     """

     def __init__(self, optim: Optimizer, *args, **kwargs):
...
...
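
A minimal usage sketch, not part of the diff above. It assumes the wrapper forwards its *args/**kwargs to torch.cuda.amp.GradScaler and that backward()/step() follow the usual ColossalaiOptimizer interface (scale the loss before backward; unscale, step, and update the scale on step); the import path is an assumption.

import torch
from torch import nn
from torch.optim import Adam
from colossalai.amp.torch_amp import TorchAMPOptimizer  # import path assumed

model = nn.Linear(16, 4).cuda()
# Scaler kwargs are assumed to be passed through to torch.cuda.amp.GradScaler;
# the values below are GradScaler's documented defaults.
optimizer = TorchAMPOptimizer(Adam(model.parameters(), lr=1e-3),
                              init_scale=2.**16,
                              growth_factor=2.0,
                              backoff_factor=0.5,
                              growth_interval=2000)

x = torch.randn(8, 16, device='cuda')
with torch.cuda.amp.autocast():
    loss = model(x).sum()

optimizer.backward(loss)  # scales the loss, then calls backward()
optimizer.step()          # unscales gradients, steps, then updates the scale
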
@@ -36,23 +28,25 @@ class TorchAMPOptimizer(ColossalaiOptimizer):