Unverified commit e3de4037, authored by Jaemin Choi, committed by GitHub

Enable DGRAD RS overlap (#754)



* Enable DGRAD RS overlap
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* fix lint; apply suggestions
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7d8ef9bf
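
Note: the new ub_overlap_rs_dgrad flag overlaps the tensor-parallel reduce-scatter with the DGRAD GEMM in the backward pass, superseding the bulk-overlap pairing (AllGather with DGRAD, reduce-scatter with WGRAD) when set. A minimal usage sketch, not part of this change; sizes and surrounding setup are illustrative, and initialize_ub() must have been called so the "qkv_dgrad"/"fc1_dgrad" userbuffers exist:

# Sketch: enabling the new overlap on a layer (illustrative sizes).
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
    set_parallel_mode=True,
    sequence_parallel=True,     # RS overlap applies to sequence-parallel layouts
    ub_tp_comm_overlap=True,    # master switch; AND-ed with every ub_* flag below
    ub_overlap_rs_dgrad=True,   # new flag: overlap reduce-scatter with the DGRAD GEMM
)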
@@ -3171,6 +3171,7 @@ class MultiheadAttention(torch.nn.Module):
         qkv_weight_interleaved: bool = True,
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
+        ub_overlap_rs_dgrad: bool = False,
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
         bias: bool = True,
@@ -3259,6 +3260,7 @@ class MultiheadAttention(torch.nn.Module):
                 zero_centered_gamma=zero_centered_gamma,
                 ub_bulk_wgrad=ub_bulk_wgrad,
                 ub_bulk_dgrad=ub_bulk_dgrad,
+                ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                 ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
                 ub_name="qkv",
@@ -3290,6 +3292,7 @@ class MultiheadAttention(torch.nn.Module):
                 zero_centered_gamma=zero_centered_gamma,
                 ub_bulk_wgrad=ub_bulk_wgrad,
                 ub_bulk_dgrad=ub_bulk_dgrad,
+                ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                 ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
                 ub_name="qkv",
...
@@ -130,6 +130,7 @@ def initialize_ub(
         "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
     ]
     layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
+    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
     # Default overlap methods for layers
     methods = {
         "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
@@ -228,6 +229,14 @@ def initialize_ub(
         )
         _ub_communicators[name] = ub_obj

+    if ub_cfgs is not None:
+        for name in dgrad_reduce_scatter_overlap:
+            if name in ub_cfgs and 'method' in ub_cfgs[name] and ub_cfgs[name]['method'] != 'bulk':
+                wgrad_name = name.replace('dgrad', 'wgrad')
+                assert wgrad_name not in ub_cfgs
+                layers_reduce_scatter_overlap.remove(wgrad_name)
+                layers_reduce_scatter_overlap.append(name)
+
     for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]):
         if ub_cfgs is not None and name in ub_cfgs:
             ub_cfg = ub_cfgs[name]
...
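
Note: with dgrad_reduce_scatter_overlap, a user opts into the new path per layer via ub_cfgs. Giving "qkv_dgrad" or "fc1_dgrad" any non-"bulk" method retires the matching "*_wgrad" entry from the reduce-scatter list and registers the dgrad layer there instead. A hedged sketch of such a config; only the "method" field is grounded in this diff, anything else a real config carries is omitted:

# Sketch: route fc1_dgrad through pipelined reduce-scatter overlap.
ub_cfgs = {
    "fc1_dgrad": {"method": "pipeline"},       # any non-"bulk" method triggers the swap
    "qkv_dgrad": {"method": "ring_exchange"},  # ring-exchange variant for qkv
}
# Per the assert above, "fc1_wgrad"/"qkv_wgrad" must NOT also appear in ub_cfgs:
# their bulk reduce-scatter slots are handed over to the dgrad layers.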
@@ -86,6 +86,7 @@ class _LayerNormLinear(torch.autograd.Function):
         primary_weights_in_fp8: bool,
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
+        ub_overlap_rs_dgrad: bool,
         ub_overlap_ag: bool,
         ub_name: str,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
@@ -316,6 +317,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.zero_centered_gamma = zero_centered_gamma
         ctx.ub_bulk_wgrad = ub_bulk_wgrad
         ctx.ub_bulk_dgrad = ub_bulk_dgrad
+        ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
         ctx.ub_name = ub_name
         ctx.requires_dgrad = inp.requires_grad
         ctx.normalization = normalization
@@ -367,6 +369,12 @@ class _LayerNormLinear(torch.autograd.Function):
                 update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
             )

+        if ctx.ub_overlap_rs_dgrad:
+            ctx.ub_bulk_dgrad = False
+            ctx.ub_bulk_wgrad = False
+            tp_world_size = get_distributed_world_size(ctx.tp_group)
+            if tp_world_size == 1:
+                ctx.ub_overlap_rs_dgrad = False
         if ctx.ub_bulk_dgrad:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
             if tp_world_size == 1 or not weight.requires_grad:
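
Note: the guard above fixes the precedence among the backward overlap modes: RS-dgrad overlap supersedes both bulk overlaps, and any overlap is moot on a single-rank tensor-parallel group. Restated as a standalone function (names hypothetical, for illustration only):

def resolve_dgrad_overlap(overlap_rs_dgrad: bool, bulk_dgrad: bool,
                          bulk_wgrad: bool, tp_world_size: int):
    """RS-dgrad overlap wins over the bulk AG/RS overlaps; everything is off at TP=1."""
    if overlap_rs_dgrad:
        bulk_dgrad = False   # the DGRAD GEMM now pairs with its own reduce-scatter
        bulk_wgrad = False   # no dgrad reduce-scatter left to bulk-overlap with WGRAD
        if tp_world_size == 1:
            overlap_rs_dgrad = False
    return overlap_rs_dgrad, bulk_dgrad, bulk_wgrad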
@@ -416,9 +424,36 @@ class _LayerNormLinear(torch.autograd.Function):
         if ctx.ub_bulk_wgrad: # allocate dgrad output
             ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad")
             dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
+        elif ctx.ub_overlap_rs_dgrad:
+            ub_obj_dgrad = get_ub(ctx.ub_name+"_dgrad")
+            dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
         else:
             dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device)

+        if ctx.ub_bulk_dgrad:
+            ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+            ub_obj = ub_obj_lnout
+        elif ctx.ub_overlap_rs_dgrad:
+            dim_size = list(grad_output.size())
+            dim_size[0] = dim_size[0] // tp_world_size
+            dim_size[1] = weight.size(1)
+            rs_out = torch.empty(
+                dim_size, dtype=ctx.activation_dtype, device=grad_output.device)
+            if ub_obj_dgrad.is_p2p_overlap():
+                if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm():
+                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                else:
+                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+            else:
+                if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm():
+                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                else:
+                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+            ub_obj = ub_obj_dgrad
+        else:
+            ub_algo = None
+            ub_obj = None
+
         if ctx.fp8:
             fp8_dtype_forward = get_fp8_te_dtype(
                 ctx.fp8_meta["recipe"], fprop_tensor=True
@@ -428,7 +463,7 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             out_index, meta_tensor, out_te_type, out_type = (
                 None, None, None, ctx.activation_dtype)
-            if ctx.ub_bulk_wgrad and ub_obj_dgrad.is_fp8_ubuf():
+            if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf():
                 out_index = tex.FP8BwdTensors.GRAD_INPUT1
                 meta_tensor = ctx.fp8_meta["scaling_bwd"]
                 out_te_type = fp8_dtype_backward
@@ -449,8 +484,9 @@ class _LayerNormLinear(torch.autograd.Function):
                 get_workspace(),
                 out=dgrad,
                 use_split_accumulator=_2X_ACC_DGRAD,
-                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
-                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
+                ub_algo=ub_algo,
+                ub=ub_obj,
+                extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None,
                 out_index=out_index,
                 fp8_meta_tensor = meta_tensor,
                 D_dtype = out_te_type,
@@ -466,8 +502,9 @@ class _LayerNormLinear(torch.autograd.Function):
                 out=dgrad,
                 layout="NN",
                 grad=True,
-                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
-                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
+                ub_algo=ub_algo,
+                ub=ub_obj,
+                extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None,
             )
         if ctx.ub_bulk_dgrad:
             ln_out_total = ub_obj_lnout.get_ubuf_output(1)
@@ -476,7 +513,7 @@ class _LayerNormLinear(torch.autograd.Function):
         if ctx.parallel_mode == "column" and ctx.sequence_parallel:
             if not ctx.ub_bulk_dgrad and handle is not None:
                 handle.wait()
-            if not ctx.ub_bulk_wgrad:
+            if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad:
                 if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
                     dgrad = dgrad + grad_outputs[1].view_as(dgrad)
                 dgrad, handle = reduce_scatter_along_first_dim(
@@ -569,7 +606,10 @@ class _LayerNormLinear(torch.autograd.Function):
                 handle.wait()

         # LayerNorm gradient
-        dgrad = dgrad.view(inputmat.shape)
+        if ctx.ub_overlap_rs_dgrad:
+            dgrad = rs_out.view(inputmat.shape)
+        else:
+            dgrad = dgrad.view(inputmat.shape)

         # Residual gradient
         if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
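
Note: the switch to rs_out matters for shapes. The GEMM writes the full dgrad into the userbuffer, while the overlapped reduce-scatter deposits this rank's shard into rs_out, which already matches the sequence-parallel input shard. A worked shape example with illustrative numbers (not from the PR):

# Shape sketch with tp_world_size = 8:
# grad_output: [s*b, h_out], weight: [h_out, h_in]
# dgrad  (GEMM output in the ubuf):        [s*b, h_in]
# rs_out (extra_output_tensor, per rank):  [s*b // 8, h_in]  -> matches inputmat.shape,
# so it can feed the LayerNorm backward directly, with no separate reduce-scatter.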
@@ -645,6 +685,7 @@ class _LayerNormLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -758,6 +799,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
         ub_overlap_ag: bool = False,
+        ub_overlap_rs_dgrad: bool = False,
         ub_name: Optional[str] = None,
     ) -> None:
         super().__init__()
@@ -778,7 +820,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.ub_bulk_wgrad = ub_bulk_wgrad
         self.ub_bulk_dgrad = ub_bulk_dgrad
         self.ub_overlap_ag = ub_overlap_ag
-        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]):
+        self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
+        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag, ub_overlap_rs_dgrad]):
             assert ub_name is not None, "Userbuffer name [string] is not set."
         self.ub_name = ub_name
@@ -1110,6 +1153,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.primary_weights_in_fp8,
             self.ub_bulk_wgrad,
             self.ub_bulk_dgrad,
+            self.ub_overlap_rs_dgrad,
             self.ub_overlap_ag,
             self.ub_name,
         )
...
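
Note: the same flag is exposed on the standalone LayerNormLinear module; per the assert above, any userbuffer flag now requires ub_name so the backward can locate the "<name>_dgrad" buffer. A hedged construction sketch (sizes illustrative):

qkv_proj = te.LayerNormLinear(
    in_features=4096,          # illustrative sizes
    out_features=12288,
    ub_overlap_rs_dgrad=True,
    ub_name="qkv",             # backward fetches get_ub("qkv_dgrad")
)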
@@ -117,6 +117,7 @@ class _LayerNormMLP(torch.autograd.Function):
         primary_weights_in_fp8: bool,
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
+        ub_overlap_rs_dgrad: bool,
         ub_overlap_rs: bool,
         ub_overlap_ag: bool,
         gemm_gelu_fusion: bool,
@@ -533,6 +534,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ctx.zero_centered_gamma = zero_centered_gamma
         ctx.ub_bulk_wgrad = ub_bulk_wgrad
         ctx.ub_bulk_dgrad = ub_bulk_dgrad
+        ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
         ctx.ub_overlap_ag = ub_overlap_ag
         ctx.requires_dgrad = inp.requires_grad
         ctx.normalization = normalization
@@ -598,6 +600,12 @@ class _LayerNormMLP(torch.autograd.Function):
         activation_func = _act_func(ctx.activation)[1]

+        if ctx.ub_overlap_rs_dgrad:
+            ctx.ub_bulk_dgrad = False
+            ctx.ub_bulk_wgrad = False
+            tp_world_size = get_distributed_world_size(ctx.tp_group)
+            if tp_world_size == 1:
+                ctx.ub_overlap_rs_dgrad = False
         if ctx.ub_bulk_dgrad:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
             if tp_world_size == 1 or not fc1_weight.requires_grad:
@@ -773,19 +781,49 @@ class _LayerNormMLP(torch.autograd.Function):
                 None, None, None, ctx.activation_dtype)
             fc1_dgrad_size = list(dgelu.size())
             fc1_dgrad_size[1] = fc1_weight.size(1)
+            # Get/alloc fc1_dgrad
             if ctx.ub_bulk_wgrad: # allocate dgrad output
                 ub_obj_dgrad = get_ub("fc1_wgrad")
                 fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
-                if ub_obj_dgrad.is_fp8_ubuf():
-                    out_index = tex.FP8BwdTensors.GRAD_INPUT2
-                    meta_tensor = ctx.fp8_meta["scaling_bwd"]
-                    out_te_type = fp8_dtype_backward
-                    out_type = torch.uint8
-                    ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
+            elif ctx.ub_overlap_rs_dgrad:
+                ub_obj_dgrad = get_ub("fc1_dgrad")
+                fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
             else:
                 fc1_dgrad = torch.empty(
                     fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
                 )

+            # FP8 RS
+            if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf():
+                out_index = tex.FP8BwdTensors.GRAD_INPUT2
+                meta_tensor = ctx.fp8_meta["scaling_bwd"]
+                out_te_type = fp8_dtype_backward
+                out_type = torch.uint8
+                ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
+
+            # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap
+            if ctx.ub_bulk_dgrad:
+                ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+                ub_obj = ub_obj_lnout
+            elif ctx.ub_overlap_rs_dgrad:
+                dim_size = list(dgelu.size())
+                dim_size[0] = dim_size[0] // tp_world_size
+                dim_size[1] = fc1_weight_t_fp8.size(0)
+                rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device)
+                if ub_obj_dgrad.is_p2p_overlap():
+                    if ub_obj_dgrad.is_atomic_gemm():
+                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                    else:
+                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                else:
+                    if ub_obj_dgrad.is_atomic_gemm():
+                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                    else:
+                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                ub_obj = ub_obj_dgrad
+            else:
+                ub_algo = None
+                ub_obj = None
+
             # FC1 DGRAD: Unconditional
             _ = tex.fp8_gemm(
                 fc1_weight_t_fp8._data,
@@ -800,8 +838,9 @@ class _LayerNormMLP(torch.autograd.Function):
                 get_workspace(),
                 out=fc1_dgrad,
                 use_split_accumulator=_2X_ACC_DGRAD,
-                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
-                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
+                ub_algo=ub_algo,
+                ub=ub_obj,
+                extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None,
                 out_index=out_index,
                 fp8_meta_tensor = meta_tensor,
                 D_dtype = out_te_type,
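
Note: when the dgrad userbuffer is FP8, the GEMM above emits quantized output straight into the buffer and the overlapped reduce-scatter runs on FP8 data; set_ubuf_scale_inv registers the dequantization scale with the communicator. A sketch of that wiring as a helper (hypothetical function, not the library API):

def bind_fp8_dgrad_ubuf(ub_obj_dgrad, ctx, fp8_dtype_backward):
    """Sketch of the FP8-ubuf setup used above for fc1_dgrad."""
    out_index = tex.FP8BwdTensors.GRAD_INPUT2      # backward scaling slot for this tensor
    meta_tensor = ctx.fp8_meta["scaling_bwd"]
    ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
    # GEMM output is FP8 stored as torch.uint8, so the overlapped reduce-scatter
    # moves quantized bytes and dequantizes with the registered scale_inv.
    return out_index, meta_tensor, fp8_dtype_backward, torch.uint8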
@@ -859,11 +898,31 @@ class _LayerNormMLP(torch.autograd.Function):
             if ctx.ub_bulk_wgrad: # allocate dgrad output
                 ub_obj_dgrad = get_ub("fc1_wgrad")
                 fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
+            elif ctx.ub_overlap_rs_dgrad:
+                ub_obj_dgrad = get_ub("fc1_dgrad")
+                fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
             else:
                 fc1_dgrad = torch.empty(
                     fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
                 )

+            # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap
+            if ctx.ub_bulk_dgrad:
+                ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+                ub_obj = ub_obj_lnout
+            elif ctx.ub_overlap_rs_dgrad:
+                dim_size = list(dgelu.size())
+                dim_size[0] = dim_size[0] // tp_world_size
+                dim_size[1] = fc1_weight.size(1)
+                rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device)
+                if ub_obj_dgrad.is_p2p_overlap():
+                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                else:
+                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                ub_obj = ub_obj_dgrad
+            else:
+                ub_algo = None
+                ub_obj = None
+
             # FC1 DGRAD: Unconditional
             _ = tex.gemm(
                 fc1_weight,
@@ -873,8 +932,9 @@ class _LayerNormMLP(torch.autograd.Function):
                 out=fc1_dgrad,
                 layout="NN",
                 grad=True,
-                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
-                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
+                ub_algo=ub_algo,
+                ub=ub_obj,
+                extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None,
             )
         if ctx.ub_bulk_dgrad:
@@ -883,7 +943,7 @@ class _LayerNormMLP(torch.autograd.Function):
         if ctx.set_parallel_mode and ctx.sequence_parallel:
             if not ctx.ub_bulk_dgrad and handle is not None:
                 handle.wait()
-            if not ctx.ub_bulk_wgrad:
+            if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad:
                 if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
                     fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad)
                 fc1_dgrad, handle = reduce_scatter_along_first_dim(
@@ -985,7 +1045,10 @@ class _LayerNormMLP(torch.autograd.Function):
                 handle.wait()

         # LayerNorm gradient
-        dgrad = fc1_dgrad.view(inputmat.shape)
+        if ctx.ub_overlap_rs_dgrad:
+            dgrad = rs_out.view(inputmat.shape)
+        else:
+            dgrad = fc1_dgrad.view(inputmat.shape)

         # Residual gradient
         if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
@@ -1087,6 +1150,7 @@ class _LayerNormMLP(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -1209,6 +1273,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         device: Union[torch.device, str] = "cuda",
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
+        ub_overlap_rs_dgrad: bool = False,
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
     ) -> None:
@@ -1231,6 +1296,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
         self.ub_bulk_wgrad = ub_bulk_wgrad
         self.ub_bulk_dgrad = ub_bulk_dgrad
+        self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
         self.ub_overlap_rs = ub_overlap_rs
         self.ub_overlap_ag = ub_overlap_ag
         # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
@@ -1238,7 +1304,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and
             self.activation == 'gelu' and not get_ub("fc1_fprop").is_atomic_gemm())
-        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag]):
+        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag, ub_overlap_rs_dgrad]):
             assert (
                 tex.userbuf_comm_available()
             ), "Userbuffer communication backend not available."
@@ -1492,6 +1558,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.primary_weights_in_fp8,
             self.ub_bulk_wgrad,
             self.ub_bulk_dgrad,
+            self.ub_overlap_rs_dgrad,
             self.ub_overlap_rs,
             self.ub_overlap_ag,
             self.gemm_gelu_fusion,
...
@@ -261,6 +261,7 @@ class TransformerLayer(torch.nn.Module):
         ub_bulk_dgrad: bool = True,
         ub_overlap_ag: bool = True,
         ub_overlap_rs: bool = True,
+        ub_overlap_rs_dgrad: bool = False,
         bias: bool = True,
         activation: str = 'gelu',
         normalization: str = "LayerNorm",
@@ -282,6 +283,7 @@ class TransformerLayer(torch.nn.Module):
         ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
         ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag
         ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs
+        ub_overlap_rs_dgrad = ub_tp_comm_overlap and ub_overlap_rs_dgrad
         bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
         self.layer_number = layer_number
@@ -357,6 +359,7 @@ class TransformerLayer(torch.nn.Module):
             "ub_bulk_dgrad" : ub_bulk_dgrad,
             "ub_overlap_ag" : ub_overlap_ag,
             "ub_overlap_rs" : ub_overlap_rs,
+            "ub_overlap_rs_dgrad" : ub_overlap_rs_dgrad,
             "qkv_format" : self.attn_input_format,
         }
@@ -410,6 +413,7 @@ class TransformerLayer(torch.nn.Module):
             zero_centered_gamma=zero_centered_gamma,
             ub_bulk_wgrad=ub_bulk_wgrad,
             ub_bulk_dgrad=ub_bulk_dgrad,
+            ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
             ub_overlap_rs=ub_overlap_rs,
             ub_overlap_ag=ub_overlap_ag,
             activation=activation,
...
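
Note: taken together, TransformerLayer now threads five userbuffer overlap flags into its submodules, with the attention path receiving them through the args dict shown above. A condensed summary (wording ours; behavior as implemented in this PR):

# Userbuffer overlap flags after this change, all gated by ub_tp_comm_overlap:
ub_flags = {
    "ub_bulk_dgrad":       "bulk AllGather overlapped with the DGRAD GEMM",
    "ub_bulk_wgrad":       "dgrad reduce-scatter bulk-overlapped with the WGRAD GEMM",
    "ub_overlap_ag":       "pipelined AllGather overlap",
    "ub_overlap_rs":       "pipelined reduce-scatter overlap in forward",
    "ub_overlap_rs_dgrad": "new: pipelined reduce-scatter overlapped with the DGRAD GEMM",
}
# Every flag is AND-ed with ub_tp_comm_overlap first, so the master switch
# still disables all overlap at once.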