Unverified Commit 7444946d authored by Shijie, committed by GitHub

[Paddle] Add nn layer (#361)



* Add nn.layer: softmax, attention, transformer
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* code refactor
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* code refactor
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* update docs and set dropout=0.1
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* Update transformer_engine/paddle/layer/attention.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e4f9e767
...@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Test TE Paddle Layer-level APIs"""
import math
import os
import pytest
from utils import assert_allclose
...@@ -605,3 +606,492 @@ class TestLayerNormMLP:
if do_calibration:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@pytest.mark.parametrize('attn_type', ['self', 'cross'])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type,
mask_type, math_dtype):
"""
Test DotProductAttention Layer
"""
paddle.set_default_dtype(math_dtype)
rtol = 1e-4
atol = 2e-2
head_size = hidden_size // num_heads
self_attn_qkv_input = paddle.normal(mean=0.0,
std=0.02,
shape=(bs, q_seqlen, 3, num_heads,
head_size)).astype(math_dtype)
cross_attn_q_input = paddle.normal(mean=0.0,
std=0.02,
shape=(bs, q_seqlen, num_heads,
head_size)).astype(math_dtype)
cross_attn_kv_input = paddle.normal(mean=0.0,
std=0.02,
shape=(bs, kv_seqlen, 2, num_heads,
head_size)).astype(math_dtype)
q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32')
kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,),
dtype='int32') if attn_type == 'cross' else q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
grad_out = paddle.normal(mean=0.0, std=0.02,
shape=(bs, q_seqlen, num_heads, head_size)).astype('float32')
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
norm_factor = math.sqrt(hidden_size // num_heads)
layer_te = te.DotProductAttention(norm_factor,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend='transformer_engine')
layer_pd = te.DotProductAttention(norm_factor,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend='paddle')
def calc_attn_output_and_grad(layer, q, kv, mask, dout):
_q = paddle.to_tensor(q, stop_gradient=False)
_kv = paddle.to_tensor(kv, stop_gradient=False) if kv is not None else None
out = layer(_q, _kv, mask)
out.backward(dout)
return out, _q.grad, _kv.grad if _kv is not None else None
if attn_type == 'self':
out, qkv_grad, _ = calc_attn_output_and_grad(layer_te, self_attn_qkv_input, None, attn_mask,
grad_out)
out_ref, qkv_grad_ref, _ = calc_attn_output_and_grad(layer_pd, self_attn_qkv_input, None,
attn_mask, grad_out)
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
q_grad = qkv_grad[:, :, 0]
k_grad = qkv_grad[:, :, 1]
v_grad = qkv_grad[:, :, 2]
q_grad_ref = qkv_grad_ref[:, :, 0]
k_grad_ref = qkv_grad_ref[:, :, 1]
v_grad_ref = qkv_grad_ref[:, :, 2]
else:
out, q_grad, kv_grad = calc_attn_output_and_grad(layer_te, cross_attn_q_input,
cross_attn_kv_input, attn_mask, grad_out)
out_ref, q_grad_ref, kv_grad_ref = calc_attn_output_and_grad(layer_pd, cross_attn_q_input,
cross_attn_kv_input, attn_mask,
grad_out)
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
k_grad = kv_grad[:, :, 0]
v_grad = kv_grad[:, :, 1]
k_grad_ref = kv_grad_ref[:, :, 0]
v_grad_ref = kv_grad_ref[:, :, 1]
valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
valid_v_grad_ref = paddle.full_like(v_grad_ref, 0)
for i in range(0, bs):
valid_q_grad_ref[i, 0:q_actual_seqlen[i], :, :] = q_grad_ref[i, 0:q_actual_seqlen[i], :, :]
valid_k_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = k_grad_ref[i,
0:kv_actual_seqlen[i], :, :]
valid_v_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = v_grad_ref[i,
0:kv_actual_seqlen[i], :, :]
assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol)
assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol)
assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol)
assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
output_layernorm, return_layernorm_output):
"""
Test Transformer Encoder Layer
"""
paddle.set_default_dtype(math_dtype)
rtol = 5e-2
atol = 5e-2
eps = 1e-3
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
grad_out = paddle.normal(mean=0.0, std=0.02,
shape=(bs, q_seqlen, hidden_size)).astype('float32')
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
layer_te = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None if has_bias else False,
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type='encoder',
backend='transformer_engine')
layer_pd = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None if has_bias else False,
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type='encoder',
backend='paddle')
# MultiHeadAttention params
if output_layernorm:
layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True)
layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True)
layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
layer_te.self_attention.layernorm_qkv.ln_weight, True)
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
layer_te.self_attention.layernorm_qkv.ln_bias, True)
layer_pd.self_attention.layernorm_qkv.weight.copy_(
layer_te.self_attention.layernorm_qkv.weight.T, True)
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_(
layer_te.self_attention.layernorm_qkv.bias, True)
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True)
layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad
layer_te.self_attention.proj.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True)
layer_pd.self_attention.proj.bias.stop_gradient = no_dbias
layer_te.self_attention.proj.bias.stop_gradient = no_dbias
# LayerNorm MLP params
layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True)
layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True)
layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True)
layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True)
layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True)
layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
if output_layernorm:
layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True)
layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True)
layer_pd.layernorm.weight.stop_gradient = no_wgrad
layer_pd.layernorm.bias.stop_gradient = no_dbias
layer_te.layernorm.weight.stop_gradient = no_wgrad
layer_te.layernorm.bias.stop_gradient = no_dbias
def calc_transformer_output_and_grad(layer, encoder_input, mask, dout):
_encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False)
out = layer(_encoder_input, mask)
out.backward(dout)
return out, _encoder_input.grad
out_ref, grad_input_ref = calc_transformer_output_and_grad(layer_pd, encoder_input, attn_mask,
grad_out)
out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.weight.grad,
layer_pd.self_attention.qkv.weight.grad.T,
rtol=rtol,
atol=atol)
else:
assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
layer_pd.self_attention.layernorm_qkv.weight.grad.T,
rtol=rtol,
atol=atol)
if not no_dbias:
if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
rtol=0.01,
atol=0.5)
else:
assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
layer_pd.self_attention.layernorm_qkv.bias.grad,
rtol=0.01,
atol=0.5)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
output_layernorm, return_layernorm_output):
"""
Test Transformer Decoder Layer
"""
paddle.set_default_dtype(math_dtype)
rtol = 5e-2
atol = 5e-2
eps = 1e-3
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
encoder_output = paddle.uniform(shape=(bs, kv_seqlen, hidden_size), dtype=math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
grad_out = paddle.normal(mean=0.0, std=0.2, shape=(bs, q_seqlen, hidden_size)).astype('float32')
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
layer_te = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None if has_bias else False,
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type='decoder',
backend='transformer_engine')
layer_pd = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None if has_bias else False,
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type='decoder',
backend='paddle')
# MultiHeadAttention params - self attn
if output_layernorm:
layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True)
layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True)
layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
layer_te.self_attention.layernorm_qkv.ln_weight, True)
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
layer_te.self_attention.layernorm_qkv.ln_bias, True)
layer_pd.self_attention.layernorm_qkv.weight.copy_(
layer_te.self_attention.layernorm_qkv.weight.T, True)
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_(
layer_te.self_attention.layernorm_qkv.bias, True)
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True)
layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad
layer_te.self_attention.proj.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True)
layer_pd.self_attention.proj.bias.stop_gradient = no_dbias
layer_te.self_attention.proj.bias.stop_gradient = no_dbias
# MultiHeadAttention params - cross attn
layer_pd.inter_attention.layernorm_query.ln_weight.copy_(
layer_te.inter_attention.layernorm_query.ln_weight, True)
layer_pd.inter_attention.layernorm_query.ln_bias.copy_(
layer_te.inter_attention.layernorm_query.ln_bias, True)
layer_pd.inter_attention.layernorm_query.weight.copy_(
layer_te.inter_attention.layernorm_query.weight.T, True)
layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.inter_attention.layernorm_query.bias.copy_(
layer_te.inter_attention.layernorm_query.bias, True)
layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
layer_pd.inter_attention.key_value.weight.copy_(layer_te.inter_attention.key_value.weight.T,
True)
layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad
layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad
layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True)
layer_pd.inter_attention.proj.weight.stop_gradient = no_wgrad
layer_te.inter_attention.proj.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.inter_attention.key_value.bias.copy_(layer_te.inter_attention.key_value.bias, True)
layer_pd.inter_attention.key_value.bias.stop_gradient = no_dbias
layer_te.inter_attention.key_value.bias.stop_gradient = no_dbias
layer_pd.inter_attention.proj.bias.copy_(layer_te.inter_attention.proj.bias, True)
layer_pd.inter_attention.proj.bias.stop_gradient = no_dbias
layer_te.inter_attention.proj.bias.stop_gradient = no_dbias
# LayerNorm MLP params
layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True)
layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True)
layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True)
layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True)
layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True)
layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
if output_layernorm:
layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True)
layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True)
layer_pd.layernorm.weight.stop_gradient = no_wgrad
layer_pd.layernorm.bias.stop_gradient = no_dbias
layer_te.layernorm.weight.stop_gradient = no_wgrad
layer_te.layernorm.bias.stop_gradient = no_dbias
def calc_transformer_output_and_grad(layer, encoder_input, mask, encoder_output,
enc_dec_attn_mask, dout):
_encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False)
_encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False)
out = layer(_encoder_input, mask, _encoder_output, enc_dec_attn_mask)
out.backward(dout)
return out, _encoder_input.grad, _encoder_output.grad
out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad(
layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out)
out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad(
layer_te, encoder_input, attn_mask, encoder_output, attn_mask, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol)
assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol)
if not no_wgrad:
if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.weight.grad,
layer_pd.self_attention.qkv.weight.grad.T,
rtol=rtol,
atol=atol)
else:
assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
layer_pd.self_attention.layernorm_qkv.weight.grad.T,
rtol=rtol,
atol=0.1)
assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad,
layer_pd.inter_attention.layernorm_query.weight.grad.T,
rtol=rtol,
atol=atol)
if not no_dbias:
if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
rtol=0.01,
atol=0.5)
else:
assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
layer_pd.self_attention.layernorm_qkv.bias.grad,
rtol=0.01,
atol=0.5)
assert_allclose(layer_te.inter_attention.layernorm_query.bias.grad,
layer_pd.inter_attention.layernorm_query.bias.grad,
rtol=rtol,
atol=atol)
...@@ -46,7 +46,7 @@ from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling
np.random.seed(10)
paddle.seed(10) paddle.seed(11)
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()
...@@ -400,7 +400,7 @@ class TestLayerNorm:
y_ref, mu_ref, rsigma_ref = self.calc_fwd_ref(x, eps, gamma, beta)
assert_allclose(y, y_ref, rtol=1e-5, atol=1e-5) assert_allclose(y, y_ref, rtol=1e-4, atol=1e-4)
assert_allclose(mu, mu_ref, rtol=1e-3, atol=1e-3)
assert_allclose(rsigma, rsigma_ref, rtol=5e-2, atol=5e-2)
...@@ -725,10 +725,8 @@ class TestFusedAttn:
q_grad = dq
k_grad = dkv[:, :, 0, :, :]
v_grad = dkv[:, :, 1, :, :]
fwd_out = paddle.reshape(
out, shape=[self.batch_size, self.q_seqlen, self.num_heads, self.head_size])
return fwd_out, q_grad, k_grad, v_grad return out, q_grad, k_grad, v_grad
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
......
...@@ -3,5 +3,6 @@
# See LICENSE for license information.
"""Transformer Engine bindings for Paddle"""
from .layer import Linear, LayerNorm, LayerNormLinear, LayerNormMLP
from .fp8 import fp8_autocast
from .layer import (Linear, LayerNorm, LayerNormLinear, LayerNormMLP, FusedScaleMaskSoftmax,
DotProductAttention, MultiHeadAttention, TransformerLayer)
...@@ -40,3 +40,9 @@ TE_DType = {
paddle.float16: tex.DType.kFloat16,
paddle.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding", "no_mask")
AttnTypes = ("self", "cross")
LayerTypes = ("encoder", "decoder")
...@@ -435,9 +435,9 @@ def fused_attn_fwd_qkvpacked(
assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv."
if set_zero:
out = paddle.full(shape=[total_seqs, h, d], fill_value=0, dtype=qkv.dtype) out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype)
else:
out = paddle.empty(shape=[total_seqs, h, d], dtype=qkv.dtype) out = paddle.empty(shape=[b, max_seqlen, h, d], dtype=qkv.dtype)
if is_training:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
...@@ -574,9 +574,9 @@ def fused_attn_fwd_kvpacked(
assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv."
if set_zero:
out = paddle.full(shape=[total_seqs_q, h, d], fill_value=0, dtype=q.dtype) out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
else:
out = paddle.empty(shape=[total_seqs_q, h, d], dtype=q.dtype) out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)
if is_training:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
......
...@@ -3,7 +3,10 @@
# See LICENSE for license information.
"""Layer level Paddle APIs"""
from .attention import DotProductAttention, MultiHeadAttention
from .layernorm import LayerNorm
from .layernorm_linear import LayerNormLinear
from .layernorm_mlp import LayerNormMLP
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
from .transformer import TransformerLayer
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Attntion API"""
import math
import warnings
from typing import Optional, Tuple, Union
import paddle
import paddle.nn.functional as F
from transformer_engine.paddle.constants import (
AttnTypes,
TE_DType,
)
from transformer_engine.paddle.cpp_extensions import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
)
from transformer_engine.paddle.utils import (attention_mask_func, mask_to_cu_seqlens)
from .base import TransformerEngineBaseLayer
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
"""Function for FusedAttention with packed QKV input"""
@staticmethod
def forward(ctx, qkv, cu_seqlens, attn_bias, rng_state, max_seqlen, attn_scale, qkv_dtype,
dropout_p, set_zero, qkv_layout, attn_bias_type, attn_mask_type, is_training):
"""Forward function for FusedAttention with packed QKV input"""
out, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
qkv,
cu_seqlens,
rng_state,
is_training,
max_seqlen,
qkv_dtype,
attn_bias,
attn_scale,
dropout_p,
set_zero,
qkv_layout,
attn_bias_type,
attn_mask_type,
)
ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, aux_ctx_tensors)
ctx.max_seqlen = max_seqlen
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.set_zero = set_zero
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
return out
@staticmethod
def backward(ctx, d_out):
"""Backward function for FusedAttention with packed QKV input"""
qkv, out, cu_seqlens, rng_state, aux_ctx_tensors = ctx.saved_tensor()
dqkv, *rest = fused_attn_bwd_qkvpacked(qkv, cu_seqlens, rng_state, out, d_out,
aux_ctx_tensors, ctx.max_seqlen, ctx.qkv_dtype,
ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
ctx.qkv_layout, ctx.attn_bias_type,
ctx.attn_mask_type)
# if no_bias, return dqkv
if ctx.attn_bias_type == "no_bias":
return (dqkv, None, None)
# else, return (dqkv, dbias)
return (dqkv, None, rest[0], None)
class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
"""Function for FusedAttention with packed KV input"""
@staticmethod
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_kv, attn_bias, rng_state, max_seqlen_q,
max_seqlen_kv, attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout,
attn_bias_type, attn_mask_type, is_training):
"""Forward function for FusedAttention with packed KV input"""
out, aux_ctx_tensors = fused_attn_fwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv,
rng_state, is_training, max_seqlen_q,
max_seqlen_kv, qkv_dtype, attn_bias,
attn_scale, dropout_p, set_zero, qkv_layout,
attn_bias_type, attn_mask_type)
ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, aux_ctx_tensors)
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.set_zero = set_zero
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
return out
@staticmethod
def backward(ctx, d_out):
"""Backward function for FusedAttention with packed KV input"""
q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, aux_ctx_tensors = ctx.saved_tensor()
dq, dkv, *rest = fused_attn_bwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
d_out, aux_ctx_tensors, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.qkv_dtype, ctx.attn_scale,
ctx.dropout_p, ctx.set_zero, ctx.qkv_layout,
ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias, return dq, dkv
if ctx.attn_bias_type == "no_bias":
return (dq, dkv, None, None, None)
# else, return (dq, dkv, dbias)
return (dq, dkv, None, None, rest[0], None)
class DotProductAttention(paddle.nn.Layer):
"""Dot Product Attention Layer
Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`attn_mask_type` is set to `"causal"`.
Parameters
----------
norm_factor : float
normalization factor for the attention scores.
attention_dropout: float, default = 0.1
dropout probability for the dropout op during multi-head attention.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation.
attention_type: {'self', 'cross'}, default = `self`
type of attention operation.
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
backend to use for attention operation.
"""
def __init__(self,
norm_factor: float,
attention_dropout: float = 0.1,
attn_mask_type: str = "causal",
attention_type: str = "self",
backend: str = 'transformer_engine') -> None:
super().__init__()
self.norm_factor = norm_factor
self.attn_mask_type = attn_mask_type
self.attention_dropout = attention_dropout
self.attention_type = attention_type
self.backend = backend
self.rng_state = paddle.zeros((2,), dtype='int64')
self.rng_state.persistable = True
if self.backend != 'transformer_engine':
self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type,
attention_mask_func,
backend=self.backend)
def forward(
self,
query_layer: paddle.Tensor,
key_value_layer: paddle.Tensor = None,
attention_mask: Optional[paddle.Tensor] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True,
) -> paddle.Tensor:
"""
Dot Product Attention Layer.
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
is set to `"causal"`.
.. note::
For self attention, :attr:`query_layer` is the `[query, key, value]` tensor
stacked along the 2nd dimension, which must be of shape (:attr:`batch_size`,
:attr:`seq_length`, 3, :attr:`num_attention_heads`, :attr:`size_per_head`).
And :attr:`key_value_layer` is `None`.
For cross attention, :attr:`query_layer` is the `[query]` tensor, which must
be of shape (:attr:`batch_size`, :attr:`seq_length`, :attr:`num_attention_heads`,
:attr:`size_per_head`). And :attr:`key_value_layer` is the `[key, value]` tensor,
which must be of shape (:attr:`batch_size`, :attr:`seq_length`, 2,
:attr:`num_attention_heads`, :attr:`size_per_head`).
Parameters
----------
query_layer : paddle.Tensor
Query tensor.
key_value_layer : paddle.Tensor
Key-value tensor.
attention_mask : Optional[paddle.Tensor], default = `None`
Boolean tensor used to mask out softmax input; ignored when :attr:`attn_mask_type` is `"causal"`.
core_attention_bias_type: str, default = `no_bias`
only no_bias type is supported currently, {`no_bias`}
core_attention_bias: Optional[paddle.Tensor], default = `None`
Bias tensor for Q * K.T
set_zero: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
"""
if self.backend == 'transformer_engine':
return self._te_forward(query_layer, key_value_layer, attention_mask,
core_attention_bias_type, core_attention_bias, set_zero)
if self.backend == 'paddle':
if core_attention_bias_type != "no_bias":
warnings.warn("Paddle backend dot product attention does not support bias yet. "
"Bias will be ignored.")
return self._pd_forward(query_layer, key_value_layer, attention_mask)
raise AttributeError(f"Backend {self.backend} is not supported.")
def _te_forward(
self,
query_layer: paddle.Tensor,
key_value_layer: paddle.Tensor = None,
attention_mask: Optional[paddle.Tensor] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True,
) -> paddle.Tensor:
gen_state = paddle.get_rng_state()[0].__getstate__()
self.rng_state[0], self.rng_state[1] = gen_state[1], gen_state[2] # [seed, offset]
if self.attention_type == "self":
# self attention - q: [b, s, 3, h, d] kv: None
assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3
and key_value_layer is None
), "query shape must be [b, s, 3, h, d] for dot product self attention"
max_seqlen = query_layer.shape[1]
cu_seqlens, _ = mask_to_cu_seqlens(attention_mask)
qkv_dtype = TE_DType[query_layer.dtype]
qkv_layout = "qkv_interleaved"
output = FusedAttnFuncPackedQKV.apply(
query_layer,
cu_seqlens,
core_attention_bias,
self.rng_state,
max_seqlen,
1.0 / self.norm_factor,
qkv_dtype,
self.attention_dropout if self.training else 0.0,
set_zero,
qkv_layout,
core_attention_bias_type,
self.attn_mask_type,
self.training,
)
elif self.attention_type == "cross":
# cross attention - q: [b, s_q, h, d] kv: [b, s_kv, 2, h, d]
assert (
len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5
and key_value_layer.shape[2] == 2
), "query shape must be [b, s, h, d] and key shape must be [b, s, 2, h, d]" \
"for dot product cross attention"
max_seqlen_q = query_layer.shape[1]
max_seqlen_kv = key_value_layer.shape[1]
cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
qkv_dtype = TE_DType[query_layer.dtype]
qkv_layout = "kv_interleaved"
output = FusedAttnFuncPackedKV.apply(
query_layer,
key_value_layer,
cu_seqlens_q,
cu_seqlens_kv,
core_attention_bias,
self.rng_state,
max_seqlen_q,
max_seqlen_kv,
1.0 / self.norm_factor,
qkv_dtype,
self.attention_dropout if self.training else 0.0,
set_zero,
qkv_layout,
core_attention_bias_type,
self.attn_mask_type,
self.training,
)
else:
raise ValueError("attention_type must be one of ['self', 'cross']")
return output
def _pd_forward(
self,
query_layer: paddle.Tensor,
key_value_layer: paddle.Tensor = None,
attention_mask: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
if self.attention_type == "self":
# self attention - q: [b, s, 3, h, d] k: None
assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3
and key_value_layer is None
), "query shape must be [b, s, 3, h, d] for dot product self attention"
q = query_layer[:, :, 0]
k = query_layer[:, :, 1]
v = query_layer[:, :, 2]
elif self.attention_type == "cross":
# cross attention - q: [b, s, h, d] kv: [b, s, 2, h, d]
assert (
len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5
and key_value_layer.shape[2] == 2
), f"query shape must be [b, s, h, d] and key_value shape must be [b, s, 2, h, d]" \
f"for dot product cross attention. The actual shape is q: {query_layer.shape}" \
f"kv: {key_value_layer.shape}"
q = query_layer
k = key_value_layer[:, :, 0]
v = key_value_layer[:, :, 1]
q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
v = paddle.transpose(x=v, perm=[0, 2, 1, 3])
product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True)
attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None)
if self.attention_dropout > 0:
attention_probs = F.dropout(
attention_probs,
self.attention_dropout,
training=self.training,
)
out = paddle.matmul(attention_probs, v)
out = paddle.transpose(out, perm=[0, 2, 1, 3]) # [b, s, h, d]
# out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
return out
class MultiHeadAttention(TransformerEngineBaseLayer):
"""Attention w/ QKV and Proj Gemms
Parameters
----------
hidden_size: int
hidden size of the model.
num_attention_heads: int
number of attention heads.
attention_dropout: float, default = 0.1
dropout probability for the dropout op during multi-head attention.
layernorm_epsilon: float, default = 1e-5
epsilon to use in the layer norm operations.
weight_attr: Union[paddle.ParamAttr, None], default = `None`
paddle.ParamAttr object for the weight parameter.
bias_attr: Union[paddle.ParamAttr, None, bool], default = `None`
paddle.ParamAttr object for the bias parameter.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation.
params_dtype: Optional[paddle.dtype], default = `None`
data type for the weights and biases.
return_layernorm_output: bool, default = `False`
whether to return the output of the layernorm operation.
input_layernorm: bool, default = `False`
whether to apply layernorm to the input.
attention_type: {'self', 'cross'}, default = `self`
type of attention operation.
zero_centered_gamma: bool, default = `False`
whether to zero initialize the gamma of the layernorm operation.
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
backend to use for attention operation.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
attention_dropout: float = 0.1,
layernorm_epsilon: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
attn_mask_type: str = "causal",
params_dtype: Optional[paddle.dtype] = None,
return_layernorm_output: bool = False,
input_layernorm: bool = False,
attention_type: str = "self",
zero_centered_gamma: bool = False,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
self.input_layernorm = input_layernorm
self.attention_type = attention_type
self.return_layernorm_output = return_layernorm_output
self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
self.weight_attr = weight_attr
self.bias_attr = bias_attr
self.attn_mask_type = attn_mask_type
assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
self.hidden_size_per_attention_head = hidden_size // num_attention_heads
self.num_attention_heads = num_attention_heads
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.backend = backend
if self.attention_type == "self":
if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear(
hidden_size,
3 * hidden_size,
eps=layernorm_epsilon,
weight_attr=self.weight_attr,
bias_attr=self.bias_attr,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
backend=self.backend,
)
else:
self.qkv = Linear(
hidden_size,
3 * hidden_size,
self.weight_attr,
self.bias_attr,
backend=self.backend,
)
else: # cross attention
if self.input_layernorm:
self.layernorm_query = LayerNormLinear(
hidden_size,
hidden_size,
eps=layernorm_epsilon,
weight_attr=self.weight_attr,
bias_attr=self.bias_attr,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
backend=self.backend,
)
else:
self.query_layer = Linear(
hidden_size,
hidden_size,
self.weight_attr,
self.bias_attr,
backend=self.backend,
)
self.key_value = Linear(
hidden_size,
2 * hidden_size,
self.weight_attr,
self.bias_attr,
backend=self.backend,
)
# Attention.
self.core_attention = DotProductAttention(
norm_factor,
attention_dropout,
attn_mask_type=attn_mask_type,
attention_type=self.attention_type,
backend=self.backend,
)
# Linear
self.proj = Linear(
hidden_size,
hidden_size,
self.weight_attr,
self.bias_attr,
backend=self.backend,
)
def forward(
self,
hidden_states: paddle.Tensor,
attention_mask: Optional[paddle.Tensor] = None,
encoder_output: Optional[paddle.Tensor] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True,
) -> Tuple[Union[paddle.Tensor, None], ...]:
"""
MultiHeadAttention Layer.
Parameters
----------
hidden_states : paddle.Tensor
Input tensor.
attention_mask : Optional[paddle.Tensor], default = `None`
Boolean tensor used to mask out softmax input; ignored when :attr:`attn_mask_type` is `"causal"`.
encoder_output : Optional[paddle.Tensor], default = `None`
Output of the encoder layer.
core_attention_bias_type: str, default = `no_bias`
only no_bias type is supported currently, {`no_bias`}
core_attention_bias: Optional[paddle.Tensor], default = `None`
Bias tensor for Q * K.T
set_zero: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
"""
# hidden_states: [b, s_q, hidden_size]
if self.attn_mask_type != "causal" and attention_mask is not None:
assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"
if self.attention_type == "self":
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(hidden_states)
if self.return_layernorm_output:
mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs
else:
mixed_qkv_layer = layernorm_qkv_outputs
else:
mixed_qkv_layer = self.qkv(hidden_states)
# [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size]
mixed_qkv_layer = mixed_qkv_layer.reshape(
shape=[0, 0, 3, self.num_attention_heads, self.hidden_size_per_attention_head])
context_layer = self.core_attention(
query_layer=mixed_qkv_layer,
key_value_layer=None,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
else: # cross attention
mixed_kv_layer = self.key_value(encoder_output)
# [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size]
mixed_kv_layer = mixed_kv_layer.reshape(
shape=[0, 0, 2, self.num_attention_heads, self.hidden_size_per_attention_head])
if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(hidden_states)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
else:
query_layer = layernorm_query_outputs
else:
query_layer = self.query_layer(hidden_states)
query_layer = query_layer.reshape(
shape=[0, 0, self.num_attention_heads, self.hidden_size_per_attention_head])
context_layer = self.core_attention(
query_layer=query_layer,
key_value_layer=mixed_kv_layer,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
context_layer = paddle.reshape(context_layer,
[0, 0, context_layer.shape[2] * context_layer.shape[3]])
# Output. [b, s, hidden]
attention_output = self.proj(context_layer)
if self.input_layernorm and self.return_layernorm_output:
return attention_output, layernorm_output
return attention_output
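As a reviewer aid, a minimal usage sketch of the two layers defined in this file. It is illustrative only: it assumes the pure-Paddle backend (so no fused cuDNN path is needed) and float32 inputs, with shapes taken from the forward docstrings above.
import paddle
import transformer_engine.paddle as te

b, s, h, d = 2, 128, 16, 64

# Self-attention consumes the packed [query, key, value] tensor of shape [b, s, 3, h, d].
qkv = paddle.normal(std=0.02, shape=(b, s, 3, h, d))
dpa = te.DotProductAttention(norm_factor=d**0.5, attention_dropout=0.0, attn_mask_type='causal', attention_type='self', backend='paddle')
core_out = dpa(qkv)    # [b, s, h, d]

# MultiHeadAttention wraps the QKV projection, core attention, and output projection.
mha = te.MultiHeadAttention(hidden_size=h * d, num_attention_heads=h, attention_dropout=0.0, backend='paddle')
hidden_states = paddle.normal(std=0.02, shape=(b, s, h * d))
attn_out = mha(hidden_states)    # [b, s, h * d]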
...@@ -126,7 +126,7 @@ class LayerNorm(paddle.nn.Layer):
"Paddle backend does not support LayerNorm with zero-centered scale.")
return F.layer_norm(x=inp,
normalized_shape=inp.shape[1:], normalized_shape=inp.shape[-1],
weight=self.weight,
bias=self.bias,
epsilon=self.eps)
......
...@@ -402,7 +402,6 @@ class LayerNormLinear(TransformerEngineBaseLayer):
if self.return_layernorm_output:
out, ln_out = out
return out, ln_out
return out
def _pd_forward(
...@@ -415,7 +414,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
"Paddle backend does not support LayerNorm with zero-centered scale.")
ln_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[1:], normalized_shape=inp.shape[-1],
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
......
...@@ -624,7 +624,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
"Paddle backend does not support LayerNorm with zero-centered scale.")
ln_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[1:], normalized_shape=inp.shape[-1],
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused scaled masked softmax functions"""
import os
import warnings
from typing import Callable, Tuple, Union, Optional
import paddle
from transformer_engine.paddle.cpp_extensions import (
scaled_upper_triang_masked_softmax_forward,
scaled_upper_triang_masked_softmax_backward,
scaled_masked_softmax_forward,
scaled_masked_softmax_backward,
scaled_softmax_forward,
scaled_softmax_backward,
)
THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128
_default_causal_mask = {}
def _get_default_causal_mask(seqlen: int) -> paddle.Tensor:
"""Return the causal upper triangular mask for softmax input"""
if seqlen not in _default_causal_mask:
_default_causal_mask[seqlen] = paddle.triu(paddle.ones((seqlen, seqlen)),
diagonal=1).cast('bool')
return _default_causal_mask[seqlen]
class ScaledUpperTriangMaskedSoftmax(paddle.autograd.PyLayer):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs: paddle.Tensor, scale: float) -> paddle.Tensor:
"""ScaledUpperTriangMaskedSoftmax fwd"""
scale_t = paddle.Tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
"""ScaledUpperTriangMaskedSoftmax bwd"""
softmax_results, scale_t = ctx.saved_tensor()
input_grads = scaled_upper_triang_masked_softmax_backward(output_grads, softmax_results,
scale_t[0])
return input_grads, None
class ScaledMaskedSoftmax(paddle.autograd.PyLayer):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs: paddle.Tensor, mask: paddle.Tensor, scale: float) -> paddle.Tensor:
"""ScaledMaskedSoftmax fwd"""
scale_t = paddle.Tensor([scale])
softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
"""ScaledMaskedSoftmax bwd"""
softmax_results, scale_t = ctx.saved_tensor()
input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
class ScaledSoftmax(paddle.autograd.PyLayer):
"""
Fused operation which performs following two operations in sequence
1. Scale the tensor.
2. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs: paddle.Tensor, scale: float) -> paddle.Tensor:
"""ScaledSoftmax fwd"""
scale_t = paddle.Tensor([scale])
softmax_results = scaled_softmax_forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
"""ScaledSoftmax bwd"""
softmax_results, scale_t = ctx.saved_tensor()
input_grads = scaled_softmax_backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
class FusedScaleMaskSoftmax(paddle.nn.Layer):
"""
fused operation: scaling + mask + softmax
Arguments:
attn_mask_type: attention mask type (padding or causal)
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed in fp32 precision.
backend: backend to use for the softmax operation, {'transformer_engine', 'paddle'}.
"""
def __init__(
self,
attn_mask_type: str,
mask_func: Callable,
softmax_in_fp32: bool = True,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = bool(int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1")))
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.backend = backend
def forward(
self,
inp: paddle.Tensor,
mask: paddle.Tensor,
scale: Optional[float] = None,
) -> paddle.Tensor:
"""FusedScaleMaskSoftmax fprop"""
# [batch_size, num_heads, s_q, s_kv]
assert inp.dim() == 4
self.input_is_fp16 = inp.dtype == paddle.float16
self.input_is_bf16 = inp.dtype == paddle.bfloat16
self.input_in_16bit_float = self.input_is_fp16 or self.input_is_bf16
assert (scale is None or self.softmax_in_fp32), "softmax should be in fp32 when scaled"
if self.backend == 'transformer_engine' and not self.is_kernel_available(*inp.shape):
warnings.warn(
"fused kernel is not available for this input shape, fall back to paddle backend")
self.backend = 'paddle'
if self.backend == 'transformer_engine':
return self._te_forward(inp, mask, scale)
if self.backend == 'paddle':
return self._pd_forward(inp, mask, scale)
raise AttributeError(f"Backend {self.backend} is not supported.")
def is_kernel_available(self, b: int, h: int, s_q: int, s_kv: int) -> bool:
"""Check FusedScaleMaskSoftmax kernel availability based on size"""
attn_batches = b * h
if (self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_16bit_float # input must be fp16 or bf16
and 16 < s_kv <= 4096 # s_kv must be in (16, 4096]
and s_q % 4 == 0 # s_q must be a multiple of 4
and attn_batches % 4 == 0 # b * h must be a multiple of 4
):
if 0 <= s_kv <= 4096:
batch_per_block = self.get_batch_per_block(int(s_kv))
if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
return True
else:
if s_q % batch_per_block == 0:
return True
return False
def _te_forward(self,
inp: paddle.Tensor,
mask: paddle.Tensor,
scale: Optional[float] = None) -> paddle.Tensor:
"""Fused masked softmax kernel"""
b, h, s_q, s_kv = inp.shape
scale = 1.0 if scale is None else scale
if self.attn_mask_type == "causal":
assert s_q == s_kv, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, s_q, s_kv)
inp = inp.reshape((-1, s_q, s_kv))
probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale)
return probs.reshape((b, h, s_q, s_kv))
# input is 4D tensor (b, h, s_q, s_kv)
if mask is not None:
return ScaledMaskedSoftmax.apply(inp, mask, scale)
return ScaledSoftmax.apply(inp, scale)
def _pd_forward(self,
inp: paddle.Tensor,
mask: paddle.Tensor,
scale: Optional[float] = None) -> paddle.Tensor:
"""Call Paddle OP"""
if self.input_in_16bit_float and self.softmax_in_fp32:
inp = paddle.cast(inp, 'float32')
if scale is not None:
inp = inp * scale
if self.attn_mask_type == "causal":
mask = _get_default_causal_mask(inp.shape[2])
mask_output = self.mask_func(inp, mask) if mask is not None else inp
probs = paddle.nn.functional.softmax(mask_output, axis=-1)
if self.input_in_16bit_float and self.softmax_in_fp32:
if self.input_is_fp16:
probs = paddle.cast(probs, 'float16')
else:
probs = paddle.cast(probs, 'bfloat16')
return probs
@staticmethod
def get_batch_per_block(key_seq_len: int) -> int:
"""Softmax utility"""
pow2 = 1 << (key_seq_len - 1).bit_length()
warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
batches_per_warp = 2 if pow2 <= 128 else 1
warps_per_block = THREADS_PER_BLOCK // warp_size
batches_per_block = warps_per_block * batches_per_warp
return batches_per_block
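A short sketch of the softmax wrapper above on the Paddle fallback path, together with the batch-per-block arithmetic for a 128-token key sequence; the shapes and values are illustrative and checked by hand, not part of the change.
import paddle
import transformer_engine.paddle as te
from transformer_engine.paddle.utils import attention_mask_func

b, h, s_q, s_kv = 2, 16, 128, 128
scores = paddle.normal(shape=(b, h, s_q, s_kv)).astype('float16')
# Padding-style mask: True marks positions that must be masked out.
mask = paddle.zeros((b, 1, s_q, s_kv), dtype='bool')

softmax = te.FusedScaleMaskSoftmax('padding', attention_mask_func, backend='paddle')
probs = softmax(scores, mask, scale=None)    # [b, h, s_q, s_kv], cast back to float16

# get_batch_per_block(128): pow2 = 128, warp_size = 32, batches_per_warp = 2,
# warps_per_block = 128 // 32 = 4, so batches_per_block = 8.
assert te.FusedScaleMaskSoftmax.get_batch_per_block(128) == 8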
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer"""
from typing import Optional, Union
import paddle
from transformer_engine.paddle.constants import (
AttnMaskTypes,
LayerTypes,
)
from transformer_engine.paddle.layer import (LayerNormMLP, LayerNorm, MultiHeadAttention)
from .base import TransformerEngineBaseLayer
class TransformerLayer(TransformerEngineBaseLayer):
r"""
TransformerLayer is made up of an attention block and a feedforward network (MLP).
This standard layer is based on the paper "Attention Is All You Need".
Parameters
----------
hidden_size : int
size of each input sample.
ffn_hidden_size : int
intermediate size to which input samples are projected.
num_attention_heads : int
number of attention heads in the transformer layer.
layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization
for numerical stability.
hidden_dropout: float, default = 0.1
dropout probability for the dropout op after FC2 layer.
attention_dropout: float, default = 0.1
dropout probability for the dropout op during multi-head attention.
self_attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
apply_residual_connection_post_layernorm : bool, default = `False`
if set to `True`, residual connections are taken
from the output of layer norm (default is taken
from input of layer norm)
output_layernorm: bool, default = `False`
if set to `True`, layer normalization is applied on the output side,
after the final dropout-add. default behavior is to apply layer
normalization on the input side, before the QKV transformation.
layer_type: {'encoder', 'decoder'}, default = `encoder`
if set to `decoder`, an additional cross-attn block is added after self-attn.
This can be used for structures like `T5` Transformer in conjunction with the
`encoder` option.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'.
params_dtype : paddle.dtype, default = `paddle.get_default_dtype()`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
"""
def __init__(self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_epsilon: float = 1e-5,
hidden_dropout: float = 0.1,
attention_dropout: float = 0.1,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
self_attn_mask_type: str = "causal",
params_dtype: Optional[paddle.dtype] = None,
apply_residual_connection_post_layernorm: bool = False,
output_layernorm: bool = False,
layer_type: str = "encoder",
zero_centered_gamma: bool = False,
activation: str = 'gelu',
backend: str = 'transformer_engine') -> None:
super().__init__()
params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
self.output_layernorm = output_layernorm
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.self_attn_mask_type = self_attn_mask_type
assert (self_attn_mask_type
in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported"
assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"
attention_args = (
hidden_size,
num_attention_heads,
attention_dropout,
layernorm_epsilon,
weight_attr,
bias_attr,
)
common_attention_kwargs = {
"params_dtype": params_dtype,
"return_layernorm_output": apply_residual_connection_post_layernorm,
"zero_centered_gamma": zero_centered_gamma,
"backend": backend,
}
self.self_attention = MultiHeadAttention(
*attention_args,
**common_attention_kwargs,
attn_mask_type=self_attn_mask_type,
input_layernorm=not output_layernorm,
attention_type="self",
)
if layer_type == "decoder":
self.inter_attention = MultiHeadAttention(
*attention_args,
**common_attention_kwargs,
attn_mask_type="padding",
input_layernorm=True,
attention_type="cross",
)
self.layernorm_mlp = LayerNormMLP(
hidden_size,
ffn_hidden_size,
eps=layernorm_epsilon,
weight_attr=weight_attr,
bias_attr=bias_attr,
activation=activation,
return_layernorm_output=apply_residual_connection_post_layernorm,
zero_centered_gamma=zero_centered_gamma,
backend=backend,
)
self.hidden_dropout = hidden_dropout
if self.output_layernorm:
self.layernorm = LayerNorm(
hidden_size,
layernorm_epsilon,
weight_attr,
bias_attr,
zero_centered_gamma=zero_centered_gamma,
backend=backend,
)
def forward(
self,
hidden_states: paddle.Tensor,
attention_mask: Optional[paddle.Tensor] = None,
encoder_output: Optional[paddle.Tensor] = None,
enc_dec_attn_mask: Optional[paddle.Tensor] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True,
) -> paddle.Tensor:
"""
Transformer Layer: attention block and a feedforward network (MLP)
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
is set to `"causal"`.
Parameters
----------
hidden_states : paddle.Tensor
Input tensor.
attention_mask : Optional[paddle.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
encoder_output : Optional[paddle.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
enc_dec_attn_mask : Optional[paddle.Tensor], default = `None`
Boolean tensor used to mask out inter-attention softmax input if using
`layer_type="decoder"`.
core_attention_bias_type: str, default = `no_bias`
core_attention_bias: Optional[paddle.Tensor], default = `None`
Bias tensor for Q * K.T
set_zero: bool, default = `True`
Whether to set output tensors to 0 or not before use.
"""
if self.self_attn_mask_type != "causal" and attention_mask is not None:
assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"
assert core_attention_bias_type in ['no_bias'], f"Only no_bias is supported currently, " \
f"but received core_attention_bias_type = {core_attention_bias_type}"
# Self attention.
self_attention_outputs = self.self_attention(
hidden_states,
attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
attention_output, residual = self_attention_outputs
else:
attention_output = self_attention_outputs
residual = hidden_states
# dropout add.
out = paddle.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=True,
)
bda_output = residual + out
# Cross attention.
if self.layer_type == "decoder":
inter_attention_outputs = self.inter_attention(
bda_output,
enc_dec_attn_mask,
encoder_output=encoder_output,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
if self.apply_residual_connection_post_layernorm:
attention_output, residual = inter_attention_outputs
else:
attention_output = inter_attention_outputs
residual = bda_output
out = paddle.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=True,
)
bda_output = residual + out
# MLP.
mlp_outputs = self.layernorm_mlp(bda_output)
if self.apply_residual_connection_post_layernorm:
mlp_output, residual = mlp_outputs
else:
mlp_output = mlp_outputs
residual = bda_output
# dropout add.
out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True)
output = residual + out
# For BERT like architectures.
if self.output_layernorm:
output = self.layernorm(output)
# output: [b, s, hidden]
return output
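A minimal end-to-end sketch of the layer above in its encoder configuration; it assumes the Paddle backend and disables dropout so the forward pass is deterministic.
import paddle
import transformer_engine.paddle as te

b, s, hidden, heads, ffn = 2, 128, 1024, 16, 4096
layer = te.TransformerLayer(hidden, ffn, heads, hidden_dropout=0.0, attention_dropout=0.0, self_attn_mask_type='causal', layer_type='encoder', backend='paddle')

x = paddle.uniform(shape=(b, s, hidden), dtype='float32')
y = layer(x)    # [b, s, hidden]
y.mean().backward()    # gradients flow to the layer parameters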
...@@ -52,3 +52,37 @@ def get_paddle_act_func(activation):
if activation not in funcs:
raise "Activation type " + activation + " is not supported."
return funcs[activation]
def attention_mask_func(attention_scores: paddle.Tensor,
attention_mask: paddle.Tensor) -> paddle.Tensor:
"""Get attention mask"""
def _masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)
attention_scores = _masked_fill(attention_scores, attention_mask, -10000.0)
return attention_scores
def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Tensor:
"""Convert mask to cu_seqlens"""
assert 'bool' in str(mask.dtype), "mask must be bool dtype"
assert len(mask.shape) == 4 and mask.shape[1] == 1, "mask must be [b, 1, s_q, s_kv]"
q_actual_seqlens = paddle.sum(mask[:, :, :, 0] == False, axis=(-1, -2), dtype='int32') # pylint: disable=singleton-comparison
q_cu_seqlens = paddle.cumsum(q_actual_seqlens)
q_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), q_cu_seqlens], axis=0)
if not need_kv:
return q_cu_seqlens, None
kv_actual_seqlens = paddle.sum(mask[:, :, 0, :] == False, axis=(-1, -2), dtype='int32') # pylint: disable=singleton-comparison
kv_cu_seqlens = paddle.cumsum(kv_actual_seqlens)
kv_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), kv_cu_seqlens], axis=0)
return q_cu_seqlens, kv_cu_seqlens
def divide(numerator: int, denominator: int) -> int:
"""Ensure that numerator is divisible by the denominator and return
the division value."""
assert (numerator % denominator == 0), f"{numerator} is not divisible by {denominator}"
return numerator // denominator
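A worked example of the mask convention used by mask_to_cu_seqlens above: False marks valid positions, so per-sample lengths are read off the leading False block and accumulated into cumulative sequence lengths with a leading zero. The values below were checked by hand and are illustrative only.
import paddle
from transformer_engine.paddle.utils import mask_to_cu_seqlens

# Two samples with s_q = s_kv = 4; sample 0 keeps 3 tokens, sample 1 keeps 2.
mask = paddle.ones((2, 1, 4, 4), dtype='bool')
mask[0, 0, :3, :3] = False
mask[1, 0, :2, :2] = False

q_cu, kv_cu = mask_to_cu_seqlens(mask, need_kv=True)
# q_cu == kv_cu == [0, 3, 5]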