"...relaxation/git@developer.sourcefind.cn:nivren/ict-csp.git" did not exist on "273418eef006aaa5cdc9e81707d20aaf858797fe"
Unverified Commit f06e2d85 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Sequence-parallel amax reduction fix (#74)



* Fix no reduce_amax option for SP case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* add warning about overriding reduce_amax
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7324fe2b
......@@ -493,6 +493,7 @@ def reduce_tensor_across_group_op_max(
def global_amax_reduction(
fp8_meta: Dict[str, Any],
reduce_amax: bool = False,
reduce_amax_across_tp_group: bool = False,
tp_group: Optional[dist_group_type] = None,
forward: bool = True,
......@@ -508,7 +509,8 @@ def global_amax_reduction(
chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])
reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"])
if reduce_amax:
reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"])
if reduce_amax_across_tp_group:
reduce_tensor_across_group_op_max(contiguous_amax, tp_group)
......
......@@ -109,7 +109,7 @@ def _prepare_backward(fp8: bool,
"""Checks and prep for BWD."""
if fp8:
# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
if not (fp8_meta["recipe"].reduce_amax or reduce_amax_across_tp_group):
amax_and_scale_update(fp8_meta, False)
else:
# From previous iteration
......@@ -125,14 +125,16 @@ def _prepare_backward(fp8: bool,
with torch.cuda.nvtx.range(name + " backward"):
yield
if not fp8 or not fp8_meta["recipe"].reduce_amax:
return
if fp8_meta["first_module"]:
global_amax_reduction(
fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
)
delete_key_from_amax_buffer(forward=False)
if fp8 and (fp8_meta["recipe"].reduce_amax or reduce_amax_across_tp_group):
if fp8_meta["first_module"]:
global_amax_reduction(
fp8_meta,
fp8_meta["recipe"].reduce_amax,
reduce_amax_across_tp_group,
tp_group,
forward=False,
)
delete_key_from_amax_buffer(forward=False)
class _NoopCat(torch.autograd.Function):
......@@ -454,10 +456,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
reduce_amax = self.fp8_meta["recipe"].reduce_amax or self.sequence_parallel
if self.fp8 and self.sequence_parallel and not self.fp8_meta["recipe"].reduce_amax:
warnings.warn(
"Amax reduction across tensor parallel group is necessary "
"when using sequence parallelism with FP8."
)
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax:
if reduce_amax:
copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
......@@ -470,7 +478,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.fp8 and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
if reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
if self.fp8_meta["first_module"]:
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
......@@ -501,11 +509,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
restore_fp8_meta_tensors(self.fp8_meta)
return
if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
if self.fp8 and self.training and reduce_amax:
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial(
global_amax_reduction,
self.fp8_meta,
self.fp8_meta["recipe"].reduce_amax,
self.sequence_parallel,
self.tp_group,
forward=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment