Unverified Commit 715c3bb8 authored by Daniel Stokes's avatar Daniel Stokes Committed by GitHub
Browse files

feat: Add support for multiple quantization modes in the UB communicators (#2043)

parent f98e3053
......@@ -62,3 +62,6 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.initialize_ub
.. autoapifunction:: transformer_engine.pytorch.destroy_ub
.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
:members: FP8, NONE
\ No newline at end of file
......@@ -263,7 +263,9 @@ def _train(opts):
te.module.base.initialize_ub(
[batched_size, hidden_size],
tp_size,
use_fp8=opts.fp8,
quantization_modes=[
UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
],
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
)
......
......@@ -12,6 +12,8 @@ import argparse
import warnings
import pprint
import yaml
from contextlib import nullcontext
from functools import partial
import torch
import torch.distributed as dist
......@@ -35,8 +37,9 @@ class multi_module_model(torch.nn.Module):
self.num_layers = num_layers
self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)])
def forward(self, x):
for layer in self.layers:
def forward(self, x, layer_contexts):
    """Apply each layer to the input in sequence, entering the paired context.

    ``layer_contexts`` holds one context-manager factory per layer (e.g.
    ``nullcontext`` or a pre-bound ``fp8_autocast``); iteration is truncated
    to the shorter of the two sequences, matching ``zip`` semantics.
    """
    for layer, make_context in zip(self.layers, layer_contexts):
        with make_context():
            x = layer(x)
    return x
