"transformer_engine/pytorch/quantization.py" did not exist on "85928d0887234a64c63b220e3c09d8a7a0d01c7b"
Unverified Commit 715c3bb8 authored by Daniel Stokes's avatar Daniel Stokes Committed by GitHub
Browse files

feat: Add support for multiple quantization modes in the UB communicators (#2043)

parent f98e3053
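
For context, below is a minimal sketch of the new `quantization_modes` API this commit introduces. The buffer shape and model sizes are illustrative, and the call assumes an already-initialized `torch.distributed` job with a tensor-parallel group and FP8-capable GPUs.

```python
# Sketch only: register Userbuffers communicators for both FP8 and non-quantized GEMMs,
# so that some layers can run in bf16 while others use FP8, all with comm/GEMM overlap.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import UserBufferQuantizationMode

seq_len, batch_size, hidden_size = 2048, 2, 4096  # illustrative sizes
tp_size = torch.distributed.get_world_size()      # assumes torch.distributed is initialized

te.initialize_ub(
    [seq_len * batch_size, hidden_size],  # communication buffer shape
    tp_size,
    quantization_modes=[
        UserBufferQuantizationMode.FP8,   # communicators used by FP8 GEMM layers
        UserBufferQuantizationMode.NONE,  # communicators used by bf16 GEMM layers
    ],
    dtype=torch.bfloat16,
)
```

The legacy `use_fp8=True/False` form still works but, as the diff below shows, it now emits a `DeprecationWarning`.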
@@ -62,3 +62,6 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.initialize_ub

 .. autoapifunction:: transformer_engine.pytorch.destroy_ub
+
+.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
+  :members: FP8, NONE
\ No newline at end of file
@@ -263,7 +263,9 @@ def _train(opts):
     te.module.base.initialize_ub(
         [batched_size, hidden_size],
         tp_size,
-        use_fp8=opts.fp8,
+        quantization_modes=[
+            UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+        ],
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
     )
...
@@ -12,6 +12,8 @@ import argparse
 import warnings
 import pprint
 import yaml
+from contextlib import nullcontext
+from functools import partial

 import torch
 import torch.distributed as dist
@@ -35,8 +37,9 @@ class multi_module_model(torch.nn.Module):
         self.num_layers = num_layers
         self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)])

-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
+    def forward(self, x, layer_contexts):
+        for layer, context in zip(self.layers, layer_contexts):
+            with context():
+                x = layer(x)
         return x
@@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None):
         default=False,
         help="Print out additional debug information.",
     )
+    parser.add_argument(
+        "--first-last-layers-bf16",
+        action="store_true",
+        default=False,
+        help="Use bf16 for first and last N layers.",
+    )
+    parser.add_argument(
+        "--num-layers-at-start-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the start to run in bf16.",
+    )
+    parser.add_argument(
+        "--num-layers-at-end-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the end to run in bf16.",
+    )
     args = parser.parse_args(argv, namespace)

     if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
         warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
         args.use_cuda_graphs = False

+    if not args.first_last_layers_bf16 and (
+        args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0
+    ):
+        warnings.warn(
+            "num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when"
+            " first-last-layers-bf16 is enabled!"
+        )
+        args.num_layers_at_start_in_bf16 = 0
+        args.num_layers_at_end_in_bf16 = 0
+
+    if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers:
+        raise ValueError(
+            "num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to"
+            " num-layers!"
+        )
+
     return args
@@ -381,10 +418,17 @@ def _train(opts):
         "qkv_dgrad": {"method": "ring_exchange"},
         "fc1_dgrad": {"method": "ring_exchange"},
     }
+    quantization_modes = [
+        UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+    ]
+
+    if opts.first_last_layers_bf16 and opts.fp8:
+        quantization_modes.append(UserBufferQuantizationMode.NONE)
+
     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
         opts.tp,
-        use_fp8=opts.fp8,
+        quantization_modes=quantization_modes,
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
         ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
@@ -423,6 +467,16 @@ def _train(opts):
     elif opts.quantization == "mxfp8":
         fp8_recipe = MXFP8BlockScaling()

+    layer_contexts = [
+        (
+            partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
+            if opts.num_layers_at_start_in_bf16 <= i
+            and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
+            else nullcontext
+        )
+        for i in range(opts.num_layers)
+    ]
+
     # Prepare random input tensors
     test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
     test_x.retain_grad()
@@ -435,8 +489,7 @@ def _train(opts):
     # Execute fwd/bwd and collect tensors to test
     def run_fwd_bwd(model, x):
         with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
-                y = model(x)
+            y = model(x, layer_contexts)
         if isinstance(y, tuple):
             out, *_ = y
         else:
...
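
The test-harness change above drives mixed precision by handing each layer its own autocast context. A standalone sketch of the same pattern follows; the layer count, sizes, and recipe are illustrative, and an FP8-capable GPU is assumed.

```python
# Sketch only: run the middle layers under fp8_autocast while the first and last
# layers stay un-quantized, mirroring the per-layer contexts used by the test above.
from contextlib import nullcontext
from functools import partial

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

num_layers = 8
num_layers_at_start_in_bf16 = 1  # illustrative: keep the first layer un-quantized
num_layers_at_end_in_bf16 = 1    # illustrative: keep the last layer un-quantized
fp8_recipe = DelayedScaling()

# One context factory per layer: fp8_autocast for the middle layers, nullcontext otherwise.
layer_contexts = [
    partial(te.fp8_autocast, enabled=True, fp8_recipe=fp8_recipe)
    if num_layers_at_start_in_bf16 <= i < num_layers - num_layers_at_end_in_bf16
    else nullcontext
    for i in range(num_layers)
]

layers = torch.nn.ModuleList([te.Linear(1024, 1024, device="cuda") for _ in range(num_layers)])

x = torch.randn(32, 1024, device="cuda")
for layer, context in zip(layers, layer_contexts):
    with context():
        x = layer(x)
```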
@@ -506,7 +506,13 @@ def main() -> None:
             model_config.num_heads * model_config.head_dim,
         ],
         torch.distributed.get_world_size(group),
-        use_fp8=model_config.quantization is not None,
+        quantization_modes=[
+            (
+                UserBufferQuantizationMode.FP8
+                if model_config.quantization is not None
+                else UserBufferQuantizationMode.NONE
+            )
+        ],
         dtype=model_config.dtype,
         bootstrap_backend=bootstrap_backend,
         ub_cfgs=userbuffer_configs,
...
@@ -511,7 +511,7 @@ void destroy_communicator_mpi(communicator *comm) {
 }

 int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
-  if (comm->free_region > NVTE_MAX_REGIONS) return -1;
+  if (comm->free_region >= NVTE_MAX_REGIONS) return -1;
   int hndl = comm->free_region;
   comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
   size_t aligned_size = bytes;
...
@@ -27,7 +27,7 @@
 using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
 using ExtBarrierOp = std::function<void(ExtComm)>;

-#define NVTE_MAX_REGIONS 16
+#define NVTE_MAX_REGIONS 32
 #define NVTE_MAX_SMS 32
 #define NVTE_MAX_OPS 32
 #define NVTE_MAX_PEERS 8192
...
@@ -33,6 +33,7 @@ from transformer_engine.pytorch.module import GroupedLinear
 from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
 from transformer_engine.pytorch.module import initialize_ub
 from transformer_engine.pytorch.module import destroy_ub
+from transformer_engine.pytorch.module import UserBufferQuantizationMode
 from transformer_engine.pytorch.attention import DotProductAttention
 from transformer_engine.pytorch.attention import MultiheadAttention
 from transformer_engine.pytorch.attention import InferenceParams
...
@@ -11,4 +11,4 @@ from .layernorm import LayerNorm
 from .rmsnorm import RMSNorm
 from .fp8_padding import Fp8Padding
 from .fp8_unpadding import Fp8Unpadding
-from .base import initialize_ub, destroy_ub
+from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode
@@ -8,6 +8,7 @@ import math
 import os
 import pickle
 import warnings
+from enum import Enum
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
 from contextlib import contextmanager
@@ -49,7 +50,7 @@ from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
 from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled

-__all__ = ["initialize_ub", "destroy_ub"]
+__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]

 _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = True
@@ -63,6 +64,15 @@ _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
 layers_atomic_ring_exchange = []


+class UserBufferQuantizationMode(Enum):
+    """
+    Quantization mode that a Userbuffers communicator is created for.
+    """
+
+    NONE = "none"
+    FP8 = "fp8"
+
+
 def get_cublas_workspace_size_bytes() -> None:
     """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
     if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
@@ -111,8 +121,9 @@ def initialize_ub(
     shape: list,
     tp_size: int,
     use_fp8: bool = False,
+    quantization_modes: List[UserBufferQuantizationMode] = None,
     dtype: torch.dtype = torch.bfloat16,
-    ub_cfgs: Optional[dict] = None,
+    ub_cfgs: Optional[Union[dict, List[dict]]] = None,
     bootstrap_backend: Union[str, torch.distributed.Backend] = None,
 ) -> None:
     r"""
@@ -128,7 +139,11 @@ def initialize_ub(
     tp_size : int
              number of GPUs in the tensor-parallel process group
     use_fp8 : bool = False
-             allocate the communication buffer for FP8 GEMM inputs/outputs
+             allocate the communication buffer for FP8 GEMM inputs/outputs.
+             DEPRECATED: please use `quantization_modes` instead.
+    quantization_modes : List[UserBufferQuantizationMode] = None
+             if provided, a UB communicator is created for each quantization setting in the list.
+             Falls back to the legacy `use_fp8` parameter if `None` is provided.
     dtype : torch.dtype = torch.bfloat16
             non-FP8 data type of the communication buffer when `use_fp8 = False`
     ub_cfgs: dict = None
@@ -152,6 +167,7 @@ def initialize_ub(
              for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
              "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
              "fc2_fprop", "fc2_wgrad"]`.
+             A list of such dicts may be provided to specify a different overlap configuration for each quantization setting in `quantization_modes`.
     bootstrap_backend : str = None
              `torch.distributed` communication backend for the all-gather, broadcast and
              barrier collectives during Userbuffers initialization. Not all backends are
@@ -168,6 +184,28 @@ def initialize_ub(
             + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
         )

+    if not quantization_modes:
+        warnings.warn(
+            "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes"
+            " instead.",
+            DeprecationWarning,
+        )
+        quantization_modes = [
+            UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
+        ]
+    else:
+        assert isinstance(quantization_modes, list), "quantization_modes must be a list"
+        assert all(
+            isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes
+        ), "quantization_modes must be a list of UserBufferQuantizationMode"
+
+    if isinstance(ub_cfgs, dict) or ub_cfgs is None:
+        ub_cfgs = [ub_cfgs] * len(quantization_modes)
+    else:
+        assert len(ub_cfgs) == len(
+            quantization_modes
+        ), "Number of ub_cfgs settings must match number of quantization configurations"
+
     global _ub_communicators
     assert _ub_communicators is None, "UB communicators are already initialized."
     _ub_communicators = {}
@@ -309,6 +347,7 @@ def initialize_ub(
     def add_ub(
         name: str,
+        quantization_mode: UserBufferQuantizationMode,
         method: str,
         is_reduce_scatter: bool,
         num_sm: int = 16,
@@ -327,7 +366,9 @@ def initialize_ub(
             warnings.warn(
                 "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
             )
-            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
+            assert (
+                quantization_mode == UserBufferQuantizationMode.FP8
+            ), "Atomic GEMM overlap supported only for FP8 GEMM."
         if method in ("bulk", "external"):
             warnings.warn(
                 f"At {name}, atomic GEMM is not supported for a bulk overlap."
@@ -367,7 +408,11 @@ def initialize_ub(
                 f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
             )

-        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
+        buffer_dtype = (
+            torch.uint8
+            if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf)
+            else dtype
+        )
         if method == "ring_exchange":
             ub_obj = tex.CommOverlapP2P(
                 shape,  # Communication buffer shape
@@ -401,38 +446,47 @@ def initialize_ub(
                 comm_priority=comm_priority,
                 rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
             )
-        _ub_communicators[name] = ub_obj
+        _ub_communicators[(name, quantization_mode)] = ub_obj

-    if ub_cfgs is not None:
-        for name in dgrad_reduce_scatter_overlap:
-            if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
-                wgrad_name = name.replace("dgrad", "wgrad")
-                assert wgrad_name not in ub_cfgs
-                layers_reduce_scatter_overlap.remove(wgrad_name)
-                layers_all_gather_overlap.remove(name)
-                layers_reduce_scatter_overlap.append(name)
-                methods["bulk"].remove(name)
-                new_method = ub_cfgs[name]["method"]
-                methods[new_method].append(name)
+    for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs):
+        if user_ub_cfg is not None:
+            for name in dgrad_reduce_scatter_overlap:
+                if (
+                    name in user_ub_cfg
+                    and "method" in user_ub_cfg[name]
+                    and user_ub_cfg[name]["method"] != "bulk"
+                ):
+                    wgrad_name = name.replace("dgrad", "wgrad")
+                    assert wgrad_name not in user_ub_cfg
+                    layers_reduce_scatter_overlap.remove(wgrad_name)
+                    layers_all_gather_overlap.remove(name)
+                    layers_reduce_scatter_overlap.append(name)
+                    methods["bulk"].remove(name)
+                    new_method = user_ub_cfg[name]["method"]
+                    methods[new_method].append(name)

-    for name in (
-        methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
-    ):
-        ub_cfg = get_default_config(name)
-        if ub_cfgs is not None and name in ub_cfgs:
-            fp8_buf = (name in layers_all_gather_overlap) or (
-                ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
-            )
-            ub_cfg.update(ub_cfgs[name])
-            ub_cfg["fp8_buf"] = fp8_buf
-        add_ub(name, **ub_cfg)
+        for name in (
+            methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
+        ):
+            ub_cfg = get_default_config(name)
+            if user_ub_cfg is not None and name in user_ub_cfg:
+                fp8_buf = (name in layers_all_gather_overlap) or (
+                    user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
+                )
+                ub_cfg.update(user_ub_cfg[name])
+                ub_cfg["fp8_buf"] = fp8_buf
+            add_ub(name, quantization_mode, **ub_cfg)


-def get_ub(name: str):
+def get_ub(name: str, use_fp8: bool):
     """Get userbuffer communicator corresponding to the given key."""
+    # For now take a `use_fp8` boolean input, as it matches the current design of the modules,
+    # favouring simplicity until the right interface becomes clear.
+    # This is mainly an internal API, so we don't need to worry about future changes.
+    key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE)
     assert _ub_communicators is not None, "UB manager is not initialized."
-    assert name in _ub_communicators, f"UB for {name} is not registered."
-    return _ub_communicators[name]
+    assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered."
+    return _ub_communicators[key]


 def destroy_ub():
...
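
Since `initialize_ub` now also accepts a list for `ub_cfgs`, here is a hedged sketch of passing one overlap config per quantization mode; the GEMM name, method, and sizes are illustrative choices, not recommended settings.

```python
# Sketch only: one ub_cfgs dict per entry of quantization_modes, in matching order.
# The FP8 communicators keep an FP8 staging buffer for proj_fprop, while the bf16
# communicators use the same overlap method without it.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import UserBufferQuantizationMode

te.initialize_ub(
    [4096, 4096],  # illustrative communication buffer shape
    8,             # illustrative tensor-parallel group size
    quantization_modes=[UserBufferQuantizationMode.FP8, UserBufferQuantizationMode.NONE],
    ub_cfgs=[
        {"proj_fprop": {"method": "pipeline", "fp8_buf": True}},  # config for the FP8 communicators
        {"proj_fprop": {"method": "pipeline"}},                   # config for the bf16 communicators
    ],
    dtype=torch.bfloat16,
)
```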
@@ -173,10 +173,10 @@ class _LayerNormLinear(torch.autograd.Function):
             ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
         )
         if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.RS
         elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG

         # Configure quantizer for norm output
@@ -575,23 +575,23 @@ class _LayerNormLinear(torch.autograd.Function):
         dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
         if ctx.ub_overlap_ag:
             # Overlap grad_output all-gather with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.AG
         elif ctx.ub_overlap_rs_dgrad:
             # Overlap dgrad reduce-scatter with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap inputmat all-gather with dgrad compute
-                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap dgrad reduce-scatter with wgrad compute
-                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
                 ub_type_wgrad = tex.CommOverlapType.RS
         # --------------------------------------------------
@@ -769,7 +769,7 @@ class _LayerNormLinear(torch.autograd.Function):
             dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()

             # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
-            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)

             ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
@@ -1492,10 +1492,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
             is_first_microbatch = False

         if self.ub_overlap_rs_fprop:
-            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_output = True
         if self.ub_overlap_rs_dgrad:
-            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_grad = True

         with torch.cuda.device(
...
@@ -307,7 +307,7 @@ class _LayerNormMLP(torch.autograd.Function):
             fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
         if ub_overlap_ag:
             # Copy into Userbuffers buffer
-            ub_obj_lnout = get_ub("fc1_fprop")
+            ub_obj_lnout = get_ub("fc1_fprop", fp8)
             ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                 ub_obj_lnout,
                 ln_out,
@@ -458,7 +458,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ub_obj_fc2out = None
         reduce_scatter_out = None
         if ub_overlap_rs:
-            ub_obj_fc2out = get_ub("fc2_fprop")
+            ub_obj_fc2out = get_ub("fc2_fprop", fp8)
             dim_size = list(act_out.size())
             dim_size[0] //= tp_world_size
             dim_size[-1] = fc2_weight.size(0)
@@ -740,7 +740,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Note: Cast to expected dtype and perform tensor-parallel communication
         ub_obj_fc2_dgrad = None
         if ctx.ub_overlap_ag:
-            ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
+            ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8)
             ctx.ub_obj_gradout = ub_obj_fc2_dgrad
         (
             grad_output,
@@ -764,7 +764,7 @@ class _LayerNormMLP(torch.autograd.Function):
             # wgrad GEMM requires input with column-wise usage
             quantizer.set_usage(rowwise=False, columnwise=True)
         if ctx.ub_bulk_dgrad:
-            ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+            ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
             ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                 ub_obj_fc1_dgrad,
                 ln_out,
@@ -869,7 +869,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 ub_obj_fc2_dgrad.get_communication_stream()
             )

-            ub_obj_fc2_wgrad = get_ub("fc2_wgrad")
+            ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8)

             ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
@@ -1036,16 +1036,16 @@ class _LayerNormMLP(torch.autograd.Function):
         fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
         if ctx.ub_overlap_rs_dgrad:
             # Overlap DGRAD+RS
-            ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+            ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
             ub_type_fc1_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap ln_out all-gather with DGRAD compute
-                ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+                ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
                 ub_type_fc1_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap FC1 DGRAD reduce-scatter with WGRAD compute
-                ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
+                ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8)
                 ub_type_fc1_wgrad = tex.CommOverlapType.RS
         # --------------------------------------------------
@@ -1539,7 +1539,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.gemm_gelu_fusion = (
             bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
             and self.activation == "gelu"
-            and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
+            and all(
+                ("fc1_fprop", use_fp8) not in _ub_communicators
+                or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm()
+                for use_fp8 in [False, True]
+            )
         )
         self.name = name
@@ -1757,7 +1761,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         fp8_output = False
         if self.ub_overlap_rs:
-            if get_ub("fc2_fprop").is_fp8_ubuf():
+            if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
                 fp8_output = True

         with torch.cuda.device(
...
@@ -145,10 +145,10 @@ class _Linear(torch.autograd.Function):
         ub_obj = None
         ub_type = None
         if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.RS
         elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG
         # ------------------------------------------------------
@@ -520,23 +520,23 @@ class _Linear(torch.autograd.Function):
         dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
         if ctx.ub_overlap_ag:
             # Overlap grad_output all-gather with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.AG
         elif ctx.ub_overlap_rs_dgrad:
             # Overlap dgrad reduce-scatter with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap inputmat all-gather with dgrad compute
-                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap dgrad reduce-scatter with wgrad compute
-                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
                 ub_type_wgrad = tex.CommOverlapType.RS
         # --------------------------------------------------
@@ -769,7 +769,7 @@ class _Linear(torch.autograd.Function):
             dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()

             # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
-            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)

             ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
@@ -1377,10 +1377,14 @@ class Linear(TransformerEngineBaseModule):
             is_first_microbatch = False

         if self.ub_overlap_rs_fprop:
-            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_output = True
         if self.ub_overlap_rs_dgrad:
-            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_grad = True

         with torch.cuda.device(
...
@@ -241,16 +241,16 @@ class UserbuffersBackwardLinear(FusedOperation):
         with_dgrad_all_gather_x = False
         with_wgrad_reduce_scatter_dx = False
         if tensor_parallel_mode == "row":
-            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
             ub_type_dgrad = CommOverlapType.AG
             with_dgrad_all_gather_dy = True
         elif tensor_parallel_mode == "column":
             if input_requires_grad and weight_requires_grad:
                 with_bulk_overlap = True
-                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
                 ub_type_dgrad = CommOverlapType.AG
                 with_dgrad_all_gather_x = True
-                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
+                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
                 ub_type_wgrad = CommOverlapType.RS
                 with_wgrad_reduce_scatter_dx = True
                 if ub_comm_wgrad.is_fp8_ubuf():
@@ -258,7 +258,7 @@ class UserbuffersBackwardLinear(FusedOperation):
                         "Userbuffers reduce-scatter is not supported with FP8 buffers"
                     )
             else:
-                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
                 ub_type_dgrad = CommOverlapType.RS
                 with_dgrad_reduce_scatter_dx = True
                 if ub_comm_dgrad.is_fp8_ubuf():
@@ -409,7 +409,7 @@ class UserbuffersBackwardLinear(FusedOperation):
             # Get the communication stream from the dgrad GEMM to use for the AG
             dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()

-            ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)

             grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...
@@ -189,7 +189,7 @@ class UserbuffersForwardLinear(FusedOperation):
             output_quantizer = None

         # Get Userbuffers communicator
-        ub_comm = get_ub(ub_comm_name + "_fprop")
+        ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute)
         with_ub_all_gather = tensor_parallel_mode == "column"
         with_ub_reduce_scatter = tensor_parallel_mode == "row"
         ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS
...