"vscode:/vscode.git/clone" did not exist on "8e7d3bc8b6367a94868a84459ddefd66b06b29e9"
Unverified Commit a7a1a070 authored by Shijie Wang, committed by GitHub

Add operators for Paddle (#285)



* add more ops
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* add skipif
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix bug
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* minor change
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* minor change on coding style
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
parent a83605df
@@ -3,13 +3,16 @@
# See LICENSE for license information.
"""Test TE operators"""
import struct
import numpy as np
import pytest
import paddle
import paddle.nn.functional as F
from utils import assert_allclose, create_fp8_meta
import transformer_engine # pylint: disable=unused-import
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
from transformer_engine.paddle.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
@@ -23,6 +26,19 @@ from transformer_engine.paddle.cpp_extensions import (
layernorm_fwd_fp8,
layernorm_fwd,
layernorm_bwd,
rmsnorm_fwd_fp8,
rmsnorm_fwd,
rmsnorm_bwd,
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
scaled_softmax_forward,
scaled_softmax_backward,
scaled_masked_softmax_forward,
scaled_masked_softmax_backward,
scaled_upper_triang_masked_softmax_forward,
scaled_upper_triang_masked_softmax_backward,
)
from transformer_engine.paddle.fp8 import is_fp8_available
@@ -31,6 +47,10 @@ GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816,
(16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()
SELF_ATTN_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
CROSS_ATTN_CASES = [(32, 128, 512, 16, 64)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]
def test_quantize_dequantize():
"""
@@ -49,6 +69,25 @@ def test_quantize_dequantize():
assert_allclose(a, b, rtol=5e-2, atol=5e-2)
def copy_bits_from_float_to_uint16(f):
"""
Copy bits
"""
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
def convert_float_to_uint16(float_list):
"""
convert float to uint16
"""
new_output = []
for x in np.nditer(float_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_output = np.reshape(new_output, float_list.shape).view(np.uint16)
return new_output
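The two helpers above emulate a float32-to-bfloat16 cast by keeping only the upper 16 bits of each float's bit pattern (truncation rather than round-to-nearest). A quick NumPy sanity check of that equivalence, offered as an illustrative sketch that assumes a little-endian machine:

# Illustrative check (little-endian assumed): the upper 16 bits of a float32
# are exactly the bfloat16 bit pattern that copy_bits_from_float_to_uint16 returns.
vals = np.array([0.5, -1.25, 3.1415927], dtype=np.float32)
high_halves = vals.view(np.uint16).reshape(-1, 2)[:, 1]    # index 1 is the high half on little-endian
assert all(copy_bits_from_float_to_uint16(v) == h for v, h in zip(vals, high_halves))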
class TestTranspose:
"""
Test transpose operators
@@ -370,3 +409,391 @@ class TestLayerNorm:
assert_allclose(dx, dx_ref, rtol=1e-5, atol=1e-5)
assert_allclose(dgamma, dgamma_ref, rtol=1e-5, atol=1e-5)
assert_allclose(dbeta, dbeta_ref, rtol=1e-5, atol=1e-5)
class TestRMSNorm:
"""
Test rmsnorm operators
"""
@staticmethod
def calc_fwd_ref(x, eps, gamma):
"""
Calculate rmsnorm reference using paddle op
"""
norm = paddle.rsqrt(paddle.mean(x**2, axis=-1, keepdim=True) + eps)
y = x * norm * gamma
return y
def calc_bwd_ref(self, x, eps, gamma, dy):
"""
Calculate rmsnorm bwd reference using paddle op
"""
x.stop_gradient = False
gamma.stop_gradient = False
y = self.calc_fwd_ref(x, eps, gamma)
paddle.autograd.backward([y], [dy], True)
return x.grad, gamma.grad
def test_rmsnorm_fwd(self):
"""
Test BF16 RMSNorm Forward
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='bfloat16')
gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
y_ref = self.calc_fwd_ref(x, eps, gamma)
assert_allclose(y, y_ref, rtol=1e-2, atol=1e-2)
@staticmethod
def test_rmsnorm_fwd_fp8():
"""
Test FP8 RMSNorm Forward
"""
fp8_dtype = tex.DType.kFloat8E4M3
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='float32')
gamma = paddle.uniform(shape=(H,), dtype='float32')
fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta(num_fp8_tensors=1, amax_history_len=1)
y_ref, rsigma_ref = rmsnorm_fwd(x, gamma, eps, tex.DType.kFloat32)
y_fp8, rsigma = rmsnorm_fwd_fp8(x, gamma, eps, fp8_meta, fp8_tensor, fp8_dtype)
y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32)
assert_allclose(y, y_ref, rtol=0.1, atol=0.01)
assert_allclose(rsigma, rsigma_ref)
def test_rmsnorm_bwd(self):
"""
Test BF16 RMSNorm Backward
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='bfloat16')
dy = paddle.uniform(shape=(N, H), dtype='bfloat16')
gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy)
_, rsigma = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
dx, dgamma = rmsnorm_bwd(dy, x, rsigma, gamma)
assert_allclose(dx, dx_ref, rtol=1e-2, atol=1e-2)
assert_allclose(dgamma, dgamma_ref, rtol=1e-2, atol=5e-2)
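For context, the rsigma returned by rmsnorm_fwd and consumed by rmsnorm_bwd presumably holds the per-row reciprocal RMS that calc_fwd_ref computes with paddle.rsqrt; a tiny NumPy restatement of that quantity, stated as an assumption based on the reference above rather than on the kernel itself:

# Assumed meaning of rsigma: one reciprocal-RMS value per row of x.
def rsigma_reference(x_np, eps=1e-3):
    return 1.0 / np.sqrt(np.mean(np.square(x_np), axis=-1) + eps)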
class TestFusedAttn:
"""
Test fused attention operators
"""
def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode='self_attn', is_causal_masking=False):
"""
set test input
"""
def _random(shape):
if self.dtype == "bfloat16":
data = np.random.normal(loc=0.0, scale=0.02, size=shape).astype("float32")
return convert_float_to_uint16(data)
return np.random.normal(loc=0.0, scale=0.02, size=shape).astype(self.dtype)
self.batch_size = b
self.q_seqlen = s_q
self.kv_seqlen = s_kv
self.num_heads = h
self.head_size = d
self.dropout_prob = 0.0
self.scaling_factor = 1.0 / np.sqrt(d)
self.q_shape = (b, s_q, h, d)
self.kv_shape = (b, s_kv, h, d)
self.fuse_qkv_shape = (b, s_q, 3, h, d)
self.fuse_kv_shape = (b, s_kv, 2, h, d)
self.bias_shape = (1, h, s_q, s_kv)
self.attn_mode = attn_mode
self.dtype = dtype
self.is_causal_masking = is_causal_masking
self.q = _random(self.q_shape)
if self.attn_mode == "self_attn":
self.kv = self.q
else:
self.kv = _random(self.kv_shape)
self.q_actual_seqlen = np.random.randint(
low=20,
high=self.q_seqlen,
size=(self.batch_size,),
dtype=np.int32,
)
self.kv_actual_seqlen = self.q_actual_seqlen
self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen)
self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0)
self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen)
self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0)
self.attn_mask = np.zeros(
shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
dtype=np.int32,
)
for i in range(0, self.batch_size):
self.attn_mask[i, 0, 0:self.q_actual_seqlen[i], 0:self.kv_actual_seqlen[i],] = 1
if self.is_causal_masking:
assert attn_mode == "self_attn", "only support causal masking for self attention"
col_beg, col_end = 1, self.q_actual_seqlen[i]
for row in range(0, self.q_actual_seqlen[i]):
self.attn_mask[i, 0, row, col_beg:col_end] = 0
col_beg += 1
dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
self.dout = paddle.to_tensor(dout, dtype=self.dtype)
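In the causal branch above, the row loop zeroes everything to the right of the diagonal inside the valid region, which, for self-attention where the query and key/value lengths coincide, amounts to keeping the lower triangle. A vectorized sketch of the same mask construction, under that assumption:

# Assumed-equivalent vectorized construction of the padding (+ optional causal) mask.
def build_attn_mask(actual_seqlen, s_q, s_kv, causal=False):
    mask = np.zeros((len(actual_seqlen), 1, s_q, s_kv), dtype=np.int32)
    for i, n in enumerate(actual_seqlen):
        block = np.ones((n, n), dtype=np.int32)
        mask[i, 0, :n, :n] = np.tril(block) if causal else block
    return mask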
def _get_reference_out(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
qk_out = paddle.matmul(
x=q_out * self.scaling_factor,
y=k_out,
transpose_x=False,
transpose_y=True,
)
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
attn_mask = (paddle.cast(attn_mask, self.dtype) - 1.0) * 1e4
attn_mask_out = qk_out + attn_mask
softmax_out = F.softmax(attn_mask_out)
if self.dropout_prob:
dropout_out = F.dropout(
softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train",
)
qkv_out = paddle.matmul(dropout_out, v_out)
else:
qkv_out = paddle.matmul(softmax_out, v_out)
out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d]
paddle.autograd.backward(
[out],
[self.dout],
retain_graph=True,
)
return out, q_tensor.grad, k_tensor.grad, v_tensor.grad
def _get_fused_attention_out(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
if self.attn_mode == "self_attn":
qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d]
qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False)
else:
q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d]
kv_tensor = paddle.to_tensor(kv, stop_gradient=False)
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
rng_state = paddle.zeros((2,), dtype=np.int64)
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
if self.attn_mode == 'self_attn':
out, softmax_aux_tensor = fused_attn_fwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
rng_state,
is_training=True,
max_seqlen=self.q_seqlen,
qkv_dtype=qkv_dtype,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
attn_mask_type="causal" if self.is_causal_masking else "padding")
dqkv, _ = fused_attn_bwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
out,
self.dout,
softmax_aux_tensor,
max_seqlen=self.q_seqlen,
qkv_dtype=qkv_dtype,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
attn_mask_type="causal" if self.is_causal_masking else "padding")
q_grad = dqkv[:, :, 0, :, :]
k_grad = dqkv[:, :, 1, :, :]
v_grad = dqkv[:, :, 2, :, :]
else: # attn_mode == 'cross_attn'
out, softmax_aux_tensor = fused_attn_fwd_kvpacked(q_tensor,
kv_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
rng_state,
is_training=True,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False)
dq, dkv, _ = fused_attn_bwd_kvpacked(q_tensor,
kv_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
out,
self.dout,
softmax_aux_tensor,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False)
q_grad = dq
k_grad = dkv[:, :, 0, :, :]
v_grad = dkv[:, :, 1, :, :]
fwd_out = paddle.reshape(
out, shape=[self.batch_size, self.q_seqlen, self.num_heads, self.head_size])
return fwd_out, q_grad, k_grad, v_grad
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True, False])
def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
"""
test self attention forward + backward
"""
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
"""
test cross attention forward + backward
"""
self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn")
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
class TestSoftmax:
"""
Test softmax operators
"""
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
def test_scaled_softmax_fwd_bwd(self, dtype):
"""test scaled softmax"""
B, H, S = (16, 4, 32)
scale = 0.8
x = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
x.stop_gradient = False
dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
y_ref = F.softmax(scale * x)
y = scaled_softmax_forward(x, scale)
paddle.autograd.backward([y_ref], [dy], True)
dx_ref = x.grad
dx = scaled_softmax_backward(dy, y, scale)
assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
def test_scaled_masked_softmax_fwd_bwd(self, dtype):
"""test scaled masked softmax"""
B, H, S = (16, 4, 32)
scale = 0.8
x = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
x.stop_gradient = False
dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
mask = paddle.reshape(x[0, 0] > 0.3, shape=(1, 1, S, S))
mask_flipped = x[0, 0] <= 0.3
mask_ref = (mask_flipped.astype(dtype) - 1.0) * 1e4
y_ref = F.softmax(scale * x + mask_ref)
y = scaled_masked_softmax_forward(x, mask, scale)
paddle.autograd.backward([y_ref], [dy], True)
dx_ref = x.grad
dx = scaled_masked_softmax_backward(dy, y, scale)
assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
def test_scaled_upper_triang_masked_softmax_fwd_bwd(self, dtype):
"""test scaled upper triang masked softmax"""
B, S = (16, 32)
scale = 0.8
x = paddle.uniform(shape=(B, S, S), dtype=dtype)
x.stop_gradient = False
dy = paddle.uniform(shape=(B, S, S), dtype=dtype)
mask = paddle.ones((S, S), dtype='int32')
col_beg, col_end = 1, S
for row in range(0, S):
mask[row, col_beg:col_end] = 0
col_beg += 1
mask_ref = (mask.astype(dtype) - 1.0) * 1e4
y_ref = F.softmax(scale * x + mask_ref)
y = scaled_upper_triang_masked_softmax_forward(x, scale)
paddle.autograd.backward([y_ref], [dy], True)
dx_ref = x.grad
dx = scaled_upper_triang_masked_softmax_backward(dy, y, scale)
assert_allclose(y_ref, y, rtol=1e-4, atol=5e-3)
assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)
@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""
import math
from typing import Optional, Tuple, Union
import paddle
import transformer_engine_paddle as tex
@@ -338,3 +339,411 @@ def layernorm_bwd(
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 LayerNorm backward"""
return tex.te_layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma)
def rmsnorm_fwd(
inp: paddle.Tensor,
weight: paddle.Tensor,
eps: float,
otype: tex.DType,
sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm forward"""
return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin)
def rmsnorm_fwd_fp8(
inp: paddle.Tensor,
weight: paddle.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""RMSNorm with FP8 output"""
out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(inp, weight, fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, eps, int(fp8_tensor),
int(otype), sm_margin)
return out, rsigma
def rmsnorm_bwd(
dz: paddle.Tensor,
x: paddle.Tensor,
rsigma: paddle.Tensor,
gamma: paddle.Tensor,
sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm backward"""
return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin)
def fused_attn_fwd_qkvpacked(
qkv: paddle.Tensor,
cu_seqlens: paddle.Tensor,
rng_state: paddle.Tensor,
is_training: bool,
max_seqlen: int,
qkv_dtype: tex.DType,
Bias: paddle.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed QKV input"""
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1]
h = qkv.shape[3]
d = qkv.shape[4]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (Bias.shape == [1, h, max_seqlen, max_seqlen
]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv."
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen <= 512) and (d
== 64):
assert (qkv_layout == "qkv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports qkv_interleaved layout,
no_bias type, and padding/causal attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
out = paddle.full(shape=[total_seqs, h, d], fill_value=0, dtype=qkv.dtype)
else:
out = paddle.empty(shape=[total_seqs, h, d], dtype=qkv.dtype)
if is_training:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
else:
softmax_aux = None
# execute kernel
tex.te_fused_attn_fwd_qkvpacked(
qkv,
cu_seqlens,
Bias,
out,
softmax_aux,
rng_state,
b,
h,
d,
total_seqs,
max_seqlen,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
)
return out, softmax_aux
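A minimal call sketch for this wrapper, assembling inputs the same way TestFusedAttn above does; the concrete shapes, sequence lengths, and the Ampere-or-newer GPU requirement are assumptions carried over from those tests rather than an independent API reference:

# Sketch: packed-QKV fused attention forward with a padding mask.
# qkv is laid out as [b, s, 3, h, d]; cu_seqlens holds cumulative valid lengths (length b + 1).
import numpy as np
import paddle
import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import fused_attn_fwd_qkvpacked

b, s, h, d = 2, 128, 16, 64
qkv = paddle.to_tensor(np.random.normal(scale=0.02, size=(b, s, 3, h, d)).astype('float16'))
actual_lens = np.array([100, 128], dtype=np.int32)      # valid tokens per sequence
cu_seqlens = paddle.to_tensor(np.insert(np.cumsum(actual_lens), 0, 0), dtype='int32')
rng_state = paddle.zeros((2,), dtype='int64')
out, softmax_aux = fused_attn_fwd_qkvpacked(
    qkv, cu_seqlens, rng_state,
    is_training=True,
    max_seqlen=s,
    qkv_dtype=tex.DType.kFloat16,
    attn_mask_type="padding",
)
out = paddle.reshape(out, [b, s, h, d])    # the kernel returns [b * s, h, d]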
def fused_attn_bwd_qkvpacked(
qkv: paddle.Tensor,
cu_seqlens: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
max_seqlen: int,
qkv_dtype: tex.DType,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed QKV input"""
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1]
h = qkv.shape[3]
d = qkv.shape[4]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen <= 512) and (d
== 64):
assert (qkv_layout == "qkv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports qkv_interleaved layout,
no_bias type, and padding/causal attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
else:
dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype)
if bias_type != "no_bias":
dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
else:
dbias = None
# execute kernel
dqkv, dbias = tex.te_fused_attn_bwd_qkvpacked(
qkv,
cu_seqlens,
o,
d_o,
softmax_aux,
dqkv,
dbias,
b,
h,
d,
total_seqs,
max_seqlen,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
)
return dqkv, dbias
def fused_attn_fwd_kvpacked(
q: paddle.Tensor,
kv: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
rng_state: paddle.Tensor,
is_training: bool,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
Bias: paddle.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "kv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed KV input"""
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
b = cu_seqlens_q.shape[0] - 1
total_seqs_q = q.shape[0] * q.shape[1]
total_seqs_kv = kv.shape[0] * kv.shape[1]
h = q.shape[2]
d = q.shape[3]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (Bias.shape == [1, h, max_seqlen_q, max_seqlen_kv
]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv."
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen_q <= 512) and (
max_seqlen_kv <= 512) and (d == 64):
assert (qkv_layout == "kv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports kv_interleaved layout,
no_bias type, and padding/causal attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
out = paddle.full(shape=[total_seqs_q, h, d], fill_value=0, dtype=q.dtype)
else:
out = paddle.empty(shape=[total_seqs_q, h, d], dtype=q.dtype)
if is_training:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
softmax_aux = None
# execute kernel
tex.te_fused_attn_fwd_kvpacked(
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
Bias,
out,
softmax_aux,
rng_state,
b,
h,
d,
total_seqs_q,
total_seqs_kv,
max_seqlen_q,
max_seqlen_kv,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
)
return out, softmax_aux
def fused_attn_bwd_kvpacked(
q: paddle.Tensor,
kv: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "kv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed KV input"""
b = cu_seqlens_q.shape[0] - 1
total_seqs_q = q.shape[0] * q.shape[1]
total_seqs_kv = kv.shape[0] * kv.shape[1]
h = q.shape[2]
d = q.shape[3]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen_q <= 512) and (
max_seqlen_kv <= 512) and (d == 64):
assert (qkv_layout == "kv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports kv_interleaved layout,
no_bias type, and padding/causal attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
else:
dq = paddle.empty(shape=q.shape, dtype=q.dtype)
dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype)
if bias_type != "no_bias":
dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = None
# execute kernel
tex.te_fused_attn_bwd_kvpacked(
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
o,
d_o,
softmax_aux,
dq,
dkv,
dbias,
b,
h,
d,
total_seqs_q,
total_seqs_kv,
max_seqlen_q,
max_seqlen_kv,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
)
return dq, dkv, dbias
def scaled_softmax_forward(
inp: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled softmax forward"""
return tex.te_scaled_softmax_forward(inp, scale_factor)
def scaled_softmax_backward(
out_grad: paddle.Tensor,
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled softmax backward"""
tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
def scaled_masked_softmax_forward(
inp: paddle.Tensor,
mask: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled masked softmax forward"""
return tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor)
def scaled_masked_softmax_backward(
out_grad: paddle.Tensor,
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled masked softmax backward"""
tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
def scaled_upper_triang_masked_softmax_forward(
inp: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled upper triang masked softmax forward"""
return tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor)
def scaled_upper_triang_masked_softmax_backward(
out_grad: paddle.Tensor,
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled upper triang masked softmax backward"""
tex.te_scaled_upper_triang_masked_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
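One design note on the backward wrappers above: each hands out_grad to the C++ kernel and returns that same tensor, so the incoming gradient is overwritten in place. A short pairing sketch in the style of TestSoftmax; cloning the upstream gradient first is an inference from that in-place pattern, not documented behaviour:

# Sketch: pairing the scaled softmax forward/backward wrappers defined above.
B, H, S, scale = 16, 4, 32, 0.8
x = paddle.uniform(shape=(B, H, S, S), dtype='float16')
dy = paddle.uniform(shape=(B, H, S, S), dtype='float16')
y = scaled_softmax_forward(x, scale)
dx = scaled_softmax_backward(dy.clone(), y, scale)    # clone() keeps dy intact for reuse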
@@ -9,8 +9,9 @@
namespace transformer_engine {
namespace paddle_ext {
TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
const DType type) {
return TensorWrapper(const_cast<void *>(data_ptr), shape, type);
}
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) {
...
@@ -8,9 +8,12 @@
#include <cublasLt.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <vector>
@@ -78,11 +81,6 @@ inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x) { // NOLINT
return x ? x->data() : nullptr;
}
inline std::vector<size_t> GetShapeArray(const paddle::optional<paddle::Tensor> &x) {
if (x) return GetShapeArray(x.get());
return {0};
}
inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
std::vector<size_t> shapes;
for (auto dim : x.shape()) {
@@ -91,6 +89,11 @@ inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
return shapes;
}
inline std::vector<size_t> GetShapeArray(const paddle::optional<paddle::Tensor> &x) {
if (x) return GetShapeArray(x.get());
return {0};
}
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
bool init_to_zeros = 0);
@@ -176,7 +179,8 @@ class cudaDevicePropertiesManager {
};
// NVTE Tensor Utils
TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
void *amax_ptr, void *scale_ptr, void *scale_inv_ptr);
...