......@@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None):
default=False,
help="Print out additional debug information.",
)
parser.add_argument(
"--first-last-layers-bf16",
action="store_true",
default=False,
help="Use bf16 for first and last N layers.",
)
parser.add_argument(
"--num-layers-at-start-in-bf16",
type=int,
default=0,
help="Number of layers at the start to run in bf16.",
)
parser.add_argument(
"--num-layers-at-end-in-bf16",
type=int,
default=0,
help="Number of layers at the end to run in bf16.",
)
args = parser.parse_args(argv, namespace)
if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
args.use_cuda_graphs = False
if not args.first_last_layers_bf16 and (
args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0
):
warnings.warn(
"num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when"
" first-last-layers-bf16 is enabled!"
)
args.num_layers_at_start_in_bf16 = 0
args.num_layers_at_end_in_bf16 = 0
if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers:
raise ValueError(
"num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to"
" num-layers!"
)
return args
......@@ -381,10 +418,17 @@ def _train(opts):
"qkv_dgrad": {"method": "ring_exchange"},
"fc1_dgrad": {"method": "ring_exchange"},
}
quantization_modes = [
UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
]
if opts.first_last_layers_bf16 and opts.fp8:
quantization_modes.append(UserBufferQuantizationMode.NONE)
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
opts.tp,
use_fp8=opts.fp8,
quantization_modes=quantization_modes,
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
......@@ -423,6 +467,16 @@ def _train(opts):
elif opts.quantization == "mxfp8":
fp8_recipe = MXFP8BlockScaling()
layer_contexts = [
(
partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
if opts.num_layers_at_start_in_bf16 <= i
and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
else nullcontext
)
for i in range(opts.num_layers)
]
# Prepare random input tensors
test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
test_x.retain_grad()
......@@ -435,8 +489,7 @@ def _train(opts):
# Execute fwd/bwd and collect tensors to test
def run_fwd_bwd(model, x):
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
y = model(x)
y = model(x, layer_contexts)
if isinstance(y, tuple):
out, *_ = y
else:
......
......@@ -506,7 +506,13 @@ def main() -> None:
model_config.num_heads * model_config.head_dim,
],
torch.distributed.get_world_size(group),
use_fp8=model_config.quantization is not None,
quantization_modes=[
(
UserBufferQuantizationMode.FP8
if model_config.quantization is not None
else UserBufferQuantizationMode.NONE
)
],
dtype=model_config.dtype,
bootstrap_backend=bootstrap_backend,
ub_cfgs=userbuffer_configs,
......
......@@ -511,7 +511,7 @@ void destroy_communicator_mpi(communicator *comm) {
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
if (comm->free_region > NVTE_MAX_REGIONS) return -1;
if (comm->free_region >= NVTE_MAX_REGIONS) return -1;
int hndl = comm->free_region;
comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
size_t aligned_size = bytes;
......
......@@ -27,7 +27,7 @@
using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
using ExtBarrierOp = std::function<void(ExtComm)>;
#define NVTE_MAX_REGIONS 16
#define NVTE_MAX_REGIONS 32
#define NVTE_MAX_SMS 32
#define NVTE_MAX_OPS 32
#define NVTE_MAX_PEERS 8192
......
......@@ -33,6 +33,7 @@ from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.module import UserBufferQuantizationMode
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.attention import InferenceParams
......
......@@ -11,4 +11,4 @@ from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
from .fp8_padding import Fp8Padding
from .fp8_unpadding import Fp8Unpadding
from .base import initialize_ub, destroy_ub
from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode
......@@ -8,6 +8,7 @@ import math
import os
import pickle
import warnings
from enum import Enum
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager
......@@ -49,7 +50,7 @@ from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled
__all__ = ["initialize_ub", "destroy_ub"]
__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
......@@ -63,6 +64,15 @@ _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = []
class UserBufferQuantizationMode(Enum):
    """Quantization setting that a UserBuffer communicator is built for.

    ``initialize_ub`` accepts a list of these modes and creates one UB
    communicator per entry; ``get_ub`` keys its lookup on ``(name, mode)``.
    """

    # High-precision (non-quantized) communication buffers.
    NONE = "none"
    # FP8-quantized GEMM inputs/outputs.
    FP8 = "fp8"
def get_cublas_workspace_size_bytes() -> None:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
......@@ -111,8 +121,9 @@ def initialize_ub(
shape: list,
tp_size: int,
use_fp8: bool = False,
quantization_modes: List[UserBufferQuantizationMode] = None,
dtype: torch.dtype = torch.bfloat16,
ub_cfgs: Optional[dict] = None,
ub_cfgs: Optional[Union[dict, List[dict]]] = None,
bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
r"""
......@@ -128,7 +139,11 @@ def initialize_ub(
tp_size : int
number of GPUs in the tensor-parallel process group
use_fp8 : bool = False
allocate the communication buffer for FP8 GEMM inputs/outputs
allocate the communication buffer for FP8 GEMM inputs/outputs.
DEPRECATED: Please use `quantization_modes` instead.
quantization_modes : List[UserBufferQuantizationMode] = None
if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list.
falls back to the legacy `use_fp8` parameter if `None` is provided.
dtype : torch.dtype = torch.bfloat16
non-FP8 data type of the communication buffer when `use_fp8 = False`
ub_cfgs: dict = None
......@@ -152,6 +167,7 @@ def initialize_ub(
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
"fc2_fprop", "fc2_wgrad"]`.
a list may be provided to specify a different overlap configuration for each of the quantization settings in `quantization_modes`
bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are
......@@ -168,6 +184,28 @@ def initialize_ub(
+ "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
)
if not quantization_modes:
warnings.warn(
"Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes"
" instead.",
DeprecationWarning,
)
quantization_modes = [
UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
]
else:
assert isinstance(quantization_modes, list), "quantization_modes must be a list"
assert all(
isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes
), "quantization_modes must be a list of UserBufferQuantizationMode"
if isinstance(ub_cfgs, dict) or ub_cfgs is None:
ub_cfgs = [ub_cfgs] * len(quantization_modes)
else:
assert len(ub_cfgs) == len(
quantization_modes
), "Number of ub_cfgs settings must match number of quantization configurations"
global _ub_communicators
assert _ub_communicators is None, "UB communicators are already initialized."
_ub_communicators = {}
......@@ -309,6 +347,7 @@ def initialize_ub(
def add_ub(
name: str,
quantization_mode: UserBufferQuantizationMode,
method: str,
is_reduce_scatter: bool,
num_sm: int = 16,
......@@ -327,7 +366,9 @@ def initialize_ub(
warnings.warn(
"Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
)
assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
assert (
quantization_mode == UserBufferQuantizationMode.FP8
), "Atomic GEMM overlap supported only for FP8 GEMM."
if method in ("bulk", "external"):
warnings.warn(
f"At {name}, atoimic GEMM not is supported for a bulk overlap."
......@@ -367,7 +408,11 @@ def initialize_ub(
f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
)
buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
buffer_dtype = (
torch.uint8
if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf)
else dtype
)
if method == "ring_exchange":
ub_obj = tex.CommOverlapP2P(
shape, # Communication buffer shape
......@@ -401,38 +446,47 @@ def initialize_ub(
comm_priority=comm_priority,
rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
)
_ub_communicators[name] = ub_obj
_ub_communicators[(name, quantization_mode)] = ub_obj
if ub_cfgs is not None:
for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs):
if user_ub_cfg is not None:
for name in dgrad_reduce_scatter_overlap:
if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
if (
name in user_ub_cfg
and "method" in user_ub_cfg[name]
and user_ub_cfg[name]["method"] != "bulk"
):
wgrad_name = name.replace("dgrad", "wgrad")
assert wgrad_name not in ub_cfgs
assert wgrad_name not in user_ub_cfg
layers_reduce_scatter_overlap.remove(wgrad_name)
layers_all_gather_overlap.remove(name)
layers_reduce_scatter_overlap.append(name)
methods["bulk"].remove(name)
new_method = ub_cfgs[name]["method"]
new_method = user_ub_cfg[name]["method"]
methods[new_method].append(name)
for name in (
methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
):
ub_cfg = get_default_config(name)
if ub_cfgs is not None and name in ub_cfgs:
if user_ub_cfg is not None and name in user_ub_cfg:
fp8_buf = (name in layers_all_gather_overlap) or (
ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
)
ub_cfg.update(ub_cfgs[name])
ub_cfg["fp8_buf"] = fp8_buf
add_ub(name, **ub_cfg)
add_ub(name, quantization_mode, **ub_cfg)
def get_ub(name: str):
def get_ub(name: str, use_fp8: bool):
    """Get userbuffer communicator corresponding to the given key.

    Parameters
    ----------
    name : str
        layer/GEMM name the communicator was registered under
        (e.g. ``"qkv_fprop"``).
    use_fp8 : bool
        whether to fetch the FP8-mode or non-quantized-mode communicator.
    """
    # `use_fp8` stays a boolean because that matches the current design in the
    # modules; it is translated here into the enum used as the registration
    # key. This is mainly an internal API, so favour simplicity until the
    # correct design becomes clear.
    mode = UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
    key = (name, mode)
    assert _ub_communicators is not None, "UB manager is not initialized."
    assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered."
    return _ub_communicators[key]
def destroy_ub():
......
......@@ -173,10 +173,10 @@ class _LayerNormLinear(torch.autograd.Function):
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
)
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.AG
# Configure quantizer for norm output
......@@ -575,23 +575,23 @@ class _LayerNormLinear(torch.autograd.Function):
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.RS
else:
if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ub_type_wgrad = tex.CommOverlapType.RS
# --------------------------------------------------
......@@ -769,7 +769,7 @@ class _LayerNormLinear(torch.autograd.Function):
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -1492,10 +1492,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
is_first_microbatch = False
if self.ub_overlap_rs_fprop:
if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
if get_ub(
self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
if get_ub(
self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
fp8_grad = True
with torch.cuda.device(
......
......@@ -307,7 +307,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
# Copy into Userbuffers buffer
ub_obj_lnout = get_ub("fc1_fprop")
ub_obj_lnout = get_ub("fc1_fprop", fp8)
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_lnout,
ln_out,
......@@ -458,7 +458,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc2out = None
reduce_scatter_out = None
if ub_overlap_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
ub_obj_fc2out = get_ub("fc2_fprop", fp8)
dim_size = list(act_out.size())
dim_size[0] //= tp_world_size
dim_size[-1] = fc2_weight.size(0)
......@@ -740,7 +740,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Cast to expected dtype and perform tensor-parallel communication
ub_obj_fc2_dgrad = None
if ctx.ub_overlap_ag:
ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8)
ctx.ub_obj_gradout = ub_obj_fc2_dgrad
(
grad_output,
......@@ -764,7 +764,7 @@ class _LayerNormMLP(torch.autograd.Function):
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
if ctx.ub_bulk_dgrad:
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_fc1_dgrad,
ln_out,
......@@ -869,7 +869,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc2_dgrad.get_communication_stream()
)
ub_obj_fc2_wgrad = get_ub("fc2_wgrad")
ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8)
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -1036,16 +1036,16 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
if ctx.ub_overlap_rs_dgrad:
# Overlap DGRAD+RS
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
ub_type_fc1_dgrad = tex.CommOverlapType.RS
else:
if ctx.ub_bulk_dgrad:
# Overlap ln_out all-gather with DGRAD compute
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
ub_type_fc1_dgrad = tex.CommOverlapType.AG
if ctx.ub_bulk_wgrad:
# Overlap FC1 DGRAD reduce-scatter with WGRAD compute
ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8)
ub_type_fc1_wgrad = tex.CommOverlapType.RS
# --------------------------------------------------
......@@ -1539,7 +1539,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.gemm_gelu_fusion = (
bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
and self.activation == "gelu"
and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
and all(
("fc1_fprop", use_fp8) not in _ub_communicators
or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm()
for use_fp8 in [False, True]
)
)
self.name = name
......@@ -1757,7 +1761,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fp8_output = False
if self.ub_overlap_rs:
if get_ub("fc2_fprop").is_fp8_ubuf():
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True
with torch.cuda.device(
......
......@@ -145,10 +145,10 @@ class _Linear(torch.autograd.Function):
ub_obj = None
ub_type = None
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.AG
# ------------------------------------------------------
......@@ -520,23 +520,23 @@ class _Linear(torch.autograd.Function):
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.RS
else:
if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ub_type_wgrad = tex.CommOverlapType.RS
# --------------------------------------------------
......@@ -769,7 +769,7 @@ class _Linear(torch.autograd.Function):
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -1377,10 +1377,14 @@ class Linear(TransformerEngineBaseModule):
is_first_microbatch = False
if self.ub_overlap_rs_fprop:
if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
if get_ub(
self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
if get_ub(
self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
fp8_grad = True
with torch.cuda.device(
......
......@@ -241,16 +241,16 @@ class UserbuffersBackwardLinear(FusedOperation):
with_dgrad_all_gather_x = False
with_wgrad_reduce_scatter_dx = False
if tensor_parallel_mode == "row":
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
ub_type_dgrad = CommOverlapType.AG
with_dgrad_all_gather_dy = True
elif tensor_parallel_mode == "column":
if input_requires_grad and weight_requires_grad:
with_bulk_overlap = True
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
ub_type_dgrad = CommOverlapType.AG
with_dgrad_all_gather_x = True
ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
ub_type_wgrad = CommOverlapType.RS
with_wgrad_reduce_scatter_dx = True
if ub_comm_wgrad.is_fp8_ubuf():
......@@ -258,7 +258,7 @@ class UserbuffersBackwardLinear(FusedOperation):
"Userbuffers reduce-scatter is not supported with FP8 buffers"
)
else:
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
ub_type_dgrad = CommOverlapType.RS
with_dgrad_reduce_scatter_dx = True
if ub_comm_dgrad.is_fp8_ubuf():
......@@ -409,7 +409,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()
ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad")
ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......
......@@ -189,7 +189,7 @@ class UserbuffersForwardLinear(FusedOperation):
output_quantizer = None
# Get Userbuffers communicator
ub_comm = get_ub(ub_comm_name + "_fprop")
ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute)
with_ub_all_gather = tensor_parallel_mode == "column"
with_ub_reduce_scatter = tensor_parallel_mode == "row"
ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment