"doc/git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "0e4b7a3929e12d1645e3e177148d15cd4cdec793"
Unverified Commit fcae8fa3 authored by Masaki Kozuki, committed by GitHub

porting GradScaler (#1220)


Co-authored-by: Sangkug Lym <slym@nvidia.com>
parent 35336133
apex/transformer/__init__.py

+from apex.transformer import amp
 from apex.transformer import functional
 from apex.transformer import parallel_state
 from apex.transformer import pipeline_parallel
@@ -9,6 +10,7 @@ from apex.transformer.enums import AttnMaskType
 __all__ = [
+    "amp",
     "functional",
     "parallel_state",
     "pipeline_parallel",
 ...
apex/transformer/amp/__init__.py (new file)

from apex.transformer.amp.grad_scaler import GradScaler

__all__ = [
    "GradScaler",
]
apex/transformer/amp/grad_scaler.py (new file)

from collections import defaultdict

import torch

from apex.transformer import parallel_state


class GradScaler(torch.cuda.amp.GradScaler):
    """
    Gradient scaler for model-parallel inf checks. Infs in gradients are checked across
    tensor-parallel ranks in (1) executing the optimizer step and (2) the gradient scaler update.
    """
    def __init__(
        self, init_scale=2.0 ** 16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True
    ):
        super().__init__(
            init_scale=init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )

    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
        retval = None
        found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())])

        # Update across all model parallel instances.
        torch.distributed.all_reduce(
            found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
        )

        if found_inf.item() == 0:
            retval = optimizer.step(*args, **kwargs)
        return retval

    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly, it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            # Update across all model parallel instances.
            torch.distributed.all_reduce(
                found_inf_combined, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
            )

            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf = found_infs[i]
                    # Update across all model parallel instances.
                    torch.distributed.all_reduce(
                        found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
                    )
                    found_inf_combined += found_inf

            torch._amp_update_scale_(
                _scale,
                _growth_tracker,
                found_inf_combined,
                self._growth_factor,
                self._backoff_factor,
                self._growth_interval,
            )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)
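
For context, here is a minimal sketch of how this model-parallel GradScaler could be dropped into a mixed-precision training step. It is illustrative only and not part of the commit: model, optimizer, and loader are hypothetical placeholders, and it assumes torch.distributed and apex.transformer.parallel_state have already been initialized for the desired tensor-parallel layout.

# Illustrative sketch, not part of this commit. Assumes torch.distributed.init_process_group()
# and parallel_state.initialize_model_parallel(...) have already been called, and that
# `model`, `optimizer`, and `loader` are user-provided placeholders.
import torch

from apex.transformer.amp import GradScaler


def train_step(model, optimizer, loader, scaler):
    for inputs, targets in loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = model(inputs, targets)
        # Scale the loss so small fp16 gradients do not underflow during backward().
        scaler.scale(loss).backward()
        # step() calls _maybe_opt_step(), which skips optimizer.step() if any rank in
        # the model-parallel group found an inf/NaN gradient.
        scaler.step(optimizer)
        # update() all-reduces the found_inf flags across the model-parallel group
        # before growing or backing off the scale.
        scaler.update()


# Construct once per training run; constructor defaults mirror torch.cuda.amp.GradScaler.
scaler = GradScaler(init_scale=2.0 ** 16)

From the caller's point of view the API is identical to torch.cuda.amp.GradScaler; the only difference is that the inf/NaN checks inside step and update are all-reduced over the model-parallel group, so every tensor-parallel rank agrees on whether to skip the step and how to adjust the scale.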