Unverified Commit d3d7ed2c authored by Sangkug Lym, committed by GitHub

Amax reduction interval (#154)



* amax reduction interval
Signed-off-by: Sangkug Lym <slym@nvidia.com>

Skip TP-domain-only AMAX reduction when the TP group is not initialized
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Update transformer_engine/pytorch/fp8.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* check TP group initialized
Signed-off-by: Sangkug Lym <slym@nvidia.com>

fix
Signed-off-by: Sangkug Lym <slym@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b2b3fbe7
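
Note on the change: amax reductions across the data-parallel (DP) domain now run only every NVTE_DP_AMAX_REDUCE_INTERVAL iterations (default 1, i.e. every iteration). Between full reductions over fp8_group, the reduction is restricted to the tensor-parallel (TP) group, and skipped entirely when no TP group is in use. A minimal, self-contained sketch of that gating (pick_reduce_group is an illustrative name, not TE API; strings stand in for real process groups):

def pick_reduce_group(idx, interval, fp8_group, tp_group, tp_size):
    """Return (group or None, next idx); None means skip this reduction."""
    if idx == 0:
        group = fp8_group               # full DP-domain reduction
    elif tp_size > 1:
        group = tp_group                # TP-domain-only reduction in between
    else:
        group = None                    # no TP domain: nothing to reduce
    return group, (idx + 1) % interval

idx = 0
for step in range(5):
    group, idx = pick_reduce_group(idx, 3, "fp8_group", "tp_group", tp_size=2)
    print(step, group)  # fp8_group, tp_group, tp_group, fp8_group, tp_group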
@@ -3,6 +3,7 @@
 # See LICENSE for license information.
 """FP8 utilities for TransformerEngine"""
+import os
 from contextlib import contextmanager
 from collections import deque
 from typing import Callable, List, Optional, Dict, Any, Tuple, Union
@@ -30,6 +31,9 @@ _buffer_delete_key_bwd = None
 _amax_reduce_handle_fwd = None
 _is_fp8_available = None
 _reason_for_no_fp8 = ""
+_dp_amax_reduce_interval = None
+_dp_amax_reduce_forward_idx = 0
+_dp_amax_reduce_backward_idx = 0


 def _check_fp8_support() -> Tuple[bool, str]:
@@ -545,6 +549,8 @@ def reduce_tensor_across_group_op_max(
 def global_amax_reduction(
     fp8_meta: Dict[str, Any],
+    tp_group: dist_group_type,
+    tp_size: int,
     forward: bool = True,
 ) -> None:
     """Concatenate, reduce, and split amaxes in the global buffer."""
@@ -555,12 +561,37 @@
     if amax_buffer_key not in _global_fp8_buffer:
         return None

+    # Reduce AMAX in DP-domain at an interval.
+    global _dp_amax_reduce_interval, _dp_amax_reduce_forward_idx, _dp_amax_reduce_backward_idx
+    if _dp_amax_reduce_interval is None:
+        _dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
+
+    tp_amax_reduce = False
+    if forward:
+        if _dp_amax_reduce_forward_idx == 0:
+            reduce_group = fp8_meta["fp8_group"]
+        else:
+            tp_amax_reduce = True
+        _dp_amax_reduce_forward_idx = (_dp_amax_reduce_forward_idx + 1) % _dp_amax_reduce_interval
+    else:
+        if _dp_amax_reduce_backward_idx == 0:
+            reduce_group = fp8_meta["fp8_group"]
+        else:
+            tp_amax_reduce = True
+        _dp_amax_reduce_backward_idx = (_dp_amax_reduce_backward_idx + 1) % _dp_amax_reduce_interval
+
+    if tp_amax_reduce:
+        if tp_size > 1:
+            reduce_group = tp_group
+        else:
+            return None
+
     chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
     contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])

     wait_handle = reduce_tensor_across_group_op_max(
         contiguous_amax,
-        fp8_meta["fp8_group"],
+        reduce_group,
         fp8_meta["async_amax_reduction"],
     )
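
The interval is read from the environment once and cached in _dp_amax_reduce_interval, so it must be set before the first FP8 forward pass. A hedged usage sketch (the value 4 is arbitrary):

import os

# Reduce across the full DP domain only every 4th iteration; in between,
# reduce within the TP group only (or skip when tp_size == 1). The value is
# cached on first use, so set it before the first FP8 forward pass.
os.environ["NVTE_DP_AMAX_REDUCE_INTERVAL"] = "4"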
@@ -105,7 +105,13 @@ def get_workspace() -> torch.Tensor:
     return _cublas_workspace

 @contextmanager
-def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None:
+def _prepare_backward(
+    fp8: bool,
+    fp8_meta: Dict[str, Any],
+    tp_group: dist_group_type,
+    tp_size: int,
+    name: str = ""
+) -> None:
     """Checks and prep for BWD."""
     if fp8:
         global _amax_reduce_handle_bwd
@@ -132,7 +138,12 @@ def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None:
     if fp8 and fp8_meta["recipe"].reduce_amax:
         if fp8_meta["first_module"]:
-            _amax_reduce_handle_bwd = global_amax_reduction(fp8_meta, forward=False)
+            _amax_reduce_handle_bwd = global_amax_reduction(
+                fp8_meta,
+                tp_group,
+                tp_size,
+                forward=False
+            )
             delete_key_from_amax_buffer(forward=False)
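
For context, _prepare_backward launches the backward-side amax reduction asynchronously and stores the handle so it can be awaited later. A rough sketch of that async max-reduction pattern with torch.distributed (assumes an initialized process group; start_async_amax_reduce is an illustrative name, not TE API):

import torch
import torch.distributed as dist

def start_async_amax_reduce(amax_buffer: torch.Tensor, group):
    # Kick off a MAX all-reduce without blocking; the caller holds the
    # returned handle and calls .wait() only when the reduced amaxes are
    # actually consumed.
    return dist.all_reduce(
        amax_buffer, op=dist.ReduceOp.MAX, group=group, async_op=True
    )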
@@ -186,7 +197,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_meta["recipe"] = get_default_fp8_recipe()
         self.fp8_meta_tensors_initialized = False
         self.tp_group = None
-        self.tp_group_initialized = False
         self.tp_size = 1
         self.sequence_parallel = False
         self.fp8_weight_shapes = []
@@ -541,7 +551,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
             set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
-            reduce_func = partial(global_amax_reduction, self.fp8_meta, forward=True)
+            reduce_func = partial(
+                global_amax_reduction,
+                self.fp8_meta,
+                self.tp_group,
+                self.tp_size,
+                forward=True
+            )
             setup_amax_forward_global_reduce_func(reduce_func)

     def set_nccl_overlap_warning_if_tp(self) -> None:
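
The forward-side reduction is registered as a zero-argument callback, with fp8_meta, tp_group, and tp_size pre-bound via functools.partial; a small generic illustration of that pattern (names here are placeholders, not TE API):

from functools import partial

def reduce_amax(meta, group, size, forward=True):
    print(meta, group, size, forward)

# Pre-bind all arguments so the framework can invoke reduce_func() later
# with no arguments of its own.
reduce_func = partial(reduce_amax, {"recipe": "fp8"}, "tp_group", 8, forward=True)
reduce_func()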
@@ -692,6 +708,7 @@ class _LayerNormLinear(torch.autograd.Function):
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
+        tp_size: int,
         sequence_parallel: bool,
         tensor_parallel: bool,
         activation_dtype: torch.dtype,
@@ -867,6 +884,7 @@
         ctx.inp_shape = inp.shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
+        ctx.tp_size = tp_size
         ctx.return_layernorm_output = return_layernorm_output
         ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
         ctx.zero_centered_gamma = zero_centered_gamma
@@ -890,7 +908,9 @@
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormLinear"):
+        with _prepare_backward(
+            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
+        ):
             (
                 inputmat,
                 ln_weight,
@@ -1065,6 +1085,7 @@
             None,
             None,
             None,
+            None,
         )
@@ -1381,6 +1402,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             self.tp_group,
+            self.tp_size,
             self.sequence_parallel,
             self.tp_size > 1,
             self.activation_dtype,
@@ -1427,6 +1449,7 @@ class _Linear(torch.autograd.Function):
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
+        tp_size: int,
         sequence_parallel: bool,
         tensor_parallel: bool,
         activation_dtype: torch.dtype,
@@ -1563,6 +1586,7 @@
         ctx.inp_shape = inp.shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
+        ctx.tp_size = tp_size
         ctx.requires_dgrad = inp.requires_grad

         # Row Parallel Linear
@@ -1579,7 +1603,9 @@
     def backward(
         ctx, grad_output: torch.Tensor
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_Linear"):
+        with _prepare_backward(
+            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
+        ):
             (
                 inputmat,
                 inputmat_t,
@@ -1730,6 +1756,7 @@
             None,
             None,
             None,
+            None,
         )
@@ -1995,6 +2022,7 @@
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             self.tp_group,
+            self.tp_size,
             self.sequence_parallel,
             self.tp_size > 1,
             self.activation_dtype,
@@ -2039,6 +2067,7 @@ class _LayerNormMLP(torch.autograd.Function):
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
+        tp_size: int,
         sequence_parallel: bool,
         tensor_parallel: bool,
         activation_dtype: torch.dtype,
@@ -2282,6 +2311,7 @@
         ctx.tensor_parallel = tensor_parallel
         ctx.inp_shape = inp.shape
         ctx.tp_group = tp_group
+        ctx.tp_size = tp_size
         ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
         ctx.return_layernorm_output = return_layernorm_output
         ctx.set_parallel_mode = set_parallel_mode
@@ -2307,7 +2337,9 @@
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormMLP"):
+        with _prepare_backward(
+            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP"
+        ):
             (
                 inputmat,
                 ln_weight,
@@ -2610,6 +2642,7 @@
             None,
             None,
             None,
+            None,
         )
@@ -2904,6 +2937,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             self.tp_group,
+            self.tp_size,
             self.sequence_parallel,
             self.tp_size > 1,
             self.activation_dtype,