Unverified Commit 430d5d5a authored by Shijie Wang, committed by GitHub

[Paddle] Add main_grad (#779)



* support main_grad
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* update main_grad and fuse_wgrad_accumulation
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix ci errors
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* minor change
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
parent 53a3bc35
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TransformerLayer encoder main_grad"""
import numpy as np
import pytest
import paddle
from paddle.distributed.fleet.utils import mix_precision_utils
import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available
is_fp8_supported, reason = is_fp8_available()
def create_optimizer(model, use_pure_bf16, use_main_grad):
'''Create optimizer'''
if use_main_grad:
assert use_pure_bf16
model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=0.0001,
multi_precision=use_pure_bf16,
)
if use_main_grad:
optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)
return optimizer
class Net(paddle.nn.Layer):
'''Network used for main_grad testing'''
def __init__(self, fuse_wgrad_accumulation):
super().__init__()
self.layer = te.TransformerLayer(
4096,
16384,
32,
layer_type='encoder',
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
def forward(self, inp):
out = self.layer(inp)
return out
def train(enable_master_grad, fuse_wgrad_accumulation=False):
'''Train function'''
paddle.seed(10)
accumulate_steps = 4
if fuse_wgrad_accumulation:
assert enable_master_grad, "fuse_wgrad_accumulation requires enable_master_grad"
model = Net(fuse_wgrad_accumulation)
optimizer = create_optimizer(model, use_pure_bf16=True, use_main_grad=enable_master_grad)
loss_list = []
for step_id in range(16):
inp = paddle.uniform([2, 1024, 4096], dtype='float32')
inp.stop_gradient = False
with te.fp8_autocast(enabled=True):
out = model(inp)
loss = out.mean()
loss_list.append(loss)
loss.backward()
# gradient accumulation
if (step_id + 1) % accumulate_steps == 0:
optimizer.step()
optimizer.clear_grad()
return loss_list
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
def test_master_grad():
'''Test main_grad'''
paddle.set_default_dtype('float32')
loss1 = train(enable_master_grad=False)
loss2 = train(enable_master_grad=True)
loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True)
np.testing.assert_allclose(loss1, loss2, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(loss1, loss3, rtol=1e-5, atol=1e-5)
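For context, a minimal sketch of what the main_grad contract relied on here amounts to when set up by hand rather than through `mix_precision_utils.MixPrecisionLayer`; the helper name below is hypothetical:
def attach_main_grad(model):
    '''Sketch only: give every trainable parameter a pre-allocated FP32
    main_grad buffer, which is what fuse_wgrad_accumulation accumulates into.'''
    for param in model.parameters():
        if not param.stop_gradient:
            param.main_grad = paddle.zeros(shape=param.shape, dtype='float32')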
......@@ -417,9 +417,9 @@ class TestGemm:
workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')
ref_out = paddle.matmul(a, b.T)
actual_out = fp8_gemm(b_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype,
a_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_INPUT, fp8_dtype,
out_dtype, workspace)
actual_out, _ = fp8_gemm(b_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype, a_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_INPUT,
fp8_dtype, out_dtype, workspace)
assert_allclose(actual_out, ref_out)
......
......@@ -26,6 +26,7 @@ def gemm(
accumulate: bool = False,
layout: str = "TN",
out: Optional[paddle.Tensor] = None,
out_dtype: Optional[paddle.dtype] = None,
bias: Optional[paddle.Tensor] = None,
use_bias: bool = False,
) -> Tuple[Union[paddle.Tensor, None], ...]:
......@@ -35,16 +36,23 @@ def gemm(
transa = layout[0] == "T"
transb = layout[1] == "T"
return_output = False
if out is None:
if accumulate:
out = paddle.zeros(
shape=[
B.shape[1] if transb else B.shape[0],
A.shape[0] if transa else A.shape[1],
],
dtype=out_dtype if out_dtype is not None else dtype,
)
else:
out = paddle.empty(
shape=[
B.shape[1] if transb else B.shape[0],
A.shape[0] if transa else A.shape[1],
],
dtype=dtype,
dtype=out_dtype if out_dtype is not None else dtype,
)
return_output = True
if gelu and not grad:
gelu_input = paddle.empty_like(out, dtype=dtype)
......@@ -94,9 +102,7 @@ def gemm(
0, # math_sm_count
)
if return_output:
return out, grad_bias, gelu_input
return None, grad_bias, gelu_input
def fp8_gemm(
......@@ -125,8 +131,16 @@ def fp8_gemm(
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_index is not None
return_output = False
if out is None:
if accumulate:
out = paddle.zeros(
shape=[
B.shape[0],
A.shape[0],
],
dtype=out_dtype,
)
else:
out = paddle.empty(
shape=[
B.shape[0],
......@@ -134,7 +148,7 @@ def fp8_gemm(
],
dtype=out_dtype,
)
return_output = True
# Use bfloat16 as default bias_dtype
bias_dtype = paddle.bfloat16 if bias is None else bias.dtype
if gelu:
......@@ -172,13 +186,7 @@ def fp8_gemm(
0, # math_sm_count
)
if return_output:
if gelu:
return out, gelu_input
return out
if gelu:
return gelu_input
return None
def cast_to_fp8(
......
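A hedged usage sketch for the new `accumulate`/`out_dtype` arguments added to `gemm` above (tensor names are placeholders and the positional prefix follows the wgrad calls later in this diff): with `accumulate=True` and no `out` buffer the output is now zero-initialised so repeated micro-batches can add into it, while `out_dtype` keeps the accumulation buffer in FP32.
# Sketch only: weight-gradient GEMM accumulating into a pre-allocated buffer.
wgrad, _, _ = gemm(
    inputmat_total,        # placeholder: input activations
    grad_output,           # placeholder: incoming gradient
    activation_dtype,
    get_workspace(),
    grad=True,
    accumulate=True,       # add into (rather than overwrite) the output
    layout="NT",
    out=weight.main_grad,  # pre-allocated FP32 accumulation buffer
    out_dtype="float32",
)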
......@@ -562,6 +562,16 @@ class MultiHeadAttention(paddle.nn.Layer):
name should be registered through
`paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
.add(rng_state_name, seed)`.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
"""
def __init__(
......@@ -584,6 +594,7 @@ class MultiHeadAttention(paddle.nn.Layer):
sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None,
num_gqa_groups: Optional[int] = None,
fuse_wgrad_accumulation: bool = False,
rng_state_name: str = 'local_seed',
backend: str = 'transformer_engine',
) -> None:
......@@ -633,6 +644,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend,
)
else:
......@@ -644,6 +656,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend,
)
......@@ -661,6 +674,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend,
)
else:
......@@ -672,6 +686,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend,
)
self.key_value = Linear(
......@@ -682,6 +697,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend,
)
......@@ -706,6 +722,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode="row" if set_parallel_mode else None,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend,
)
......
......@@ -202,6 +202,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
fuse_wgrad_accumulation: bool,
is_first_microbatch: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
if normalization == "RMSNorm":
......@@ -285,6 +286,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.sequence_parallel = sequence_parallel
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_wgrad = not weight.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient
......@@ -329,6 +331,12 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0],
ctx.parallel_mode == "row")
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation
and not ctx.is_first_microbatch)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# Prepare ln_out for Linear bwd
linear_inputmat = ln_out
if ctx.fp8_enabled:
......@@ -365,6 +373,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.tensor_parallel,
ctx.sequence_parallel,
ctx.tp_group,
ctx.fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad,
)
if not ctx.fp8_enabled:
......@@ -396,6 +406,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
# weight_fp8 and weight_t_fp8 are stop_gradient tensors
weight_cache_grad = (None, None)
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
wgrad = None
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma if ctx.requires_ln_wgrad else None,
......@@ -449,6 +461,15 @@ class LayerNormLinear(TransformerEngineBaseLayer):
When set to `None`, no communication is performed.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
"""
def __init__(
......@@ -464,6 +485,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
parallel_mode: Optional[str] = None,
sequence_parallel: bool = False,
tp_group: Union[dist_group_type, None] = None,
fuse_wgrad_accumulation: bool = False,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
......@@ -497,6 +519,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
# LayerNorm weights
self.ln_weight = self.create_parameter(
shape=[self.in_features],
......@@ -610,6 +634,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.sequence_parallel,
self.tp_group,
self.tp_size,
self.fuse_wgrad_accumulation,
is_first_microbatch,
)
......
......@@ -211,6 +211,8 @@ def _mlp_backward(
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
fuse_wgrad_accumulation: bool,
accumulate_wgrad_into_param_main_grad: bool,
):
(
fc1_dgrad,
......@@ -238,6 +240,7 @@ def _mlp_backward(
fc2_input,
None,
fc2_input_fp8_index,
fc2_weight,
fc2_weight_t_fp8,
fc2_weight_fp8_index,
grad_output,
......@@ -253,6 +256,8 @@ def _mlp_backward(
tensor_parallel,
sequence_parallel,
tp_group,
fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad,
)
if activation == "gelu":
......@@ -299,6 +304,7 @@ def _mlp_backward(
fc1_input,
None,
fc1_input_fp8_index,
fc1_weight,
fc1_weight_t_fp8,
fc1_weight_fp8_index,
dgelu_no_fp8,
......@@ -314,6 +320,8 @@ def _mlp_backward(
tensor_parallel,
sequence_parallel,
tp_group,
fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad,
)
else:
dgelu, fc2_wgrad, fc2_bgrad = _linear_bwd_non_fp8(
......@@ -328,6 +336,8 @@ def _mlp_backward(
tensor_parallel,
sequence_parallel,
tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
gelu_input=fc1_out,
activation=activation,
)
......@@ -347,6 +357,8 @@ def _mlp_backward(
tensor_parallel,
sequence_parallel,
tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
)
return (
fc1_dgrad,
......@@ -393,6 +405,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
fuse_wgrad_accumulation: bool,
is_first_microbatch: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
if normalization == "RMSNorm":
......@@ -498,6 +511,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.sequence_parallel = sequence_parallel
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_fc1_wgrad = not fc1_weight.stop_gradient
ctx.requires_fc2_wgrad = not fc2_weight.stop_gradient
......@@ -548,6 +562,12 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fc2_bgrad,
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], True)
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation
and not ctx.is_first_microbatch)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
(
fc1_dgrad,
fc1_wgrad,
......@@ -585,6 +605,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.tensor_parallel,
ctx.sequence_parallel,
ctx.tp_group,
ctx.fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad,
)
if not ctx.fp8_enabled:
# fc2_bias is fused with gemm for non-FP8 path
......@@ -619,6 +641,11 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fc1_weight_cache_grad = (None, None)
fc2_weight_cache_grad = (None, None)
if ctx.requires_fc1_wgrad and ctx.fuse_wgrad_accumulation:
fc1_wgrad = None
if ctx.requires_fc2_wgrad and ctx.fuse_wgrad_accumulation:
fc2_wgrad = None
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma if ctx.requires_ln_wgrad else None,
......@@ -679,6 +706,14 @@ class LayerNormMLP(TransformerEngineBaseLayer):
tp_group : paddle.distributed.collective.Group, default = `None`
tensor parallel process group.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
"""
def __init__(
......@@ -695,6 +730,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
set_parallel_mode: bool = False,
sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None,
fuse_wgrad_accumulation: bool = False,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
......@@ -720,6 +756,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.set_parallel_mode = set_parallel_mode
self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
if self.set_parallel_mode:
self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size)
else:
......@@ -876,6 +914,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.sequence_parallel,
self.tp_group,
self.tp_size,
self.fuse_wgrad_accumulation,
is_first_microbatch,
)
......
......@@ -96,7 +96,7 @@ def _linear_fwd_fp8(
out=weight_fp8,
)
out = fp8_gemm(
out, _ = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
weight_fp8_index,
......@@ -245,6 +245,7 @@ def _linear_bwd_fp8(
inputmat: paddle.Tensor,
inputmat_t: paddle.Tensor,
inputmat_fp8_index: FP8FwdTensors,
weight: paddle.Tensor,
weight_t_fp8: paddle.Tensor,
weight_fp8_index: FP8FwdTensors,
grad_output: paddle.Tensor,
......@@ -260,6 +261,8 @@ def _linear_bwd_fp8(
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
fuse_wgrad_accumulation: bool,
accumulate_wgrad_into_param_main_grad: bool,
):
dgrad, wgrad, handle = None, None, None
......@@ -275,7 +278,7 @@ def _linear_bwd_fp8(
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
if requires_dgrad:
dgrad = fp8_gemm(
dgrad, _ = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
weight_fp8_index,
......@@ -303,7 +306,8 @@ def _linear_bwd_fp8(
if inputmat_t_total is None:
inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
clear_tensor_data(inputmat_total)
wgrad = fp8_gemm(
wgrad, _ = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses,
inputmat_fp8_index,
......@@ -312,8 +316,10 @@ def _linear_bwd_fp8(
fp8_meta["scaling_bwd"].scale_inv,
grad_output_fp8_index,
fp8_dtype_backward,
activation_dtype,
"float32" if fuse_wgrad_accumulation else activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
clear_tensor_data(inputmat_t_total, grad_output_t)
......@@ -323,11 +329,17 @@ def _linear_bwd_fp8(
grad_output,
activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
layout="NT",
out=weight.main_grad if fuse_wgrad_accumulation else None,
out_dtype="float32" if fuse_wgrad_accumulation else None,
)
clear_tensor_data(inputmat_total)
if fuse_wgrad_accumulation:
weight.main_grad = wgrad
if parallel_mode == "column" and tensor_parallel and handle is not None:
handle.wait()
......@@ -346,6 +358,8 @@ def _linear_bwd_non_fp8(
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
fuse_wgrad_accumulation: bool,
accumulate_wgrad_into_param_main_grad: bool,
gelu_input: Union[paddle.Tensor, None] = None,
activation: str = "",
):
......@@ -386,10 +400,16 @@ def _linear_bwd_non_fp8(
grad_output,
activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
layout="NT",
out=weight.main_grad if fuse_wgrad_accumulation else None,
out_dtype="float32" if fuse_wgrad_accumulation else None,
use_bias=requires_bgrad,
)
if fuse_wgrad_accumulation:
weight.main_grad = wgrad
elif requires_bgrad:
bgrad = grad_output.sum(axis=0)
......@@ -421,6 +441,8 @@ def _linear_bwd(
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
fuse_wgrad_accumulation: bool,
accumulate_wgrad_into_param_main_grad: bool,
):
dgrad, wgrad, bgrad = None, None, None
if fp8_enabled:
......@@ -428,6 +450,7 @@ def _linear_bwd(
inputmat,
inputmat_t,
inputmat_fp8_index,
weight,
weight_t_fp8,
weight_fp8_index,
grad_output,
......@@ -443,6 +466,8 @@ def _linear_bwd(
tensor_parallel,
sequence_parallel,
tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
)
else:
dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
......@@ -457,6 +482,8 @@ def _linear_bwd(
tensor_parallel,
sequence_parallel,
tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
)
return dgrad, wgrad, bgrad
......@@ -483,6 +510,7 @@ class _Linear(paddle.autograd.PyLayer):
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
fuse_wgrad_accumulation: bool,
is_first_microbatch: bool,
) -> paddle.Tensor:
# Make sure input dimensions are compatible
......@@ -561,6 +589,7 @@ class _Linear(paddle.autograd.PyLayer):
ctx.sequence_parallel = sequence_parallel
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_wgrad = not weight.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient
......@@ -591,6 +620,11 @@ class _Linear(paddle.autograd.PyLayer):
bgrad,
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output,
ctx.parallel_mode == "row")
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation
and not ctx.is_first_microbatch)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
dgrad, wgrad, bgrad_ = _linear_bwd(
inputmat,
......@@ -614,6 +648,8 @@ class _Linear(paddle.autograd.PyLayer):
ctx.tensor_parallel,
ctx.sequence_parallel,
ctx.tp_group,
ctx.fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad,
)
if not ctx.fp8_enabled:
......@@ -626,18 +662,22 @@ class _Linear(paddle.autograd.PyLayer):
# weight_fp8 and weight_t_fp8 are stop_gradient tensors
weight_cache_grad = (None, None)
dgrad_return = dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None
if not ctx.use_bias:
return (
wgrad if ctx.requires_wgrad else None,
*weight_cache_grad,
dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
)
bgrad_return = ()
elif ctx.requires_bgrad:
bgrad_return = (bgrad,)
else:
bgrad_return = (None,)
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
wgrad = None
return (
wgrad if ctx.requires_wgrad else None,
*weight_cache_grad,
dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
bgrad if ctx.requires_bgrad else None,
dgrad_return,
*bgrad_return,
)
......@@ -668,6 +708,16 @@ class Linear(TransformerEngineBaseLayer):
When set to `None`, no communication is performed.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
"""
def __init__(
......@@ -679,6 +729,7 @@ class Linear(TransformerEngineBaseLayer):
parallel_mode: Optional[str] = None,
sequence_parallel: bool = False,
tp_group: Union[dist_group_type, None] = None,
fuse_wgrad_accumulation: bool = False,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
......@@ -705,6 +756,8 @@ class Linear(TransformerEngineBaseLayer):
self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
# Initialize weight parameter
with track_rng_state(enable=self.tensor_parallel):
# TE linear weight is in column major
......@@ -779,6 +832,7 @@ class Linear(TransformerEngineBaseLayer):
self.sequence_parallel,
self.tp_group,
self.tp_size,
self.fuse_wgrad_accumulation,
is_first_microbatch,
)
......
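An illustrative, untested sketch of the new flag on a standalone Linear layer; the weight attribute name, shapes, and the manual main_grad allocation are assumptions, mirroring what MixPrecisionLayer sets up in the test above:
layer = te.Linear(1024, 1024, fuse_wgrad_accumulation=True)
# The weight must carry a pre-allocated FP32 main_grad buffer before backward.
layer.weight.main_grad = paddle.zeros(shape=layer.weight.shape, dtype='float32')
for _ in range(4):  # gradient-accumulation micro-steps
    inp = paddle.uniform([16, 1024], dtype='float32')
    layer(inp).mean().backward()  # wgrad is accumulated into weight.main_grad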
......@@ -100,6 +100,16 @@ class TransformerLayer(paddle.nn.Layer):
specified name should be registered through
`paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
.add(rng_state_name, seed)`.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
"""
def __init__(self,
......@@ -124,6 +134,7 @@ class TransformerLayer(paddle.nn.Layer):
set_parallel_mode: bool = False,
sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None,
fuse_wgrad_accumulation: bool = False,
attention_dropout_rng_state_name: str = 'local_seed',
hidden_dropout_rng_state_name: str = 'global_seed',
backend: str = 'transformer_engine') -> None:
......@@ -168,6 +179,7 @@ class TransformerLayer(paddle.nn.Layer):
'max_sequence_length': max_sequence_length,
"tp_group": tp_group,
"num_gqa_groups": num_gqa_groups,
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"rng_state_name": attention_dropout_rng_state_name,
"backend": backend,
}
......@@ -202,6 +214,7 @@ class TransformerLayer(paddle.nn.Layer):
set_parallel_mode=set_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=backend,
)
......