"vscode:/vscode.git/clone" did not exist on "90bac96321e3f189f4b60074c1e4ad782dd3ceb6"
Unverified Commit d3d7ed2c authored by Sangkug Lym, committed by GitHub

Amax reduction interval (#154)



* Amax reduction interval
Signed-off-by: Sangkug Lym <slym@nvidia.com>

Skip the TP-domain-only AMAX reduction when the TP group is not initialized
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Update transformer_engine/pytorch/fp8.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Check that the TP group is initialized
Signed-off-by: Sangkug Lym <slym@nvidia.com>

fix
Signed-off-by: Sangkug Lym <slym@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b2b3fbe7
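For anyone trying the feature out, a minimal usage sketch (illustrative, not part of the commit): the interval is read from the `NVTE_DP_AMAX_REDUCE_INTERVAL` environment variable added in this diff, with a default of "1", i.e. reduce across the full fp8_group on every iteration.

```python
import os

# Hypothetical setting: run the full fp8_group (DP-domain) amax reduction
# only every 4th call; the calls in between use the cheaper TP-only
# reduction introduced by this commit (or skip it when tp_size == 1).
# Set this before the first reduction, since the value is cached in
# _dp_amax_reduce_interval on first use.
os.environ["NVTE_DP_AMAX_REDUCE_INTERVAL"] = "4"
```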
@@ -3,6 +3,7 @@
 # See LICENSE for license information.

 """FP8 utilities for TransformerEngine"""
+import os
 from contextlib import contextmanager
 from collections import deque
 from typing import Callable, List, Optional, Dict, Any, Tuple, Union
@@ -30,6 +31,9 @@ _buffer_delete_key_bwd = None
 _amax_reduce_handle_fwd = None
 _is_fp8_available = None
 _reason_for_no_fp8 = ""
+_dp_amax_reduce_interval = None
+_dp_amax_reduce_forward_idx = 0
+_dp_amax_reduce_backward_idx = 0


 def _check_fp8_support() -> Tuple[bool, str]:
@@ -545,6 +549,8 @@ def reduce_tensor_across_group_op_max(
 def global_amax_reduction(
     fp8_meta: Dict[str, Any],
+    tp_group: dist_group_type,
+    tp_size: int,
     forward: bool = True,
 ) -> None:
     """Concatenate, reduce, and split amaxes in the global buffer."""
@@ -555,12 +561,37 @@
     if amax_buffer_key not in _global_fp8_buffer:
         return None

+    # Reduce AMAX in DP-domain at an interval.
+    global _dp_amax_reduce_interval, _dp_amax_reduce_forward_idx, _dp_amax_reduce_backward_idx
+    if _dp_amax_reduce_interval is None:
+        _dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
+
+    tp_amax_reduce = False
+    if forward:
+        if _dp_amax_reduce_forward_idx == 0:
+            reduce_group = fp8_meta["fp8_group"]
+        else:
+            tp_amax_reduce = True
+        _dp_amax_reduce_forward_idx = (_dp_amax_reduce_forward_idx + 1) % _dp_amax_reduce_interval
+    else:
+        if _dp_amax_reduce_backward_idx == 0:
+            reduce_group = fp8_meta["fp8_group"]
+        else:
+            tp_amax_reduce = True
+        _dp_amax_reduce_backward_idx = (_dp_amax_reduce_backward_idx + 1) % _dp_amax_reduce_interval
+
+    if tp_amax_reduce:
+        if tp_size > 1:
+            reduce_group = tp_group
+        else:
+            return None
+
     chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
     contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])

     wait_handle = reduce_tensor_across_group_op_max(
         contiguous_amax,
-        fp8_meta["fp8_group"],
+        reduce_group,
         fp8_meta["async_amax_reduction"],
     )
...
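To make the control flow of the new interval logic in `transformer_engine/pytorch/fp8.py` easier to follow, here is a standalone sketch of just the group-selection decision; `pick_reduce_group` is a hypothetical helper for illustration, not code from the commit:

```python
from typing import Optional

def pick_reduce_group(step: int, interval: int, tp_size: int) -> Optional[str]:
    """Mirror the decision in global_amax_reduction: call 0 of every
    interval reduces over the full fp8_group; the remaining calls reduce
    over the TP group only, or skip entirely when tp_size == 1."""
    if step % interval == 0:
        return "fp8_group"                       # full DP(xTP)-domain reduction
    return "tp_group" if tp_size > 1 else None   # TP-only, or no reduction

# With interval=4 and tp_size=2 the calls cycle: full, TP, TP, TP, full, ...
assert [pick_reduce_group(s, 4, 2) for s in range(5)] == \
       ["fp8_group", "tp_group", "tp_group", "tp_group", "fp8_group"]
# With tp_size=1 the intermediate calls skip the reduction altogether.
assert pick_reduce_group(1, 4, 1) is None
```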
@@ -105,7 +105,13 @@ def get_workspace() -> torch.Tensor:
     return _cublas_workspace


 @contextmanager
-def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None:
+def _prepare_backward(
+    fp8: bool,
+    fp8_meta: Dict[str, Any],
+    tp_group: dist_group_type,
+    tp_size: int,
+    name: str = ""
+) -> None:
     """Checks and prep for BWD."""
     if fp8:
         global _amax_reduce_handle_bwd
@@ -132,7 +138,12 @@ def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> N
     if fp8 and fp8_meta["recipe"].reduce_amax:
         if fp8_meta["first_module"]:
-            _amax_reduce_handle_bwd = global_amax_reduction(fp8_meta, forward=False)
+            _amax_reduce_handle_bwd = global_amax_reduction(
+                fp8_meta,
+                tp_group,
+                tp_size,
+                forward=False
+            )
         delete_key_from_amax_buffer(forward=False)
@@ -186,7 +197,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_meta["recipe"] = get_default_fp8_recipe()
         self.fp8_meta_tensors_initialized = False
         self.tp_group = None
-        self.tp_group_initialized = False
         self.tp_size = 1
         self.sequence_parallel = False
         self.fp8_weight_shapes = []
@@ -541,7 +551,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
             set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
-            reduce_func = partial(global_amax_reduction, self.fp8_meta, forward=True)
+            reduce_func = partial(
+                global_amax_reduction,
+                self.fp8_meta,
+                self.tp_group,
+                self.tp_size,
+                forward=True
+            )
             setup_amax_forward_global_reduce_func(reduce_func)

     def set_nccl_overlap_warning_if_tp(self) -> None:
@@ -692,6 +708,7 @@ class _LayerNormLinear(torch.autograd.Function):
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
+        tp_size: int,
         sequence_parallel: bool,
         tensor_parallel: bool,
         activation_dtype: torch.dtype,
@@ -867,6 +884,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.inp_shape = inp.shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
+        ctx.tp_size = tp_size
         ctx.return_layernorm_output = return_layernorm_output
         ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
         ctx.zero_centered_gamma = zero_centered_gamma
@@ -890,7 +908,9 @@ class _LayerNormLinear(torch.autograd.Function):
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormLinear"):
+        with _prepare_backward(
+            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
+        ):
             (
                 inputmat,
                 ln_weight,
@@ -1065,6 +1085,7 @@ class _LayerNormLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -1381,6 +1402,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             self.tp_group,
+            self.tp_size,
             self.sequence_parallel,
             self.tp_size > 1,
             self.activation_dtype,
@@ -1427,6 +1449,7 @@ class _Linear(torch.autograd.Function):
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
+        tp_size: int,
         sequence_parallel: bool,
         tensor_parallel: bool,
         activation_dtype: torch.dtype,
@@ -1563,6 +1586,7 @@ class _Linear(torch.autograd.Function):
         ctx.inp_shape = inp.shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
+        ctx.tp_size = tp_size
         ctx.requires_dgrad = inp.requires_grad

         # Row Parallel Linear
@@ -1579,7 +1603,9 @@ class _Linear(torch.autograd.Function):
     def backward(
         ctx, grad_output: torch.Tensor
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_Linear"):
+        with _prepare_backward(
+            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
+        ):
             (
                 inputmat,
                 inputmat_t,
@@ -1730,6 +1756,7 @@ class _Linear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -1995,6 +2022,7 @@ class Linear(TransformerEngineBaseModule):
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             self.tp_group,
+            self.tp_size,
             self.sequence_parallel,
             self.tp_size > 1,
             self.activation_dtype,
@@ -2039,6 +2067,7 @@ class _LayerNormMLP(torch.autograd.Function):
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
+        tp_size: int,
         sequence_parallel: bool,
         tensor_parallel: bool,
         activation_dtype: torch.dtype,
@@ -2282,6 +2311,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ctx.tensor_parallel = tensor_parallel
         ctx.inp_shape = inp.shape
         ctx.tp_group = tp_group
+        ctx.tp_size = tp_size
         ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
         ctx.return_layernorm_output = return_layernorm_output
         ctx.set_parallel_mode = set_parallel_mode
@@ -2307,7 +2337,9 @@ class _LayerNormMLP(torch.autograd.Function):
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormMLP"):
+        with _prepare_backward(
+            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP"
+        ):
             (
                 inputmat,
                 ln_weight,
@@ -2610,6 +2642,7 @@ class _LayerNormMLP(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -2904,6 +2937,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             self.tp_group,
+            self.tp_size,
             self.sequence_parallel,
             self.tp_size > 1,
             self.activation_dtype,
...