Unverified Commit d8f678dc authored by Chen Cui, committed by GitHub
Browse files

Return layernorm output in the gathered form (#697)



* first draft of return_layernorm_output_gathered
Signed-off-by: Chen Cui <chcui@nvidia.com>

* explain use case more thoroughly in docstring
Signed-off-by: Chen Cui <chcui@nvidia.com>

* add same option in `LayerNormMLP`
Signed-off-by: Chen Cui <chcui@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>

* address comments
Signed-off-by: Chen Cui <chcui@nvidia.com>

* add same option in LayerNormMLP
Signed-off-by: Chen Cui <chcui@nvidia.com>

* address linter errors
Signed-off-by: Chen Cui <chcui@nvidia.com>

---------
Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b0f65354
...@@ -77,6 +77,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -77,6 +77,7 @@ class _LayerNormLinear(torch.autograd.Function):
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
parallel_mode: Union[str, None], parallel_mode: Union[str, None],
return_layernorm_output: bool, return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
is_grad_enabled: bool, is_grad_enabled: bool,
fwd_ln_sm_margin: int, fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int, bwd_ln_sm_margin: int,
...@@ -134,11 +135,23 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -134,11 +135,23 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma, zero_centered_gamma,
is_grad_enabled) is_grad_enabled)
# Column Parallel Linear
ln_out_gathered = False
if ub_split_ag or ub_atomic_gemm_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif parallel_mode == "column" and sequence_parallel:
ln_out_gathered = True
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
# If residual connection is after LN, we need `ln_out_return` # If residual connection is after LN, we need `ln_out_return`
# tensor in higher precision, this comes at the cost # tensor in higher precision, this comes at the cost
# of an extra fp8 cast. # of an extra fp8 cast.
if return_layernorm_output: if return_layernorm_output:
ln_out_return = ln_out ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
if fp8: if fp8:
ln_out = tex.cast_to_fp8( ln_out = tex.cast_to_fp8(
ln_out, ln_out,
...@@ -146,14 +159,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -146,14 +159,6 @@ class _LayerNormLinear(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
) )
# Column Parallel Linear
if ub_split_ag or ub_atomic_gemm_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif parallel_mode == "column" and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
if fp8: if fp8:
bias_dtype = ( bias_dtype = (
...@@ -284,6 +289,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -284,6 +289,8 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.return_layernorm_output = return_layernorm_output ctx.return_layernorm_output = return_layernorm_output
ctx.return_layernorm_output_gathered = return_layernorm_output_gathered \
and ln_out_gathered
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma ctx.zero_centered_gamma = zero_centered_gamma
ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_wgrad = ub_bulk_wgrad
...@@ -302,6 +309,10 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -302,6 +309,10 @@ class _LayerNormLinear(torch.autograd.Function):
out = out.view(-1, *inp.shape[1:-1], out.shape[-1]) out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
if return_layernorm_output: if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp.shape)
shape[0] *= tp_size
return out, ln_out_return.view(shape)
return out, ln_out_return.view_as(inp) return out, ln_out_return.view_as(inp)
return out return out
...@@ -445,6 +456,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -445,6 +456,8 @@ class _LayerNormLinear(torch.autograd.Function):
if not ctx.ub_bulk_dgrad and handle is not None: if not ctx.ub_bulk_dgrad and handle is not None:
handle.wait() handle.wait()
if not ctx.ub_bulk_wgrad: if not ctx.ub_bulk_wgrad:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
dgrad, handle = reduce_scatter_along_first_dim( dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True dgrad, ctx.tp_group, async_op=True
) )
...@@ -538,7 +551,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -538,7 +551,7 @@ class _LayerNormLinear(torch.autograd.Function):
dgrad = dgrad.view(inputmat.shape) dgrad = dgrad.view(inputmat.shape)
# Residual gradient # Residual gradient
if ctx.return_layernorm_output: if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad) dgrad = dgrad + grad_outputs[1].view_as(dgrad)
if ctx.normalization == "LayerNorm": if ctx.normalization == "LayerNorm":
...@@ -611,6 +624,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -611,6 +624,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -638,6 +652,12 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -638,6 +652,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
together with the output of the linear transformation. together with the output of the linear transformation.
Example use case: residual connection for transformer module is Example use case: residual connection for transformer module is
taken post layernorm. taken post layernorm.
return_layernorm_output_gathered : bool, default = `False`
if set to `True`, output of layernorm is returned after the all
gather operation. Ignored if return_layernorm_output is False.
Example use case: with sequence parallel, input to residual connection
for transformer module (e.g. LoRA) will need to be gathered.
Returning layernorm output gathered will prevent a redundant gather.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
Configuration for splitting the weight and bias tensors along dim 0 into Configuration for splitting the weight and bias tensors along dim 0 into
multiple PyTorch parameters. If a list or tuple of strings is provided, multiple PyTorch parameters. If a list or tuple of strings is provided,
...@@ -711,6 +731,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -711,6 +731,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None, parallel_mode: Optional[str] = None,
return_layernorm_output: bool = False, return_layernorm_output: bool = False,
return_layernorm_output_gathered: bool = False,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
...@@ -732,6 +753,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -732,6 +753,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_bias = return_bias self.return_bias = return_bias
self.apply_bias = self.use_bias and not return_bias self.apply_bias = self.use_bias and not return_bias
self.return_layernorm_output = return_layernorm_output self.return_layernorm_output = return_layernorm_output
self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_wgrad = ub_bulk_wgrad
...@@ -1067,6 +1089,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1067,6 +1089,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.activation_dtype, self.activation_dtype,
self.parallel_mode, self.parallel_mode,
self.return_layernorm_output, self.return_layernorm_output,
self.return_layernorm_output_gathered,
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
......
...@@ -105,6 +105,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -105,6 +105,7 @@ class _LayerNormMLP(torch.autograd.Function):
tensor_parallel: bool, tensor_parallel: bool,
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
return_layernorm_output: bool, return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
bias_gelu_nvfusion: bool, bias_gelu_nvfusion: bool,
set_parallel_mode: bool, set_parallel_mode: bool,
is_grad_enabled: bool, is_grad_enabled: bool,
...@@ -174,11 +175,23 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -174,11 +175,23 @@ class _LayerNormMLP(torch.autograd.Function):
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma, zero_centered_gamma,
is_grad_enabled) is_grad_enabled)
# Column Parallel Linear
ln_out_gathered = False
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif set_parallel_mode and sequence_parallel:
ln_out_gathered = True
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
# If residual connection is after LN, we need `ln_out` # If residual connection is after LN, we need `ln_out`
# tensor in higher precision, this comes at the cost # tensor in higher precision, this comes at the cost
# of an extra fp8 cast. # of an extra fp8 cast.
if return_layernorm_output: if return_layernorm_output:
ln_out_return = ln_out ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
if fp8: if fp8:
ln_out = tex.cast_to_fp8( ln_out = tex.cast_to_fp8(
ln_out, ln_out,
...@@ -186,14 +199,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -186,14 +199,6 @@ class _LayerNormMLP(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
) )
# Column Parallel Linear
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif set_parallel_mode and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
if fp8: if fp8:
bias_dtype = ( bias_dtype = (
...@@ -503,6 +508,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -503,6 +508,8 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.bias_gelu_nvfusion = bias_gelu_nvfusion ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
ctx.return_layernorm_output = return_layernorm_output ctx.return_layernorm_output = return_layernorm_output
ctx.return_layernorm_output_gathered = return_layernorm_output_gathered \
and ln_out_gathered
ctx.set_parallel_mode = set_parallel_mode ctx.set_parallel_mode = set_parallel_mode
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma ctx.zero_centered_gamma = zero_centered_gamma
...@@ -525,6 +532,10 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -525,6 +532,10 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1]) fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1])
if return_layernorm_output: if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp.shape)
shape[0] *= tp_size
return fc2_out, ln_out_return.view(shape)
return fc2_out, ln_out_return.view_as(inp) return fc2_out, ln_out_return.view_as(inp)
return fc2_out return fc2_out
...@@ -856,6 +867,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -856,6 +867,8 @@ class _LayerNormMLP(torch.autograd.Function):
if not ctx.ub_bulk_dgrad and handle is not None: if not ctx.ub_bulk_dgrad and handle is not None:
handle.wait() handle.wait()
if not ctx.ub_bulk_wgrad: if not ctx.ub_bulk_wgrad:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad)
fc1_dgrad, handle = reduce_scatter_along_first_dim( fc1_dgrad, handle = reduce_scatter_along_first_dim(
fc1_dgrad, ctx.tp_group, async_op=True fc1_dgrad, ctx.tp_group, async_op=True
) )
...@@ -958,7 +971,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -958,7 +971,7 @@ class _LayerNormMLP(torch.autograd.Function):
dgrad = fc1_dgrad.view(inputmat.shape) dgrad = fc1_dgrad.view(inputmat.shape)
# Residual gradient # Residual gradient
if ctx.return_layernorm_output: if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad) dgrad = dgrad + grad_outputs[1].view_as(dgrad)
if ctx.normalization == "LayerNorm": if ctx.normalization == "LayerNorm":
...@@ -1058,6 +1071,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1058,6 +1071,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -1093,6 +1107,12 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1093,6 +1107,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
together with the output of the linear transformation. together with the output of the linear transformation.
Example use case: residual connection for transformer module Example use case: residual connection for transformer module
is taken post layernorm. is taken post layernorm.
return_layernorm_output_gathered : bool, default = `False`
if set to `True`, output of layernorm is returned after the all
gather operation. Ignored if return_layernorm_output is False.
Example use case: with sequence parallel, input to residual connection
for transformer module (e.g. LoRA) will need to be gathered.
Returning layernorm output gathered will prevent a redundant gather.
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to the LayerNorm formula changes to
...@@ -1166,6 +1186,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1166,6 +1186,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fuse_wgrad_accumulation: bool = False, fuse_wgrad_accumulation: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
return_layernorm_output: bool = False, return_layernorm_output: bool = False,
return_layernorm_output_gathered: bool = False,
seq_length: Optional[int] = None, seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None, micro_batch_size: Optional[int] = None,
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
...@@ -1189,6 +1210,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1189,6 +1210,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.return_bias = return_bias self.return_bias = return_bias
self.apply_bias = bias and not return_bias self.apply_bias = bias and not return_bias
self.return_layernorm_output = return_layernorm_output self.return_layernorm_output = return_layernorm_output
self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.bias_gelu_nvfusion = (bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) and self.bias_gelu_nvfusion = (bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) and
self.activation == 'gelu') self.activation == 'gelu')
self.set_parallel_mode = set_parallel_mode self.set_parallel_mode = set_parallel_mode
...@@ -1456,6 +1478,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1456,6 +1478,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.tp_size > 1, self.tp_size > 1,
self.activation_dtype, self.activation_dtype,
self.return_layernorm_output, self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion, self.bias_gelu_nvfusion,
self.set_parallel_mode, self.set_parallel_mode,
torch.is_grad_enabled(), torch.is_grad_enabled(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment