Unverified commit 430d5d5a authored by Shijie Wang, committed by GitHub

[Paddle] Add main_grad (#779)



* support main_grad
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* update main_grad and fuse_wgrad_accumulation
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix ci errors
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* minor change
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
parent 53a3bc35
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TransformerLayer encoder main_grad"""
import numpy as np
import pytest
import paddle
from paddle.distributed.fleet.utils import mix_precision_utils
import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available
is_fp8_supported, reason = is_fp8_available()
def create_optimizer(model, use_pure_bf16, use_main_grad):
'''Create optimizer'''
if use_main_grad:
assert use_pure_bf16
model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=0.0001,
multi_precision=use_pure_bf16,
)
if use_main_grad:
optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)
return optimizer
class Net(paddle.nn.Layer):
'''Network use for main_grad testing'''
def __init__(self, fuse_wgrad_accumulation):
super().__init__()
self.layer = te.TransformerLayer(
4096,
16384,
32,
layer_type='encoder',
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
def forward(self, inp):
out = self.layer(inp)
return out
def train(enable_master_grad, fuse_wgrad_accumulation=False):
'''Train function'''
paddle.seed(10)
accumulate_steps = 4
if fuse_wgrad_accumulation:
assert enable_master_grad, "fuse_wgrad_accumulation requires enable_master_grad"
model = Net(fuse_wgrad_accumulation)
optimizer = create_optimizer(model, use_pure_bf16=True, use_main_grad=enable_master_grad)
loss_list = []
for step_id in range(16):
inp = paddle.uniform([2, 1024, 4096], dtype='float32')
inp.stop_gradient = False
with te.fp8_autocast(enabled=True):
out = model(inp)
loss = out.mean()
loss_list.append(loss)
loss.backward()
# gradient accumulation
if (step_id + 1) % accumulate_steps == 0:
optimizer.step()
optimizer.clear_grad()
return loss_list
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
def test_master_grad():
'''Test main_grad'''
paddle.set_default_dtype('float32')
loss1 = train(enable_master_grad=False)
loss2 = train(enable_master_grad=True)
loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True)
np.testing.assert_allclose(loss1, loss2, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(loss1, loss3, rtol=1e-5, atol=1e-5)
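The test leans on mix_precision_utils.MixPrecisionLayer to create the FP32 `main_grad` buffers that `fuse_wgrad_accumulation` expects. For readers not using the Fleet utility, a minimal sketch of that contract (illustrative only, not part of this change; it assumes extra attributes can be attached to Paddle parameters, which is what the utility itself does):

# Illustrative sketch only: what MixPrecisionLayer sets up for the fused wgrad path.
import paddle
import transformer_engine.paddle as te

layer = te.Linear(4096, 4096, fuse_wgrad_accumulation=True)
for param in layer.parameters():
    if not param.stop_gradient:
        # Pre-allocated FP32 accumulation buffer with the same shape as the parameter.
        param.main_grad = paddle.zeros(shape=param.shape, dtype='float32')

The fused backward pass then writes or accumulates the weight gradient directly into `param.main_grad` and returns `None` for the regular gradient slot.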
@@ -417,9 +417,9 @@ class TestGemm:
         workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')
         ref_out = paddle.matmul(a, b.T)
-        actual_out = fp8_gemm(b_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype,
-                              a_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_INPUT, fp8_dtype,
-                              out_dtype, workspace)
+        actual_out, _ = fp8_gemm(b_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_WEIGHT,
+                                 fp8_dtype, a_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_INPUT,
+                                 fp8_dtype, out_dtype, workspace)
         assert_allclose(actual_out, ref_out)
......
@@ -26,6 +26,7 @@ def gemm(
     accumulate: bool = False,
     layout: str = "TN",
     out: Optional[paddle.Tensor] = None,
+    out_dtype: Optional[paddle.dtype] = None,
     bias: Optional[paddle.Tensor] = None,
     use_bias: bool = False,
 ) -> Tuple[Union[paddle.Tensor, None], ...]:
@@ -35,16 +36,23 @@ def gemm(
     transa = layout[0] == "T"
     transb = layout[1] == "T"
 
-    return_output = False
     if out is None:
-        out = paddle.empty(
-            shape=[
-                B.shape[1] if transb else B.shape[0],
-                A.shape[0] if transa else A.shape[1],
-            ],
-            dtype=dtype,
-        )
-        return_output = True
+        if accumulate:
+            out = paddle.zeros(
+                shape=[
+                    B.shape[1] if transb else B.shape[0],
+                    A.shape[0] if transa else A.shape[1],
+                ],
+                dtype=out_dtype if out_dtype is not None else dtype,
+            )
+        else:
+            out = paddle.empty(
+                shape=[
+                    B.shape[1] if transb else B.shape[0],
+                    A.shape[0] if transa else A.shape[1],
+                ],
+                dtype=out_dtype if out_dtype is not None else dtype,
+            )
 
     if gelu and not grad:
         gelu_input = paddle.empty_like(out, dtype=dtype)
@@ -94,9 +102,7 @@ def gemm(
         0,  # math_sm_count
     )
 
-    if return_output:
-        return out, grad_bias, gelu_input
-    return None, grad_bias, gelu_input
+    return out, grad_bias, gelu_input


 def fp8_gemm(
@@ -125,16 +131,24 @@ def fp8_gemm(
     if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
         assert fp8_meta_tensor is not None and out_index is not None
 
-    return_output = False
     if out is None:
-        out = paddle.empty(
-            shape=[
-                B.shape[0],
-                A.shape[0],
-            ],
-            dtype=out_dtype,
-        )
-        return_output = True
+        if accumulate:
+            out = paddle.zeros(
+                shape=[
+                    B.shape[0],
+                    A.shape[0],
+                ],
+                dtype=out_dtype,
+            )
+        else:
+            out = paddle.empty(
+                shape=[
+                    B.shape[0],
+                    A.shape[0],
+                ],
+                dtype=out_dtype,
+            )
 
     # Use bfloat16 as default bias_dtype
     bias_dtype = paddle.bfloat16 if bias is None else bias.dtype
     if gelu:
@@ -172,13 +186,7 @@ def fp8_gemm(
         0,  # math_sm_count
     )
 
-    if return_output:
-        if gelu:
-            return out, gelu_input
-        return out
-    if gelu:
-        return gelu_input
-    return None
+    return out, gelu_input


 def cast_to_fp8(
......
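The reworked allocation and return logic above means `gemm`/`fp8_gemm` now always return a tuple, and that with `accumulate=True` the product is added into the caller-supplied `out` buffer (typically `weight.main_grad`), optionally in a wider `out_dtype`, instead of overwriting a freshly allocated output. A plain-Paddle reference of that contract, for illustration only (the `gemm_reference` helper below is not part of the PR and ignores the transpose layouts of the real kernel):

import paddle

def gemm_reference(a, b, out=None, accumulate=False, out_dtype='float32'):
    """Mimics the accumulate/out/out_dtype contract of the TE gemm wrappers (illustration only)."""
    prod = paddle.matmul(a, b).astype(out_dtype)
    if out is None:
        # accumulate=True with out=None starts from zeros, so it reduces to the product.
        return prod
    if accumulate:
        out.add_(prod)            # out += a @ b  (e.g. into a pre-allocated main_grad)
    else:
        paddle.assign(prod, out)  # out  = a @ b
    return out

main_grad = paddle.zeros([8, 8], dtype='float32')  # pre-allocated FP32 buffer
for _ in range(4):                                  # four "microbatches"
    a = paddle.randn([8, 16], dtype='float32')
    b = paddle.randn([16, 8], dtype='float32')
    gemm_reference(a, b, out=main_grad, accumulate=True)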
...@@ -562,6 +562,16 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -562,6 +562,16 @@ class MultiHeadAttention(paddle.nn.Layer):
name should be registered through name should be registered through
`paddle.distributed.fleet.meta_parallel.get_rng_state_tracker() `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
.add(rng_state_name, seed)`. .add(rng_state_name, seed)`.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
""" """
def __init__( def __init__(
...@@ -584,6 +594,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -584,6 +594,7 @@ class MultiHeadAttention(paddle.nn.Layer):
sequence_parallel: bool = False, sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
num_gqa_groups: Optional[int] = None, num_gqa_groups: Optional[int] = None,
fuse_wgrad_accumulation: bool = False,
rng_state_name: str = 'local_seed', rng_state_name: str = 'local_seed',
backend: str = 'transformer_engine', backend: str = 'transformer_engine',
) -> None: ) -> None:
...@@ -633,6 +644,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -633,6 +644,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group, tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend, backend=self.backend,
) )
else: else:
...@@ -644,6 +656,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -644,6 +656,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group, tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend, backend=self.backend,
) )
...@@ -661,6 +674,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -661,6 +674,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group, tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend, backend=self.backend,
) )
else: else:
...@@ -672,6 +686,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -672,6 +686,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group, tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend, backend=self.backend,
) )
self.key_value = Linear( self.key_value = Linear(
...@@ -682,6 +697,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -682,6 +697,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group, tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend, backend=self.backend,
) )
...@@ -706,6 +722,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -706,6 +722,7 @@ class MultiHeadAttention(paddle.nn.Layer):
parallel_mode="row" if set_parallel_mode else None, parallel_mode="row" if set_parallel_mode else None,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group, tp_group=self.tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=self.backend, backend=self.backend,
) )
......
...@@ -202,6 +202,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -202,6 +202,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
fuse_wgrad_accumulation: bool,
is_first_microbatch: bool, is_first_microbatch: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
if normalization == "RMSNorm": if normalization == "RMSNorm":
...@@ -285,6 +286,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -285,6 +286,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.requires_dgrad = not inp.stop_gradient ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_wgrad = not weight.stop_gradient ctx.requires_wgrad = not weight.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient ctx.requires_bgrad = use_bias and not bias.stop_gradient
...@@ -329,6 +331,12 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -329,6 +331,12 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0],
ctx.parallel_mode == "row") ctx.parallel_mode == "row")
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation
and not ctx.is_first_microbatch)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# Prepare ln_out for Linear bwd # Prepare ln_out for Linear bwd
linear_inputmat = ln_out linear_inputmat = ln_out
if ctx.fp8_enabled: if ctx.fp8_enabled:
...@@ -365,6 +373,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -365,6 +373,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.tensor_parallel, ctx.tensor_parallel,
ctx.sequence_parallel, ctx.sequence_parallel,
ctx.tp_group, ctx.tp_group,
ctx.fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad,
) )
if not ctx.fp8_enabled: if not ctx.fp8_enabled:
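The `is_first_microbatch` handling above (repeated verbatim in the LayerNormMLP and Linear backward passes further down) reduces to a small decision rule. A standalone restatement for clarity, with an illustrative helper name that does not exist in the codebase:

def should_accumulate_into_main_grad(fuse_wgrad_accumulation, is_first_microbatch):
    """Decide whether the fused wgrad GEMM adds into weight.main_grad or overwrites it."""
    if is_first_microbatch is not None:
        # Overwrite on the first microbatch of an accumulation window,
        # accumulate on every later microbatch.
        return fuse_wgrad_accumulation and not is_first_microbatch
    # No microbatch information from the caller: always accumulate when fusion is on.
    return fuse_wgrad_accumulation

On the first microbatch the fused GEMM overwrites `weight.main_grad`; on later microbatches it accumulates into it; with fusion disabled the flag is always `False` and the original path is unchanged.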
@@ -396,14 +406,16 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
             # weight_fp8 and weight_t_fp8 are stop_gradient tensors
             weight_cache_grad = (None, None)
 
+        if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+            wgrad = None
         return (
             dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma if ctx.requires_ln_wgrad else None,
             *dbeta_out,
             wgrad if ctx.requires_wgrad else None,
             *weight_cache_grad,
             *bgrad_out,
         )
class LayerNormLinear(TransformerEngineBaseLayer): class LayerNormLinear(TransformerEngineBaseLayer):
...@@ -449,6 +461,15 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -449,6 +461,15 @@ class LayerNormLinear(TransformerEngineBaseLayer):
When set to `None`, no communication is performed. When set to `None`, no communication is performed.
sequence_parallel : bool, default = `False` sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism. if set to `True`, uses sequence parallelism.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
""" """
def __init__( def __init__(
...@@ -464,6 +485,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -464,6 +485,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
parallel_mode: Optional[str] = None, parallel_mode: Optional[str] = None,
sequence_parallel: bool = False, sequence_parallel: bool = False,
tp_group: Union[dist_group_type, None] = None, tp_group: Union[dist_group_type, None] = None,
fuse_wgrad_accumulation: bool = False,
backend: str = 'transformer_engine', backend: str = 'transformer_engine',
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -497,6 +519,8 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -497,6 +519,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.sequence_parallel = self.tensor_parallel and sequence_parallel self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
# LayerNorm weights # LayerNorm weights
self.ln_weight = self.create_parameter( self.ln_weight = self.create_parameter(
shape=[self.in_features], shape=[self.in_features],
...@@ -610,6 +634,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -610,6 +634,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.sequence_parallel, self.sequence_parallel,
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
self.fuse_wgrad_accumulation,
is_first_microbatch, is_first_microbatch,
) )
......
...@@ -211,6 +211,8 @@ def _mlp_backward( ...@@ -211,6 +211,8 @@ def _mlp_backward(
tensor_parallel: bool, tensor_parallel: bool,
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
fuse_wgrad_accumulation: bool,
accumulate_wgrad_into_param_main_grad: bool,
): ):
( (
fc1_dgrad, fc1_dgrad,
...@@ -238,6 +240,7 @@ def _mlp_backward( ...@@ -238,6 +240,7 @@ def _mlp_backward(
fc2_input, fc2_input,
None, None,
fc2_input_fp8_index, fc2_input_fp8_index,
fc2_weight,
fc2_weight_t_fp8, fc2_weight_t_fp8,
fc2_weight_fp8_index, fc2_weight_fp8_index,
grad_output, grad_output,
...@@ -253,6 +256,8 @@ def _mlp_backward( ...@@ -253,6 +256,8 @@ def _mlp_backward(
tensor_parallel, tensor_parallel,
sequence_parallel, sequence_parallel,
tp_group, tp_group,
fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad,
) )
if activation == "gelu": if activation == "gelu":
...@@ -299,6 +304,7 @@ def _mlp_backward( ...@@ -299,6 +304,7 @@ def _mlp_backward(
fc1_input, fc1_input,
None, None,
fc1_input_fp8_index, fc1_input_fp8_index,
fc1_weight,
fc1_weight_t_fp8, fc1_weight_t_fp8,
fc1_weight_fp8_index, fc1_weight_fp8_index,
dgelu_no_fp8, dgelu_no_fp8,
...@@ -314,6 +320,8 @@ def _mlp_backward( ...@@ -314,6 +320,8 @@ def _mlp_backward(
tensor_parallel, tensor_parallel,
sequence_parallel, sequence_parallel,
tp_group, tp_group,
fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad,
) )
else: else:
dgelu, fc2_wgrad, fc2_bgrad = _linear_bwd_non_fp8( dgelu, fc2_wgrad, fc2_bgrad = _linear_bwd_non_fp8(
...@@ -328,6 +336,8 @@ def _mlp_backward( ...@@ -328,6 +336,8 @@ def _mlp_backward(
tensor_parallel, tensor_parallel,
sequence_parallel, sequence_parallel,
tp_group, tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
gelu_input=fc1_out, gelu_input=fc1_out,
activation=activation, activation=activation,
) )
...@@ -347,6 +357,8 @@ def _mlp_backward( ...@@ -347,6 +357,8 @@ def _mlp_backward(
tensor_parallel, tensor_parallel,
sequence_parallel, sequence_parallel,
tp_group, tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
) )
return ( return (
fc1_dgrad, fc1_dgrad,
...@@ -393,6 +405,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -393,6 +405,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
fuse_wgrad_accumulation: bool,
is_first_microbatch: bool, is_first_microbatch: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
if normalization == "RMSNorm": if normalization == "RMSNorm":
...@@ -498,6 +511,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -498,6 +511,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.requires_dgrad = not inp.stop_gradient ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_fc1_wgrad = not fc1_weight.stop_gradient ctx.requires_fc1_wgrad = not fc1_weight.stop_gradient
ctx.requires_fc2_wgrad = not fc2_weight.stop_gradient ctx.requires_fc2_wgrad = not fc2_weight.stop_gradient
...@@ -548,6 +562,12 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -548,6 +562,12 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fc2_bgrad, fc2_bgrad,
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], True) ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], True)
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation
and not ctx.is_first_microbatch)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
( (
fc1_dgrad, fc1_dgrad,
fc1_wgrad, fc1_wgrad,
...@@ -585,6 +605,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -585,6 +605,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.tensor_parallel, ctx.tensor_parallel,
ctx.sequence_parallel, ctx.sequence_parallel,
ctx.tp_group, ctx.tp_group,
ctx.fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad,
) )
if not ctx.fp8_enabled: if not ctx.fp8_enabled:
# fc2_bias is fused with gemm for non-FP8 path # fc2_bias is fused with gemm for non-FP8 path
@@ -619,17 +641,22 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
             fc1_weight_cache_grad = (None, None)
             fc2_weight_cache_grad = (None, None)
 
+        if ctx.requires_fc1_wgrad and ctx.fuse_wgrad_accumulation:
+            fc1_wgrad = None
+        if ctx.requires_fc2_wgrad and ctx.fuse_wgrad_accumulation:
+            fc2_wgrad = None
+
         return (
             dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma if ctx.requires_ln_wgrad else None,
             *dbeta_out,
             fc1_wgrad if ctx.requires_fc1_wgrad else None,
             *fc1_weight_cache_grad,
             *fc1_bgrad_out,
             fc2_wgrad if ctx.requires_fc2_wgrad else None,
             *fc2_weight_cache_grad,
             *fc2_bgrad_out,
         )
class LayerNormMLP(TransformerEngineBaseLayer): class LayerNormMLP(TransformerEngineBaseLayer):
...@@ -679,6 +706,14 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -679,6 +706,14 @@ class LayerNormMLP(TransformerEngineBaseLayer):
tp_group : paddle.distributed.collective.Group, default = `None` tp_group : paddle.distributed.collective.Group, default = `None`
tensor parallel process group. tensor parallel process group.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
""" """
def __init__( def __init__(
...@@ -695,6 +730,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -695,6 +730,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
sequence_parallel: bool = False, sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
fuse_wgrad_accumulation: bool = False,
backend: str = 'transformer_engine', backend: str = 'transformer_engine',
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -720,6 +756,8 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -720,6 +756,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.set_parallel_mode = set_parallel_mode self.set_parallel_mode = set_parallel_mode
self.sequence_parallel = self.tensor_parallel and sequence_parallel self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
if self.set_parallel_mode: if self.set_parallel_mode:
self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size) self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size)
else: else:
...@@ -876,6 +914,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -876,6 +914,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.sequence_parallel, self.sequence_parallel,
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
self.fuse_wgrad_accumulation,
is_first_microbatch, is_first_microbatch,
) )
......
@@ -96,7 +96,7 @@ def _linear_fwd_fp8(
             out=weight_fp8,
         )
 
-        out = fp8_gemm(
+        out, _ = fp8_gemm(
             weight_fp8,
             fp8_meta["scaling_fwd"].scale_inv,
             weight_fp8_index,
...@@ -245,6 +245,7 @@ def _linear_bwd_fp8( ...@@ -245,6 +245,7 @@ def _linear_bwd_fp8(
inputmat: paddle.Tensor, inputmat: paddle.Tensor,
inputmat_t: paddle.Tensor, inputmat_t: paddle.Tensor,
inputmat_fp8_index: FP8FwdTensors, inputmat_fp8_index: FP8FwdTensors,
weight: paddle.Tensor,
weight_t_fp8: paddle.Tensor, weight_t_fp8: paddle.Tensor,
weight_fp8_index: FP8FwdTensors, weight_fp8_index: FP8FwdTensors,
grad_output: paddle.Tensor, grad_output: paddle.Tensor,
...@@ -260,6 +261,8 @@ def _linear_bwd_fp8( ...@@ -260,6 +261,8 @@ def _linear_bwd_fp8(
tensor_parallel: bool, tensor_parallel: bool,
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
fuse_wgrad_accumulation: bool,
accumulate_wgrad_into_param_main_grad: bool,
): ):
dgrad, wgrad, handle = None, None, None dgrad, wgrad, handle = None, None, None
@@ -275,7 +278,7 @@ def _linear_bwd_fp8(
     fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
     fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
     if requires_dgrad:
-        dgrad = fp8_gemm(
+        dgrad, _ = fp8_gemm(
             weight_t_fp8,
             fwd_scale_inverses,
             weight_fp8_index,
@@ -303,7 +306,8 @@ def _linear_bwd_fp8(
             if inputmat_t_total is None:
                 inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
                 clear_tensor_data(inputmat_total)
-            wgrad = fp8_gemm(
+
+            wgrad, _ = fp8_gemm(
                 inputmat_t_total,
                 fwd_scale_inverses,
                 inputmat_fp8_index,
@@ -312,8 +316,10 @@ def _linear_bwd_fp8(
                 fp8_meta["scaling_bwd"].scale_inv,
                 grad_output_fp8_index,
                 fp8_dtype_backward,
-                activation_dtype,
+                "float32" if fuse_wgrad_accumulation else activation_dtype,
                 get_workspace(),
+                accumulate=accumulate_wgrad_into_param_main_grad,
+                out=weight.main_grad if fuse_wgrad_accumulation else None,
                 use_split_accumulator=_2X_ACC_WGRAD,
             )
             clear_tensor_data(inputmat_t_total, grad_output_t)
@@ -323,11 +329,17 @@ def _linear_bwd_fp8(
                 grad_output,
                 activation_dtype,
                 get_workspace(),
-                layout="NT",
                 grad=True,
+                accumulate=accumulate_wgrad_into_param_main_grad,
+                layout="NT",
+                out=weight.main_grad if fuse_wgrad_accumulation else None,
+                out_dtype="float32" if fuse_wgrad_accumulation else None,
             )
             clear_tensor_data(inputmat_total)
 
+            if fuse_wgrad_accumulation:
+                weight.main_grad = wgrad
+
     if parallel_mode == "column" and tensor_parallel and handle is not None:
         handle.wait()
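Both wgrad branches above force the GEMM output dtype to FP32 when `fuse_wgrad_accumulation` is enabled: repeatedly adding small gradient contributions into a half-precision accumulator silently drops them once the running sum grows. A small NumPy illustration of that effect (not from this PR):

import numpy as np

step = np.float32(1e-4)
acc_fp16 = np.float16(0.0)
acc_fp32 = np.float32(0.0)
for _ in range(10_000):
    acc_fp16 = np.float16(acc_fp16 + np.float16(step))  # half-precision accumulator stalls
    acc_fp32 = acc_fp32 + step                           # FP32 keeps accumulating
print(acc_fp16, acc_fp32)  # the FP16 sum stalls around 0.25 while FP32 reaches ~1.0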
...@@ -346,6 +358,8 @@ def _linear_bwd_non_fp8( ...@@ -346,6 +358,8 @@ def _linear_bwd_non_fp8(
tensor_parallel: bool, tensor_parallel: bool,
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
fuse_wgrad_accumulation: bool,
accumulate_wgrad_into_param_main_grad: bool,
gelu_input: Union[paddle.Tensor, None] = None, gelu_input: Union[paddle.Tensor, None] = None,
activation: str = "", activation: str = "",
): ):
@@ -386,10 +400,16 @@ def _linear_bwd_non_fp8(
             grad_output,
             activation_dtype,
             get_workspace(),
-            layout="NT",
             grad=True,
+            accumulate=accumulate_wgrad_into_param_main_grad,
+            layout="NT",
+            out=weight.main_grad if fuse_wgrad_accumulation else None,
+            out_dtype="float32" if fuse_wgrad_accumulation else None,
             use_bias=requires_bgrad,
         )
+
+        if fuse_wgrad_accumulation:
+            weight.main_grad = wgrad
+
     elif requires_bgrad:
         bgrad = grad_output.sum(axis=0)
...@@ -421,6 +441,8 @@ def _linear_bwd( ...@@ -421,6 +441,8 @@ def _linear_bwd(
tensor_parallel: bool, tensor_parallel: bool,
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
fuse_wgrad_accumulation: bool,
accumulate_wgrad_into_param_main_grad: bool,
): ):
dgrad, wgrad, bgrad = None, None, None dgrad, wgrad, bgrad = None, None, None
if fp8_enabled: if fp8_enabled:
...@@ -428,6 +450,7 @@ def _linear_bwd( ...@@ -428,6 +450,7 @@ def _linear_bwd(
inputmat, inputmat,
inputmat_t, inputmat_t,
inputmat_fp8_index, inputmat_fp8_index,
weight,
weight_t_fp8, weight_t_fp8,
weight_fp8_index, weight_fp8_index,
grad_output, grad_output,
...@@ -443,6 +466,8 @@ def _linear_bwd( ...@@ -443,6 +466,8 @@ def _linear_bwd(
tensor_parallel, tensor_parallel,
sequence_parallel, sequence_parallel,
tp_group, tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
) )
else: else:
dgrad, wgrad, bgrad = _linear_bwd_non_fp8( dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
...@@ -457,6 +482,8 @@ def _linear_bwd( ...@@ -457,6 +482,8 @@ def _linear_bwd(
tensor_parallel, tensor_parallel,
sequence_parallel, sequence_parallel,
tp_group, tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
) )
return dgrad, wgrad, bgrad return dgrad, wgrad, bgrad
...@@ -483,6 +510,7 @@ class _Linear(paddle.autograd.PyLayer): ...@@ -483,6 +510,7 @@ class _Linear(paddle.autograd.PyLayer):
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
fuse_wgrad_accumulation: bool,
is_first_microbatch: bool, is_first_microbatch: bool,
) -> paddle.Tensor: ) -> paddle.Tensor:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
...@@ -561,6 +589,7 @@ class _Linear(paddle.autograd.PyLayer): ...@@ -561,6 +589,7 @@ class _Linear(paddle.autograd.PyLayer):
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.requires_dgrad = not inp.stop_gradient ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_wgrad = not weight.stop_gradient ctx.requires_wgrad = not weight.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient ctx.requires_bgrad = use_bias and not bias.stop_gradient
...@@ -591,6 +620,11 @@ class _Linear(paddle.autograd.PyLayer): ...@@ -591,6 +620,11 @@ class _Linear(paddle.autograd.PyLayer):
bgrad, bgrad,
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output, ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output,
ctx.parallel_mode == "row") ctx.parallel_mode == "row")
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation
and not ctx.is_first_microbatch)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
dgrad, wgrad, bgrad_ = _linear_bwd( dgrad, wgrad, bgrad_ = _linear_bwd(
inputmat, inputmat,
...@@ -614,6 +648,8 @@ class _Linear(paddle.autograd.PyLayer): ...@@ -614,6 +648,8 @@ class _Linear(paddle.autograd.PyLayer):
ctx.tensor_parallel, ctx.tensor_parallel,
ctx.sequence_parallel, ctx.sequence_parallel,
ctx.tp_group, ctx.tp_group,
ctx.fuse_wgrad_accumulation,
accumulate_wgrad_into_param_main_grad,
) )
if not ctx.fp8_enabled: if not ctx.fp8_enabled:
@@ -626,19 +662,23 @@ class _Linear(paddle.autograd.PyLayer):
             # weight_fp8 and weight_t_fp8 are stop_gradient tensors
             weight_cache_grad = (None, None)
 
+        dgrad_return = dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None
+
         if not ctx.use_bias:
-            return (
-                wgrad if ctx.requires_wgrad else None,
-                *weight_cache_grad,
-                dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
-            )
+            bgrad_return = ()
+        elif ctx.requires_bgrad:
+            bgrad_return = (bgrad,)
+        else:
+            bgrad_return = (None,)
+
+        if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+            wgrad = None
+
         return (
             wgrad if ctx.requires_wgrad else None,
             *weight_cache_grad,
-            dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
-            bgrad if ctx.requires_bgrad else None,
+            dgrad_return,
+            *bgrad_return,
         )
class Linear(TransformerEngineBaseLayer): class Linear(TransformerEngineBaseLayer):
...@@ -668,6 +708,16 @@ class Linear(TransformerEngineBaseLayer): ...@@ -668,6 +708,16 @@ class Linear(TransformerEngineBaseLayer):
When set to `None`, no communication is performed. When set to `None`, no communication is performed.
sequence_parallel : bool, default = `False` sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism. if set to `True`, uses sequence parallelism.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
""" """
def __init__( def __init__(
...@@ -679,6 +729,7 @@ class Linear(TransformerEngineBaseLayer): ...@@ -679,6 +729,7 @@ class Linear(TransformerEngineBaseLayer):
parallel_mode: Optional[str] = None, parallel_mode: Optional[str] = None,
sequence_parallel: bool = False, sequence_parallel: bool = False,
tp_group: Union[dist_group_type, None] = None, tp_group: Union[dist_group_type, None] = None,
fuse_wgrad_accumulation: bool = False,
backend: str = 'transformer_engine', backend: str = 'transformer_engine',
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -705,6 +756,8 @@ class Linear(TransformerEngineBaseLayer): ...@@ -705,6 +756,8 @@ class Linear(TransformerEngineBaseLayer):
self.sequence_parallel = self.tensor_parallel and sequence_parallel self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
# Initialize weight parameter # Initialize weight parameter
with track_rng_state(enable=self.tensor_parallel): with track_rng_state(enable=self.tensor_parallel):
# TE linear weight is in column major # TE linear weight is in column major
...@@ -779,6 +832,7 @@ class Linear(TransformerEngineBaseLayer): ...@@ -779,6 +832,7 @@ class Linear(TransformerEngineBaseLayer):
self.sequence_parallel, self.sequence_parallel,
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
self.fuse_wgrad_accumulation,
is_first_microbatch, is_first_microbatch,
) )
......
...@@ -100,6 +100,16 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -100,6 +100,16 @@ class TransformerLayer(paddle.nn.Layer):
specified name should be registered through specified name should be registered through
`paddle.distributed.fleet.meta_parallel.get_rng_state_tracker() `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
.add(rng_state_name, seed)`. .add(rng_state_name, seed)`.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
""" """
def __init__(self, def __init__(self,
...@@ -124,6 +134,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -124,6 +134,7 @@ class TransformerLayer(paddle.nn.Layer):
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
sequence_parallel: bool = False, sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
fuse_wgrad_accumulation: bool = False,
attention_dropout_rng_state_name: str = 'local_seed', attention_dropout_rng_state_name: str = 'local_seed',
hidden_dropout_rng_state_name: str = 'global_seed', hidden_dropout_rng_state_name: str = 'global_seed',
backend: str = 'transformer_engine') -> None: backend: str = 'transformer_engine') -> None:
...@@ -168,6 +179,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -168,6 +179,7 @@ class TransformerLayer(paddle.nn.Layer):
'max_sequence_length': max_sequence_length, 'max_sequence_length': max_sequence_length,
"tp_group": tp_group, "tp_group": tp_group,
"num_gqa_groups": num_gqa_groups, "num_gqa_groups": num_gqa_groups,
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"rng_state_name": attention_dropout_rng_state_name, "rng_state_name": attention_dropout_rng_state_name,
"backend": backend, "backend": backend,
} }
...@@ -202,6 +214,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -202,6 +214,7 @@ class TransformerLayer(paddle.nn.Layer):
set_parallel_mode=set_parallel_mode, set_parallel_mode=set_parallel_mode,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
tp_group=tp_group, tp_group=tp_group,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
backend=backend, backend=backend,
) )
......