Unverified Commit 18da4e88 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

TP communication overlap with userbuffers (#147)



* Port initial changes
Co-authored-by: default avatarSangkug Lym <slym@nvidia.com>
Co-authored-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* readd FA include for PyTorch
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Re-enable sm_70 + cleanup
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* LICENSE, cleanup header
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* 5k -> 173 errors
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* license and fixes in userbuffers-host
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* next round fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* final cpp cleanup
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* pylinting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix from linting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Turn off default async amax reduction (#148)
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove unused code path
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* cleanup Macros
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* fix conflict resolution bug
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* Fix gencode flags in setup (#145)

* Fix gencode flags based on cuda version
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* review suggestions
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert append_nvcc_threads change
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change overlap config dict error message
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* simplify ub initialization
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix sanity imports
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* cpplint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix TensorFlow build
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix TE macros in public header
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* compiles with and w/o MPI
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes for python side annotations for conditional compile
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* link gdrAPI only when MPI found
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix comments for dummy var
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix linking
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* load MPI before TE
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add Py side argument checks
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove unused code and catch silent failures
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cpp tests
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix find_lib path for tests
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>
Co-authored-by: default avatarSangkug Lym <slym@nvidia.com>
Co-authored-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
parent 7bb2af35
......@@ -121,7 +121,8 @@ at::Tensor te_gemm_ts(at::Tensor A,
workspace,
workspaceSize_arg,
accumulate_arg,
use_split_accumulator_arg);
use_split_accumulator_arg,
0);
return D;
}
......
......@@ -85,6 +85,8 @@ _2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_amax_reduce_handle_bwd = None
......@@ -147,6 +149,105 @@ def _prepare_backward(
delete_key_from_amax_buffer(forward=False)
def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    ub_cfgs: Optional[dict] = None
) -> None:
    """Initialize communicators for TP comm overlap using userbuffers.

    Parameters
    ----------
    shape : list
        Shape of the sample communication buffer allocated for every
        registered layer (one buffer per layer name).
    tp_size : int
        Tensor-parallel group size used to construct the communicators.
    use_fp8 : bool, default = False
        When True, AllGather buffers for the layers in ``fp8_buf`` are
        allocated as ``torch.uint8`` (fp8 payload) instead of bfloat16.
    ub_cfgs : Optional[dict], default = None
        Optional per-layer overrides. Keys are layer names; values are dicts
        that may contain "method", "num_sm", "cga_size", "num_splits",
        "set_sm_margin" and "aggregate".

    Raises
    ------
    AssertionError
        If communicators were already initialized.
    """
    global _ub_communicators
    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}
    rank_id = torch.distributed.get_rank()

    # Increase the workspace by the number of maximum concurrent streams.
    global _cublas_workspace
    _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe.
    fp8_buf = [
        "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
    ]
    # Default overlap methods for layers.
    methods = {
        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
        "pipeline": ["proj_fprop", "fc2_fprop"],
        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
    }

    def get_method(name):
        # Reverse lookup: layer name -> default overlap method.
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def add_ub(
        name: str,
        method: str,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: int = 0,
        num_splits: int = 4,
        aggregate: int = 0,
    ) -> None:
        # Allocate the sample buffer and register the communicator for `name`.
        dtype = torch.uint8 if (use_fp8 and name in fp8_buf) else torch.bfloat16
        sample_buffer = torch.empty(shape, dtype=dtype, device='cuda')
        if method == 'ring_exchange':
            ub_obj = tex.UbufP2PCommOverlap(
                sample_buffer,       # Sample userbuffer
                rank_id,             # Rank id
                tp_size,             # TP size
                aggregate,           # Aggregate 2X GEMM chunks
                _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
            )
        else:
            ub_obj = tex.UbufCommOverlap(
                sample_buffer,       # Sample userbuffer
                rank_id,             # Rank id
                tp_size,             # TP size
                num_sm,              # Number of communication SMs
                cga_size,            # CGA cluster size
                num_splits,          # Number of communication splits
                set_sm_margin,       # Set SM margin
                _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
            )
        _ub_communicators[name] = ub_obj

    for name in (methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]):
        if ub_cfgs is not None and name in ub_cfgs:
            # User-supplied config: fall back to defaults per missing key.
            ub_cfg = ub_cfgs[name]
            add_ub(
                name,
                ub_cfg.get("method", get_method(name)),
                ub_cfg.get("num_sm", 16),
                ub_cfg.get("cga_size", 2),
                ub_cfg.get("set_sm_margin", 0),
                # NOTE(review): default here is 0 while add_ub's own default
                # is 4 -- looks intentional but confirm against callers.
                ub_cfg.get("num_splits", 0),
                ub_cfg.get("aggregate", 0),
            )
        else:
            method = get_method(name)
            if method == "pipeline":
                add_ub(name, method)
            else:
                # Non-pipeline defaults disable GEMM splitting.
                add_ub(name, method, num_splits=0)
