Unverified commit 18da4e88, authored by Kirthi Shankar Sivamani and committed by GitHub

TP communication overlap with userbuffers (#147)

* Port initial changes

Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Re-add FA include for PyTorch

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Re-enable sm_70 + cleanup

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* LICENSE, clean up header

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* 5k -> 173 errors

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* License and fixes in userbuffers-host

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Next round of fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Final cpp cleanup

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Pylinting

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix from linting

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Turn off default async amax reduction (#148)

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove unused code path

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Clean up macros

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Fix conflict resolution bug

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Fix gencode flags in setup (#145)

* Fix gencode flags based on CUDA version

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review suggestions

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Revert append_nvcc_threads change

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change overlap config dict error message

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Simplify UB initialization

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix sanity imports

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cpplint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix TensorFlow build

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix TE macros in public header

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Compiles with and without MPI

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes for Python-side annotations for conditional compile

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Link gdrAPI only when MPI is found

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix comments for dummy var

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix linking

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Load MPI before TE

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add Python-side argument checks

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove unused code and catch silent failures

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cpp tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix find_lib path for tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

parent 7bb2af35
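
A minimal end-to-end usage sketch of the API this commit introduces, assuming a single multi-GPU node with NCCL already available (all sizes are illustrative; the import path follows the module.py changes below, and initialize_ub must run before the first forward pass):

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.module import initialize_ub

torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
tp_size = torch.distributed.get_world_size()

# Global (sequence-parallel) activation shape used to size the userbuffers.
seq_len, batch_size, hidden_size = 2048, 4, 12288
initialize_ub(
    shape=[seq_len * batch_size, hidden_size],
    tp_size=tp_size,
    use_fp8=True,  # fp8 buffers for the AllGather userbuffers
)

layer = te.TransformerLayer(
    hidden_size=hidden_size,
    ffn_hidden_size=4 * hidden_size,
    num_attention_heads=96,
    tp_size=tp_size,
    set_parallel_mode=True,
    sequence_parallel=True,
    ub_tp_comm_overlap=True,  # flag added by this PR
).cuda()
layer.set_tensor_parallel_group(torch.distributed.group.WORLD)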
@@ -121,7 +121,8 @@ at::Tensor te_gemm_ts(at::Tensor A,
                       workspace,
                       workspaceSize_arg,
                       accumulate_arg,
-                      use_split_accumulator_arg);
+                      use_split_accumulator_arg,
+                      0);
   return D;
 }
......
@@ -85,6 +85,8 @@ _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = True
 _2X_ACC_WGRAD = True
 _cublas_workspace = None
+_ub_communicators = None
+_NUM_MAX_UB_STREAMS = 3
 _amax_reduce_handle_bwd = None
@@ -147,6 +149,105 @@ def _prepare_backward(
     delete_key_from_amax_buffer(forward=False)


+def initialize_ub(
+    shape: list,
+    tp_size: int,
+    use_fp8: bool = False,
+    ub_cfgs: Optional[dict] = None
+) -> None:
+    """Initialize communicators for TP comm overlap using userbuffers."""
+    global _ub_communicators
+    assert _ub_communicators is None, "UB communicators are already initialized."
+    _ub_communicators = {}
+    rank_id = torch.distributed.get_rank()
+    # Increase the workspace by the number of maximum concurrent streams
+    global _cublas_workspace
+    _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
+
+    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
+    fp8_buf = [
+        "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
+    ]
+    # Default overlap methods for layers
+    methods = {
+        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
+        "pipeline": ["proj_fprop", "fc2_fprop"],
+        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
+    }
+
+    def get_method(name):
+        for method, names in methods.items():
+            if name in names:
+                return method
+        raise KeyError(f"Given layer name {name} does not exist.")
+
+    def add_ub(
+        name: str,
+        method: str,
+        num_sm: int = 16,
+        cga_size: int = 2,
+        set_sm_margin: int = 0,
+        num_splits: int = 4,
+        aggregate: int = 0,
+    ) -> None:
+        dtype = torch.uint8 if (use_fp8 and name in fp8_buf) else torch.bfloat16
+        sample_buffer = torch.empty(shape, dtype=dtype, device='cuda')
+        if method == 'ring_exchange':
+            ub_obj = tex.UbufP2PCommOverlap(
+                sample_buffer,        # Sample userbuffer
+                rank_id,              # Rank id
+                tp_size,              # TP size
+                aggregate,            # Aggregate 2X GEMM chunks
+                _NUM_MAX_UB_STREAMS,  # Max concurrent GEMM streams
+            )
+        else:
+            ub_obj = tex.UbufCommOverlap(
+                sample_buffer,        # Sample userbuffer
+                rank_id,              # Rank id
+                tp_size,              # TP size
+                num_sm,               # Number of communication SMs
+                cga_size,             # CGA cluster size
+                num_splits,           # Number of communication splits
+                set_sm_margin,        # Set SM margin
+                _NUM_MAX_UB_STREAMS,  # Max concurrent GEMM streams
+            )
+        _ub_communicators[name] = ub_obj
+
+    for name in (methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]):
+        if ub_cfgs is not None and name in ub_cfgs:
+            ub_cfg = ub_cfgs[name]
+            method = ub_cfg["method"] if "method" in ub_cfg else get_method(name)
+            num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16
+            cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2
+            num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0
+            set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0
+            aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0
+            add_ub(
+                name,
+                method,
+                num_sm,
+                cga_size,
+                set_sm_margin,
+                num_splits,
+                aggregate
+            )
+        else:
+            method = get_method(name)
+            if method == "pipeline":
+                add_ub(name, method)
+            else:
+                add_ub(name, method, num_splits=0)
+
+
+def get_ub(name: str):
+    """Get the userbuffer communicator corresponding to the given key."""
+    global _ub_communicators
+    assert _ub_communicators is not None, "UB manager is not initialized."
+    assert name in _ub_communicators, f"UB for {name} is not registered."
+    return _ub_communicators[name]
+
+
 class _NoopCat(torch.autograd.Function):
     """This class is a no-op replacement for `torch.cat`."""
@@ -596,9 +697,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # No-FP8 case: bgrad is fused with wgrad for this case.
         if not ctx.fp8:
             if gather_grad_output:
-                grad_output_mat, _ = gather_along_first_dim(
-                    grad_output_mat, ctx.tp_group
-                )
+                if not ctx.ub_split_ag:
+                    grad_output_mat, _ = gather_along_first_dim(
+                        grad_output_mat, ctx.tp_group
+                    )
+                else:
+                    ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True)
+                    grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1)
             return grad_output_mat, None, None, None

         fp8_dtype_backward = get_fp8_te_dtype(
@@ -610,6 +715,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             gather_grad_output
             and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
         ):
+            assert (
+                not ctx.ub_split_ag
+            ), "override_linear_precision.wgrad not supported with ub_split_ag"
             grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
         # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
         elif gather_grad_output:
@@ -617,14 +725,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 grad_bias = grad_output_mat.sum(dim=0)
             else:
                 grad_bias = None
-            grad_output_c = cast_to_fp8(
+            if ctx.ub_split_ag:
+                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
+            else:
+                grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
+            cast_to_fp8(
                 grad_output_mat,
                 ctx.fp8_meta["scaling_bwd"],
                 tex.FP8BwdTensors.GRAD_OUTPUT1,
                 fp8_dtype_backward,
+                out=grad_output_c,
             )
-            grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
-            grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
+            if not ctx.ub_split_ag:
+                grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
+                grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
+            else:
+                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
+                grad_output_t = None

             return grad_output_mat, grad_output_c, grad_output_t, grad_bias
@@ -718,6 +835,9 @@ class _LayerNormLinear(torch.autograd.Function):
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
         zero_centered_gamma: bool,
+        ub_bulk_wgrad: bool,
+        ub_bulk_dgrad: bool,
+        ub_split_ag: bool,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -733,16 +853,26 @@ class _LayerNormLinear(torch.autograd.Function):
         inputmat = cast_if_needed(inputmat, activation_dtype)
         ln_weight = cast_if_needed(ln_weight, activation_dtype)
         ln_bias = cast_if_needed(ln_bias, activation_dtype)

         # If residual connection is after LN, we need `ln_out`
         # tensor in higher precision, this comes at the cost
         # of an extra fp8 cast.
+        if ub_split_ag:
+            tp_world_size = get_distributed_world_size(tp_group)
+            if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
+                ub_split_ag = False
+        if ub_split_ag:
+            dim_size = list(inputmat.size())
+            dim_size[0] = dim_size[0] * tp_world_size
+            ub_obj_lnout = get_ub("qkv_fprop")
+            ln_out = ub_obj_lnout.get_ubuf_output(0)
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             if not return_layernorm_output:
                 if is_grad_enabled:
-                    ln_out, mu, rsigma = layernorm_fwd_fp8(
+                    if not ub_split_ag:
+                        ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
+                    _, mu, rsigma = layernorm_fwd_fp8(
                         inputmat,
                         ln_weight,
                         ln_bias,
@@ -752,6 +882,7 @@ class _LayerNormLinear(torch.autograd.Function):
                         fp8_dtype_forward,
                         fwd_ln_sm_margin,
                         zero_centered_gamma,
+                        ln_out=ln_out,
                     )
                 else:
                     mu = rsigma = None
@@ -783,17 +914,25 @@ class _LayerNormLinear(torch.autograd.Function):
             )
         else:
             if is_grad_enabled:
-                ln_out, mu, rsigma = tex.layernorm_fwd(
-                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
-                )
+                if ub_split_ag:
+                    _, mu, rsigma = tex.layernorm_fwd_noalloc(
+                        inputmat, ln_weight, ln_bias, ln_out, eps,
+                        fwd_ln_sm_margin, zero_centered_gamma
+                    )
+                else:
+                    ln_out, mu, rsigma = tex.layernorm_fwd(
+                        inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
+                    )
             else:
                 ln_out, mu, rsigma = layernorm_fwd_inf(
                     inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
                 ), None, None

         ln_out_return = ln_out

         # Column Parallel Linear
-        if parallel_mode == "column" and sequence_parallel:
+        if ub_split_ag:
+            ln_out_total = ub_obj_lnout.get_ubuf_output(1)
+            ln_out = torch.empty_like(ln_out)
+        elif parallel_mode == "column" and sequence_parallel:
             ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
         else:
             ln_out_total = ln_out
@@ -838,6 +977,9 @@ class _LayerNormLinear(torch.autograd.Function):
                 bias=bias,
                 use_bias=use_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
+                ub=ub_obj_lnout if ub_split_ag else None,
+                extra_output_tensor=ln_out if ub_split_ag else None,
             )
         else:
             # Cast for native AMP
@@ -859,6 +1001,9 @@ class _LayerNormLinear(torch.autograd.Function):
                 get_workspace(),
                 bias=bias,
                 use_bias=use_bias,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
+                ub=ub_obj_lnout if ub_split_ag else None,
+                extra_output_tensor=ln_out if ub_split_ag else None,
             )

         if is_grad_enabled:
@@ -888,6 +1033,8 @@ class _LayerNormLinear(torch.autograd.Function):
             ctx.return_layernorm_output = return_layernorm_output
             ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
             ctx.zero_centered_gamma = zero_centered_gamma
+            ctx.ub_bulk_wgrad = ub_bulk_wgrad
+            ctx.ub_bulk_dgrad = ub_bulk_dgrad
             ctx.requires_dgrad = inp.requires_grad

         # Row Parallel Linear
@@ -922,6 +1069,15 @@ class _LayerNormLinear(torch.autograd.Function):
             fwd_scale_inverses,
         ) = ctx.saved_tensors

+        if ctx.ub_bulk_dgrad:
+            tp_world_size = get_distributed_world_size(ctx.tp_group)
+            if tp_world_size == 1:
+                ctx.ub_bulk_dgrad = False
+        if ctx.ub_bulk_dgrad:
+            dim_size = list(ln_out.size())
+            dim_size[0] = dim_size[0] * tp_world_size
+            ub_obj_lnout = get_ub("qkv_dgrad")
+            ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
         (
             grad_output,
             grad_output_c,
@@ -931,9 +1087,14 @@ class _LayerNormLinear(torch.autograd.Function):
             ctx, grad_outputs[0], ctx.parallel_mode == "row"
         )

+        if ctx.ub_bulk_wgrad:
+            tp_world_size = get_distributed_world_size(ctx.tp_group)
+            if tp_world_size == 1:
+                ctx.ub_bulk_wgrad = False
+
         # Column Parallel Linear
         # Overlap input AG with dgrad
-        if ctx.parallel_mode == "column" and ctx.sequence_parallel:
+        if (not ctx.ub_bulk_dgrad) and ctx.parallel_mode == "column" and ctx.sequence_parallel:
             ln_out_total, handle = gather_along_first_dim(
                 ln_out, ctx.tp_group, async_op=True
             )
@@ -947,6 +1108,15 @@ class _LayerNormLinear(torch.autograd.Function):
         else:
             accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

+        dgrad_size = list(grad_output.size())
+        dgrad_size[1] = weight.size(1)
+        if ctx.ub_bulk_wgrad:  # allocate dgrad output
+            ub_obj_dgrad = get_ub("qkv_wgrad")
+            dgrad = ub_obj_dgrad.get_ubuf_output(1)  # AllGather output
+        else:
+            dgrad = torch.empty(
+                dgrad_size, dtype=ctx.activation_dtype, device=weight.device
+            )
         if ctx.fp8:
             fp8_dtype_forward = get_fp8_te_dtype(
                 ctx.fp8_meta["recipe"], fprop_tensor=True
@@ -956,7 +1126,7 @@ class _LayerNormLinear(torch.autograd.Function):
             )

             # DGRAD: Evaluated unconditionally to feed into Linear backward
-            dgrad = fp8_gemm(
+            _ = fp8_gemm(
                 weight_t_fp8,
                 fwd_scale_inverses,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -967,25 +1137,35 @@ class _LayerNormLinear(torch.autograd.Function):
                 fp8_dtype_backward,
                 ctx.activation_dtype,
                 get_workspace(),
+                out=dgrad,
                 use_split_accumulator=_2X_ACC_DGRAD,
+                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
+                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
             )
         else:
             # DGRAD: Evaluated unconditionally to feed into Linear backward
-            dgrad, _, _ = gemm(
+            _, _, _ = gemm(
                 weight,
                 grad_output,
                 ctx.activation_dtype,
                 get_workspace(),
+                out=dgrad,
                 layout="NN",
                 grad=True,
+                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
+                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
             )
+        if ctx.ub_bulk_dgrad:
+            ln_out_total = ub_obj_lnout.get_ubuf_output(1)

         # Overlap dgrad-RS/AR with wgrad
         if ctx.parallel_mode == "column" and ctx.sequence_parallel:
-            handle.wait()
-            dgrad, handle = reduce_scatter_along_first_dim(
-                dgrad, ctx.tp_group, async_op=True
-            )
+            if not ctx.ub_bulk_dgrad:
+                handle.wait()
+            if not ctx.ub_bulk_wgrad:
+                dgrad, handle = reduce_scatter_along_first_dim(
+                    dgrad, ctx.tp_group, async_op=True
+                )
         elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
             dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
@@ -1008,6 +1188,9 @@ class _LayerNormLinear(torch.autograd.Function):
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                     use_split_accumulator=_2X_ACC_WGRAD,
+                    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
+                    if ctx.ub_bulk_wgrad else None,
+                    ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
                 )
             else:
                 ln_out_total_c = cast_from_fp8(
@@ -1026,6 +1209,9 @@ class _LayerNormLinear(torch.autograd.Function):
                     grad=True,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
+                    if ctx.ub_bulk_wgrad else None,
+                    ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
                 )
         else:
             # WGRAD
@@ -1039,10 +1225,15 @@ class _LayerNormLinear(torch.autograd.Function):
                 use_bias=ctx.use_bias,
                 accumulate=accumulate_wgrad_into_param_main_grad,
                 out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
+                ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
             )

+        if ctx.ub_bulk_wgrad:
+            dgrad = ub_obj_dgrad.get_ubuf_output(0)  # Reduce-scatter output
         # Column Parallel Linear
-        if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
+        elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
             handle.wait()

         # LayerNorm gradient
@@ -1086,6 +1277,9 @@ class _LayerNormLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
+            None,
         )
@@ -1179,6 +1373,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
         skip_weight_param_allocation: bool = False,
         parameters_split: Optional[Tuple[str, ...]] = None,
         zero_centered_gamma: bool = False,
+        ub_bulk_wgrad: bool = False,
+        ub_bulk_dgrad: bool = False,
+        ub_split_ag: bool = False,
     ) -> None:
         super().__init__()
         self.in_features = in_features
@@ -1190,6 +1387,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.return_layernorm_output = return_layernorm_output
         self.parameters_split = parameters_split
         self.zero_centered_gamma = zero_centered_gamma
+        self.ub_bulk_wgrad = ub_bulk_wgrad
+        self.ub_bulk_dgrad = ub_bulk_dgrad
+        self.ub_split_ag = ub_split_ag
+
+        if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag:
+            assert (
+                tex.userbuf_comm_available()
+            ), "Userbuffer communication backend not available."

         if tp_group is None:
             self.tp_size = tp_size
@@ -1308,6 +1513,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
+
         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
         if self.parallel_mode == "row" and self.apply_bias:
@@ -1412,6 +1618,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.fwd_ln_sm_margin,
             self.bwd_ln_sm_margin,
             self.zero_centered_gamma,
+            self.ub_bulk_wgrad,
+            self.ub_bulk_dgrad,
+            self.ub_split_ag,
         )
         out = fwd_fn(*args)
@@ -1455,6 +1664,8 @@ class _Linear(torch.autograd.Function):
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
+        ub_split_rs: bool,
+        ub_split_ag: bool,
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = weight.shape[-1]
@@ -1466,6 +1677,10 @@ class _Linear(torch.autograd.Function):
         update_fp8_weights = is_first_microbatch is None or is_first_microbatch

+        if ub_split_rs:
+            tp_world_size = get_distributed_world_size(tp_group)
+            if tp_world_size == 1:
+                ub_split_rs = False
         # Cast for native AMP
         inputmat = cast_if_needed(inputmat, activation_dtype)
         inputmat_no_fp8 = inputmat
@@ -1529,7 +1744,19 @@ class _Linear(torch.autograd.Function):
                 fp8_dtype_forward,
             )

-            out = fp8_gemm(
+            if ub_split_rs:
+                ub_obj_projout = get_ub("proj_fprop")
+                out = ub_obj_projout.get_ubuf_output(1)
+                dim_size = list(inputmat_total.size())
+                dim_size[0] = dim_size[0] // tp_world_size
+                dim_size[1] = weight.size(0)
+                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+            else:
+                dim_size = list(inputmat_total.size())
+                dim_size[1] = weight.size(0)
+                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+
+            _ = fp8_gemm(
                 weight_fp8,
                 fp8_meta["scaling_fwd"].scale_inv,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -1543,6 +1770,10 @@ class _Linear(torch.autograd.Function):
                 bias=bias,
                 use_bias=use_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
+                out=out,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
+                ub=ub_obj_projout if ub_split_rs else None,
+                extra_output_tensor=rs_out if ub_split_rs else None,
             )
         else:
             # Cast for native AMP
@@ -1557,13 +1788,29 @@ class _Linear(torch.autograd.Function):
                 fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                     torch.amax(weight).float()

-            out, _, _ = gemm(
+            if ub_split_rs:
+                ub_obj_projout = get_ub("proj_fprop")
+                out = ub_obj_projout.get_ubuf_output(1)
+                dim_size = list(inputmat_total.size())
+                dim_size[0] = dim_size[0] // tp_world_size
+                dim_size[1] = weight.size(0)
+                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+            else:
+                dim_size = list(inputmat_total.size())
+                dim_size[1] = weight.size(0)
+                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+
+            _, _, _ = gemm(
                 weight,
                 inputmat_total,
                 activation_dtype,
                 get_workspace(),
                 bias=bias,
                 use_bias=use_bias,
+                out=out,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
+                ub=ub_obj_projout if ub_split_rs else None,
+                extra_output_tensor=rs_out if ub_split_rs else None,
             )

         if is_grad_enabled:
@@ -1586,11 +1833,14 @@ class _Linear(torch.autograd.Function):
             ctx.inp_shape = inp.shape
             ctx.parallel_mode = parallel_mode
             ctx.tp_group = tp_group
+            ctx.ub_split_ag = ub_split_ag
             ctx.tp_size = tp_size
             ctx.requires_dgrad = inp.requires_grad

         # Row Parallel Linear
-        if parallel_mode == "row" and sequence_parallel:
+        if ub_split_rs:
+            out = rs_out
+        elif parallel_mode == "row" and sequence_parallel:
             out, _ = reduce_scatter_along_first_dim(out, tp_group)
         elif parallel_mode == "row" and tensor_parallel:
             out, _ = allreduce(out, tp_group)
@@ -1614,6 +1864,14 @@ class _Linear(torch.autograd.Function):
             fwd_scale_inverses,
         ) = ctx.saved_tensors

+        if ctx.ub_split_ag:
+            tp_world_size = get_distributed_world_size(ctx.tp_group)
+            if tp_world_size == 1:
+                ctx.ub_split_ag = False
+        if ctx.ub_split_ag:
+            dim_size = list(grad_output.size())
+            dim_size[0] = dim_size[0] * tp_world_size
+            ctx.ub_obj_gradout = get_ub("proj_dgrad")
         (
             grad_output,
             grad_output_c,
@@ -1667,6 +1925,8 @@ class _Linear(torch.autograd.Function):
                 ctx.activation_dtype,
                 get_workspace(),
                 use_split_accumulator=_2X_ACC_DGRAD,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
+                ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
             )
         else:
             dgrad, _, _ = gemm(
@@ -1676,6 +1936,8 @@ class _Linear(torch.autograd.Function):
                 get_workspace(),
                 layout="NN",
                 grad=True,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
+                ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
             )

         # Overlap dgrad-RS/AR with wgrad
@@ -1691,6 +1953,8 @@ class _Linear(torch.autograd.Function):
         if ctx.fp8:
             # WGRAD
             if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
+                if ctx.ub_split_ag:
+                    grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                 wgrad = fp8_gemm(
                     inputmat_t_total,
                     fwd_scale_inverses,
@@ -1757,6 +2021,8 @@ class _Linear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
@@ -1838,6 +2104,8 @@ class Linear(TransformerEngineBaseModule):
         parallel_mode: Optional[str] = None,
         skip_weight_param_allocation: bool = False,
         parameters_split: Optional[Tuple[str, ...]] = None,
+        ub_split_rs: bool = False,
+        ub_split_ag: bool = False,
     ) -> None:
         super().__init__()
         self.in_features = in_features
@@ -1847,6 +2115,13 @@ class Linear(TransformerEngineBaseModule):
         self.return_bias = return_bias
         self.apply_bias = bias and not return_bias
         self.parameters_split = parameters_split
+        self.ub_split_rs = ub_split_rs
+        self.ub_split_ag = ub_split_ag
+
+        if ub_split_rs or ub_split_ag:
+            assert (
+                tex.userbuf_comm_available()
+            ), "Userbuffer communication backend not available."

         if tp_group is None:
             self.tp_size = tp_size
@@ -2028,6 +2303,8 @@ class Linear(TransformerEngineBaseModule):
             self.activation_dtype,
             self.parallel_mode,
             torch.is_grad_enabled(),
+            self.ub_split_rs,
+            self.ub_split_ag,
         )
         out = linear_fn(*args)
@@ -2078,6 +2355,10 @@ class _LayerNormMLP(torch.autograd.Function):
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
         zero_centered_gamma: bool,
+        ub_bulk_wgrad: bool,
+        ub_bulk_dgrad: bool,
+        ub_split_rs: bool,
+        ub_split_ag: bool,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -2094,6 +2375,18 @@ class _LayerNormMLP(torch.autograd.Function):
         ln_weight = cast_if_needed(ln_weight, activation_dtype)
         ln_bias = cast_if_needed(ln_bias, activation_dtype)

+        if ub_split_ag:
+            tp_world_size = get_distributed_world_size(tp_group)
+            if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
+                ub_split_ag = False
+        if ub_split_ag:
+            ub_obj_lnout = get_ub("fc1_fprop")
+            ln_out = ub_obj_lnout.get_ubuf_output(0)
+        if ub_split_rs:
+            tp_world_size = get_distributed_world_size(tp_group)
+            if tp_world_size == 1:
+                ub_split_rs = False
+
         # If residual connection is after LN, we need `ln_out`
         # tensor in higher precision, this comes at the cost
         # of an extra fp8 cast.
@@ -2101,7 +2394,9 @@ class _LayerNormMLP(torch.autograd.Function):
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             if not return_layernorm_output:
                 if is_grad_enabled:
-                    ln_out, mu, rsigma = layernorm_fwd_fp8(
+                    if not ub_split_ag:
+                        ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
+                    _, mu, rsigma = layernorm_fwd_fp8(
                         inputmat,
                         ln_weight,
                         ln_bias,
@@ -2111,6 +2406,7 @@ class _LayerNormMLP(torch.autograd.Function):
                         fp8_dtype_forward,
                         fwd_ln_sm_margin,
                         zero_centered_gamma,
+                        ln_out=ln_out,
                     )
                 else:
                     ln_out = layernorm_fwd_fp8_inf(
@@ -2135,9 +2431,15 @@ class _LayerNormMLP(torch.autograd.Function):
             )
         else:
             if is_grad_enabled:
-                ln_out, mu, rsigma = tex.layernorm_fwd(
-                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
-                )
+                if ub_split_ag:
+                    _, mu, rsigma = tex.layernorm_fwd_noalloc(
+                        inputmat, ln_weight, ln_bias, ln_out, eps,
+                        fwd_ln_sm_margin, zero_centered_gamma
+                    )
+                else:
+                    ln_out, mu, rsigma = tex.layernorm_fwd(
+                        inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
+                    )
             else:
                 ln_out, mu, rsigma = layernorm_fwd_inf(
                     inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
@@ -2145,7 +2447,10 @@ class _LayerNormMLP(torch.autograd.Function):
         ln_out_return = ln_out

         # Column Parallel Linear
-        if set_parallel_mode and sequence_parallel:
+        if ub_split_ag:
+            ln_out_total = ub_obj_lnout.get_ubuf_output(1)
+            ln_out = torch.empty_like(ln_out)
+        elif set_parallel_mode and sequence_parallel:
             ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
         else:
             ln_out_total = ln_out
@@ -2208,6 +2513,9 @@ class _LayerNormMLP(torch.autograd.Function):
                 bias=fc1_bias,
                 use_bias=use_fc1_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
+                ub=ub_obj_lnout if ub_split_ag else None,
+                extra_output_tensor=ln_out if ub_split_ag else None,
             )

             gelu_out = fp8_gelu(
@@ -2217,7 +2525,19 @@ class _LayerNormMLP(torch.autograd.Function):
                 fp8_dtype_forward,
             )

-            fc2_out = fp8_gemm(
+            if ub_split_rs:
+                ub_obj_fc2out = get_ub("fc2_fprop")
+                fc2_out = ub_obj_fc2out.get_ubuf_output(1)
+                dim_size = list(gelu_out.size())
+                dim_size[0] = dim_size[0] // tp_world_size
+                dim_size[1] = fc2_weight.size(0)
+                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
+            else:
+                dim_size = list(gelu_out.size())
+                dim_size[1] = fc2_weight.size(0)
+                fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
+
+            _ = fp8_gemm(
                 fc2_weight_fp8,
                 fp8_meta["scaling_fwd"].scale_inv,
                 tex.FP8FwdTensors.GEMM2_WEIGHT,
@@ -2231,6 +2551,10 @@ class _LayerNormMLP(torch.autograd.Function):
                 bias=fc2_bias,
                 use_bias=use_fc2_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
+                out=fc2_out,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
+                ub=ub_obj_fc2out if ub_split_rs else None,
+                extra_output_tensor=rs_out if ub_split_rs else None,
             )
         else:
             # Cast for native AMP
@@ -2259,6 +2583,9 @@ class _LayerNormMLP(torch.autograd.Function):
                 bias=fc1_bias,
                 use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
                 gelu=not bias_gelu_nvfusion,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
+                ub=ub_obj_lnout if ub_split_ag else None,
+                extra_output_tensor=ln_out if ub_split_ag else None,
             )

             if bias_gelu_nvfusion:
@@ -2276,14 +2603,30 @@ class _LayerNormMLP(torch.autograd.Function):
                 fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
                     torch.amax(fc2_weight).float()

-            fc2_out, _, _ = gemm(
+            if ub_split_rs:
+                ub_obj_fc2out = get_ub("fc2_fprop")
+                fc2_out = ub_obj_fc2out.get_ubuf_output(1)
+                dim_size = list(gelu_out.size())
+                dim_size[0] = dim_size[0] // tp_world_size
+                dim_size[1] = fc2_weight.size(0)
+                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
+            else:
+                dim_size = list(gelu_out.size())
+                dim_size[1] = fc2_weight.size(0)
+                fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
+
+            _, _, _ = gemm(
                 fc2_weight,
                 gelu_out,
                 activation_dtype,
                 get_workspace(),
                 bias=fc2_bias,
                 use_bias=use_fc2_bias,
+                out=fc2_out,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
+                ub=ub_obj_fc2out if ub_split_rs else None,
+                extra_output_tensor=rs_out if ub_split_rs else None,
             )

         if is_grad_enabled:
             ctx.save_for_backward(
                 inputmat,
@@ -2317,10 +2660,15 @@ class _LayerNormMLP(torch.autograd.Function):
             ctx.set_parallel_mode = set_parallel_mode
             ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
             ctx.zero_centered_gamma = zero_centered_gamma
+            ctx.ub_bulk_wgrad = ub_bulk_wgrad
+            ctx.ub_bulk_dgrad = ub_bulk_dgrad
+            ctx.ub_split_ag = ub_split_ag
             ctx.requires_dgrad = inp.requires_grad

         # Row Parallel Linear
-        if set_parallel_mode and sequence_parallel:
+        if ub_split_rs:
+            fc2_out = rs_out
+        elif set_parallel_mode and sequence_parallel:
             fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
         elif set_parallel_mode and tensor_parallel:
             fc2_out, _ = allreduce(fc2_out, tp_group)
@@ -2356,6 +2704,24 @@ class _LayerNormMLP(torch.autograd.Function):
             fwd_scale_inverses,
         ) = ctx.saved_tensors

+        if ctx.ub_bulk_dgrad:
+            tp_world_size = get_distributed_world_size(ctx.tp_group)
+            if tp_world_size == 1:
+                ctx.ub_bulk_dgrad = False
+        if ctx.ub_bulk_dgrad:
+            dim_size = list(ln_out.size())
+            dim_size[0] = dim_size[0] * tp_world_size
+            ub_obj_lnout = get_ub("fc1_dgrad")
+            ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
+        if ctx.ub_split_ag:
+            tp_world_size = get_distributed_world_size(ctx.tp_group)
+            if tp_world_size == 1:
+                ctx.ub_split_ag = False
+        if ctx.ub_split_ag:
+            dim_size = list(grad_outputs[0].size())
+            dim_size[0] = dim_size[0] * tp_world_size
+            ctx.ub_obj_gradout = get_ub("fc2_dgrad")
+
         ctx.use_bias = ctx.use_fc2_bias  # For grad_output_preprocess
         (
             grad_output,
@@ -2365,10 +2731,13 @@ class _LayerNormMLP(torch.autograd.Function):
         ) = TransformerEngineBaseModule.grad_output_preprocess(
             ctx, grad_outputs[0], True
         )
-
+        if ctx.ub_bulk_wgrad:
+            tp_world_size = get_distributed_world_size(ctx.tp_group)
+            if tp_world_size == 1:
+                ctx.ub_bulk_wgrad = False
         # Column Parallel Linear
         # Overlap input AG with dgrad
-        if ctx.set_parallel_mode and ctx.sequence_parallel:
+        if (not ctx.ub_bulk_dgrad) and ctx.set_parallel_mode and ctx.sequence_parallel:
             ln_out_total, handle = gather_along_first_dim(
                 ln_out, ctx.tp_group, async_op=True
             )
@@ -2403,8 +2772,11 @@ class _LayerNormMLP(torch.autograd.Function):
                 ctx.activation_dtype,
                 get_workspace(),
                 use_split_accumulator=_2X_ACC_DGRAD,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
+                ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
             )
-
+            if ctx.ub_split_ag:
+                grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
             # FC2 WGRAD
             if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                 if fc2_weight.requires_grad:
@@ -2469,8 +2841,17 @@ class _LayerNormMLP(torch.autograd.Function):
                 )
                 dgelu_t = None

+            fc1_dgrad_size = list(dgelu.size())
+            fc1_dgrad_size[1] = fc1_weight.size(1)
+            if ctx.ub_bulk_wgrad:  # allocate dgrad output
+                ub_obj_dgrad = get_ub("fc1_wgrad")
+                fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1)  # AllGather output
+            else:
+                fc1_dgrad = torch.empty(
+                    fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
+                )
             # FC1 DGRAD: Unconditional
-            fc1_dgrad = fp8_gemm(
+            _ = fp8_gemm(
                 fc1_weight_t_fp8,
                 fwd_scale_inverses,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -2481,7 +2862,10 @@ class _LayerNormMLP(torch.autograd.Function):
                 fp8_dtype_backward,
                 ctx.activation_dtype,
                 get_workspace(),
+                out=fc1_dgrad,
                 use_split_accumulator=_2X_ACC_DGRAD,
+                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
+                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
             )
         else:
             # FC2 DGRAD; Unconditional
@@ -2494,6 +2878,8 @@ class _LayerNormMLP(torch.autograd.Function):
                 gelu=not ctx.bias_gelu_nvfusion,
                 grad=True,
                 gelu_input=fc1_out,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
+                ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
             )

             # FC2 WGRAD
@@ -2515,22 +2901,38 @@ class _LayerNormMLP(torch.autograd.Function):
             else:
                 dgelu = fc2_dgrad

+            fc1_dgrad_size = list(dgelu.size())
+            fc1_dgrad_size[1] = fc1_weight.size(1)
+            if ctx.ub_bulk_wgrad:  # allocate dgrad output
+                ub_obj_dgrad = get_ub("fc1_wgrad")
+                fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1)  # AllGather output
+            else:
+                fc1_dgrad = torch.empty(
+                    fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
+                )
             # FC1 DGRAD: Unconditional
-            fc1_dgrad, _, _ = gemm(
+            _, _, _ = gemm(
                 fc1_weight,
                 dgelu,
                 ctx.activation_dtype,
                 get_workspace(),
+                out=fc1_dgrad,
                 layout="NN",
                 grad=True,
+                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
+                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
             )
+        if ctx.ub_bulk_dgrad:
+            ln_out_total = ub_obj_lnout.get_ubuf_output(1)

         # Overlap dgrad-RS/AR with wgrad
         if ctx.set_parallel_mode and ctx.sequence_parallel:
-            handle.wait()
-            fc1_dgrad, handle = reduce_scatter_along_first_dim(
-                fc1_dgrad, ctx.tp_group, async_op=True
-            )
+            if not ctx.ub_bulk_dgrad:
+                handle.wait()
+            if not ctx.ub_bulk_wgrad:
+                fc1_dgrad, handle = reduce_scatter_along_first_dim(
+                    fc1_dgrad, ctx.tp_group, async_op=True
+                )
         elif ctx.set_parallel_mode and ctx.tensor_parallel:
             fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
@@ -2555,6 +2957,9 @@ class _LayerNormMLP(torch.autograd.Function):
                     if ctx.fuse_wgrad_accumulation
                     else None,
                     use_split_accumulator=_2X_ACC_WGRAD,
+                    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
+                    if ctx.ub_bulk_wgrad else None,
+                    ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
                 )
             else:
                 ln_out_total_c = cast_from_fp8(
@@ -2575,6 +2980,9 @@ class _LayerNormMLP(torch.autograd.Function):
                     out=fc1_weight.main_grad
                     if ctx.fuse_wgrad_accumulation
                     else None,
+                    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
+                    if ctx.ub_bulk_wgrad else None,
+                    ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
                 )
         else:
             # FC1 WGRAD
@@ -2588,6 +2996,8 @@ class _LayerNormMLP(torch.autograd.Function):
                 use_bias=not ctx.bias_gelu_nvfusion,
                 accumulate=accumulate_wgrad_into_param_main_grad,
                 out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
+                ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
             )

         if ctx.bias_gelu_nvfusion:
@@ -2596,7 +3006,9 @@ class _LayerNormMLP(torch.autograd.Function):
             fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs

         # Column Parallel Linear
-        if ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
+        if ctx.ub_bulk_wgrad:
+            fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0)  # Reduce-scatter output
+        elif ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
             handle.wait()

         # LayerNorm gradient
@@ -2643,6 +3055,10 @@ class _LayerNormMLP(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
+            None,
+            None,
         )
@@ -2741,6 +3157,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
         micro_batch_size: Optional[int] = None,
         set_parallel_mode: bool = False,
         zero_centered_gamma: bool = False,
+        ub_bulk_wgrad: bool = False,
+        ub_bulk_dgrad: bool = False,
+        ub_split_rs: bool = False,
+        ub_split_ag: bool = False,
     ) -> None:
         super().__init__()
@@ -2752,6 +3172,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
         self.set_parallel_mode = set_parallel_mode
         self.zero_centered_gamma = zero_centered_gamma
+        self.ub_bulk_wgrad = ub_bulk_wgrad
+        self.ub_bulk_dgrad = ub_bulk_dgrad
+        self.ub_split_rs = ub_split_rs
+        self.ub_split_ag = ub_split_ag
+
+        if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_rs or ub_split_ag:
+            assert (
+                tex.userbuf_comm_available()
+            ), "Userbuffer communication backend not available."

         if tp_group is None:
             self.tp_size = tp_size
@@ -2948,6 +3377,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.fwd_ln_sm_margin,
             self.bwd_ln_sm_margin,
             self.zero_centered_gamma,
+            self.ub_bulk_wgrad,
+            self.ub_bulk_dgrad,
+            self.ub_split_rs,
+            self.ub_split_ag,
         )
         out = fwd_fn(*args)
......
@@ -15,6 +15,7 @@ import torch

 from flash_attn.flash_attn_interface import flash_attn_unpadded_func

+import transformer_engine_extensions as tex
 from transformer_engine.pytorch.module import LayerNormLinear, Linear, LayerNormMLP, LayerNorm
 from transformer_engine.pytorch.jit import (
     set_jit_fusion_options,
@@ -495,6 +496,10 @@ class MultiHeadAttention(torch.nn.Module):
         fuse_qkv_params: bool = False,
         zero_centered_gamma: bool = False,
         qkv_weight_interleaved: bool = True,
+        ub_bulk_wgrad: bool = False,
+        ub_bulk_dgrad: bool = False,
+        ub_split_rs: bool = False,
+        ub_split_ag: bool = False,
         bias: bool = True,
     ) -> None:
         super().__init__()
@@ -547,6 +552,9 @@ class MultiHeadAttention(torch.nn.Module):
                 return_layernorm_output=return_layernorm_output,
                 parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
                 zero_centered_gamma=zero_centered_gamma,
+                ub_bulk_wgrad=ub_bulk_wgrad,
+                ub_bulk_dgrad=ub_bulk_dgrad,
+                ub_split_ag=ub_split_ag,
                 **common_gemm_kwargs,
             )
         else:
@@ -572,6 +580,9 @@ class MultiHeadAttention(torch.nn.Module):
                 parallel_mode=qkv_parallel_mode,
                 return_layernorm_output=return_layernorm_output,
                 zero_centered_gamma=zero_centered_gamma,
+                ub_bulk_wgrad=ub_bulk_wgrad,
+                ub_bulk_dgrad=ub_bulk_dgrad,
+                ub_split_ag=ub_split_ag,
                 **common_gemm_kwargs,
             )
         else:
@@ -616,6 +627,8 @@ class MultiHeadAttention(torch.nn.Module):
             bias=bias,
             return_bias=True,
             parallel_mode="row" if set_parallel_mode else None,
+            ub_split_rs=ub_split_rs,
+            ub_split_ag=ub_split_ag,
             **common_gemm_kwargs,
         )
@@ -911,6 +924,12 @@ class TransformerLayer(torch.nn.Module):
                   `set_tensor_parallel_group(tp_group)` method on the initialized module before the
                   forward pass to supply the tensor parallel group needed for tensor and sequence
                   parallel collectives.
+    ub_bulk_wgrad : bool, default = `False`
+                   bulk overlap the UserBuffer ReduceScatter with the WGRAD GEMM
+    ub_bulk_dgrad : bool, default = `False`
+                   bulk overlap the UserBuffer AllGather with the DGRAD GEMM
+    ub_split_ag : bool, default = `False`
+                   split-pipelined overlap of the UserBuffer AllGather with the following GEMM

     Optimization parameters
     -----------------------
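
The flags documented above are also exposed directly on the underlying modules (see the LayerNormLinear, Linear, and LayerNormMLP signature changes earlier in this diff), so they can be set without going through TransformerLayer. A hedged sketch against the changed LayerNormLinear signature (sizes illustrative; the TP group still has to be supplied via set_tensor_parallel_group as described in the docstring above):

from transformer_engine.pytorch.module import LayerNormLinear

qkv_proj = LayerNormLinear(
    in_features=12288,
    out_features=3 * 12288,
    tp_size=8,
    parallel_mode="column",
    sequence_parallel=True,
    ub_bulk_wgrad=True,   # bulk overlap: userbuffer ReduceScatter with the WGRAD GEMM
    ub_bulk_dgrad=True,   # bulk overlap: userbuffer AllGather with the DGRAD GEMM
    ub_split_ag=True,     # split-pipelined AllGather overlapped with the fprop GEMM
)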
@@ -970,6 +989,7 @@ class TransformerLayer(torch.nn.Module):
         fuse_qkv_params: bool = False,
         zero_centered_gamma: bool = False,
         qkv_weight_interleaved: bool = True,
+        ub_tp_comm_overlap: bool = False,
         bias: bool = True,
     ) -> None:
         super().__init__()
@@ -980,6 +1000,16 @@ class TransformerLayer(torch.nn.Module):
                 category=DeprecationWarning,
             )

+        if ub_tp_comm_overlap:
+            assert (
+                tex.userbuf_comm_available()
+            ), "Userbuffer communication backend not available."
+
+        ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
+        ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
+        ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1")))
+        ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1")))
+        ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1")))
+
         bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
         self.layer_number = layer_number
         self.output_layernorm = output_layernorm
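
Each individual overlap is additionally gated by an environment variable read once in TransformerLayer.__init__ (see the hunk above), so one of them can be disabled without touching model code. A hedged sketch; the variable must be set before any layer is constructed:

import os
from transformer_engine.pytorch import TransformerLayer

# Keep the other overlaps but fall back to a plain reduce-scatter after fprop.
os.environ["NVTE_UB_SPLIT_RS"] = "0"

layer = TransformerLayer(
    hidden_size=12288,
    ffn_hidden_size=49152,
    num_attention_heads=96,
    set_parallel_mode=True,
    sequence_parallel=True,
    ub_tp_comm_overlap=True,
)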
@@ -1037,6 +1067,10 @@ class TransformerLayer(torch.nn.Module):
             "fuse_qkv_params": fuse_qkv_params,
             "zero_centered_gamma": zero_centered_gamma,
             "qkv_weight_interleaved" : qkv_weight_interleaved,
+            "ub_bulk_wgrad" : ub_bulk_wgrad,
+            "ub_bulk_dgrad" : ub_bulk_dgrad,
+            "ub_split_ag" : ub_split_ag,
+            "ub_split_rs" : ub_split_rs,
         }

         self.self_attention = MultiHeadAttention(
@@ -1080,6 +1114,10 @@ class TransformerLayer(torch.nn.Module):
             micro_batch_size=micro_batch_size,
             set_parallel_mode=set_parallel_mode,
             zero_centered_gamma=zero_centered_gamma,
+            ub_bulk_wgrad=ub_bulk_wgrad,
+            ub_bulk_dgrad=ub_bulk_dgrad,
+            ub_split_rs=ub_split_rs,
+            ub_split_ag=ub_split_ag,
         )
         self.hidden_dropout = hidden_dropout
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
......
@@ -568,7 +568,7 @@ py::object TFE_Py_TeGemm_wrapper(
   nvte_cublas_gemm(a_tensor.data(), b_tensor.data(), d_tensor.data(),
                    bias_tensor.data(), gelu_input_tensor.data(), transa,
                    transb, grad, workspace_tensor.data(), accumulate,
-                   use_split_accumulate, stream);
+                   use_split_accumulate, 0, stream);

   auto d_eager = CreateTensor(d_ptr, d_shape, otype);
   if (use_gelu && !grad) {
......