Unverified Commit f06e2d85 authored by Kirthi Shankar Sivamani, committed by GitHub

Sequence-parallel amax reduction fix (#74)



* Fix no reduce_amax option for SP case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* add warning about overriding reduce_amax
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7324fe2b
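Context for the diff below: with sequence parallelism, each tensor-parallel rank only sees its own slice of the activations, so the locally observed amax values differ from rank to rank; unless they are max-reduced across the tensor-parallel group, the ranks derive different FP8 scaling factors for the same tensor. Before this fix, setting `reduce_amax=False` in the recipe disabled the whole reduction path, including the TP-group reduction that sequence parallelism needs. The sketch below is illustrative only, not Transformer Engine code: it simulates the ranks locally, and the 448.0 constant is the largest representable E4M3 value.

```python
# Illustrative sketch: why sequence parallelism needs a MAX reduction of amax
# across the tensor-parallel group (ranks are simulated locally here).
import torch

torch.manual_seed(0)
activation = torch.randn(8, 1024) * 3.0    # full sequence
sp_shards = activation.chunk(4, dim=0)      # each TP rank holds one sequence slice

local_amaxes = torch.stack([shard.abs().max() for shard in sp_shards])
print("per-rank amax:", local_amaxes.tolist())          # differs rank to rank

# Without a reduction, each rank would derive its own FP8 scale:
fp8_max = 448.0                                          # E4M3 representable max
print("per-rank scales:", (fp8_max / local_amaxes).tolist())

# A MAX reduction (what an all-reduce with op=MAX does across the TP group)
# gives every rank the same amax, hence the same scale:
global_amax = local_amaxes.max()
print("reduced amax:", global_amax.item(), "-> shared scale:", (fp8_max / global_amax).item())
```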
@@ -493,6 +493,7 @@ def reduce_tensor_across_group_op_max(
 def global_amax_reduction(
     fp8_meta: Dict[str, Any],
+    reduce_amax: bool = False,
     reduce_amax_across_tp_group: bool = False,
     tp_group: Optional[dist_group_type] = None,
     forward: bool = True,
@@ -508,6 +509,7 @@ def global_amax_reduction(
     chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
     contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])

-    reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"])
+    if reduce_amax:
+        reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"])
     if reduce_amax_across_tp_group:
         reduce_tensor_across_group_op_max(contiguous_amax, tp_group)
...
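The hunk above decouples the two reductions: the all-reduce over the FP8 (data-parallel) group now runs only when the new `reduce_amax` argument is set, while the tensor-parallel reduction stays behind its own flag. A minimal sketch of that control flow, with hypothetical helper names (`all_reduce_max_`, `reduce_amax_buffer`) standing in for the library's internals:

```python
# Minimal sketch of the post-change control flow (not the library's implementation).
from __future__ import annotations

from typing import List, Optional

import torch
import torch.distributed as dist


def all_reduce_max_(t: torch.Tensor, group: Optional[dist.ProcessGroup]) -> None:
    """In-place MAX all-reduce; a no-op when no process group is initialized."""
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(t, op=dist.ReduceOp.MAX, group=group)


def reduce_amax_buffer(
    amax_buffer: List[torch.Tensor],
    reduce_amax: bool = False,                  # reduce over the FP8 (data-parallel) group
    reduce_amax_across_tp_group: bool = False,  # required for sequence parallelism
    fp8_group: Optional[dist.ProcessGroup] = None,
    tp_group: Optional[dist.ProcessGroup] = None,
) -> List[torch.Tensor]:
    # Concatenate all buffered amax tensors, reduce once, then split back.
    chunk_sizes = [x.numel() for x in amax_buffer]
    contiguous_amax = torch.cat(amax_buffer)
    if reduce_amax:
        all_reduce_max_(contiguous_amax, fp8_group)
    if reduce_amax_across_tp_group:
        all_reduce_max_(contiguous_amax, tp_group)
    return list(torch.split(contiguous_amax, chunk_sizes))


# Example (single process, no process group -> the reductions are skipped):
print(reduce_amax_buffer([torch.tensor([1.0, 2.0]), torch.tensor([3.0])], reduce_amax=True))
```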
@@ -109,7 +109,7 @@ def _prepare_backward(fp8: bool,
     """Checks and prep for BWD."""
     if fp8:
         # Update amax and scale; Skip all setup for global amax reduction
-        if not fp8_meta["recipe"].reduce_amax:
+        if not (fp8_meta["recipe"].reduce_amax or reduce_amax_across_tp_group):
             amax_and_scale_update(fp8_meta, False)
         else:
             # From previous iteration
@@ -125,12 +125,14 @@ def _prepare_backward(fp8: bool,
     with torch.cuda.nvtx.range(name + " backward"):
         yield

-    if not fp8 or not fp8_meta["recipe"].reduce_amax:
-        return
-
-    if fp8_meta["first_module"]:
-        global_amax_reduction(
-            fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
-        )
-        delete_key_from_amax_buffer(forward=False)
+    if fp8 and (fp8_meta["recipe"].reduce_amax or reduce_amax_across_tp_group):
+        if fp8_meta["first_module"]:
+            global_amax_reduction(
+                fp8_meta,
+                fp8_meta["recipe"].reduce_amax,
+                reduce_amax_across_tp_group,
+                tp_group,
+                forward=False,
+            )
+            delete_key_from_amax_buffer(forward=False)
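The backward-path gating mirrors the forward-path fix: previously the early return fired whenever the recipe's `reduce_amax` was off, which also skipped the tensor-parallel reduction that sequence parallelism relies on. A small illustrative truth table of the new condition (not library code):

```python
# Sketch of the gating used at the end of the backward pass after this change.
from itertools import product


def should_reduce(fp8: bool, recipe_reduce_amax: bool, reduce_amax_across_tp_group: bool) -> bool:
    # Reduce at the end of backward if FP8 is active and either flag is set.
    return fp8 and (recipe_reduce_amax or reduce_amax_across_tp_group)


for fp8, ra, tp in product([False, True], repeat=3):
    print(f"fp8={fp8!s:<5} recipe.reduce_amax={ra!s:<5} tp_group={tp!s:<5} -> {should_reduce(fp8, ra, tp)}")
```

Even when the condition holds, only the module marked `first_module` in the autocast region launches the reduction, so the all-reduce happens once per backward pass rather than once per layer.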
@@ -454,10 +456,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             self.set_fp8_weights()

         update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
+        reduce_amax = self.fp8_meta["recipe"].reduce_amax or self.sequence_parallel
+        if self.fp8 and self.sequence_parallel and not self.fp8_meta["recipe"].reduce_amax:
+            warnings.warn(
+                "Amax reduction across tensor parallel group is necessary "
+                "when using sequence parallelism with FP8."
+            )

         # Previous iteration was grad_enabled
         if self.fp8_meta.get("update_amax_and_scale_fwd", False):
-            if self.fp8_meta["recipe"].reduce_amax:
+            if reduce_amax:
                 copy_amax_from_global_buffer(self.fp8_meta, forward=True)
             amax_and_scale_update(
                 self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
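This hunk is where the override happens: the effective flag used by the module is the recipe's `reduce_amax` OR'd with `sequence_parallel`, and a warning is emitted when sequence parallelism forces it on despite the recipe. A hedged sketch of that behaviour in isolation (a hypothetical free function, not the module's actual method):

```python
# Sketch: how sequence parallelism overrides a recipe with reduce_amax disabled.
import warnings


def effective_reduce_amax(recipe_reduce_amax: bool, sequence_parallel: bool, fp8: bool = True) -> bool:
    if fp8 and sequence_parallel and not recipe_reduce_amax:
        warnings.warn(
            "Amax reduction across tensor parallel group is necessary "
            "when using sequence parallelism with FP8."
        )
    return recipe_reduce_amax or sequence_parallel


print(effective_reduce_amax(recipe_reduce_amax=False, sequence_parallel=True))   # warns, True
print(effective_reduce_amax(recipe_reduce_amax=False, sequence_parallel=False))  # False
```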
@@ -470,7 +478,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if self.fp8 and self.training:
             # Setup for amax reduction
-            if self.fp8_meta["recipe"].reduce_amax:
+            if reduce_amax:
                 self.fp8_meta["first_module"] = is_first_fp8_module()
                 if self.fp8_meta["first_module"]:
                     self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
@@ -501,11 +509,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             restore_fp8_meta_tensors(self.fp8_meta)
             return

-        if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
+        if self.fp8 and self.training and reduce_amax:
             set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
             reduce_func = partial(
                 global_amax_reduction,
                 self.fp8_meta,
+                self.fp8_meta["recipe"].reduce_amax,
                 self.sequence_parallel,
                 self.tp_group,
                 forward=True,
...
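Finally, the deferred reduction registered during the forward pass gains the same extra argument. The pattern is `functools.partial`: the flags and groups are bound when the forward runs, and the resulting callable is invoked later from the backward side. An illustrative stand-in (the stub below is hypothetical, not `global_amax_reduction` itself):

```python
# Sketch of the deferred-call pattern: bind the reduction arguments now,
# invoke the reduction later from a backward hook.
from functools import partial


def global_amax_reduction_stub(fp8_meta, reduce_amax, reduce_amax_across_tp_group, tp_group, forward=True):
    # Stand-in for the real reduction; just reports which reductions would run.
    print(f"fp8_group reduce: {reduce_amax}, tp_group reduce: {reduce_amax_across_tp_group}, forward={forward}")


fp8_meta = {"recipe_reduce_amax": False}  # hypothetical stand-in for fp8_meta["recipe"].reduce_amax
reduce_func = partial(
    global_amax_reduction_stub,
    fp8_meta,
    fp8_meta["recipe_reduce_amax"],  # the new positional argument added by this commit
    True,                            # sequence_parallel -> reduce across the TP group
    None,                            # tp_group placeholder
    forward=True,
)

# ...invoked later, e.g. from the backward of the first FP8 module:
reduce_func()
```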