Unverified commit d8f678dc authored by Chen Cui, committed by GitHub

Return layernorm output in the gathered form (#697)



* first draft of return_layernorm_output_gathered
Signed-off-by: Chen Cui <chcui@nvidia.com>

* explain use case more thoroughly in docstring
Signed-off-by: Chen Cui <chcui@nvidia.com>

* add same option in `LayerNormMLP`
Signed-off-by: Chen Cui <chcui@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>

* address comments
Signed-off-by: Chen Cui <chcui@nvidia.com>

* add same option in LayerNormMLP
Signed-off-by: Chen Cui <chcui@nvidia.com>

* address linter errors
Signed-off-by: Chen Cui <chcui@nvidia.com>

---------
Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b0f65354
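In practice the new option is consumed like this (an illustrative sketch only: the tensor-parallel process-group setup is omitted, and on a single GPU the gathered and ungathered outputs coincide):

```python
# Illustrative usage of the option added in this commit; not taken from the diff.
import torch
import transformer_engine.pytorch as te

layer = te.LayerNormLinear(
    in_features=1024,
    out_features=1024,
    return_layernorm_output=True,
    return_layernorm_output_gathered=True,  # new flag introduced here
)

x = torch.randn(128, 2, 1024, device="cuda")  # [seq, batch, hidden]
out, ln_out = layer(x)
# With sequence parallelism enabled, ln_out is the all-gathered layernorm
# output (full sequence length), ready to feed e.g. a LoRA branch without
# a second gather.
```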
@@ -77,6 +77,7 @@ class _LayerNormLinear(torch.autograd.Function):
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
@@ -134,11 +135,23 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_ln_sm_margin,
zero_centered_gamma,
is_grad_enabled)
# Column Parallel Linear
ln_out_gathered = False
if ub_split_ag or ub_atomic_gemm_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif parallel_mode == "column" and sequence_parallel:
ln_out_gathered = True
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
# If residual connection is after LN, we need `ln_out_return`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
if return_layernorm_output:
ln_out_return = ln_out
ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
if fp8:
ln_out = tex.cast_to_fp8(
ln_out,
@@ -146,14 +159,6 @@ class _LayerNormLinear(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
# Column Parallel Linear
if ub_split_ag or ub_atomic_gemm_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif parallel_mode == "column" and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
if fp8:
bias_dtype = (
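The core of the forward-pass change (mirrored in `_LayerNormMLP` below) is that the column-parallel all-gather now runs before the fp8 cast and the residual branch, so the gathered, higher-precision tensor can be handed back when requested. A minimal single-process sketch, with `fake_all_gather` standing in for `gather_along_first_dim` over the tensor-parallel group:

```python
import torch

def fake_all_gather(shard: torch.Tensor, tp_size: int) -> torch.Tensor:
    # Single-process stand-in for gather_along_first_dim over the TP group.
    return shard.repeat(tp_size, *([1] * (shard.dim() - 1)))

tp_size, sequence_parallel, return_layernorm_output_gathered = 2, True, True
ln_out = torch.randn(4, 8)        # per-rank shard: [seq/tp, hidden]

ln_out_gathered = False
if sequence_parallel:             # the column-parallel + sequence-parallel branch
    ln_out_gathered = True
    ln_out_total = fake_all_gather(ln_out, tp_size)   # full sequence: [8, 8]
else:
    ln_out_total = ln_out

# Because the gather now happens before the fp8 cast, the residual can be
# returned in gathered, higher-precision form when requested.
ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
```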
@@ -284,6 +289,8 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.return_layernorm_output = return_layernorm_output
ctx.return_layernorm_output_gathered = return_layernorm_output_gathered \
and ln_out_gathered
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.ub_bulk_wgrad = ub_bulk_wgrad
@@ -302,6 +309,10 @@ class _LayerNormLinear(torch.autograd.Function):
out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp.shape)
shape[0] *= tp_size
return out, ln_out_return.view(shape)
return out, ln_out_return.view_as(inp)
return out
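The gathered tensor can no longer be `view_as(inp)`: its leading (sequence) dimension is `tp_size` times the per-rank shard's, which is exactly what the `shape[0] *= tp_size` bookkeeping restores. A toy check with made-up sizes:

```python
import torch

tp_size = 2
inp = torch.empty(512, 4, 1024)   # local shard under sequence parallelism: [seq/tp, batch, hidden]
shape = list(inp.shape)
shape[0] *= tp_size               # view target for the gathered output
assert shape == [1024, 4, 1024]
```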
@@ -445,6 +456,8 @@ class _LayerNormLinear(torch.autograd.Function):
if not ctx.ub_bulk_dgrad and handle is not None:
handle.wait()
if not ctx.ub_bulk_wgrad:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
@@ -538,7 +551,7 @@ class _LayerNormLinear(torch.autograd.Function):
dgrad = dgrad.view(inputmat.shape)
# Residual gradient
if ctx.return_layernorm_output:
if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
if ctx.normalization == "LayerNorm":
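The backward pass mirrors this: if the gathered layernorm output was returned, `grad_outputs[1]` arrives full-sequence sized and is folded into `dgrad` before the reduce-scatter, while the original post-reduce-scatter addition is kept only for the non-gathered case (the `_LayerNormMLP` backward below gets the same treatment). A rough single-process sketch with a stand-in collective:

```python
import torch

def fake_reduce_scatter(full: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
    # Single-process stand-in for reduce_scatter_along_first_dim;
    # a real reduce-scatter also sums contributions across ranks.
    chunk = full.shape[0] // tp_size
    return full[rank * chunk:(rank + 1) * chunk]

tp_size, rank = 2, 0
dgrad = torch.randn(8, 16)           # gradient wrt the gathered layernorm output
residual_grad = torch.randn(8, 16)   # grad_outputs[1]: grad of the returned ln output

return_layernorm_output_gathered = True
if return_layernorm_output_gathered:
    # The returned residual was gathered, so its gradient is full-sequence sized
    # and must be folded in before the reduce-scatter back to per-rank shards.
    dgrad = dgrad + residual_grad
dgrad = fake_reduce_scatter(dgrad, tp_size, rank)   # [4, 16] per-rank shard
# In the non-gathered case, the residual gradient is per-rank sized and is
# added after this point instead (the original code path).
```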
@@ -611,6 +624,7 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
@@ -638,6 +652,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
return_layernorm_output_gathered : bool, default = `False`
if set to `True`, the layernorm output is returned after the all-gather
operation. Ignored if `return_layernorm_output` is `False`.
Example use case: with sequence parallelism, the input to the residual
connection of the transformer module (e.g. LoRA) needs to be gathered.
Returning the gathered layernorm output avoids a redundant gather.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
Configuration for splitting the weight and bias tensors along dim 0 into
multiple PyTorch parameters. If a list or tuple of strings is provided,
@@ -711,6 +731,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None,
return_layernorm_output: bool = False,
return_layernorm_output_gathered: bool = False,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
@@ -732,6 +753,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_bias = return_bias
self.apply_bias = self.use_bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.zero_centered_gamma = zero_centered_gamma
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad
@@ -1067,6 +1089,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
torch.is_grad_enabled(),
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
@@ -105,6 +105,7 @@ class _LayerNormMLP(torch.autograd.Function):
tensor_parallel: bool,
activation_dtype: torch.dtype,
return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
bias_gelu_nvfusion: bool,
set_parallel_mode: bool,
is_grad_enabled: bool,
@@ -174,11 +175,23 @@ class _LayerNormMLP(torch.autograd.Function):
fwd_ln_sm_margin,
zero_centered_gamma,
is_grad_enabled)
# Column Parallel Linear
ln_out_gathered = False
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif set_parallel_mode and sequence_parallel:
ln_out_gathered = True
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
# If residual connection is after LN, we need `ln_out`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
if return_layernorm_output:
ln_out_return = ln_out
ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
if fp8:
ln_out = tex.cast_to_fp8(
ln_out,
@@ -186,14 +199,6 @@ class _LayerNormMLP(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
# Column Parallel Linear
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif set_parallel_mode and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
if fp8:
bias_dtype = (
@@ -503,6 +508,8 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.tp_size = tp_size
ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
ctx.return_layernorm_output = return_layernorm_output
ctx.return_layernorm_output_gathered = return_layernorm_output_gathered \
and ln_out_gathered
ctx.set_parallel_mode = set_parallel_mode
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
@@ -525,6 +532,10 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1])
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp.shape)
shape[0] *= tp_size
return fc2_out, ln_out_return.view(shape)
return fc2_out, ln_out_return.view_as(inp)
return fc2_out
@@ -856,6 +867,8 @@ class _LayerNormMLP(torch.autograd.Function):
if not ctx.ub_bulk_dgrad and handle is not None:
handle.wait()
if not ctx.ub_bulk_wgrad:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad)
fc1_dgrad, handle = reduce_scatter_along_first_dim(
fc1_dgrad, ctx.tp_group, async_op=True
)
@@ -958,7 +971,7 @@ class _LayerNormMLP(torch.autograd.Function):
dgrad = fc1_dgrad.view(inputmat.shape)
# Residual gradient
if ctx.return_layernorm_output:
if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
if ctx.normalization == "LayerNorm":
@@ -1058,6 +1071,7 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
@@ -1093,6 +1107,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
together with the output of the linear transformation.
Example use case: residual connection for transformer module
is taken post layernorm.
return_layernorm_output_gathered : bool, default = `False`
if set to `True`, the layernorm output is returned after the all-gather
operation. Ignored if `return_layernorm_output` is `False`.
Example use case: with sequence parallelism, the input to the residual
connection of the transformer module (e.g. LoRA) needs to be gathered.
Returning the gathered layernorm output avoids a redundant gather.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
@@ -1166,6 +1186,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fuse_wgrad_accumulation: bool = False,
params_dtype: Optional[torch.dtype] = None,
return_layernorm_output: bool = False,
return_layernorm_output_gathered: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
set_parallel_mode: bool = False,
@@ -1189,6 +1210,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.bias_gelu_nvfusion = (bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) and
self.activation == 'gelu')
self.set_parallel_mode = set_parallel_mode
@@ -1456,6 +1478,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion,
self.set_parallel_mode,
torch.is_grad_enabled(),
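For the LayerNormMLP path, the docstring's LoRA use case would look roughly like the following; the adapter layers are hypothetical and not part of Transformer Engine, and the distributed setup is again omitted:

```python
# Hypothetical LoRA-style consumer of the (gathered) layernorm output.
import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    return_layernorm_output=True,
    return_layernorm_output_gathered=True,
)
lora_a = torch.nn.Linear(1024, 16, bias=False, device="cuda")
lora_b = torch.nn.Linear(16, 1024, bias=False, device="cuda")

x = torch.randn(128, 2, 1024, device="cuda")  # [seq, batch, hidden]
mlp_out, ln_out = mlp(x)
# Under sequence parallelism ln_out would already be gathered to the full
# sequence length, so the adapter sees the same tokens as the fused kernel.
adapted = mlp_out + lora_b(lora_a(ln_out))
```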