Unverified commit a7a1a070 authored by Shijie, committed by GitHub

Add operators for Paddle (#285)



* add more ops
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* add skipif
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix bug
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* minor change
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* minor change on coding style
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
parent a83605df
@@ -3,13 +3,16 @@
# See LICENSE for license information.
"""Test TE operators"""
import struct
import numpy as np
import pytest
import paddle
import paddle.nn.functional as F
from utils import assert_allclose, create_fp8_meta
import transformer_engine # pylint: disable=unused-import
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
from transformer_engine.paddle.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
@@ -23,6 +26,19 @@ from transformer_engine.paddle.cpp_extensions import (
layernorm_fwd_fp8,
layernorm_fwd,
layernorm_bwd,
rmsnorm_fwd_fp8,
rmsnorm_fwd,
rmsnorm_bwd,
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
scaled_softmax_forward,
scaled_softmax_backward,
scaled_masked_softmax_forward,
scaled_masked_softmax_backward,
scaled_upper_triang_masked_softmax_forward,
scaled_upper_triang_masked_softmax_backward,
)
from transformer_engine.paddle.fp8 import is_fp8_available
@@ -31,6 +47,10 @@ GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816,
(16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()
SELF_ATTN_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
CROSS_ATTN_CASES = [(32, 128, 512, 16, 64)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]
def test_quantize_dequantize():
"""
@@ -49,6 +69,25 @@ def test_quantize_dequantize():
assert_allclose(a, b, rtol=5e-2, atol=5e-2)
def copy_bits_from_float_to_uint16(f):
"""
Return the upper 16 bits of a float32's bit pattern, i.e. its bfloat16 representation.
"""
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
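# Worked example: 1.0 is 0x3F800000 as float32 bits; shifting right by 16 keeps the
# upper half, 0x3F80, which is exactly the bfloat16 bit pattern of 1.0.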
def convert_float_to_uint16(float_list):
"""
Convert a float32 array to bfloat16 bit patterns stored as uint16.
"""
new_output = []
for x in np.nditer(float_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_output = np.reshape(new_output, float_list.shape).view(np.uint16)
return new_output
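# Usage sketch (illustrative shapes): numpy has no bfloat16 dtype, so bf16 data is
# carried as uint16 bit patterns and reinterpreted on the paddle side, e.g.:
#   data = np.random.normal(size=(2, 2)).astype("float32")
#   t = paddle.to_tensor(convert_float_to_uint16(data), dtype="bfloat16")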
class TestTranspose:
"""
Test transpose operators
@@ -370,3 +409,391 @@ class TestLayerNorm:
assert_allclose(dx, dx_ref, rtol=1e-5, atol=1e-5)
assert_allclose(dgamma, dgamma_ref, rtol=1e-5, atol=1e-5)
assert_allclose(dbeta, dbeta_ref, rtol=1e-5, atol=1e-5)
class TestRMSNorm:
"""
Test rmsnorm operators
"""
@staticmethod
def calc_fwd_ref(x, eps, gamma):
"""
Calculate rmsnorm reference using paddle op
"""
norm = paddle.rsqrt(paddle.mean(x**2, axis=-1, keepdim=True) + eps)
y = x * norm * gamma
return y
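# Worked example with gamma = 1 and eps = 0: for x = [3, 4],
# rms = sqrt((9 + 16) / 2) = sqrt(12.5) ~ 3.536, so y ~ [0.849, 1.131].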
def calc_bwd_ref(self, x, eps, gamma, dy):
"""
Calculate rmsnorm bwd reference using paddle op
"""
x.stop_gradient = False
gamma.stop_gradient = False
y = self.calc_fwd_ref(x, eps, gamma)
paddle.autograd.backward([y], [dy], True)
return x.grad, gamma.grad
def test_rmsnorm_fwd(self):
"""
Test BF16 RMSNorm Forward
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='bfloat16')
gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
y_ref = self.calc_fwd_ref(x, eps, gamma)
assert_allclose(y, y_ref, rtol=1e-2, atol=1e-2)
@staticmethod
def test_rmsnorm_fwd_fp8():
"""
Test FP8 RMSNorm Forward
"""
fp8_dtype = tex.DType.kFloat8E4M3
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='float32')
gamma = paddle.uniform(shape=(H,), dtype='float32')
fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta(num_fp8_tensors=1, amax_history_len=1)
y_ref, rsigma_ref = rmsnorm_fwd(x, gamma, eps, tex.DType.kFloat32)
y_fp8, rsigma = rmsnorm_fwd_fp8(x, gamma, eps, fp8_meta, fp8_tensor, fp8_dtype)
y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32)
assert_allclose(y, y_ref, rtol=0.1, atol=0.01)
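# FP8 E4M3 keeps only 3 mantissa bits (relative step ~2**-3), hence the loose
# tolerances above; rsigma is computed in FP32, so the default tolerance suffices.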
assert_allclose(rsigma, rsigma_ref)
def test_rmsnorm_bwd(self):
"""
Test BF16 RMSNorm Backward
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='bfloat16')
dy = paddle.uniform(shape=(N, H), dtype='bfloat16')
gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy)
_, rsigma = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
dx, dgamma = rmsnorm_bwd(dy, x, rsigma, gamma)
assert_allclose(dx, dx_ref, rtol=1e-2, atol=1e-2)
assert_allclose(dgamma, dgamma_ref, rtol=1e-2, atol=5e-2)
class TestFusedAttn:
"""
Test fused attention operators
"""
def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode='self_attn', is_causal_masking=False):
"""
set test input
"""
def _random(shape):
if self.dtype == "bfloat16":
data = np.random.normal(loc=0.0, scale=0.02, size=shape).astype("float32")
return convert_float_to_uint16(data)
return np.random.normal(loc=0.0, scale=0.02, size=shape).astype(self.dtype)
self.batch_size = b
self.q_seqlen = s_q
self.kv_seqlen = s_kv
self.num_heads = h
self.head_size = d
self.dropout_prob = 0.0
self.scaling_factor = 1.0 / np.sqrt(d)
self.q_shape = (b, s_q, h, d)
self.kv_shape = (b, s_kv, h, d)
self.fuse_qkv_shape = (b, s_q, 3, h, d)
self.fuse_kv_shape = (b, s_kv, 2, h, d)
self.bias_shape = (1, h, s_q, s_kv)
self.attn_mode = attn_mode
self.dtype = dtype
self.is_causal_masking = is_causal_masking
self.q = _random(self.q_shape)
if self.attn_mode == "self_attn":
self.kv = self.q
else:
self.kv = _random(self.kv_shape)
self.q_actual_seqlen = np.random.randint(
low=20,
high=self.q_seqlen,
size=(self.batch_size,),
dtype=np.int32,
)
self.kv_actual_seqlen = self.q_actual_seqlen
self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen)
self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0)
self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen)
self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0)
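# Worked example: actual seqlens [3, 5, 2] give a cu_seqlen of [0, 3, 8, 10];
# entries i and i+1 bracket sequence i in the packed layout.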
self.attn_mask = np.zeros(
shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
dtype=np.int32,
)
for i in range(0, self.batch_size):
self.attn_mask[i, 0, 0:self.q_actual_seqlen[i], 0:self.kv_actual_seqlen[i]] = 1
if self.is_causal_masking:
assert attn_mode == "self_attn", "only support causal masking for self attention"
col_beg, col_end = 1, self.q_actual_seqlen[i]
for row in range(0, self.q_actual_seqlen[i]):
self.attn_mask[i, 0, row, col_beg:col_end] = 0
col_beg += 1
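# Equivalently, with L = self.q_actual_seqlen[i]:
#   self.attn_mask[i, 0, :L, :L] = np.tril(np.ones((L, L), dtype=np.int32))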
dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
self.dout = paddle.to_tensor(dout, dtype=self.dtype)
def _get_reference_out(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
qk_out = paddle.matmul(
x=q_out * self.scaling_factor,
y=k_out,
transpose_x=False,
transpose_y=True,
)
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
attn_mask = (paddle.cast(attn_mask, self.dtype) - 1.0) * 1e4
attn_mask_out = qk_out + attn_mask
softmax_out = F.softmax(attn_mask_out)
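# Note: dropout_prob is fixed at 0.0 in these tests, so this branch (which would
# also require a self.training attribute) is never exercised.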
if self.dropout_prob:
dropout_out = F.dropout(
softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train",
)
qkv_out = paddle.matmul(dropout_out, v_out)
else:
qkv_out = paddle.matmul(softmax_out, v_out)
out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d]
paddle.autograd.backward(
[out],
[self.dout],
retain_graph=True,
)
return out, q_tensor.grad, k_tensor.grad, v_tensor.grad
def _get_fused_attention_out(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
if self.attn_mode == "self_attn":
qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d]
qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False)
else:
q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d]
kv_tensor = paddle.to_tensor(kv, stop_gradient=False)
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
rng_state = paddle.zeros((2,), dtype=np.int64)
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
if self.attn_mode == 'self_attn':
out, softmax_aux_tensor = fused_attn_fwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
rng_state,
is_training=True,
max_seqlen=self.q_seqlen,
qkv_dtype=qkv_dtype,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
attn_mask_type="causal" if self.is_causal_masking else "padding")
dqkv, _ = fused_attn_bwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
out,
self.dout,
softmax_aux_tensor,
max_seqlen=self.q_seqlen,
qkv_dtype=qkv_dtype,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
attn_mask_type="causal" if self.is_causal_masking else "padding")
q_grad = dqkv[:, :, 0, :, :]
k_grad = dqkv[:, :, 1, :, :]
v_grad = dqkv[:, :, 2, :, :]
else: # attn_mode == 'cross_attn'
out, softmax_aux_tensor = fused_attn_fwd_kvpacked(q_tensor,
kv_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
rng_state,
is_training=True,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False)
dq, dkv, _ = fused_attn_bwd_kvpacked(q_tensor,
kv_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
out,
self.dout,
softmax_aux_tensor,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False)
q_grad = dq
k_grad = dkv[:, :, 0, :, :]
v_grad = dkv[:, :, 1, :, :]
fwd_out = paddle.reshape(
out, shape=[self.batch_size, self.q_seqlen, self.num_heads, self.head_size])
return fwd_out, q_grad, k_grad, v_grad
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True, False])
def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
"""
test self attention forward + backward
"""
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
"""
test cross attention forward + backward
"""
self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn")
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
class TestSoftmax:
"""
Test softmax operators
"""
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
def test_scaled_softmax_fwd_bwd(self, dtype):
"""test scaled softmax"""
B, H, S = (16, 4, 32)
scale = 0.8
x = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
x.stop_gradient = False
dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
y_ref = F.softmax(scale * x)
y = scaled_softmax_forward(x, scale)
paddle.autograd.backward([y_ref], [dy], True)
dx_ref = x.grad
dx = scaled_softmax_backward(dy, y, scale)
assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
def test_scaled_masked_softmax_fwd_bwd(self, dtype):
"""test scaled masked softmax"""
B, H, S = (16, 4, 32)
scale = 0.8
x = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
x.stop_gradient = False
dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
mask = paddle.reshape(x[0, 0] > 0.3, shape=(1, 1, S, S))
mask_flipped = x[0, 0] <= 0.3
mask_ref = (mask_flipped.astype(dtype) - 1.0) * 1e4
y_ref = F.softmax(scale * x + mask_ref)
y = scaled_masked_softmax_forward(x, mask, scale)
paddle.autograd.backward([y_ref], [dy], True)
dx_ref = x.grad
dx = scaled_masked_softmax_backward(dy, y, scale)
assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
def test_scaled_upper_triang_masked_softmax_fwd_bwd(self, dtype):
"""test scaled upper triang masked softmax"""
B, S = (16, 32)
scale = 0.8
x = paddle.uniform(shape=(B, S, S), dtype=dtype)
x.stop_gradient = False
dy = paddle.uniform(shape=(B, S, S), dtype=dtype)
mask = paddle.ones((S, S), dtype='int32')
col_beg, col_end = 1, S
for row in range(0, S):
mask[row, col_beg:col_end] = 0
col_beg += 1
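# Equivalently: mask = paddle.tril(paddle.ones((S, S), dtype='int32'))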
mask_ref = (mask.astype(dtype) - 1.0) * 1e4
y_ref = F.softmax(scale * x + mask_ref)
y = scaled_upper_triang_masked_softmax_forward(x, scale)
paddle.autograd.backward([y_ref], [dy], True)
dx_ref = x.grad
dx = scaled_upper_triang_masked_softmax_backward(dy, y, scale)
assert_allclose(y_ref, y, rtol=1e-4, atol=5e-3)
assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)
@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""
import math
from typing import Optional, Tuple, Union
import paddle
import transformer_engine_paddle as tex
@@ -338,3 +339,411 @@ def layernorm_bwd(
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 LayerNorm backward"""
return tex.te_layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma)
def rmsnorm_fwd(
inp: paddle.Tensor,
weight: paddle.Tensor,
eps: float,
otype: tex.DType,
sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm forward"""
return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin)
def rmsnorm_fwd_fp8(
inp: paddle.Tensor,
weight: paddle.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""RMSNorm with FP8 output"""
out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(inp, weight, fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, eps, int(fp8_tensor),
int(otype), sm_margin)
return out, rsigma
def rmsnorm_bwd(
dz: paddle.Tensor,
x: paddle.Tensor,
rsigma: paddle.Tensor,
gamma: paddle.Tensor,
sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm backward"""
return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin)
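# Minimal usage sketch (mirrors TestRMSNorm above); eps and dtype are illustrative:
#   y, rsigma = rmsnorm_fwd(x, gamma, 1e-3, tex.DType.kBFloat16)
#   dx, dgamma = rmsnorm_bwd(dy, x, rsigma, gamma)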
def fused_attn_fwd_qkvpacked(
qkv: paddle.Tensor,
cu_seqlens: paddle.Tensor,
rng_state: paddle.Tensor,
is_training: bool,
max_seqlen: int,
qkv_dtype: tex.DType,
Bias: Optional[paddle.Tensor] = None,
attn_scale: Optional[float] = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed QKV input"""
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1]
h = qkv.shape[3]
d = qkv.shape[4]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert Bias.shape == [1, h, max_seqlen, max_seqlen], \
"bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv."
# BF16/FP16 fused attention API
if qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16) and max_seqlen <= 512 and d == 64:
assert (qkv_layout == "qkv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports qkv_interleaved layout,
no_bias type, and padding/causal attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
out = paddle.full(shape=[total_seqs, h, d], fill_value=0, dtype=qkv.dtype)
else:
out = paddle.empty(shape=[total_seqs, h, d], dtype=qkv.dtype)
if is_training:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
else:
softmax_aux = None
# execute kernel
tex.te_fused_attn_fwd_qkvpacked(
qkv,
cu_seqlens,
Bias,
out,
softmax_aux,
rng_state,
b,
h,
d,
total_seqs,
max_seqlen,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
)
return out, softmax_aux
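# Minimal usage sketch (mirrors TestFusedAttn above); qkv is packed as
# [b, s, 3, h, d] and cu_seqlens holds cumulative actual lengths, e.g. [0, 3, 8, 10]:
#   out, softmax_aux = fused_attn_fwd_qkvpacked(
#       qkv_tensor, q_cu_seqlen_tensor, rng_state, is_training=True,
#       max_seqlen=512, qkv_dtype=tex.DType.kBFloat16)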
def fused_attn_bwd_qkvpacked(
qkv: paddle.Tensor,
cu_seqlens: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
max_seqlen: int,
qkv_dtype: tex.DType,
attn_scale: Optional[float] = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed QKV input"""
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1]
h = qkv.shape[3]
d = qkv.shape[4]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
# BF16/FP16 fused attention API
if qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16) and max_seqlen <= 512 and d == 64:
assert (qkv_layout == "qkv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports qkv_interleaved layout,
no_bias type, and padding/causal attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
else:
dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype)
if bias_type != "no_bias":
dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
else:
dbias = None
# execute kernel
dqkv, dbias = tex.te_fused_attn_bwd_qkvpacked(
qkv,
cu_seqlens,
o,
d_o,
softmax_aux,
dqkv,
dbias,
b,
h,
d,
total_seqs,
max_seqlen,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
)
return dqkv, dbias
def fused_attn_fwd_kvpacked(
q: paddle.Tensor,
kv: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
rng_state: paddle.Tensor,
is_training: bool,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
Bias: Optional[paddle.Tensor] = None,
attn_scale: Optional[float] = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "kv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed KV input"""
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
b = cu_seqlens_q.shape[0] - 1
total_seqs_q = q.shape[0] * q.shape[1]
total_seqs_kv = kv.shape[0] * kv.shape[1]
h = q.shape[2]
d = q.shape[3]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert Bias.shape == [1, h, max_seqlen_q, max_seqlen_kv], \
"bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv."
# BF16/FP16 fused attention API
if qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16) and max_seqlen_q <= 512 and max_seqlen_kv <= 512 and d == 64:
assert (qkv_layout == "kv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports kv_interleaved layout,
no_bias type, and padding/causal attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
out = paddle.full(shape=[total_seqs_q, h, d], fill_value=0, dtype=q.dtype)
else:
out = paddle.empty(shape=[total_seqs_q, h, d], dtype=q.dtype)
if is_training:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
softmax_aux = None
# execute kernel
tex.te_fused_attn_fwd_kvpacked(
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
Bias,
out,
softmax_aux,
rng_state,
b,
h,
d,
total_seqs_q,
total_seqs_kv,
max_seqlen_q,
max_seqlen_kv,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
)
return out, softmax_aux
def fused_attn_bwd_kvpacked(
q: paddle.Tensor,
kv: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
attn_scale: Optional[float] = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "kv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed KV input"""
b = cu_seqlens_q.shape[0] - 1
total_seqs_q = q.shape[0] * q.shape[1]
total_seqs_kv = kv.shape[0] * kv.shape[1]
h = q.shape[2]
d = q.shape[3]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
# BF16/FP16 fused attention API
if qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16) and max_seqlen_q <= 512 and max_seqlen_kv <= 512 and d == 64:
assert (qkv_layout == "kv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports kv_interleaved layout,
no_bias type, and padding/causal attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
else:
dq = paddle.empty(shape=q.shape, dtype=q.dtype)
dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype)
if bias_type != "no_bias":
dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = None
# execute kernel
tex.te_fused_attn_bwd_kvpacked(
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
o,
d_o,
softmax_aux,
dq,
dkv,
dbias,
b,
h,
d,
total_seqs_q,
total_seqs_kv,
max_seqlen_q,
max_seqlen_kv,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
)
return dq, dkv, dbias
def scaled_softmax_forward(
inp: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled softmax forward"""
return tex.te_scaled_softmax_forward(inp, scale_factor)
def scaled_softmax_backward(
out_grad: paddle.Tensor,
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled softmax backward"""
tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
def scaled_masked_softmax_forward(
inp: paddle.Tensor,
mask: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled masked softmax forward"""
return tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor)
def scaled_masked_softmax_backward(
out_grad: paddle.Tensor,
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled masked softmax backward"""
tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
def scaled_upper_triang_masked_softmax_forward(
inp: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled upper triang masked softmax forward"""
return tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor)
def scaled_upper_triang_masked_softmax_backward(
out_grad: paddle.Tensor,
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled upper triang masked softmax backward"""
tex.te_scaled_upper_triang_masked_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
@@ -9,8 +9,9 @@
namespace transformer_engine {
namespace paddle_ext {
-TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type) {
-  return TensorWrapper(data_ptr, shape, type);
+TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
+                             const DType type) {
+  return TensorWrapper(const_cast<void *>(data_ptr), shape, type);
}
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) {
@@ -8,9 +8,12 @@
#include <cublasLt.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <vector>
@@ -78,11 +81,6 @@ inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x) { // NOLINT
return x ? x->data() : nullptr;
}
-inline std::vector<size_t> GetShapeArray(const paddle::optional<paddle::Tensor> &x) {
-  if (x) return GetShapeArray(x.get());
-  return {0};
-}
inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
std::vector<size_t> shapes;
for (auto dim : x.shape()) {
@@ -91,6 +89,11 @@ inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
return shapes;
}
+inline std::vector<size_t> GetShapeArray(const paddle::optional<paddle::Tensor> &x) {
+  if (x) return GetShapeArray(x.get());
+  return {0};
+}
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
bool init_to_zeros = 0);
@@ -176,7 +179,8 @@ class cudaDevicePropertiesManager {
};
// NVTE Tensor Utils
-TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type);
+TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
+                             const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
void *amax_ptr, void *scale_ptr, void *scale_inv_ptr);