Unverified Commit 71725099 authored by Shijie, committed by GitHub

[Paddle] Add RMSNorm, RoPE and SwiGLU (#599)



* use separate qkv
Signed-off-by: jaywan <jaywan@nvidia.com>

add support for GQA
Signed-off-by: jaywan <jaywan@nvidia.com>

minor changes
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

change rtol
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

fix reshape issue
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

add rmsnorm and rotary position embedding
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

update rmsnorm
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

refactor layernorm and rmsnorm
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

support swiglu
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

add fused rope
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

minor changes
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

add rope api to __init__
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

minor changes
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

fix fp8 dtype issue
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* simplify ut cases
Signed-off-by: jaywan <jaywan@nvidia.com>

* Update transformer_engine/paddle/layer/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Shijie <505749828@qq.com>

* fix name issue
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Signed-off-by: jaywan <jaywan@nvidia.com>
Signed-off-by: Shijie <505749828@qq.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 2187a8f3
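
For reference, a minimal usage sketch of the features this commit introduces (the layer names and arguments mirror the tests in the diff below; the sizes are illustrative assumptions, so treat this as a sketch rather than part of the change):

import paddle
import transformer_engine.paddle as te

# LayerNormMLP can now use RMSNorm instead of LayerNorm, and SwiGLU instead of GELU
layer = te.LayerNormMLP(
    hidden_size=256,
    ffn_hidden_size=1024,
    normalization='RMSNorm',
    activation='swiglu',
)
x = paddle.uniform(shape=(2, 256), dtype='float32')
y = layer(x)

# Rotary position embeddings are exposed as a standalone helper
rope = te.RotaryPositionEmbedding(dim=64, max_position_embeddings=2048)
cos, sin = rope(max_seq_len=128)
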
@@ -301,8 +301,9 @@ class TestLayerNormLinear:
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('return_ln_out', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
+@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
def test_layernorm_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad,
-no_wgrad, return_ln_out, activation_dtype):
+no_wgrad, return_ln_out, activation_dtype, normalization):
"""
Test BF16 LayerNormLinear Layer
"""
@@ -314,11 +315,13 @@ class TestLayerNormLinear:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
eps = 1e-3
+has_ln_bias = normalization == 'LayerNorm'
layer_te = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
eps=eps,
+normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
)
@@ -327,12 +330,14 @@ class TestLayerNormLinear:
in_features=in_features,
out_features=out_features,
eps=eps,
+normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend='paddle',
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
+if has_ln_bias:
layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias:
@@ -340,9 +345,10 @@ class TestLayerNormLinear:
layer_te.weight.stop_gradient = no_wgrad
layer_te.ln_weight.stop_gradient = no_wgrad
-layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.weight.stop_gradient = no_wgrad
layer_pd.ln_weight.stop_gradient = no_wgrad
+if has_ln_bias:
+layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_te.bias.stop_gradient = no_dbias
@@ -362,6 +368,7 @@ class TestLayerNormLinear:
assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol)
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
if not no_dbias:
+if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
@@ -378,9 +385,10 @@ class TestLayerNormLinear:
@pytest.mark.parametrize('do_calibration', [True, False])
@pytest.mark.parametrize('return_ln_out', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
+@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
def test_layernorm_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad,
no_wgrad, fp8_wgrad, do_calibration, return_ln_out,
-activation_dtype):
+activation_dtype, normalization):
"""
Test FP8 LayerNormLinear Layer
"""
@@ -392,6 +400,7 @@ class TestLayerNormLinear:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
eps = 1e-3
+has_ln_bias = normalization == 'LayerNorm'
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
@@ -399,6 +408,7 @@ class TestLayerNormLinear:
in_features=in_features,
out_features=out_features,
eps=eps,
+normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
)
@@ -407,12 +417,14 @@ class TestLayerNormLinear:
in_features=in_features,
out_features=out_features,
eps=eps,
+normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend='paddle',
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
+if has_ln_bias:
layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias:
@@ -420,9 +432,10 @@ class TestLayerNormLinear:
layer_te.weight.stop_gradient = no_wgrad
layer_te.ln_weight.stop_gradient = no_wgrad
-layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.weight.stop_gradient = no_wgrad
layer_pd.ln_weight.stop_gradient = no_wgrad
+if has_ln_bias:
+layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_te.bias.stop_gradient = no_dbias
@@ -444,6 +457,7 @@ class TestLayerNormLinear:
assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol)
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
if not no_dbias:
+if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
@@ -523,8 +537,11 @@ class TestLayerNormMLP:
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('return_ln_out', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
+@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
+@pytest.mark.parametrize('activation', ['gelu', 'swiglu'])
def test_layernorm_mlp_bf16(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad,
-no_wgrad, return_ln_out, activation_dtype):
+no_wgrad, return_ln_out, activation_dtype, normalization,
+activation):
"""
Tests for TestLayerNormMLP layer
"""
@@ -536,11 +553,14 @@ class TestLayerNormMLP:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
eps = 1e-3
+has_ln_bias = normalization == 'LayerNorm'
layer_te = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
+normalization=normalization,
+activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
)
@@ -548,11 +568,14 @@ class TestLayerNormMLP:
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
+normalization=normalization,
+activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend='paddle',
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
+if has_ln_bias:
layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True)
layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True)
@@ -563,10 +586,11 @@ class TestLayerNormMLP:
layer_te.fc1_weight.stop_gradient = no_wgrad
layer_te.fc2_weight.stop_gradient = no_wgrad
layer_te.ln_weight.stop_gradient = no_wgrad
-layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.fc1_weight.stop_gradient = no_wgrad
layer_pd.fc2_weight.stop_gradient = no_wgrad
layer_pd.ln_weight.stop_gradient = no_wgrad
+if has_ln_bias:
+layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_te.fc1_bias.stop_gradient = no_dbias
@@ -595,6 +619,7 @@ class TestLayerNormMLP:
rtol=rtol,
atol=atol)
if not no_dbias:
+if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
assert_allclose(layer_te.fc1_bias.grad,
@@ -618,9 +643,11 @@ class TestLayerNormMLP:
@pytest.mark.parametrize('do_calibration', [True, False])
@pytest.mark.parametrize('return_ln_out', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
+@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
+@pytest.mark.parametrize('activation', ['gelu', 'swiglu'])
def test_layernorm_mlp_fp8(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad,
-no_wgrad, fp8_wgrad, do_calibration, return_ln_out,
-activation_dtype):
+no_wgrad, fp8_wgrad, do_calibration, return_ln_out, activation_dtype,
+normalization, activation):
"""
Test FP8 LayerNormMLP Layer
"""
@@ -632,6 +659,7 @@ class TestLayerNormMLP:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
eps = 1e-3
+has_ln_bias = normalization == 'LayerNorm'
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
@@ -639,6 +667,8 @@ class TestLayerNormMLP:
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
+normalization=normalization,
+activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
)
@@ -647,11 +677,14 @@ class TestLayerNormMLP:
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
+normalization=normalization,
+activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend='paddle',
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
+if has_ln_bias:
layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True)
layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True)
@@ -662,10 +695,11 @@ class TestLayerNormMLP:
layer_te.fc1_weight.stop_gradient = no_wgrad
layer_te.fc2_weight.stop_gradient = no_wgrad
layer_te.ln_weight.stop_gradient = no_wgrad
-layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.fc1_weight.stop_gradient = no_wgrad
layer_pd.fc2_weight.stop_gradient = no_wgrad
layer_pd.ln_weight.stop_gradient = no_wgrad
+if has_ln_bias:
+layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_te.fc1_bias.stop_gradient = no_dbias
@@ -696,6 +730,7 @@ class TestLayerNormMLP:
rtol=rtol,
atol=atol)
if not no_dbias:
+if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
assert_allclose(layer_te.fc1_bias.grad,
@@ -782,9 +817,9 @@ class TestLayerNormMLP:
atol=atol)
-@pytest.mark.parametrize('bs', [1, 2, 8])
+@pytest.mark.parametrize('bs', [1, 2])
-@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
+@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]])
@pytest.mark.parametrize('attn_type', ['self', 'cross'])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@@ -881,19 +916,21 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)
-@pytest.mark.parametrize('bs', [1, 2, 8])
+@pytest.mark.parametrize('bs', [1, 2])
-@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
+@pytest.mark.parametrize('num_gqa_groups', [1, 2, 4])
-@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
+@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[256, 4, 1024]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
+@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
-math_dtype, output_layernorm, return_layernorm_output):
+math_dtype, output_layernorm, return_layernorm_output,
+normalization):
"""
Test Transformer Encoder Layer
"""
@@ -901,6 +938,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
rtol = 5e-2
atol = 5e-2
eps = 1e-3
+has_ln_bias = normalization == 'LayerNorm'
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
@@ -945,6 +983,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type='encoder',
+normalization=normalization,
backend='transformer_engine')
layer_pd = te.TransformerLayer(hidden_size,
ffn_hidden_size,
@@ -959,6 +998,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type='encoder',
+normalization=normalization,
backend='paddle')
# MultiHeadAttention params
@@ -973,16 +1013,17 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
layer_te.self_attention.layernorm_qkv.ln_weight, True)
-layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
-layer_te.self_attention.layernorm_qkv.ln_bias, True)
layer_pd.self_attention.layernorm_qkv.weight.copy_(
layer_te.self_attention.layernorm_qkv.weight.T, True)
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
-layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
-layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
+if has_ln_bias:
+layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
+layer_te.self_attention.layernorm_qkv.ln_bias, True)
+layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
+layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_(
layer_te.self_attention.layernorm_qkv.bias, True)
@@ -999,17 +1040,18 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
# LayerNorm MLP params
layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True)
-layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True)
layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True)
layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
-layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
-layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
+if has_ln_bias:
+layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
+layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
+layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True)
layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True)
@@ -1062,10 +1104,10 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
atol=0.5)
-@pytest.mark.parametrize('bs', [1, 2, 8])
+@pytest.mark.parametrize('bs', [1, 2])
-@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
+@pytest.mark.parametrize('num_gqa_groups', [1, 2, 4])
-@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
+@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[256, 4, 1024]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@@ -1073,10 +1115,11 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
@pytest.mark.parametrize('recompute_core_attention', [True, False])
+@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
math_dtype, output_layernorm, return_layernorm_output,
-recompute_core_attention):
+recompute_core_attention, normalization):
"""
Test Transformer Decoder Layer
"""
@@ -1084,6 +1127,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
rtol = 5e-2
atol = 6e-2
eps = 1e-3
+has_ln_bias = normalization == 'LayerNorm'
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
@@ -1137,6 +1181,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type='decoder',
+normalization=normalization,
backend='transformer_engine')
layer_pd = te.TransformerLayer(hidden_size,
ffn_hidden_size,
@@ -1151,6 +1196,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type='decoder',
+normalization=normalization,
backend='paddle')
# MultiHeadAttention params - self attn
@@ -1165,16 +1211,17 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
layer_te.self_attention.layernorm_qkv.ln_weight, True)
-layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
-layer_te.self_attention.layernorm_qkv.ln_bias, True)
layer_pd.self_attention.layernorm_qkv.weight.copy_(
layer_te.self_attention.layernorm_qkv.weight.T, True)
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
-layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
-layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
+if has_ln_bias:
+layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
+layer_te.self_attention.layernorm_qkv.ln_bias, True)
+layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
+layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_(
layer_te.self_attention.layernorm_qkv.bias, True)
@@ -1192,16 +1239,17 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
# MultiHeadAttention params - cross attn
layer_pd.inter_attention.layernorm_query.ln_weight.copy_(
layer_te.inter_attention.layernorm_query.ln_weight, True)
-layer_pd.inter_attention.layernorm_query.ln_bias.copy_(
-layer_te.inter_attention.layernorm_query.ln_bias, True)
layer_pd.inter_attention.layernorm_query.weight.copy_(
layer_te.inter_attention.layernorm_query.weight.T, True)
layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
-layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
-layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
+if has_ln_bias:
+layer_pd.inter_attention.layernorm_query.ln_bias.copy_(
+layer_te.inter_attention.layernorm_query.ln_bias, True)
+layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
+layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.inter_attention.layernorm_query.bias.copy_(
layer_te.inter_attention.layernorm_query.bias, True)
@@ -1225,17 +1273,18 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
# LayerNorm MLP params
layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True)
-layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True)
layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True)
layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
-layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
-layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
+if has_ln_bias:
+layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
+layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
+layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True)
layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True)
......
@@ -5,12 +5,6 @@
import struct
-from utils import (
-assert_allclose,
-create_fp8_meta,
-get_fused_attention_backend,
-is_fused_attention_supported,
-)
import numpy as np
import paddle
import paddle.nn.functional as F
@@ -34,6 +28,10 @@ from transformer_engine.paddle.cpp_extensions import (
cast_transpose_bgrad,
te_gelu,
gelu_fp8,
+swiglu,
+swiglu_fp8,
+swiglu_pd,
+dswiglu,
dgelu_cast_transpose_bgrad_fp8,
layernorm_fwd_fp8,
layernorm_fwd,
@@ -62,9 +60,9 @@ GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816,
(16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()
-SELF_ATTN_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
+SELF_ATTN_CASES = [(2, 512, 12, 64)]
-CROSS_ATTN_CASES = [(32, 128, 512, 16, 64)]
+CROSS_ATTN_CASES = [(2, 128, 512, 12, 64)]
-FLASH_ATTN_CASES = [(4, 1024, 16, 64), (2, 2048, 16, 128)]
+FLASH_ATTN_CASES = [(2, 1024, 16, 64), (2, 2048, 16, 128)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]
@@ -296,6 +294,55 @@ class TestActivation:
assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01)
assert_allclose(dbias, x.grad.sum(axis=0), rtol=0.1, atol=0.01)
+@staticmethod
+def test_swiglu_bf16():
+"""
+Test BF16 SwiGLU Forward
+"""
+a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
+swiglu_out = swiglu(a, otype=tex.DType.kBFloat16)
+swiglu_ref = swiglu_pd(a)
+assert_allclose(swiglu_out, swiglu_ref, rtol=1e-2)
+@staticmethod
+@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
+def test_swiglu_fp8(fp8_dtype):
+"""
+Test FP8 SwiGLU Forward
+"""
+a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
+fp8_meta = create_fp8_meta()
+swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
+swiglu_out = cast_from_fp8(swiglu_out_fp8,
+fp8_meta,
+FP8FwdTensors.GEMM1_INPUT,
+itype=fp8_dtype,
+otype=tex.DType.kFloat32)
+swiglu_ref = swiglu_pd(a)
+assert_allclose(swiglu_out, swiglu_ref, rtol=0.1, atol=0.01)
+@staticmethod
+def test_swiglu_bwd():
+"""
+Test SwiGLU Backward
+"""
+# y = SwiGLU(x), calculate ref
+x = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
+x.stop_gradient = False
+y = swiglu_pd(x)
+y_grad = paddle.rand(shape=(16, 16), dtype='bfloat16') * 2 - 1
+paddle.autograd.backward([y], [y_grad], True)
+# calculate fp8
+x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16)
+assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)
class TestGemm:
"""
......
@@ -4,6 +4,15 @@
"""Transformer Engine bindings for Paddle"""
from .fp8 import fp8_autocast
-from .layer import (Linear, LayerNorm, LayerNormLinear, LayerNormMLP, FusedScaleMaskSoftmax,
-DotProductAttention, MultiHeadAttention, TransformerLayer)
+from .layer import (
+Linear,
+LayerNorm,
+LayerNormLinear,
+LayerNormMLP,
+FusedScaleMaskSoftmax,
+DotProductAttention,
+MultiHeadAttention,
+TransformerLayer,
+RotaryPositionEmbedding,
+)
from .recompute import recompute
@@ -6,6 +6,7 @@
import math
from typing import Optional, Tuple, Union
import paddle
+import paddle.nn.functional as F
import transformer_engine_paddle as tex
from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
from .fp8 import FP8TensorMeta
@@ -328,6 +329,56 @@ def gelu_fp8(
return out
+def swiglu(
+inp: paddle.Tensor,
+otype: tex.DType,
+) -> paddle.Tensor:
+"""Non FP8 SWIGLU"""
+return tex.te_swiglu(
+inp,
+int(otype),
+)
+def swiglu_pd(inp: paddle.Tensor,) -> paddle.Tensor:
+"""Native SWIGLU"""
+gate_out, up_out = paddle.chunk(inp, chunks=2, axis=-1)
+out = F.silu(gate_out) * up_out
+return out
+def swiglu_fp8(
+inp: paddle.Tensor,
+fp8_meta_tensor: FP8TensorMeta,
+fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
+otype: tex.DType,
+) -> paddle.Tensor:
+"""SWIGLU + FP8 cast"""
+out, _, _ = tex.te_swiglu_fp8(
+inp,
+fp8_meta_tensor.scale,
+fp8_meta_tensor.amax_history,
+fp8_meta_tensor.scale_inv,
+fp8_tensor.value,
+int(otype),
+)
+return out
+def dswiglu(
+grad_output: paddle.Tensor,
+swiglu_input: paddle.Tensor,
+otype: tex.DType,
+) -> paddle.Tensor:
+"""dSWIGLU"""
+return tex.te_dswiglu(
+grad_output,
+swiglu_input,
+int(otype),
+)
def dgelu_cast_transpose_bgrad_fp8(
grad_output: paddle.Tensor,
gelu_input: paddle.Tensor,
@@ -404,9 +455,10 @@ def rmsnorm_fwd(
eps: float,
otype: tex.DType,
sm_margin: int = 0,
+zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm forward"""
-return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin)
+return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin, zero_centered_gamma)
def rmsnorm_fwd_fp8(
@@ -417,12 +469,13 @@ def rmsnorm_fwd_fp8(
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
sm_margin: int = 0,
+zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""RMSNorm with FP8 output"""
out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(inp, weight, fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, eps, fp8_tensor.value,
-int(otype), sm_margin)
+int(otype), sm_margin, zero_centered_gamma)
return out, rsigma
@@ -432,9 +485,10 @@ def rmsnorm_bwd(
rsigma: paddle.Tensor,
gamma: paddle.Tensor,
sm_margin: int = 0,
+zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm backward"""
-return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin)
+return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma)
def mask_to_cu_seqlens(
......
@@ -218,6 +218,66 @@ std::vector<paddle::Tensor> te_gelu(const paddle::Tensor &input, int64_t otype)
return {output};
}
+std::vector<paddle::Tensor> te_swiglu(const paddle::Tensor &input, int64_t otype) {
+auto shape = GetShapeArray(input);
+NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
+size_t M = shape[0];
+size_t N = shape[1];
+auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2},
+Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
+auto input_cu = MakeNvteTensor(input);
+auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(output), Int2NvteDType(otype));
+nvte_swiglu(input_cu.data(), output_cu.data(), input.stream());
+return {output};
+}
+std::vector<paddle::Tensor> te_swiglu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
+paddle::Tensor &amax, // NOLINT
+paddle::Tensor &scale_inv, // NOLINT
+int64_t index, int64_t otype) {
+auto shape = GetShapeArray(input);
+NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
+size_t M = shape[0];
+size_t N = shape[1];
+auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2},
+Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
+auto input_cu = MakeNvteTensor(input);
+auto output_cu = MakeNvteTensor(
+output.data(), GetShapeArray(output), Int2NvteDType(otype), GetDataPtr<float>(amax, index),
+const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
+nvte_swiglu(input_cu.data(), output_cu.data(), input.stream());
+return {output};
+}
+std::vector<paddle::Tensor> te_dswiglu(const paddle::Tensor &grad, const paddle::Tensor &input,
+int64_t otype) {
+auto shape = GetShapeArray(input);
+NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
+size_t M = shape[0];
+size_t N = shape[1];
+auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
+auto input_cu = MakeNvteTensor(input.data(), {M, N}, Paddle2NvteDType(input.dtype()));
+auto grad_cu = MakeNvteTensor(grad.data(), {M, N / 2}, Paddle2NvteDType(grad.dtype()));
+auto output_cu = MakeNvteTensor(output.data(), {M, N}, Paddle2NvteDType(output.dtype()));
+nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), input.stream());
+return {output};
+}
std::vector<paddle::Tensor> te_cast_transpose_bgrad_dgelu(const paddle::Tensor &grad_output,
const paddle::Tensor &gelu_input,
const paddle::Tensor &scale,
@@ -406,7 +466,9 @@ std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const pad
std::vector<paddle::Tensor> te_rmsnorm_fwd(const paddle::Tensor &input,
const paddle::Tensor &weight, float eps, int64_t otype,
-int64_t sm_margin) {
+int64_t sm_margin, bool zero_centered_gamma) {
+NVTE_CHECK(zero_centered_gamma == false,
+"zero_centered_gamma is not supported yet for RMSNorm.");
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
@@ -448,14 +510,16 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(const paddle::Tensor &input,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
float eps, int64_t index, int64_t otype,
-int64_t sm_margin) {
+int64_t sm_margin, bool zero_centered_gamma) {
+NVTE_CHECK(zero_centered_gamma == false,
+"zero_centered_gamma is not supported yet for RMSNorm.");
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
size_t N = shape[0];
size_t H = shape[1];
-auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
+auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto rsigma =
paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
@@ -487,7 +551,10 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(const paddle::Tensor &input,
std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x,
const paddle::Tensor &rsigma,
-const paddle::Tensor &gamma, int64_t sm_margin) {
+const paddle::Tensor &gamma, int64_t sm_margin,
+bool zero_centered_gamma) {
+NVTE_CHECK(zero_centered_gamma == false,
+"zero_centered_gamma is not supported yet for RMSNorm.");
auto dx = paddle::empty_like(x, x.dtype(), x.place());
auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
@@ -1374,6 +1441,25 @@ PD_BUILD_OP(te_gelu)
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu));
+PD_BUILD_OP(te_swiglu)
+.Inputs({"Input"})
+.Outputs({"Output"})
+.Attrs({"otype: int64_t"})
+.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu));
+PD_BUILD_OP(te_swiglu_fp8)
+.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
+.Outputs({"Output", "Amax", "ScaleInv"})
+.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
+.Attrs({"index: int64_t", "otype: int64_t"})
+.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu_fp8));
+PD_BUILD_OP(te_dswiglu)
+.Inputs({"Grad", "Input"})
+.Outputs({"Output"})
+.Attrs({"otype: int64_t"})
+.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_dswiglu));
PD_BUILD_OP(te_cast_transpose_bgrad_dgelu)
.Inputs({"GradOutput", "GeluInput", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"CastedDgelu", "TransposedDgelu", "Dbias", "Amax", "ScaleInv"})
@@ -1404,20 +1490,21 @@ PD_BUILD_OP(te_layernorm_bwd)
PD_BUILD_OP(te_rmsnorm_fwd)
.Inputs({"Input", "Weight"})
.Outputs({"Output", "InvVariance"})
-.Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t"})
+.Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd));
PD_BUILD_OP(te_rmsnorm_fwd_fp8)
.Inputs({"Input", "Weight", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "InvVariance", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
-.Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t"})
+.Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t",
+"zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd_fp8));
PD_BUILD_OP(te_rmsnorm_bwd)
.Inputs({"Dz", "X", "Rsigma", "Gamma"})
.Outputs({"Dx", "Dgamma"})
-.Attrs({"sm_margin: int64_t"})
+.Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_bwd));
PD_BUILD_OP(te_fused_attn_fwd_qkvpacked)
......
@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Layer level Paddle APIs"""
-from .attention import DotProductAttention, MultiHeadAttention
+from .attention import DotProductAttention, MultiHeadAttention, RotaryPositionEmbedding
from .layernorm import LayerNorm
from .layernorm_linear import LayerNormLinear
from .layernorm_mlp import LayerNormMLP
......
...@@ -10,6 +10,10 @@ from typing import Optional, Tuple, Union ...@@ -10,6 +10,10 @@ from typing import Optional, Tuple, Union
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
fused_rotary_position_embedding = None
import transformer_engine_paddle as tex import transformer_engine_paddle as tex
from .layernorm_linear import LayerNormLinear from .layernorm_linear import LayerNormLinear
...@@ -30,7 +34,7 @@ from ..distributed import get_tp_group_and_world_size, track_rng_state ...@@ -30,7 +34,7 @@ from ..distributed import get_tp_group_and_world_size, track_rng_state
from ..utils import attention_mask_func, divide from ..utils import attention_mask_func, divide
from ..recompute import recompute from ..recompute import recompute
__all__ = ["DotProductAttention", "MultiHeadAttention"] __all__ = ["DotProductAttention", "MultiHeadAttention", "RotaryPositionEmbedding"]
def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
...@@ -47,6 +51,81 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: ...@@ -47,6 +51,81 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size]) return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size])
class RotaryPositionEmbedding(paddle.nn.Layer):
"""
Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
"""
def __init__(
self,
dim: int,
max_position_embeddings: int,
):
"""
Parameters
----------
dim: int
rotary embedding dimension
max_position_embeddings: int
max_position_embeddings before position interpolation
"""
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.inv_freq = 1.0 / (10000**(paddle.cast(paddle.arange(0, dim, 2), dtype='float32') /
self.dim))
self._set_cos_sin_cache(seq_len=max_position_embeddings)
def _set_cos_sin_cache(self, seq_len):
self.max_seq_len_cached = seq_len
# [seq_len]
t = paddle.arange(seq_len, dtype="float32")
# [seq_len, dim/2]
freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
# [seq_len, dim]
emb = paddle.concat([freqs, freqs], axis=-1)
# [1, seqlen, 1, dim]
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]
def forward(self, max_seq_len: int):
"""
Create rotary position embedding frequencies
Parameters
----------
max_seq_len: int
sequence length of a sample
"""
cos = self.cos_cached[:, :, :max_seq_len, ...]
sin = self.sin_cached[:, :, :max_seq_len, ...]
return (cos, sin)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
"""Applies rotary positional embedding to the input."""
if position_ids is None:
# Note: Only for LlamaForCausalLMPipe model pretraining
cos = cos[:, :q.shape[1], :, :] # [bs, seq_len, 1, dim]
sin = sin[:, :q.shape[1], :, :] # [bs, seq_len, 1, dim]
else:
cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim]
sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
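
As an illustration of how the non-fused path composes, here is a minimal sketch using the two helpers above together with RotaryPositionEmbedding; the shapes (batch 2, sequence 4, 8 heads, head size 16) are made up and nothing below is part of the commit itself:

import paddle

# toy tensors in [batch, seq_len, num_heads, head_size] layout (shapes are illustrative)
q = paddle.randn([2, 4, 8, 16])
k = paddle.randn([2, 4, 8, 16])

# build the cos/sin cache once; with position_ids=None, apply_rotary_pos_emb
# slices the cache down to the query length before broadcasting
rope = RotaryPositionEmbedding(dim=16, max_position_embeddings=4096)
cos, sin = rope(max_seq_len=4)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape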
class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
"""Function for FusedAttention with packed QKV input""" """Function for FusedAttention with packed QKV input"""
...@@ -450,6 +529,8 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -450,6 +529,8 @@ class MultiHeadAttention(paddle.nn.Layer):
whether to apply layernorm to the input. whether to apply layernorm to the input.
attention_type: {'self', 'cross'}, default = `self` attention_type: {'self', 'cross'}, default = `self`
type of attention operation. type of attention operation.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
zero_centered_gamma: bool, default = `False` zero_centered_gamma: bool, default = `False`
whether to zero initialize the gamma of the layernorm operation. whether to zero initialize the gamma of the layernorm operation.
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
...@@ -491,11 +572,13 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -491,11 +572,13 @@ class MultiHeadAttention(paddle.nn.Layer):
layernorm_epsilon: float = 1e-5, layernorm_epsilon: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None, weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None, bias_attr: Union[paddle.ParamAttr, None, bool] = None,
max_sequence_length: Optional[int] = None,
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
params_dtype: Optional[paddle.dtype] = None, params_dtype: Optional[paddle.dtype] = None,
return_layernorm_output: bool = False, return_layernorm_output: bool = False,
input_layernorm: bool = False, input_layernorm: bool = False,
attention_type: str = "self", attention_type: str = "self",
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
sequence_parallel: bool = False, sequence_parallel: bool = False,
...@@ -509,6 +592,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -509,6 +592,7 @@ class MultiHeadAttention(paddle.nn.Layer):
self.attention_type = attention_type self.attention_type = attention_type
self.return_layernorm_output = return_layernorm_output self.return_layernorm_output = return_layernorm_output
self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
self.max_sequence_length = max_sequence_length
self.weight_attr = weight_attr self.weight_attr = weight_attr
self.bias_attr = bias_attr self.bias_attr = bias_attr
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
...@@ -544,6 +628,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -544,6 +628,7 @@ class MultiHeadAttention(paddle.nn.Layer):
weight_attr=self.weight_attr, weight_attr=self.weight_attr,
bias_attr=self.bias_attr, bias_attr=self.bias_attr,
return_layernorm_output=return_layernorm_output, return_layernorm_output=return_layernorm_output,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
...@@ -571,6 +656,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -571,6 +656,7 @@ class MultiHeadAttention(paddle.nn.Layer):
weight_attr=self.weight_attr, weight_attr=self.weight_attr,
bias_attr=self.bias_attr, bias_attr=self.bias_attr,
return_layernorm_output=return_layernorm_output, return_layernorm_output=return_layernorm_output,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
...@@ -628,6 +714,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -628,6 +714,7 @@ class MultiHeadAttention(paddle.nn.Layer):
hidden_states: paddle.Tensor, hidden_states: paddle.Tensor,
attention_mask: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None,
encoder_output: Optional[paddle.Tensor] = None, encoder_output: Optional[paddle.Tensor] = None,
rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
core_attention_bias_type: str = "no_bias", core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[paddle.Tensor] = None, core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True, set_zero: bool = True,
...@@ -645,6 +732,9 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -645,6 +732,9 @@ class MultiHeadAttention(paddle.nn.Layer):
Boolean tensor used to mask out softmax input when not using attention. Boolean tensor used to mask out softmax input when not using attention.
encoder_output : Optional[paddle.Tensor], default = `None` encoder_output : Optional[paddle.Tensor], default = `None`
Output of the encoder layer. Output of the encoder layer.
rotary_pos_emb: Tuple[paddle.Tensor, paddle.Tensor], default = `None`
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = `no_bias` core_attention_bias_type: str, default = `no_bias`
only support no_bias type currently, {`no_bias`} only support no_bias type currently, {`no_bias`}
core_attention_bias: Optional[paddle.Tensor], default = `None` core_attention_bias: Optional[paddle.Tensor], default = `None`
...@@ -675,8 +765,8 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -675,8 +765,8 @@ class MultiHeadAttention(paddle.nn.Layer):
if input_dim == 2: if input_dim == 2:
# hidden_states: [b * s_q, hidden_size] # hidden_states: [b * s_q, hidden_size]
# need to get max_seq_len from attention_mask # need to get max_seq_len from attention_mask
assert attention_mask is not None assert self.max_sequence_length is not None, "max_sequence_length must be provided"
max_seq_len = attention_mask.shape[-1] max_seq_len = self.max_sequence_length
elif input_dim == 3: elif input_dim == 3:
# hidden_states: [b, s_q, hidden_size] # hidden_states: [b, s_q, hidden_size]
max_seq_len = hidden_states.shape[1] max_seq_len = hidden_states.shape[1]
...@@ -723,30 +813,6 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -723,30 +813,6 @@ class MultiHeadAttention(paddle.nn.Layer):
shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head]) shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
for x in (query_layer, key_layer, value_layer)) for x in (query_layer, key_layer, value_layer))
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
if recompute_core_attention:
context_layer = recompute(
self.core_attention,
query_layer,
key_layer,
value_layer,
attention_mask,
core_attention_bias_type,
core_attention_bias,
set_zero,
use_reentrant=False,
)
else:
context_layer = self.core_attention(
query_layer=query_layer,
key_layer=key_layer,
value_layer=value_layer,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
else: # cross attention else: # cross attention
mixed_kv_layer = self.key_value( mixed_kv_layer = self.key_value(
encoder_output, encoder_output,
...@@ -785,6 +851,23 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -785,6 +851,23 @@ class MultiHeadAttention(paddle.nn.Layer):
-1, max_seq_len, self.num_attention_heads_per_partition, -1, max_seq_len, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head self.hidden_size_per_attention_head
]) ])
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
if fused_rotary_position_embedding is None:
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, q_pos_emb,
k_pos_emb)
else:
query_layer, key_layer, _ = fused_rotary_position_embedding(
query_layer,
key_layer,
v=None,
sin=k_pos_emb,
cos=q_pos_emb,
position_ids=None,
use_neox_rotary_style=False,
)
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name): with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
if recompute_core_attention: if recompute_core_attention:
context_layer = recompute( context_layer = recompute(
......
...@@ -17,6 +17,9 @@ from ..cpp_extensions import ( ...@@ -17,6 +17,9 @@ from ..cpp_extensions import (
layernorm_fwd, layernorm_fwd,
layernorm_fwd_fp8, layernorm_fwd_fp8,
layernorm_bwd, layernorm_bwd,
rmsnorm_fwd_fp8,
rmsnorm_fwd,
rmsnorm_bwd,
) )
from .base import TransformerEngineBaseLayer from .base import TransformerEngineBaseLayer
...@@ -44,82 +47,129 @@ from ..utils import ( ...@@ -44,82 +47,129 @@ from ..utils import (
__all__ = ["LayerNormLinear"] __all__ = ["LayerNormLinear"]
def _layernorm_fwd_fp8_cast( def _apply_normalization_fwd(
normalization: str,
inputmat: paddle.Tensor, inputmat: paddle.Tensor,
ln_weight: paddle.Tensor, norm_weight: paddle.Tensor,
ln_bias: paddle.Tensor, norm_bias: Union[paddle.Tensor, None],
out_fp8_index: FP8FwdTensors, out_fp8_index: FP8FwdTensors,
eps: float, eps: float,
fp8_enabled: bool, fp8_enabled: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
activation_dtype: paddle.dtype, activation_dtype: paddle.dtype,
return_layernorm_output: bool, return_norm_output: bool,
fwd_ln_sm_margin: int, fwd_norm_sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
): ):
"""Performs LayerNorm + FP8_Cast for FP8 path. LayerNorm only for BF16 path""" """Performs LayerNorm + FP8_Cast for FP8 path. LayerNorm only for BF16 path"""
assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!"
ln_weight = cast_if_needed_inplace(ln_weight, activation_dtype) if normalization == "RMSNorm":
ln_bias = cast_if_needed_inplace(ln_bias, activation_dtype) assert norm_bias is None, "RMSNorm does not support bias!"
norm_weight = cast_if_needed_inplace(norm_weight, activation_dtype)
if norm_bias is not None:
norm_bias = cast_if_needed_inplace(norm_bias, activation_dtype)
norm_kwargs = {
"inp": inputmat,
"weight": norm_weight,
"eps": eps,
"otype": TE_DType[activation_dtype],
"sm_margin": fwd_norm_sm_margin,
"zero_centered_gamma": zero_centered_gamma,
}
fwd_normalization_funcs = {
('LayerNorm', True, True): layernorm_fwd,
('LayerNorm', True, False): layernorm_fwd_fp8,
('LayerNorm', False, True): layernorm_fwd,
('LayerNorm', False, False): layernorm_fwd,
('RMSNorm', True, True): rmsnorm_fwd,
('RMSNorm', True, False): rmsnorm_fwd_fp8,
('RMSNorm', False, True): rmsnorm_fwd,
('RMSNorm', False, False): rmsnorm_fwd,
}
if normalization == "LayerNorm":
norm_kwargs["bias"] = norm_bias
norm_fwd_func = fwd_normalization_funcs[(normalization, fp8_enabled, return_norm_output)]
if fp8_enabled: if fp8_enabled:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output: if not return_norm_output:
ln_out, mu, rsigma = layernorm_fwd_fp8( fp8_kwargs = {
inputmat, "fp8_meta_tensor": fp8_meta["scaling_fwd"],
ln_weight, "fp8_tensor": out_fp8_index,
ln_bias, "otype": fp8_dtype_forward,
eps, }
norm_kwargs.update(fp8_kwargs)
out_tuple = norm_fwd_func(**norm_kwargs)
if normalization == "LayerNorm":
norm_out_return, mu, rsigma = out_tuple
else: # RMSNorm
norm_out_return, rsigma = out_tuple
mu = None
if fp8_enabled and return_norm_output:
norm_out = cast_to_fp8(
norm_out_return,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
out_fp8_index, out_fp8_index,
fp8_dtype_forward, fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
) )
ln_out_return = ln_out
else: else:
ln_out_return, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps, norm_out = norm_out_return
TE_DType[activation_dtype], fwd_ln_sm_margin,
zero_centered_gamma)
ln_out = cast_to_fp8(
ln_out_return,
fp8_meta["scaling_fwd"],
out_fp8_index,
fp8_dtype_forward,
)
else:
ln_out, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps,
TE_DType[activation_dtype], fwd_ln_sm_margin,
zero_centered_gamma)
ln_out_return = ln_out
return ( return (
ln_out_return, norm_out_return,
ln_out, norm_out,
mu, mu,
rsigma, rsigma,
) )
def _layernorm_bwd( def _apply_normalization_bwd(
normalization: str,
inputmat: paddle.Tensor, inputmat: paddle.Tensor,
dgrad: paddle.Tensor, dgrad: paddle.Tensor,
ln_weight: paddle.Tensor, norm_weight: paddle.Tensor,
mu: paddle.Tensor, mu: Union[paddle.Tensor, None],
rsigma: paddle.Tensor, rsigma: paddle.Tensor,
grad_ln_out_return: paddle.Tensor, grad_norm_out_return: paddle.Tensor,
return_layernorm_output: bool, return_norm_output: bool,
bwd_ln_sm_margin: int, bwd_norm_sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
): ):
assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!"
if normalization == "RMSNorm":
assert mu is None, "RMSNorm does not support bias!"
# LayerNorm gradient # LayerNorm gradient
d_ln_out = dgrad.reshape(inputmat.shape) d_norm_out = dgrad.reshape(inputmat.shape)
# Residual gradient # Residual gradient
if return_layernorm_output: if return_norm_output:
d_ln_out = d_ln_out + grad_ln_out_return.reshape(d_ln_out.shape) d_norm_out = d_norm_out + grad_norm_out_return.reshape(d_norm_out.shape)
return layernorm_bwd(d_ln_out, inputmat, mu, rsigma, ln_weight, bwd_ln_sm_margin, norm_bwd_func = layernorm_bwd if normalization == "LayerNorm" else rmsnorm_bwd
zero_centered_gamma) norm_bwd_kwargs = {
"dz": d_norm_out,
"x": inputmat,
"rsigma": rsigma,
"gamma": norm_weight,
"sm_margin": bwd_norm_sm_margin,
"zero_centered_gamma": zero_centered_gamma,
}
if normalization == "LayerNorm":
norm_bwd_kwargs["mu"] = mu
out_tuple = norm_bwd_func(**norm_bwd_kwargs)
if normalization == "LayerNorm":
dxmat, dgamma, dbeta = out_tuple
else: # RMSNorm
dxmat, dgamma = out_tuple
dbeta = None
return dxmat, dgamma, dbeta
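
The two helpers above dispatch between the LayerNorm and RMSNorm kernels on the key (normalization, fp8_enabled, return_norm_output); for RMSNorm there is no mean (mu) among the forward outputs and no dbeta among the backward outputs. As a minimal pure-Paddle sketch of what the RMSNorm kernels compute, mirroring the paddle-backend fallback used elsewhere in this commit (the function name and shapes below are made up, and gradients are obtained via autograd rather than the hand-written kernel):

import paddle

def rmsnorm_reference(x, gamma, eps=1e-5):
    # what rmsnorm_fwd computes: rsigma = 1 / RMS(x), y = x * rsigma * gamma
    rsigma = paddle.rsqrt(paddle.mean(x * x, axis=-1, keepdim=True) + eps)
    return x * rsigma * gamma

x = paddle.randn([8, 64])
x.stop_gradient = False
gamma = paddle.ones([64])
gamma.stop_gradient = False

y = rmsnorm_reference(x, gamma)
y.backward(paddle.ones_like(y))
# RMSNorm has no mean subtraction and no beta, so unlike layernorm_bwd the
# backward pass only produces (dx, dgamma); there is no mu input and no dbeta output.
dx, dgamma = x.grad, gamma.grad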
class _LayerNormLinear(paddle.autograd.PyLayer): class _LayerNormLinear(paddle.autograd.PyLayer):
...@@ -130,7 +180,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -130,7 +180,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx, ctx,
inp: paddle.Tensor, inp: paddle.Tensor,
ln_weight: paddle.Tensor, ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor, ln_bias: Union[paddle.Tensor, None],
weight: paddle.Tensor, weight: paddle.Tensor,
weight_fp8: Optional[paddle.Tensor], weight_fp8: Optional[paddle.Tensor],
weight_t_fp8: Optional[paddle.Tensor], weight_t_fp8: Optional[paddle.Tensor],
...@@ -146,6 +196,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -146,6 +196,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
fwd_ln_sm_margin: int, fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int, bwd_ln_sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
normalization: str,
parallel_mode: Union[str, None], parallel_mode: Union[str, None],
tensor_parallel: bool, tensor_parallel: bool,
sequence_parallel: bool, sequence_parallel: bool,
...@@ -153,6 +204,10 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -153,6 +204,10 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
tp_size: int, tp_size: int,
is_first_microbatch: bool, is_first_microbatch: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
if normalization == "RMSNorm":
assert ln_bias is None, "RMSNorm does not support bias!"
else: # LayerNorm
assert ln_bias is not None, "LayerNorm requires bias!"
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.shape[0] in_features = ln_weight.shape[0]
assert inp.shape[-1] == in_features, "GEMM not possible" assert inp.shape[-1] == in_features, "GEMM not possible"
...@@ -167,7 +222,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -167,7 +222,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ln_out, ln_out,
mu, mu,
rsigma, rsigma,
) = _layernorm_fwd_fp8_cast( ) = _apply_normalization_fwd(
normalization,
inputmat, inputmat,
ln_weight, ln_weight,
ln_bias, ln_bias,
...@@ -232,9 +288,11 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -232,9 +288,11 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.requires_dgrad = not inp.stop_gradient ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_wgrad = not weight.stop_gradient ctx.requires_wgrad = not weight.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient ctx.requires_bgrad = use_bias and not bias.stop_gradient
ctx.requires_ln_bgrad = not ln_bias.stop_gradient ctx.requires_ln_bgrad = ln_bias is not None and not ln_bias.stop_gradient
ctx.requires_ln_wgrad = not ln_weight.stop_gradient ctx.requires_ln_wgrad = not ln_weight.stop_gradient
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.has_ln_bias = ln_bias is not None
ctx.normalization = normalization
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1])) out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
...@@ -314,7 +372,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -314,7 +372,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
bgrad = bgrad_ bgrad = bgrad_
# LayerNorm Bwd # LayerNorm Bwd
dxmat, dgamma, dbeta = _layernorm_bwd( dxmat, dgamma, dbeta = _apply_normalization_bwd(
ctx.normalization,
inputmat, inputmat,
dgrad, dgrad,
ln_weight, ln_weight,
...@@ -328,6 +387,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -328,6 +387,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
bgrad = bgrad if ctx.requires_bgrad else None bgrad = bgrad if ctx.requires_bgrad else None
bgrad_out = (bgrad,) if ctx.use_bias else () bgrad_out = (bgrad,) if ctx.use_bias else ()
dbeta = dbeta if ctx.requires_ln_bgrad else None
dbeta_out = (dbeta,) if ctx.has_ln_bias else ()
if not ctx.fp8_enabled or ctx.is_first_microbatch is None: if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
weight_cache_grad = () weight_cache_grad = ()
...@@ -338,7 +399,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -338,7 +399,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
return ( return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma if ctx.requires_ln_wgrad else None, dgamma if ctx.requires_ln_wgrad else None,
dbeta if ctx.requires_ln_bgrad else None, *dbeta_out,
wgrad if ctx.requires_wgrad else None, wgrad if ctx.requires_wgrad else None,
*weight_cache_grad, *weight_cache_grad,
*bgrad_out, *bgrad_out,
...@@ -361,6 +422,8 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -361,6 +422,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
optional `paddle.ParamAttr` for weight. optional `paddle.ParamAttr` for weight.
bias_attr: Union[paddle.ParamAttr, None, bool], default = None bias_attr: Union[paddle.ParamAttr, None, bool], default = None
optional `paddle.ParamAttr` for bias. optional `paddle.ParamAttr` for bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
return_layernorm_output : bool, default = `False` return_layernorm_output : bool, default = `False`
if set to `True`, output of layernorm is returned from the forward if set to `True`, output of layernorm is returned from the forward
together with the output of the linear transformation. together with the output of the linear transformation.
...@@ -395,6 +458,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -395,6 +458,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
eps: float = 1e-5, eps: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None, weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None, bias_attr: Union[paddle.ParamAttr, None, bool] = None,
normalization: str = 'LayerNorm',
return_layernorm_output: bool = False, return_layernorm_output: bool = False,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
parallel_mode: Optional[str] = None, parallel_mode: Optional[str] = None,
...@@ -407,6 +471,8 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -407,6 +471,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.eps = eps self.eps = eps
self.normalization = normalization
assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!"
self.return_layernorm_output = return_layernorm_output self.return_layernorm_output = return_layernorm_output
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
self.backend = backend self.backend = backend
...@@ -439,16 +505,19 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -439,16 +505,19 @@ class LayerNormLinear(TransformerEngineBaseLayer):
dtype=self._dtype, dtype=self._dtype,
is_bias=False, is_bias=False,
) )
if self.normalization != "RMSNorm":
self.ln_bias = self.create_parameter( self.ln_bias = self.create_parameter(
shape=[self.in_features], shape=[self.in_features],
attr=paddle.ParamAttr(initializer=Constant(value=0.0)), attr=paddle.ParamAttr(initializer=Constant(value=0.0)),
dtype=self._dtype, dtype=self._dtype,
is_bias=True, is_bias=True,
) )
else:
self.ln_bias = None
if self.sequence_parallel: if self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.ln_weight) mark_as_sequence_parallel_parameter(self.ln_weight)
if self.ln_bias is not None:
mark_as_sequence_parallel_parameter(self.ln_bias) mark_as_sequence_parallel_parameter(self.ln_bias)
# Initialize Linear weight parameter # Initialize Linear weight parameter
...@@ -534,6 +603,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -534,6 +603,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.normalization,
self.parallel_mode, self.parallel_mode,
self.tensor_parallel, self.tensor_parallel,
self.sequence_parallel, self.sequence_parallel,
...@@ -566,19 +636,24 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -566,19 +636,24 @@ class LayerNormLinear(TransformerEngineBaseLayer):
warnings.warn( warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.") "`is_first_microbatch` is not supported for paddle backend and is ignored.")
ln_out = F.layer_norm(x=inp, if self.normalization == "RMSNorm":
norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps)
norm_out = inp * norm * self.ln_weight
else: # LayerNorm
norm_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[-1], normalized_shape=inp.shape[-1],
weight=self.ln_weight, weight=self.ln_weight,
bias=self.ln_bias, bias=self.ln_bias,
epsilon=self.eps) epsilon=self.eps)
if self.parallel_mode == 'column' and self.tensor_parallel: if self.parallel_mode == 'column' and self.tensor_parallel:
ln_out = identity(ln_out, self.tp_group) norm_out = identity(norm_out, self.tp_group)
out = F.linear(ln_out, self.weight, self.bias if self.gemm_bias_fused_add else None) out = F.linear(norm_out, self.weight, self.bias if self.gemm_bias_fused_add else None)
if self.parallel_mode == 'row' and self.tensor_parallel: if self.parallel_mode == 'row' and self.tensor_parallel:
out, _ = allreduce(out, self.tp_group) out, _ = allreduce(out, self.tp_group)
out = out + self.bias if self.bias is not None else out out = out + self.bias if self.bias is not None else out
if self.return_layernorm_output: if self.return_layernorm_output:
return out, ln_out return out, norm_out
return out return out
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
......
...@@ -12,13 +12,17 @@ import paddle.nn.functional as F ...@@ -12,13 +12,17 @@ import paddle.nn.functional as F
from paddle.nn.initializer import Constant from paddle.nn.initializer import Constant
from .base import TransformerEngineBaseLayer from .base import TransformerEngineBaseLayer
from .layernorm_linear import _layernorm_fwd_fp8_cast, _layernorm_bwd from .layernorm_linear import _apply_normalization_fwd, _apply_normalization_bwd
from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8 from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8
from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, dist_group_type from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, dist_group_type
from ..cpp_extensions import ( from ..cpp_extensions import (
cast_from_fp8, cast_from_fp8,
dgelu_cast_transpose_bgrad_fp8,
gelu_fp8, gelu_fp8,
swiglu_fp8,
swiglu,
dswiglu,
cast_transpose_bgrad,
dgelu_cast_transpose_bgrad_fp8,
) )
from ..distributed import ( from ..distributed import (
allreduce, allreduce,
...@@ -91,13 +95,22 @@ def _mlp_forward( ...@@ -91,13 +95,22 @@ def _mlp_forward(
is_grad_enabled, is_grad_enabled,
is_first_microbatch, is_first_microbatch,
) )
if activation == "gelu":
gelu_out = gelu_fp8( gelu_out = gelu_fp8(
fc1_out, fc1_out,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
fc2_input_fp8_index, fc2_input_fp8_index,
fp8_dtype_forward, fp8_dtype_forward,
) )
elif activation == "swiglu":
gelu_out = swiglu_fp8(
fc1_out,
fp8_meta["scaling_fwd"],
fc2_input_fp8_index,
fp8_dtype_forward,
)
else:
raise NotImplementedError("Activation type " + activation + " is not supported!")
fc2_out, fc2_weight_t_fp8 = _linear_fwd_fp8( fc2_out, fc2_weight_t_fp8 = _linear_fwd_fp8(
gelu_out, gelu_out,
...@@ -118,7 +131,7 @@ def _mlp_forward( ...@@ -118,7 +131,7 @@ def _mlp_forward(
is_first_microbatch, is_first_microbatch,
) )
else: else:
fc1_out, gelu_out = _linear_fwd_non_fp8( fc1_outputs = _linear_fwd_non_fp8(
inputmat, inputmat,
inputmat_fp8_index, inputmat_fp8_index,
fc1_weight, fc1_weight,
...@@ -135,6 +148,14 @@ def _mlp_forward( ...@@ -135,6 +148,14 @@ def _mlp_forward(
activation=activation, activation=activation,
) )
if activation == "gelu":
fc1_out, gelu_out = fc1_outputs
elif activation == "swiglu":
fc1_out = fc1_outputs
gelu_out = swiglu(fc1_out, TE_DType[activation_dtype])
else:
raise NotImplementedError("Activation type " + activation + " is not supported!")
fc2_out = _linear_fwd_non_fp8( fc2_out = _linear_fwd_non_fp8(
gelu_out, gelu_out,
fc2_input_fp8_index, fc2_input_fp8_index,
...@@ -234,6 +255,7 @@ def _mlp_backward( ...@@ -234,6 +255,7 @@ def _mlp_backward(
tp_group, tp_group,
) )
if activation == "gelu":
# GELU Bwd # GELU Bwd
dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8( dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8(
fc2_dgrad, fc2_dgrad,
...@@ -242,6 +264,14 @@ def _mlp_backward( ...@@ -242,6 +264,14 @@ def _mlp_backward(
fc1_grad_output_fp8_index, fc1_grad_output_fp8_index,
fp8_dtype_backward, fp8_dtype_backward,
) )
elif activation == "swiglu":
dgelu = dswiglu(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype])
fc1_bgrad_, dgelu, dgelu_t = cast_transpose_bgrad(
dgelu,
fp8_meta["scaling_bwd"],
fc1_grad_output_fp8_index,
fp8_dtype_backward,
)
if requires_fc1_bgrad: if requires_fc1_bgrad:
fc1_bgrad = fc1_bgrad_ fc1_bgrad = fc1_bgrad_
...@@ -301,6 +331,10 @@ def _mlp_backward( ...@@ -301,6 +331,10 @@ def _mlp_backward(
gelu_input=fc1_out, gelu_input=fc1_out,
activation=activation, activation=activation,
) )
if activation == "swiglu":
dgelu = dswiglu(dgelu, fc1_out, TE_DType[dgelu.dtype])
fc1_dgrad, fc1_wgrad, fc1_bgrad = _linear_bwd_non_fp8( fc1_dgrad, fc1_wgrad, fc1_bgrad = _linear_bwd_non_fp8(
fc1_input, fc1_input,
fc1_weight, fc1_weight,
...@@ -331,7 +365,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -331,7 +365,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx, ctx,
inp: paddle.Tensor, inp: paddle.Tensor,
ln_weight: paddle.Tensor, ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor, ln_bias: Union[paddle.Tensor, None],
fc1_weight: paddle.Tensor, fc1_weight: paddle.Tensor,
fc1_weight_fp8: Optional[paddle.Tensor], fc1_weight_fp8: Optional[paddle.Tensor],
fc1_weight_t_fp8: Optional[paddle.Tensor], fc1_weight_t_fp8: Optional[paddle.Tensor],
...@@ -352,6 +386,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -352,6 +386,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fwd_ln_sm_margin: int, fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int, bwd_ln_sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
normalization: str,
activation: str, activation: str,
set_parallel_mode: bool, set_parallel_mode: bool,
tensor_parallel: bool, tensor_parallel: bool,
...@@ -360,6 +395,10 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -360,6 +395,10 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
tp_size: int, tp_size: int,
is_first_microbatch: bool, is_first_microbatch: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
if normalization == "RMSNorm":
assert ln_bias is None, "RMSNorm does not support bias!"
else: # LayerNorm
assert ln_bias is not None, "LayerNorm requires bias!"
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.shape[0] in_features = ln_weight.shape[0]
assert inp.shape[-1] == in_features, "GEMM not possible" assert inp.shape[-1] == in_features, "GEMM not possible"
...@@ -370,7 +409,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -370,7 +409,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
assert_dim_for_fp8_forward_exec(fc2_weight) assert_dim_for_fp8_forward_exec(fc2_weight)
# only support gelu for now # only support gelu for now
assert activation == 'gelu' assert activation in ["gelu", "swiglu"], "Only gelu and swiglu are supported for now"
# LayerNorm Fwd + FP8 Cast # LayerNorm Fwd + FP8 Cast
( (
...@@ -378,7 +417,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -378,7 +417,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ln_out, ln_out,
mu, mu,
rsigma, rsigma,
) = _layernorm_fwd_fp8_cast( ) = _apply_normalization_fwd(
normalization,
inputmat, inputmat,
ln_weight, ln_weight,
ln_bias, ln_bias,
...@@ -463,9 +503,11 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -463,9 +503,11 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.requires_fc2_wgrad = not fc2_weight.stop_gradient ctx.requires_fc2_wgrad = not fc2_weight.stop_gradient
ctx.requires_fc1_bgrad = use_fc1_bias and not fc1_bias.stop_gradient ctx.requires_fc1_bgrad = use_fc1_bias and not fc1_bias.stop_gradient
ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient
ctx.requires_ln_bgrad = not ln_bias.stop_gradient ctx.requires_ln_bgrad = ln_bias is not None and not ln_bias.stop_gradient
ctx.requires_ln_wgrad = not ln_weight.stop_gradient ctx.requires_ln_wgrad = not ln_weight.stop_gradient
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.has_ln_bias = ln_bias is not None
ctx.normalization = normalization
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1])) fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1]))
...@@ -549,7 +591,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -549,7 +591,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fc2_bgrad = fc2_bgrad_ fc2_bgrad = fc2_bgrad_
# LayerNorm Bwd # LayerNorm Bwd
dxmat, dgamma, dbeta = _layernorm_bwd( dxmat, dgamma, dbeta = _apply_normalization_bwd(
ctx.normalization,
inputmat, inputmat,
fc1_dgrad, fc1_dgrad,
ln_weight, ln_weight,
...@@ -565,6 +608,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -565,6 +608,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fc2_bgrad = fc2_bgrad if ctx.requires_fc2_bgrad else None fc2_bgrad = fc2_bgrad if ctx.requires_fc2_bgrad else None
fc1_bgrad_out = (fc1_bgrad,) if ctx.use_fc1_bias else () fc1_bgrad_out = (fc1_bgrad,) if ctx.use_fc1_bias else ()
fc2_bgrad_out = (fc2_bgrad,) if ctx.use_fc2_bias else () fc2_bgrad_out = (fc2_bgrad,) if ctx.use_fc2_bias else ()
dbeta = dbeta if ctx.requires_ln_bgrad else None
dbeta_out = (dbeta,) if ctx.has_ln_bias else ()
if not ctx.fp8_enabled or ctx.is_first_microbatch is None: if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
fc1_weight_cache_grad = () fc1_weight_cache_grad = ()
...@@ -577,7 +622,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -577,7 +622,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
return ( return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma if ctx.requires_ln_wgrad else None, dgamma if ctx.requires_ln_wgrad else None,
dbeta if ctx.requires_ln_bgrad else None, *dbeta_out,
fc1_wgrad if ctx.requires_fc1_wgrad else None, fc1_wgrad if ctx.requires_fc1_wgrad else None,
*fc1_weight_cache_grad, *fc1_weight_cache_grad,
*fc1_bgrad_out, *fc1_bgrad_out,
...@@ -604,6 +649,8 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -604,6 +649,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
optional `paddle.ParamAttr` for weight. optional `paddle.ParamAttr` for weight.
bias_attr: Union[paddle.ParamAttr, None, bool], default = None bias_attr: Union[paddle.ParamAttr, None, bool], default = None
optional `paddle.ParamAttr` for bias. optional `paddle.ParamAttr` for bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
activation : str, default = 'gelu' activation : str, default = 'gelu'
activation function used. activation function used.
Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'. Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'.
...@@ -641,6 +688,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -641,6 +688,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
eps: float = 1e-5, eps: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None, weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None, bias_attr: Union[paddle.ParamAttr, None, bool] = None,
normalization: str = "LayerNorm",
activation: str = "gelu", activation: str = "gelu",
return_layernorm_output: bool = False, return_layernorm_output: bool = False,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
...@@ -654,6 +702,8 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -654,6 +702,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size self.ffn_hidden_size = ffn_hidden_size
self.eps = eps self.eps = eps
self.normalization = normalization
assert normalization in ["LayerNorm", "RMSNorm"], "Normalization type not supported"
self.activation = activation self.activation = activation
self.return_layernorm_output = return_layernorm_output self.return_layernorm_output = return_layernorm_output
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
...@@ -684,22 +734,31 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -684,22 +734,31 @@ class LayerNormMLP(TransformerEngineBaseLayer):
is_bias=False, is_bias=False,
) )
if self.normalization != "RMSNorm":
self.ln_bias = self.create_parameter( self.ln_bias = self.create_parameter(
shape=[self.hidden_size], shape=[self.hidden_size],
attr=paddle.ParamAttr(initializer=Constant(value=0.0)), attr=paddle.ParamAttr(initializer=Constant(value=0.0)),
dtype=self._dtype, dtype=self._dtype,
is_bias=True, is_bias=True,
) )
else:
self.ln_bias = None
if self.sequence_parallel: if self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.ln_weight) mark_as_sequence_parallel_parameter(self.ln_weight)
if self.ln_bias is not None:
mark_as_sequence_parallel_parameter(self.ln_bias) mark_as_sequence_parallel_parameter(self.ln_bias)
# FC1 weights # FC1 weights
if self.activation in ["swiglu"]:
fc1_output_features = self.size_per_partition * 2
else:
fc1_output_features = self.size_per_partition
with track_rng_state(enable=self.tensor_parallel): with track_rng_state(enable=self.tensor_parallel):
self.fc1_weight = self.create_parameter( self.fc1_weight = self.create_parameter(
shape=[self.size_per_partition, self.hidden_size] if self.backend shape=[fc1_output_features, self.hidden_size] if self.backend
== 'transformer_engine' else [self.hidden_size, self.size_per_partition], == 'transformer_engine' else [self.hidden_size, fc1_output_features],
attr=self._weight_attr, attr=self._weight_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=False, is_bias=False,
...@@ -717,7 +776,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -717,7 +776,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
if self.has_bias: if self.has_bias:
self.fc1_bias = self.create_parameter( self.fc1_bias = self.create_parameter(
shape=[self.size_per_partition], shape=[fc1_output_features],
attr=self._bias_attr, attr=self._bias_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=True, is_bias=True,
...@@ -809,6 +868,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -809,6 +868,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.normalization,
self.activation, self.activation,
self.set_parallel_mode, self.set_parallel_mode,
self.tensor_parallel, self.tensor_parallel,
...@@ -842,14 +902,18 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -842,14 +902,18 @@ class LayerNormMLP(TransformerEngineBaseLayer):
warnings.warn( warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.") "`is_first_microbatch` is not supported for paddle backend and is ignored.")
ln_out = F.layer_norm(x=inp, if self.normalization == "RMSNorm":
norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps)
norm_out = inp * norm * self.ln_weight
else: # LayerNorm
norm_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[-1], normalized_shape=inp.shape[-1],
weight=self.ln_weight, weight=self.ln_weight,
bias=self.ln_bias, bias=self.ln_bias,
epsilon=self.eps) epsilon=self.eps)
if self.set_parallel_mode and self.tensor_parallel: if self.set_parallel_mode and self.tensor_parallel:
ln_out = identity(ln_out, self.tp_group) norm_out = identity(norm_out, self.tp_group)
fc1_out = F.linear(ln_out, self.fc1_weight, self.fc1_bias) fc1_out = F.linear(norm_out, self.fc1_weight, self.fc1_bias)
act_func = get_paddle_act_func(self.activation) act_func = get_paddle_act_func(self.activation)
act_out = act_func(fc1_out) act_out = act_func(fc1_out)
out = F.linear(act_out, self.fc2_weight, out = F.linear(act_out, self.fc2_weight,
...@@ -858,7 +922,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -858,7 +922,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
out, _ = allreduce(out, self.tp_group) out, _ = allreduce(out, self.tp_group)
out = out + self.fc2_bias if self.fc2_bias is not None else out out = out + self.fc2_bias if self.fc2_bias is not None else out
if self.return_layernorm_output: if self.return_layernorm_output:
return out, ln_out return out, norm_out
return out return out
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
......
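
Taken together, these LayerNormMLP changes mean that activation='swiglu' doubles the FC1 output width (a gating half and a value half) and normalization='RMSNorm' skips creating the layernorm bias entirely. A hedged usage sketch follows, assuming the layer is exposed as transformer_engine.paddle.LayerNormMLP and a CUDA device is available for the transformer_engine backend; the sizes are illustrative:

import paddle
import transformer_engine.paddle as te   # assumed import path

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,          # with swiglu, the FC1 weight becomes [2 * 4096, 1024]
    normalization="RMSNorm",       # ln_bias is not created for RMSNorm
    activation="swiglu",
)

x = paddle.randn([16, 128, 1024])
y = mlp(x)                         # output keeps the input shape: [16, 128, 1024]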
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""RMSNorm API"""
import os
from typing import Union, Tuple
import paddle
from paddle.nn.initializer import Constant
from ..constants import TE_DType
from ..cpp_extensions import rmsnorm_fwd, rmsnorm_bwd
from ..distributed import mark_as_sequence_parallel_parameter
__all__ = ["RMSNorm"]
class _RMSNorm(paddle.autograd.PyLayer):
"""functional RMSNorm"""
@staticmethod
def forward(
ctx,
inp: paddle.Tensor,
rmsnorm_weight: paddle.Tensor,
eps: float,
fwd_rmsnorm_sm_margin: int,
bwd_rmsnorm_sm_margin: int,
zero_centered_gamma: bool,
) -> paddle.Tensor:
# Make sure input dimensions are compatible
in_features = rmsnorm_weight.shape[0]
assert inp.shape[-1] == in_features, "RMSNorm not possible"
inputmat = inp.reshape((-1, in_features))
rmsnorm_out, rsigma = rmsnorm_fwd(inputmat, rmsnorm_weight, eps, TE_DType[inp.dtype],
fwd_rmsnorm_sm_margin, zero_centered_gamma)
ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.requires_dx = not inp.stop_gradient
ctx.requires_dw = not rmsnorm_weight.stop_gradient
return rmsnorm_out.reshape(inp.shape)
@staticmethod
def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
inputmat, rmsnorm_weight, rsigma = ctx.saved_tensor()
d_rmsnorm_out = grad_output.reshape(inputmat.shape)
dxmat, dgamma = rmsnorm_bwd(d_rmsnorm_out, inputmat, rsigma, rmsnorm_weight,
ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma)
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None,
dgamma if ctx.requires_dw else None,
)
class RMSNorm(paddle.nn.Layer):
r"""
Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in
the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
.. math::
y = \frac{x}{RMS_\varepsilon(x)} * \gamma
where
.. math::
RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2 + \varepsilon}
:math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size`
Parameters
----------
hidden_size : int
size of each input sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
weight_attr: Union[paddle.ParamAttr, None], default = None
optional `paddle.ParamAttr` for weight.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in RMSNorm is initialized to 0 and
the RMSNorm formula changes to
.. math::
y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma)
backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
backend to use for rmsnorm operation.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None,
zero_centered_gamma: bool = False,
sequence_parallel: bool = False,
backend: str = "transformer_engine",
) -> None:
super().__init__()
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.sequence_parallel = sequence_parallel
self.backend = backend
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
if not self._weight_attr:
self._weight_attr = paddle.ParamAttr(initializer=Constant(1.0))
self.weight = self.create_parameter(
shape=[hidden_size],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
if self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.weight)
# These many SMs are subtracted from the total SM count when calling forward
# and backward RMSNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with RMSNorm.
self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor:
return _RMSNorm.apply(
inp,
self.weight,
self.eps,
self.fwd_rmsnorm_sm_margin,
self.bwd_rmsnorm_sm_margin,
self.zero_centered_gamma,
)
def _pd_forward(
self,
inp: paddle.Tensor,
) -> paddle.Tensor:
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support RMSNorm with zero_centered_gamma.")
norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps)
y = inp * norm * self.weight
return y
def forward(self, *args, **kwargs):
if self.backend == "transformer_engine":
return self._te_forward(*args, **kwargs)
if self.backend == "paddle":
return self._pd_forward(*args, **kwargs)
raise AttributeError(f"Backend {self.backend} not supported.")
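
A short usage sketch of the new module, assuming it is re-exported as transformer_engine.paddle.RMSNorm alongside the other layers (that export is not shown in this diff) and that a CUDA device is available for the transformer_engine backend:

import paddle
import transformer_engine.paddle as te   # assumed export path for the new layer

rmsnorm_te = te.RMSNorm(hidden_size=1024, eps=1e-5)                     # fused TE kernel
rmsnorm_pd = te.RMSNorm(hidden_size=1024, eps=1e-5, backend="paddle")   # pure-Paddle fallback

x = paddle.randn([8, 1024])
# both backends compute y = x / sqrt(mean(x^2) + eps) * weight, so their
# outputs should agree up to normal floating-point tolerance
y_te = rmsnorm_te(x)
y_pd = rmsnorm_pd(x)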
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Transformer""" """Transformer"""
from typing import Optional, Union from typing import Optional, Tuple, Union
import warnings import warnings
import paddle import paddle
...@@ -60,6 +60,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -60,6 +60,7 @@ class TransformerLayer(paddle.nn.Layer):
if set to `decoder`, an additional cross-attn block is added after self-attn. if set to `decoder`, an additional cross-attn block is added after self-attn.
This can be used for structures like `T5` Transformer in conjunction with the This can be used for structures like `T5` Transformer in conjunction with the
`encoder` option. `encoder` option.
normalization: {'LayerNorm', 'RMSNorm'}, default = `LayerNorm`
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to the LayerNorm formula changes to
...@@ -111,11 +112,13 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -111,11 +112,13 @@ class TransformerLayer(paddle.nn.Layer):
attention_dropout: float = 0.1, attention_dropout: float = 0.1,
weight_attr: Union[paddle.ParamAttr, None] = None, weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None, bias_attr: Union[paddle.ParamAttr, None, bool] = None,
max_sequence_length: Optional[int] = None,
self_attn_mask_type: str = "causal", self_attn_mask_type: str = "causal",
params_dtype: Optional[paddle.dtype] = None, params_dtype: Optional[paddle.dtype] = None,
apply_residual_connection_post_layernorm: bool = False, apply_residual_connection_post_layernorm: bool = False,
output_layernorm: bool = False, output_layernorm: bool = False,
layer_type: str = "encoder", layer_type: str = "encoder",
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
activation: str = 'gelu', activation: str = 'gelu',
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
...@@ -158,9 +161,11 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -158,9 +161,11 @@ class TransformerLayer(paddle.nn.Layer):
common_attention_kwargs = { common_attention_kwargs = {
"params_dtype": params_dtype, "params_dtype": params_dtype,
"return_layernorm_output": apply_residual_connection_post_layernorm, "return_layernorm_output": apply_residual_connection_post_layernorm,
"normalization": normalization,
"zero_centered_gamma": zero_centered_gamma, "zero_centered_gamma": zero_centered_gamma,
"set_parallel_mode": set_parallel_mode, "set_parallel_mode": set_parallel_mode,
"sequence_parallel": self.sequence_parallel, "sequence_parallel": self.sequence_parallel,
'max_sequence_length': max_sequence_length,
"tp_group": tp_group, "tp_group": tp_group,
"num_gqa_groups": num_gqa_groups, "num_gqa_groups": num_gqa_groups,
"rng_state_name": attention_dropout_rng_state_name, "rng_state_name": attention_dropout_rng_state_name,
...@@ -190,6 +195,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -190,6 +195,7 @@ class TransformerLayer(paddle.nn.Layer):
eps=layernorm_epsilon, eps=layernorm_epsilon,
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=bias_attr, bias_attr=bias_attr,
normalization=normalization,
activation=activation, activation=activation,
return_layernorm_output=apply_residual_connection_post_layernorm, return_layernorm_output=apply_residual_connection_post_layernorm,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
...@@ -223,6 +229,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -223,6 +229,7 @@ class TransformerLayer(paddle.nn.Layer):
attention_mask: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None,
encoder_output: Optional[paddle.Tensor] = None, encoder_output: Optional[paddle.Tensor] = None,
enc_dec_attn_mask: Optional[paddle.Tensor] = None, enc_dec_attn_mask: Optional[paddle.Tensor] = None,
rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
core_attention_bias_type: str = "no_bias", core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[paddle.Tensor] = None, core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True, set_zero: bool = True,
...@@ -249,6 +256,9 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -249,6 +256,9 @@ class TransformerLayer(paddle.nn.Layer):
enc_dec_attn_mask : Optional[paddle.Tensor], default = `None` enc_dec_attn_mask : Optional[paddle.Tensor], default = `None`
Boolean tensor used to mask out inter-attention softmax input if using Boolean tensor used to mask out inter-attention softmax input if using
`layer_type="decoder"`. `layer_type="decoder"`.
rotary_pos_emb : Optional[Tuple[paddle.Tensor, paddle.Tensor]], default = `None`
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied
core_attention_bias_type: str, default = `no_bias` core_attention_bias_type: str, default = `no_bias`
core_attention_bias: Optional[paddle.Tensor], default = `None` core_attention_bias: Optional[paddle.Tensor], default = `None`
Bias tensor for Q * K.T Bias tensor for Q * K.T
...@@ -284,6 +294,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -284,6 +294,7 @@ class TransformerLayer(paddle.nn.Layer):
core_attention_bias_type=core_attention_bias_type, core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
set_zero=set_zero, set_zero=set_zero,
rotary_pos_emb=rotary_pos_emb,
recompute_core_attention=recompute_core_attention, recompute_core_attention=recompute_core_attention,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
) )
......
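
To tie the pieces together, here is a hedged end-to-end sketch of passing the new rotary_pos_emb argument through TransformerLayer; the constructor argument names and the te.RotaryPositionEmbedding export are assumptions based on the surrounding API, and all sizes are made up:

import paddle
import transformer_engine.paddle as te   # assumed import path

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    normalization="RMSNorm",
    activation="swiglu",
)

rope = te.RotaryPositionEmbedding(dim=1024 // 16, max_position_embeddings=4096)
rotary_pos_emb = rope(max_seq_len=128)   # (cos, sin) tuple consumed by MultiHeadAttention

x = paddle.randn([4, 128, 1024])
out = layer(x, rotary_pos_emb=rotary_pos_emb)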
...@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union ...@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from .cpp_extensions import swiglu_pd
def cast_if_needed(tensor: Union[paddle.Tensor, None], def cast_if_needed(tensor: Union[paddle.Tensor, None],
...@@ -48,6 +49,8 @@ def get_paddle_act_func(activation): ...@@ -48,6 +49,8 @@ def get_paddle_act_func(activation):
funcs = { funcs = {
'gelu': F.gelu, 'gelu': F.gelu,
'relu': F.relu, 'relu': F.relu,
'silu': F.silu,
'swiglu': swiglu_pd,
} }
if activation not in funcs: if activation not in funcs:
raise "Activation type " + activation + " is not supported." raise "Activation type " + activation + " is not supported."
......
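
For reference, SwiGLU consumes the doubled FC1 output, splits it along the last dimension into a gating half and a linear half, and returns silu(gate) * value. A minimal pure-Paddle sketch of that computation follows; the exact half ordering used by swiglu_pd and the fused swiglu kernel is not visible in this diff, so treat the split order below as an assumption:

import paddle
import paddle.nn.functional as F

def swiglu_reference(x):
    # x: [..., 2 * d]; assumed ordering: first half gates, second half is the value
    gate, value = paddle.chunk(x, chunks=2, axis=-1)
    return F.silu(gate) * value

fc1_out = paddle.randn([4, 8192])        # FC1 output when ffn_hidden_size = 4096
act_out = swiglu_reference(fc1_out)      # [4, 4096], fed into FC2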