"csrc/engine/vscode:/vscode.git/clone" did not exist on "1f5ab1c5f0a99fd2e03ad88fccc6c9112b781791"
Unverified commit d8a2f352, authored by Kirthi Shankar Sivamani and committed by GitHub

Remove redundant AR for SP case (#79)



* Remove redundant amax AR for SP case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* update advanced docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 97b344cd
@@ -132,7 +132,7 @@
 "source": [
  "We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
  "\n",
- "Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). FP8 training requires extra synchronization for the scaling factors, so the data-parallel process group must also be passed to the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group."
+ "Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). FP8 training requires extra synchronization for the scaling factors, so the data-parallel process group must also be passed to the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group. In this case, the tensor parallel group must also be passed to the **fp8_group** argument in the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager, either directly or as a subset of a larger distributed group."
 ]
 },
 {
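For context, a minimal sketch of the setup the updated documentation describes, assuming `torch.distributed` has already been initialized with one GPU per process; the group choices and layer sizes below are illustrative and not part of this change:

```python
# Illustrative sketch: data parallelism + FP8 with Transformer Engine modules.
# Assumes torch.distributed.init_process_group() has already been called and
# that each process drives exactly one GPU. Group choices are examples only.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

dp_group = torch.distributed.new_group()        # example data-parallel group
world_group = torch.distributed.group.WORLD     # covers the tensor-parallel ranks too

model = te.Linear(1024, 1024).cuda()
model = torch.nn.parallel.DistributedDataParallel(model, process_group=dp_group)

# The group passed to fp8_autocast must include every rank that needs amax
# synchronization; with sequence parallelism the tensor-parallel ranks must be
# covered as well, either directly or as a subset of a larger group such as WORLD.
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(reduce_amax=True),
                     fp8_group=world_group):
    out = model(torch.randn(32, 1024, device="cuda"))
out.sum().backward()
```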
@@ -493,9 +493,6 @@ def reduce_tensor_across_group_op_max(
 def global_amax_reduction(
     fp8_meta: Dict[str, Any],
-    reduce_amax: bool = False,
-    reduce_amax_across_tp_group: bool = False,
-    tp_group: Optional[dist_group_type] = None,
     forward: bool = True,
 ) -> None:
     """Concatenate, reduce, and split amaxes in the global buffer."""
@@ -509,10 +506,7 @@ def global_amax_reduction(
     chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
     contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])
-    if reduce_amax:
-        reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"])
-    if reduce_amax_across_tp_group:
-        reduce_tensor_across_group_op_max(contiguous_amax, tp_group)
+    reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"])
     _global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
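For readers unfamiliar with the buffer handling above: the reduction is a concatenate, all-reduce with MAX, then split pattern, so all amaxes are synchronized in a single collective. A standalone sketch of that pattern in plain PyTorch (not the Transformer Engine implementation; the helper name is invented):

```python
# Generic sketch of the concatenate -> all-reduce(MAX) -> split pattern used
# for amax synchronization. Assumes torch.distributed is already initialized.
from typing import List, Optional
import torch
import torch.distributed as dist

def reduce_amaxes_across_group(amaxes: List[torch.Tensor],
                               group: Optional[dist.ProcessGroup] = None) -> List[torch.Tensor]:
    """Element-wise max of every amax tensor across all ranks in `group`."""
    chunk_sizes = [t.numel() for t in amaxes]
    contiguous = torch.cat([t.reshape(-1) for t in amaxes])
    dist.all_reduce(contiguous, op=dist.ReduceOp.MAX, group=group)  # one collective for all tensors
    return list(contiguous.split(chunk_sizes))
```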
@@ -101,15 +101,11 @@ def get_workspace() -> torch.Tensor:
     return _cublas_workspace
 @contextmanager
-def _prepare_backward(fp8: bool,
-                      fp8_meta: Dict[str, Any],
-                      reduce_amax_across_tp_group: bool,
-                      tp_group: Union[dist_group_type, None],
-                      name: str = ""):
+def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None:
     """Checks and prep for BWD."""
     if fp8:
         # Update amax and scale; Skip all setup for global amax reduction
-        if not (fp8_meta["recipe"].reduce_amax or reduce_amax_across_tp_group):
+        if not fp8_meta["recipe"].reduce_amax:
             amax_and_scale_update(fp8_meta, False)
         else:
             # From previous iteration
@@ -125,15 +121,9 @@ def _prepare_backward(fp8: bool,
     with torch.cuda.nvtx.range(name + " backward"):
         yield
-    if fp8 and (fp8_meta["recipe"].reduce_amax or reduce_amax_across_tp_group):
+    if fp8 and fp8_meta["recipe"].reduce_amax:
         if fp8_meta["first_module"]:
-            global_amax_reduction(
-                fp8_meta,
-                fp8_meta["recipe"].reduce_amax,
-                reduce_amax_across_tp_group,
-                tp_group,
-                forward=False,
-            )
+            global_amax_reduction(fp8_meta, forward=False)
             delete_key_from_amax_buffer(forward=False)
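The simplified `_prepare_backward` is a generator-based context manager: setup runs before `yield`, the caller's backward executes inside an NVTX range, and the amax reduction fires once after the body completes. A generic sketch of that shape, with placeholder callables rather than the library's helpers:

```python
# Generic pre/post sketch of the simplified _prepare_backward shape.
# The callables here are placeholders, not Transformer Engine APIs.
from contextlib import contextmanager
import torch

@contextmanager
def timed_backward(name: str, needs_reduction: bool, run_reduction):
    # per-iteration setup would happen here, before the backward body runs
    with torch.cuda.nvtx.range(name + " backward"):
        yield                       # the wrapped backward computation executes here
    if needs_reduction:
        run_reduction()             # e.g. one global amax all-reduce per iteration
```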
@@ -456,16 +446,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.set_fp8_weights()
         update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
-        reduce_amax = self.fp8_meta["recipe"].reduce_amax or self.sequence_parallel
-        if self.fp8 and self.sequence_parallel and not self.fp8_meta["recipe"].reduce_amax:
-            warnings.warn(
-                "Amax reduction across tensor parallel group is necessary "
-                "when using sequence parallelism with FP8."
-            )
+        if self.fp8 and self.sequence_parallel:
+            assert self.fp8_meta["recipe"].reduce_amax, \
+                "Amax reduction across tensor parallel group is " \
+                "necessary when using sequence parallelism with FP8."
         # Previous iteration was grad_enabled
         if self.fp8_meta.get("update_amax_and_scale_fwd", False):
-            if reduce_amax:
+            if self.fp8_meta["recipe"].reduce_amax:
                 copy_amax_from_global_buffer(self.fp8_meta, forward=True)
                 amax_and_scale_update(
                     self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
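With the assertion above, sequence parallelism under FP8 now requires the recipe itself to enable amax reduction, rather than the module silently adding a second tensor-parallel reduction. A hedged example of a recipe that satisfies the check (all values other than `reduce_amax` are illustrative):

```python
# Illustrative recipe that satisfies the new assertion when sequence
# parallelism is enabled; values other than reduce_amax are just examples.
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,      # E4M3 forward, E5M2 backward
    amax_history_len=16,
    amax_compute_algo="max",
    reduce_amax=True,              # required with sequence_parallel=True under FP8
)
```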
@@ -478,7 +466,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if self.fp8 and self.training:
             # Setup for amax reduction
-            if reduce_amax:
+            if self.fp8_meta["recipe"].reduce_amax:
                 self.fp8_meta["first_module"] = is_first_fp8_module()
                 if self.fp8_meta["first_module"]:
                     self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
@@ -509,16 +497,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             restore_fp8_meta_tensors(self.fp8_meta)
             return
-        if self.fp8 and self.training and reduce_amax:
+        if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
             set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
-            reduce_func = partial(
-                global_amax_reduction,
-                self.fp8_meta,
-                self.fp8_meta["recipe"].reduce_amax,
-                self.sequence_parallel,
-                self.tp_group,
-                forward=True,
-            )
+            reduce_func = partial(global_amax_reduction, self.fp8_meta, forward=True)
             setup_amax_forward_global_reduce_func(reduce_func)
     def set_nccl_overlap_warning_if_tp(self) -> None:
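Here `functools.partial` binds the metadata and direction now so the reduction can be registered and run later, at the end of the autocast region. A self-contained sketch of this deferral pattern (the registry and function names are invented for illustration):

```python
# Self-contained sketch of deferring a call with functools.partial and a
# registration hook; the registry here is invented for illustration.
from functools import partial

_deferred_reductions = []

def setup_reduce_func(fn) -> None:
    _deferred_reductions.append(fn)

def run_reduce_funcs() -> None:
    for fn in _deferred_reductions:
        fn()
    _deferred_reductions.clear()

def global_reduction(meta: dict, forward: bool = True) -> None:
    print(f"reducing amax buffer, forward={forward}")

setup_reduce_func(partial(global_reduction, {"dummy": True}, forward=True))
run_reduce_funcs()   # prints: reducing amax buffer, forward=True
```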
@@ -863,8 +844,7 @@ class _LayerNormLinear(torch.autograd.Function):
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
-                               name="_LayerNormLinear"):
+        with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormLinear"):
             (
                 inputmat,
                 ln_weight,
@@ -1550,8 +1530,7 @@ class _Linear(torch.autograd.Function):
     def backward(
         ctx, grad_output: torch.Tensor
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
-                               name="_Linear"):
+        with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_Linear"):
             (
                 inputmat,
                 inputmat_t,
@@ -2270,8 +2249,7 @@ class _LayerNormMLP(torch.autograd.Function):
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
-                               name="_LayerNormMLP"):
+        with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormMLP"):
             (
                 inputmat,
                 ln_weight,
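All three autograd functions now pass only the FP8 flag, the FP8 metadata, and a name. A minimal, self-contained example of wrapping a custom `torch.autograd.Function` backward in such a context manager (the context manager here is a stand-in, not `_prepare_backward` itself):

```python
# Minimal example of wrapping a custom backward in a named context manager,
# mirroring the simplified call sites above; prepare_backward is a stand-in.
from contextlib import contextmanager
import torch

@contextmanager
def prepare_backward(name: str):
    # stand-in for the setup/teardown the real helper performs
    yield

class ScaledLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        with prepare_backward(name="ScaledLinear"):
            inp, weight = ctx.saved_tensors
            return grad_output @ weight, grad_output.t() @ inp

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
ScaledLinear.apply(x, w).sum().backward()
```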