def get_ub(name: str):
    """Get userbuffer communicator corresponding to given key.

    Parameters
    ----------
    name : str
        Layer key registered by ``initialize_ub`` (e.g. "qkv_fprop").

    Returns
    -------
    The UB communicator object registered under ``name``.

    Raises
    ------
    AssertionError
        If ``initialize_ub`` has not been called, or ``name`` was never
        registered.
    """
    # Read-only access to the module-level registry; no `global` needed.
    assert _ub_communicators is not None, "UB manager is not initialized."
    assert name in _ub_communicators, f"UB for {name} is not registered."
    return _ub_communicators[name]
class _NoopCat(torch.autograd.Function):
"""This class is a no-op replacement for `torch.cat`."""
......@@ -596,9 +697,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# No-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8:
if gather_grad_output:
grad_output_mat, _ = gather_along_first_dim(
grad_output_mat, ctx.tp_group
)
if not ctx.ub_split_ag:
grad_output_mat, _ = gather_along_first_dim(
grad_output_mat, ctx.tp_group
)
else:
ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True)
grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1)
return grad_output_mat, None, None, None
fp8_dtype_backward = get_fp8_te_dtype(
......@@ -610,6 +715,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
gather_grad_output
and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
):
assert (
not ctx.ub_split_ag
), "override_linear_precision.wgrad not supported with ub_split_ag"
grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
# FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
elif gather_grad_output:
......@@ -617,14 +725,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_bias = grad_output_mat.sum(dim=0)
else:
grad_bias = None
grad_output_c = cast_to_fp8(
if ctx.ub_split_ag:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
else:
grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
out=grad_output_c,
)
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
if not ctx.ub_split_ag:
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
else:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
grad_output_t = None
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
......@@ -718,6 +835,9 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_ag: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -733,16 +853,26 @@ class _LayerNormLinear(torch.autograd.Function):
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)
# If residual connection is after LN, we need `ln_out`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
if ub_split_ag:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
ub_split_ag = False
if ub_split_ag:
dim_size = list(inputmat.size())
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0)
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
if is_grad_enabled:
ln_out, mu, rsigma = layernorm_fwd_fp8(
if not ub_split_ag:
ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
_, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight,
ln_bias,
......@@ -752,6 +882,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
ln_out = ln_out
)
else:
mu = rsigma = None
......@@ -783,17 +914,25 @@ class _LayerNormLinear(torch.autograd.Function):
)
else:
if is_grad_enabled:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
if ub_split_ag:
_, mu, rsigma = tex.layernorm_fwd_noalloc(
inputmat, ln_weight, ln_bias, ln_out, eps,
fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out_return = ln_out
# Column Parallel Linear
if parallel_mode == "column" and sequence_parallel:
if ub_split_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif parallel_mode == "column" and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
......@@ -838,6 +977,9 @@ class _LayerNormLinear(torch.autograd.Function):
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
)
else:
# Cast for native AMP
......@@ -859,6 +1001,9 @@ class _LayerNormLinear(torch.autograd.Function):
get_workspace(),
bias=bias,
use_bias=use_bias,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
)
if is_grad_enabled:
......@@ -888,6 +1033,8 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.return_layernorm_output = return_layernorm_output
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.requires_dgrad = inp.requires_grad
# Row Parallel Linear
......@@ -922,6 +1069,15 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_scale_inverses,
) = ctx.saved_tensors
if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_bulk_dgrad = False
if ctx.ub_bulk_dgrad:
dim_size = list(ln_out.size())
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_dgrad")
ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
(
grad_output,
grad_output_c,
......@@ -931,9 +1087,14 @@ class _LayerNormLinear(torch.autograd.Function):
ctx, grad_outputs[0], ctx.parallel_mode == "row"
)
if ctx.ub_bulk_wgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_bulk_wgrad = False
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if (not ctx.ub_bulk_dgrad) and ctx.parallel_mode == "column" and ctx.sequence_parallel:
ln_out_total, handle = gather_along_first_dim(
ln_out, ctx.tp_group, async_op=True
)
......@@ -947,6 +1108,15 @@ class _LayerNormLinear(torch.autograd.Function):
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
dgrad_size = list(grad_output.size())
dgrad_size[1] = weight.size(1)
if ctx.ub_bulk_wgrad: # allocate dgrad output
ub_obj_dgrad = get_ub("qkv_wgrad")
dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
else:
dgrad = torch.empty (dgrad_size, dtype=ctx.activation_dtype, device=weight.device)
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
......@@ -956,7 +1126,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad = fp8_gemm(
_ = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
......@@ -967,25 +1137,35 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
out=dgrad,
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad, _, _ = gemm(
_, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
out=dgrad,
layout="NN",
grad=True,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
)
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
if not ctx.ub_bulk_dgrad:
handle.wait()
if not ctx.ub_bulk_wgrad:
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
......@@ -1008,6 +1188,9 @@ class _LayerNormLinear(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
else:
ln_out_total_c = cast_from_fp8(
......@@ -1026,6 +1209,9 @@ class _LayerNormLinear(torch.autograd.Function):
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
else:
# WGRAD
......@@ -1039,10 +1225,15 @@ class _LayerNormLinear(torch.autograd.Function):
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
if ctx.ub_bulk_wgrad:
dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
# LayerNorm gradient
......@@ -1086,6 +1277,9 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
None,
None,
)
......@@ -1179,6 +1373,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
zero_centered_gamma: bool = False,
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
......@@ -1190,6 +1387,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_ag = ub_split_ag
if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if tp_group is None:
self.tp_size = tp_size
......@@ -1308,6 +1513,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias:
......@@ -1412,6 +1618,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_ag,
)
out = fwd_fn(*args)
......@@ -1455,6 +1664,8 @@ class _Linear(torch.autograd.Function):
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
ub_split_rs: bool,
ub_split_ag: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
......@@ -1466,6 +1677,10 @@ class _Linear(torch.autograd.Function):
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
if ub_split_rs:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
ub_split_rs = False
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
inputmat_no_fp8 = inputmat
......@@ -1529,7 +1744,19 @@ class _Linear(torch.autograd.Function):
fp8_dtype_forward,
)
out = fp8_gemm(
if ub_split_rs:
ub_obj_projout = get_ub("proj_fprop")
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
_ = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
......@@ -1543,6 +1770,10 @@ class _Linear(torch.autograd.Function):
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
out=out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_projout if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
)
else:
# Cast for native AMP
......@@ -1557,13 +1788,29 @@ class _Linear(torch.autograd.Function):
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float()
out, _, _ = gemm(
if ub_split_rs:
ub_obj_projout = get_ub("proj_fprop")
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
_, _, _ = gemm(
weight,
inputmat_total,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
out=out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_projout if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
)
if is_grad_enabled:
......@@ -1586,11 +1833,14 @@ class _Linear(torch.autograd.Function):
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.ub_split_ag = ub_split_ag
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
if ub_split_rs:
out = rs_out
elif parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
......@@ -1614,6 +1864,14 @@ class _Linear(torch.autograd.Function):
fwd_scale_inverses,
) = ctx.saved_tensors
if ctx.ub_split_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_split_ag = False
if ctx.ub_split_ag:
dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * tp_world_size
ctx.ub_obj_gradout = get_ub("proj_dgrad")
(
grad_output,
grad_output_c,
......@@ -1667,6 +1925,8 @@ class _Linear(torch.autograd.Function):
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
)
else:
dgrad, _, _ = gemm(
......@@ -1676,6 +1936,8 @@ class _Linear(torch.autograd.Function):
get_workspace(),
layout="NN",
grad=True,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
)
# Overlap dgrad-RS/AR with wgrad
......@@ -1691,6 +1953,8 @@ class _Linear(torch.autograd.Function):
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if ctx.ub_split_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
wgrad = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses,
......@@ -1757,6 +2021,8 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -1838,6 +2104,8 @@ class Linear(TransformerEngineBaseModule):
parallel_mode: Optional[str] = None,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
......@@ -1847,6 +2115,13 @@ class Linear(TransformerEngineBaseModule):
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.parameters_split = parameters_split
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
if ub_split_rs or ub_split_ag:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if tp_group is None:
self.tp_size = tp_size
......@@ -2028,6 +2303,8 @@ class Linear(TransformerEngineBaseModule):
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
self.ub_split_rs,
self.ub_split_ag,
)
out = linear_fn(*args)
......@@ -2078,6 +2355,10 @@ class _LayerNormMLP(torch.autograd.Function):
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_rs: bool,
ub_split_ag: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -2094,6 +2375,18 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if ub_split_ag:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
ub_split_ag = False
if ub_split_ag:
ub_obj_lnout = get_ub("fc1_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0)
if ub_split_rs:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
ub_split_rs = False
# If residual connection is after LN, we need `ln_out`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
......@@ -2101,7 +2394,9 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
if is_grad_enabled:
ln_out, mu, rsigma = layernorm_fwd_fp8(
if not ub_split_ag:
ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
_, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight,
ln_bias,
......@@ -2111,6 +2406,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
ln_out = ln_out,
)
else:
ln_out = layernorm_fwd_fp8_inf(
......@@ -2135,9 +2431,15 @@ class _LayerNormMLP(torch.autograd.Function):
)
else:
if is_grad_enabled:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
if ub_split_ag:
_, mu, rsigma = tex.layernorm_fwd_noalloc(
inputmat, ln_weight, ln_bias, ln_out, eps,
fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
......@@ -2145,7 +2447,10 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_return = ln_out
# Column Parallel Linear
if set_parallel_mode and sequence_parallel:
if ub_split_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif set_parallel_mode and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
......@@ -2208,6 +2513,9 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc1_bias,
use_bias=use_fc1_bias,
use_split_accumulator=_2X_ACC_FPROP,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
)
gelu_out = fp8_gelu(
......@@ -2217,7 +2525,19 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
)
fc2_out = fp8_gemm(
if ub_split_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
fc2_out = ub_obj_fc2out.get_ubuf_output(1)
dim_size = list(gelu_out.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = fc2_weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
else:
dim_size = list(gelu_out.size())
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
_ = fp8_gemm(
fc2_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_WEIGHT,
......@@ -2231,6 +2551,10 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc2_bias,
use_bias=use_fc2_bias,
use_split_accumulator=_2X_ACC_FPROP,
out=fc2_out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_fc2out if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
)
else:
# Cast for native AMP
......@@ -2259,6 +2583,9 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc1_bias,
use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
gelu=not bias_gelu_nvfusion,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
)
if bias_gelu_nvfusion:
......@@ -2276,14 +2603,30 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
torch.amax(fc2_weight).float()
fc2_out, _, _ = gemm(
if ub_split_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
fc2_out = ub_obj_fc2out.get_ubuf_output(1)
dim_size = list(gelu_out.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = fc2_weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
else:
dim_size = list(gelu_out.size())
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
_, _, _ = gemm(
fc2_weight,
gelu_out,
activation_dtype,
get_workspace(),
bias=fc2_bias,
use_bias=use_fc2_bias,
out=fc2_out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_fc2out if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
)
if is_grad_enabled:
ctx.save_for_backward(
inputmat,
......@@ -2317,10 +2660,15 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.set_parallel_mode = set_parallel_mode
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_split_ag = ub_split_ag
ctx.requires_dgrad = inp.requires_grad
# Row Parallel Linear
if set_parallel_mode and sequence_parallel:
if ub_split_rs:
fc2_out = rs_out
elif set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
elif set_parallel_mode and tensor_parallel:
fc2_out, _ = allreduce(fc2_out, tp_group)
......@@ -2356,6 +2704,24 @@ class _LayerNormMLP(torch.autograd.Function):
fwd_scale_inverses,
) = ctx.saved_tensors
if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_bulk_dgrad = False
if ctx.ub_bulk_dgrad:
dim_size = list(ln_out.size())
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("fc1_dgrad")
ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
if ctx.ub_split_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_split_ag = False
if ctx.ub_split_ag:
dim_size = list(grad_outputs[0].size())
dim_size[0] = dim_size[0] * tp_world_size
ctx.ub_obj_gradout = get_ub("fc2_dgrad")
ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess
(
grad_output,
......@@ -2365,10 +2731,13 @@ class _LayerNormMLP(torch.autograd.Function):
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], True
)
if ctx.ub_bulk_wgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_bulk_wgrad = False
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.set_parallel_mode and ctx.sequence_parallel:
if (not ctx.ub_bulk_dgrad) and ctx.set_parallel_mode and ctx.sequence_parallel:
ln_out_total, handle = gather_along_first_dim(
ln_out, ctx.tp_group, async_op=True
)
......@@ -2403,8 +2772,11 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
)
if ctx.ub_split_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
# FC2 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if fc2_weight.requires_grad:
......@@ -2469,8 +2841,17 @@ class _LayerNormMLP(torch.autograd.Function):
)
dgelu_t = None
fc1_dgrad_size = list(dgelu.size())
fc1_dgrad_size[1] = fc1_weight.size(1)
if ctx.ub_bulk_wgrad: # allocate dgrad output
ub_obj_dgrad = get_ub("fc1_wgrad")
fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
else:
fc1_dgrad = torch.empty(
fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
)
# FC1 DGRAD: Unconditional
fc1_dgrad = fp8_gemm(
_ = fp8_gemm(
fc1_weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
......@@ -2481,7 +2862,10 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
out=fc1_dgrad,
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
)
else:
# FC2 DGRAD; Unconditional
......@@ -2494,6 +2878,8 @@ class _LayerNormMLP(torch.autograd.Function):
gelu=not ctx.bias_gelu_nvfusion,
grad=True,
gelu_input=fc1_out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
)
# FC2 WGRAD
......@@ -2515,22 +2901,38 @@ class _LayerNormMLP(torch.autograd.Function):
else:
dgelu = fc2_dgrad
fc1_dgrad_size = list(dgelu.size())
fc1_dgrad_size[1] = fc1_weight.size(1)
if ctx.ub_bulk_wgrad: # allocate dgrad output
ub_obj_dgrad = get_ub("fc1_wgrad")
fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
else:
fc1_dgrad = torch.empty(
fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
)
# FC1 DGRAD: Unconditional
fc1_dgrad, _, _ = gemm(
_, _, _ = gemm(
fc1_weight,
dgelu,
ctx.activation_dtype,
get_workspace(),
out=fc1_dgrad,
layout="NN",
grad=True,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
)
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
# Overlap dgrad-RS/AR with wgrad
if ctx.set_parallel_mode and ctx.sequence_parallel:
handle.wait()
fc1_dgrad, handle = reduce_scatter_along_first_dim(
fc1_dgrad, ctx.tp_group, async_op=True
)
if not ctx.ub_bulk_dgrad:
handle.wait()
if not ctx.ub_bulk_wgrad:
fc1_dgrad, handle = reduce_scatter_along_first_dim(
fc1_dgrad, ctx.tp_group, async_op=True
)
elif ctx.set_parallel_mode and ctx.tensor_parallel:
fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
......@@ -2555,6 +2957,9 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.fuse_wgrad_accumulation
else None,
use_split_accumulator=_2X_ACC_WGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
)
else:
ln_out_total_c = cast_from_fp8(
......@@ -2575,6 +2980,9 @@ class _LayerNormMLP(torch.autograd.Function):
out=fc1_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
)
else:
# FC1 WGRAD
......@@ -2588,6 +2996,8 @@ class _LayerNormMLP(torch.autograd.Function):
use_bias=not ctx.bias_gelu_nvfusion,
accumulate=accumulate_wgrad_into_param_main_grad,
out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
if ctx.bias_gelu_nvfusion:
......@@ -2596,7 +3006,9 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
# Column Parallel Linear
if ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
if ctx.ub_bulk_wgrad:
fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
elif ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
handle.wait()
# LayerNorm gradient
......@@ -2643,6 +3055,10 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
None,
None,
None,
)
......@@ -2741,6 +3157,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
micro_batch_size: Optional[int] = None,
set_parallel_mode: bool = False,
zero_centered_gamma: bool = False,
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
) -> None:
super().__init__()
......@@ -2752,6 +3172,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
self.set_parallel_mode = set_parallel_mode
self.zero_centered_gamma = zero_centered_gamma
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_rs or ub_split_ag:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if tp_group is None:
self.tp_size = tp_size
......@@ -2948,6 +3377,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_rs,
self.ub_split_ag,
)
out = fwd_fn(*args)
......
......@@ -15,6 +15,7 @@ import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
import transformer_engine_extensions as tex
from transformer_engine.pytorch.module import LayerNormLinear, Linear, LayerNormMLP, LayerNorm
from transformer_engine.pytorch.jit import (
set_jit_fusion_options,
......@@ -495,6 +496,10 @@ class MultiHeadAttention(torch.nn.Module):
fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
bias: bool = True,
) -> None:
super().__init__()
......@@ -547,6 +552,9 @@ class MultiHeadAttention(torch.nn.Module):
return_layernorm_output=return_layernorm_output,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
**common_gemm_kwargs,
)
else:
......@@ -572,6 +580,9 @@ class MultiHeadAttention(torch.nn.Module):
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
**common_gemm_kwargs,
)
else:
......@@ -616,6 +627,8 @@ class MultiHeadAttention(torch.nn.Module):
bias=bias,
return_bias=True,
parallel_mode="row" if set_parallel_mode else None,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
**common_gemm_kwargs,
)
......@@ -911,6 +924,12 @@ class TransformerLayer(torch.nn.Module):
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
ub_bulk_wgrad: bool, default = False
Bulk overlap UserBuffer ReduceScatter | WGRAD GEMM
ub_bulk_dgrad: bool, default = False
Bulk overlap UserBuffer AllGather | DGRAD GEMM
ub_split_rs: bool, default = False
Split pipelined overlap UserBuffer GEMM -> ReduceScatter
ub_split_ag: bool, default = False
Split pipelined overlap UserBuffer AllGather -> GEMM
Optimization parameters
-----------------------
......@@ -970,6 +989,7 @@ class TransformerLayer(torch.nn.Module):
fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False,
bias: bool = True,
) -> None:
super().__init__()
......@@ -980,6 +1000,16 @@ class TransformerLayer(torch.nn.Module):
category=DeprecationWarning,
)
if ub_tp_comm_overlap:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1")))
ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1")))
ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1")))
bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
self.layer_number = layer_number
self.output_layernorm = output_layernorm
......@@ -1037,6 +1067,10 @@ class TransformerLayer(torch.nn.Module):
"fuse_qkv_params": fuse_qkv_params,
"zero_centered_gamma": zero_centered_gamma,
"qkv_weight_interleaved" : qkv_weight_interleaved,
"ub_bulk_wgrad" : ub_bulk_wgrad,
"ub_bulk_dgrad" : ub_bulk_dgrad,
"ub_split_ag" : ub_split_ag,
"ub_split_rs" : ub_split_rs,
}
self.self_attention = MultiHeadAttention(
......@@ -1080,6 +1114,10 @@ class TransformerLayer(torch.nn.Module):
micro_batch_size=micro_batch_size,
set_parallel_mode=set_parallel_mode,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
)
self.hidden_dropout = hidden_dropout
......
......@@ -568,7 +568,7 @@ py::object TFE_Py_TeGemm_wrapper(
nvte_cublas_gemm(a_tensor.data(), b_tensor.data(), d_tensor.data(),
bias_tensor.data(), gelu_input_tensor.data(), transa,
transb, grad, workspace_tensor.data(), accumulate,
use_split_accumulate, stream);
use_split_accumulate, 0, stream);
auto d_eager = CreateTensor(d_ptr, d_shape, otype);
if (use_gelu && !grad) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment