Unverified commit a7a1a070, authored by Shijie, committed by GitHub
Browse files

Add operators for Paddle (#285)



* add more ops
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* add skipif
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix bug
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* minor change
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* minor change on coding style
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
parent a83605df
......@@ -3,13 +3,16 @@
# See LICENSE for license information.
"""Test TE operators"""
import struct
import numpy as np
import pytest
import paddle
import paddle.nn.functional as F
from utils import assert_allclose, create_fp8_meta
import transformer_engine # pylint: disable=unused-import
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
from transformer_engine.paddle.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
......@@ -23,6 +26,19 @@ from transformer_engine.paddle.cpp_extensions import (
layernorm_fwd_fp8,
layernorm_fwd,
layernorm_bwd,
rmsnorm_fwd_fp8,
rmsnorm_fwd,
rmsnorm_bwd,
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
scaled_softmax_forward,
scaled_softmax_backward,
scaled_masked_softmax_forward,
scaled_masked_softmax_backward,
scaled_upper_triang_masked_softmax_forward,
scaled_upper_triang_masked_softmax_backward,
)
from transformer_engine.paddle.fp8 import is_fp8_available
......@@ -31,6 +47,10 @@ GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816,
(16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()
SELF_ATTN_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
CROSS_ATTN_CASES = [(32, 128, 512, 16, 64)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]
def test_quantize_dequantize():
"""
......@@ -49,6 +69,25 @@ def test_quantize_dequantize():
assert_allclose(a, b, rtol=5e-2, atol=5e-2)
def copy_bits_from_float_to_uint16(f):
    """
    Reinterpret a float as its float32 bit pattern and keep the upper 16 bits
    (the bfloat16 truncation of the value).
    """
    raw = struct.pack('<f', f)
    (bits,) = struct.unpack('<I', raw)
    return bits >> 16
def convert_float_to_uint16(float_list):
    """
    Convert a float array to bfloat16 bit patterns stored as uint16.

    Each element is converted to float32, reinterpreted as uint32, and its
    upper 16 bits are kept (round-toward-zero bfloat16 truncation).

    Args:
        float_list: numpy array (or array-like) of floats, any shape.

    Returns:
        np.ndarray of dtype uint16 with the same shape as the input.
    """
    # Vectorized replacement for the original per-element np.nditer loop:
    # one C-level pass instead of a Python-level struct.pack per element.
    as_f32 = np.ascontiguousarray(float_list, dtype=np.float32)
    return (as_f32.view(np.uint32) >> 16).astype(np.uint16)
class TestTranspose:
"""
Test transpose operators
......@@ -370,3 +409,391 @@ class TestLayerNorm:
assert_allclose(dx, dx_ref, rtol=1e-5, atol=1e-5)
assert_allclose(dgamma, dgamma_ref, rtol=1e-5, atol=1e-5)
assert_allclose(dbeta, dbeta_ref, rtol=1e-5, atol=1e-5)
class TestRMSNorm:
    """
    Test rmsnorm operators
    """

    @staticmethod
    def calc_fwd_ref(x, eps, gamma):
        """
        Calculate rmsnorm reference using paddle op
        """
        # RMSNorm: y = x / sqrt(mean(x^2) + eps) * gamma.
        # Unlike LayerNorm there is no mean subtraction and no beta.
        norm = paddle.rsqrt(paddle.mean(x**2, axis=-1, keepdim=True) + eps)
        y = x * norm * gamma
        return y

    def calc_bwd_ref(self, x, eps, gamma, dy):
        """
        Calculate rmsnorm bwd reference using paddle op
        """
        x.stop_gradient = False
        gamma.stop_gradient = False
        y = self.calc_fwd_ref(x, eps, gamma)
        # Backprop dy through the reference graph to obtain dx and dgamma.
        paddle.autograd.backward([y], [dy], True)
        return x.grad, gamma.grad

    def test_rmsnorm_fwd(self):
        """
        Test BF16 RMSNorm Forward
        """
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype='bfloat16')
        gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
        # rmsnorm_fwd returns (y, rsigma); rsigma is only needed for backward.
        y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
        y_ref = self.calc_fwd_ref(x, eps, gamma)
        assert_allclose(y, y_ref, rtol=1e-2, atol=1e-2)

    @staticmethod
    def test_rmsnorm_fwd_fp8():
        """
        Test FP8 RMSNorm Forward
        """
        fp8_dtype = tex.DType.kFloat8E4M3
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype='float32')
        gamma = paddle.uniform(shape=(H,), dtype='float32')
        fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
        fp8_meta = create_fp8_meta(num_fp8_tensors=1, amax_history_len=1)
        y_ref, rsigma_ref = rmsnorm_fwd(x, gamma, eps, tex.DType.kFloat32)
        y_fp8, rsigma = rmsnorm_fwd_fp8(x, gamma, eps, fp8_meta, fp8_tensor, fp8_dtype)
        # Cast the FP8 output back to FP32 before comparing; the looser
        # tolerances account for FP8 quantization error.
        y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32)
        assert_allclose(y, y_ref, rtol=0.1, atol=0.01)
        assert_allclose(rsigma, rsigma_ref)

    def test_rmsnorm_bwd(self):
        """
        Test BF16 RMSNorm Backward
        """
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype='bfloat16')
        dy = paddle.uniform(shape=(N, H), dtype='bfloat16')
        gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
        dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy)
        # Forward is run only to obtain rsigma, which the fused backward
        # kernel consumes instead of recomputing.
        _, rsigma = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
        dx, dgamma = rmsnorm_bwd(dy, x, rsigma, gamma)
        assert_allclose(dx, dx_ref, rtol=1e-2, atol=1e-2)
        assert_allclose(dgamma, dgamma_ref, rtol=1e-2, atol=5e-2)
class TestFusedAttn:
    """
    Test fused attention operators
    """

    def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode='self_attn', is_causal_masking=False):
        """
        set test input
        """

        def _random(shape):
            # bfloat16 has no native numpy dtype: draw float32 samples and
            # keep the upper 16 bits of each element as uint16 bit patterns.
            if self.dtype == "bfloat16":
                data = np.random.normal(loc=0.0, scale=0.02, size=shape).astype("float32")
                return convert_float_to_uint16(data)
            return np.random.normal(loc=0.0, scale=0.02, size=shape).astype(self.dtype)

        self.batch_size = b
        self.q_seqlen = s_q
        self.kv_seqlen = s_kv
        self.num_heads = h
        self.head_size = d
        self.dropout_prob = 0.0
        self.scaling_factor = 1.0 / np.sqrt(d)
        self.q_shape = (b, s_q, h, d)
        self.kv_shape = (b, s_kv, h, d)
        self.fuse_qkv_shape = (b, s_q, 3, h, d)
        self.fuse_kv_shape = (b, s_kv, 2, h, d)
        self.bias_shape = (1, h, s_q, s_kv)
        self.attn_mode = attn_mode
        self.dtype = dtype
        self.is_causal_masking = is_causal_masking

        self.q = _random(self.q_shape)
        if self.attn_mode == "self_attn":
            self.kv = self.q
        else:
            self.kv = _random(self.kv_shape)

        # Variable-length sequences: actual lengths are random and strictly
        # below the padded maximum (high is exclusive in np.random.randint).
        self.q_actual_seqlen = np.random.randint(
            low=20,
            high=self.q_seqlen,
            size=(self.batch_size,),
            dtype=np.int32,
        )
        self.kv_actual_seqlen = self.q_actual_seqlen
        # Cumulative sequence lengths with a leading 0 — the cu_seqlens
        # format consumed by the fused attention kernels.
        self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen)
        self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0)
        self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen)
        self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0)

        # attn_mask: 1 marks valid (attended) positions, 0 marks padding.
        self.attn_mask = np.zeros(
            shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
            dtype=np.int32,
        )
        for i in range(0, self.batch_size):
            self.attn_mask[i, 0, 0:self.q_actual_seqlen[i], 0:self.kv_actual_seqlen[i],] = 1
            if self.is_causal_masking:
                assert attn_mode == "self_attn", "only support causal masking for self attention"
                # Zero the strict upper triangle row by row so each query
                # position attends only to itself and earlier positions.
                col_beg, col_end = 1, self.q_actual_seqlen[i]
                for row in range(0, self.q_actual_seqlen[i]):
                    self.attn_mask[i, 0, row, col_beg:col_end] = 0
                    col_beg += 1

        dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
        self.dout = paddle.to_tensor(dout, dtype=self.dtype)

    def _get_reference_out(self):
        """Unfused paddle attention used as the numerical reference."""
        paddle.disable_static(place=paddle.CUDAPlace(0))

        q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
        k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
        v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)

        q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3])  # [b, s, h, d] -> [b, h, s, d]
        k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3])  # [b, s, h, d] -> [b, h, s, d]
        v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3])  # [b, s, h, d] -> [b, h, s, d]

        qk_out = paddle.matmul(
            x=q_out * self.scaling_factor,
            y=k_out,
            transpose_x=False,
            transpose_y=True,
        )

        # Additive mask: valid positions contribute 0, masked ones -1e4.
        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
        attn_mask = (paddle.cast(attn_mask, self.dtype) - 1.0) * 1e4
        attn_mask_out = qk_out + attn_mask
        softmax_out = F.softmax(attn_mask_out)

        # dropout_prob is 0.0 in set_input, so the else branch is taken in
        # the current tests.
        if self.dropout_prob:
            dropout_out = F.dropout(
                softmax_out,
                self.dropout_prob,
                training=self.training,
                mode="upscale_in_train",
            )
            qkv_out = paddle.matmul(dropout_out, v_out)
        else:
            qkv_out = paddle.matmul(softmax_out, v_out)

        out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3])  # [b, h, s, d] -> [b, s, h, d]

        paddle.autograd.backward(
            [out],
            [self.dout],
            retain_graph=True,
        )
        return out, q_tensor.grad, k_tensor.grad, v_tensor.grad

    def _get_fused_attention_out(self):
        """Run the TE fused attention kernels (forward + backward)."""
        paddle.disable_static(place=paddle.CUDAPlace(0))

        if self.attn_mode == "self_attn":
            qkv = np.stack([self.q, self.kv, self.kv], axis=2)  # [b, s, 3, h, d]
            qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False)
        else:
            q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
            kv = np.stack([self.kv, self.kv], axis=2)  # [b, s, 2, h, d]
            kv_tensor = paddle.to_tensor(kv, stop_gradient=False)

        q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
        kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
        rng_state = paddle.zeros((2,), dtype=np.int64)

        qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
        out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
        if self.attn_mode == 'self_attn':
            out, softmax_aux_tensor = fused_attn_fwd_qkvpacked(
                qkv_tensor,
                q_cu_seqlen_tensor,
                rng_state,
                is_training=True,
                max_seqlen=self.q_seqlen,
                qkv_dtype=qkv_dtype,
                Bias=None,
                attn_scale=self.scaling_factor,
                dropout=self.dropout_prob,
                set_zero=False,
                attn_mask_type="causal" if self.is_causal_masking else "padding")
            dqkv, _ = fused_attn_bwd_qkvpacked(
                qkv_tensor,
                q_cu_seqlen_tensor,
                out,
                self.dout,
                softmax_aux_tensor,
                max_seqlen=self.q_seqlen,
                qkv_dtype=qkv_dtype,
                attn_scale=self.scaling_factor,
                dropout=self.dropout_prob,
                set_zero=False,
                attn_mask_type="causal" if self.is_causal_masking else "padding")
            # Split the packed dQKV back into per-tensor gradients.
            q_grad = dqkv[:, :, 0, :, :]
            k_grad = dqkv[:, :, 1, :, :]
            v_grad = dqkv[:, :, 2, :, :]
        else:    # attn_mode == 'cross_attn'
            out, softmax_aux_tensor = fused_attn_fwd_kvpacked(q_tensor,
                                                              kv_tensor,
                                                              q_cu_seqlen_tensor,
                                                              kv_cu_seqlen_tensor,
                                                              rng_state,
                                                              is_training=True,
                                                              max_seqlen_q=self.q_seqlen,
                                                              max_seqlen_kv=self.kv_seqlen,
                                                              qkv_dtype=qkv_dtype,
                                                              Bias=None,
                                                              attn_scale=self.scaling_factor,
                                                              dropout=self.dropout_prob,
                                                              set_zero=False)
            dq, dkv, _ = fused_attn_bwd_kvpacked(q_tensor,
                                                 kv_tensor,
                                                 q_cu_seqlen_tensor,
                                                 kv_cu_seqlen_tensor,
                                                 out,
                                                 self.dout,
                                                 softmax_aux_tensor,
                                                 max_seqlen_q=self.q_seqlen,
                                                 max_seqlen_kv=self.kv_seqlen,
                                                 qkv_dtype=qkv_dtype,
                                                 attn_scale=self.scaling_factor,
                                                 dropout=self.dropout_prob,
                                                 set_zero=False)
            q_grad = dq
            k_grad = dkv[:, :, 0, :, :]
            v_grad = dkv[:, :, 1, :, :]

        # The fused kernel returns [total_seqs, h, d]; restore batch layout.
        fwd_out = paddle.reshape(
            out, shape=[self.batch_size, self.q_seqlen, self.num_heads, self.head_size])
        return fwd_out, q_grad, k_grad, v_grad

    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
                        reason="cuDNN fMHA requires Ampere+ GPU")
    @pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    @pytest.mark.parametrize('is_causal_masking', [True, False])
    def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
        """
        test self attention forward + backward
        """
        self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
        reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
        fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
        assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
        assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)

    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
                        reason="cuDNN fMHA requires Ampere+ GPU")
    @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES)
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
        """
        test cross attention forward + backward
        """
        self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn")
        reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
        fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
        assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
        assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
class TestSoftmax:
    """
    Test softmax operators
    """

    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_scaled_softmax_fwd_bwd(self, dtype):
        """test scaled softmax"""
        B, H, S = (16, 4, 32)
        scale = 0.8

        x = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
        x.stop_gradient = False
        dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype)

        # Reference: plain softmax(scale * x) via paddle autograd.
        y_ref = F.softmax(scale * x)
        y = scaled_softmax_forward(x, scale)

        paddle.autograd.backward([y_ref], [dy], True)
        dx_ref = x.grad
        dx = scaled_softmax_backward(dy, y, scale)

        assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
        assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)

    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_scaled_masked_softmax_fwd_bwd(self, dtype):
        """test scaled masked softmax"""
        B, H, S = (16, 4, 32)
        scale = 0.8

        x = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
        x.stop_gradient = False
        dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype)

        # Kernel mask is True where attention is masked out; the reference
        # builds the equivalent additive mask (0 kept / -1e4 masked) from the
        # flipped predicate.
        mask = paddle.reshape(x[0, 0] > 0.3, shape=(1, 1, S, S))
        mask_flipped = x[0, 0] <= 0.3
        mask_ref = (mask_flipped.astype(dtype) - 1.0) * 1e4

        y_ref = F.softmax(scale * x + mask_ref)
        y = scaled_masked_softmax_forward(x, mask, scale)

        paddle.autograd.backward([y_ref], [dy], True)
        dx_ref = x.grad
        dx = scaled_masked_softmax_backward(dy, y, scale)

        assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
        assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)

    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_scaled_upper_triang_masked_softmax_fwd_bwd(self, dtype):
        """test scaled upper triang masked softmax"""
        B, S = (16, 32)
        scale = 0.8

        # Note: this kernel takes a 3D [B, S, S] input.
        x = paddle.uniform(shape=(B, S, S), dtype=dtype)
        x.stop_gradient = False
        dy = paddle.uniform(shape=(B, S, S), dtype=dtype)

        # Build a causal mask: 1 on and below the diagonal, 0 strictly above,
        # then convert to additive form (0 kept / -1e4 masked).
        mask = paddle.ones((S, S), dtype='int32')
        col_beg, col_end = 1, S
        for row in range(0, S):
            mask[row, col_beg:col_end] = 0
            col_beg += 1
        mask_ref = (mask.astype(dtype) - 1.0) * 1e4

        y_ref = F.softmax(scale * x + mask_ref)
        y = scaled_upper_triang_masked_softmax_forward(x, scale)

        paddle.autograd.backward([y_ref], [dy], True)
        dx_ref = x.grad
        dx = scaled_upper_triang_masked_softmax_backward(dy, y, scale)

        assert_allclose(y_ref, y, rtol=1e-4, atol=5e-3)
        assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""
import math
from typing import Optional, Tuple, Union
import paddle
import transformer_engine_paddle as tex
......@@ -338,3 +339,411 @@ def layernorm_bwd(
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 LayerNorm backward"""
return tex.te_layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma)
def rmsnorm_fwd(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float,
    otype: tex.DType,
    sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Non-FP8 RMSNorm forward.

    Args:
        inp: 2D input tensor.
        weight: gamma scale tensor.
        eps: epsilon added to the mean of squares for numerical stability.
        otype: output data type.
        sm_margin: number of SMs reserved (subtracted from the SM count the
            kernel may use).

    Returns:
        Tuple of (output, rsigma); rsigma is the per-row reciprocal RMS
        consumed by rmsnorm_bwd.
    """
    # Fixed return annotation: the extension returns exactly two tensors
    # (output and rsigma), not three.
    return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin)
def rmsnorm_fwd_fp8(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """RMSNorm with FP8 output.

    Args:
        inp: 2D input tensor.
        weight: gamma scale tensor.
        eps: epsilon added to the mean of squares for numerical stability.
        fp8_meta_tensor: FP8 meta holding scale / amax_history / scale_inv.
        fp8_tensor: index selecting which FP8 meta entry to use.
        otype: FP8 output data type.
        sm_margin: number of SMs reserved from the kernel.

    Returns:
        Tuple of (fp8_output, rsigma). The extension also returns the updated
        amax/scale bookkeeping tensors, which are dropped here.
    """
    # Fixed return annotation: this function returns a 2-tuple, as the
    # `return out, rsigma` below shows.
    out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(inp, weight, fp8_meta_tensor.scale,
                                               fp8_meta_tensor.amax_history,
                                               fp8_meta_tensor.scale_inv, eps, int(fp8_tensor),
                                               int(otype), sm_margin)
    return out, rsigma
def rmsnorm_bwd(
    dz: paddle.Tensor,
    x: paddle.Tensor,
    rsigma: paddle.Tensor,
    gamma: paddle.Tensor,
    sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Non-FP8 RMSNorm backward.

    Args:
        dz: gradient of the loss w.r.t. the RMSNorm output.
        x: forward-pass input tensor.
        rsigma: reciprocal RMS produced by rmsnorm_fwd.
        gamma: scale tensor used in the forward pass.
        sm_margin: number of SMs reserved from the kernel.

    Returns:
        Tuple of (dx, dgamma). Fixed return annotation: RMSNorm has no beta,
        so unlike layernorm_bwd there is no third (dbeta) output.
    """
    return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin)
def fused_attn_fwd_qkvpacked(
    qkv: paddle.Tensor,
    cu_seqlens: paddle.Tensor,
    rng_state: paddle.Tensor,
    is_training: bool,
    max_seqlen: int,
    qkv_dtype: tex.DType,
    Bias: paddle.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "qkv_interleaved",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Fused Attention FWD for packed QKV input"""
    # Shapes derived from the packed layout: qkv is indexed as
    # [dim0, dim1, 3, h, d]; cu_seqlens has batch_size + 1 entries
    # (leading 0 plus cumulative sequence lengths).
    b = cu_seqlens.shape[0] - 1
    total_seqs = qkv.shape[0] * qkv.shape[1]
    h = qkv.shape[3]
    d = qkv.shape[4]

    # Default attention scale is 1/sqrt(head_size).
    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert (Bias.shape == [1, h, max_seqlen, max_seqlen
                              ]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
        assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv."

    # BF16/FP16 fused attention API
    # Only the cuDNN fMHA configuration (seqlen <= 512, head dim 64) is
    # supported; anything else fails fast below.
    if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen <= 512) and (d
                                                                                             == 64):
        assert (qkv_layout == "qkv_interleaved" and bias_type == "no_bias"
                and (attn_mask_type in ("padding", "causal"))
               ), """The fused attention currently only supports qkv_interleaved layout,
               no_bias type, and padding/causal attention mask type."""
    else:
        assert False, "No support for this dtype and max_seqlen combination."

    # set_zero pre-fills the output with zeros (paddle.full) instead of
    # leaving it uninitialized (paddle.empty).
    if set_zero:
        out = paddle.full(shape=[total_seqs, h, d], fill_value=0, dtype=qkv.dtype)
    else:
        out = paddle.empty(shape=[total_seqs, h, d], dtype=qkv.dtype)
    # softmax_aux is only needed to recompute gradients in the backward pass.
    if is_training:
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
    else:
        softmax_aux = None

    # execute kernel (writes out / softmax_aux in place)
    tex.te_fused_attn_fwd_qkvpacked(
        qkv,
        cu_seqlens,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        total_seqs,
        max_seqlen,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
    )

    return out, softmax_aux
def fused_attn_bwd_qkvpacked(
    qkv: paddle.Tensor,
    cu_seqlens: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    max_seqlen: int,
    qkv_dtype: tex.DType,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "qkv_interleaved",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Fused Attention BWD for packed QKV input"""
    # Shapes derived from the packed layout, mirroring the forward pass.
    b = cu_seqlens.shape[0] - 1
    total_seqs = qkv.shape[0] * qkv.shape[1]
    h = qkv.shape[3]
    d = qkv.shape[4]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    # BF16/FP16 fused attention API — same support envelope as the forward.
    if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen <= 512) and (d
                                                                                             == 64):
        assert (qkv_layout == "qkv_interleaved" and bias_type == "no_bias"
                and (attn_mask_type in ("padding", "causal"))
               ), """The fused attention currently only supports qkv_interleaved layout,
               no_bias type, and padding attention mask type."""
    else:
        assert False, "No support for this dtype and max_seqlen combination."

    # dqkv buffer passed to the kernel; zero-filled when set_zero is True.
    if set_zero:
        dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
    else:
        dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype)
    if bias_type != "no_bias":
        dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
    else:
        dbias = None

    # execute kernel
    dqkv, dbias = tex.te_fused_attn_bwd_qkvpacked(
        qkv,
        cu_seqlens,
        o,
        d_o,
        softmax_aux,
        dqkv,
        dbias,
        b,
        h,
        d,
        total_seqs,
        max_seqlen,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
    )

    return dqkv, dbias
def fused_attn_fwd_kvpacked(
    q: paddle.Tensor,
    kv: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    rng_state: paddle.Tensor,
    is_training: bool,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    Bias: paddle.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "kv_interleaved",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Fused Attention FWD for packed KV input"""
    assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
           ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"

    # Shapes: q is [dim0, dim1, h, d]; kv packs K and V along a size-2 axis.
    # cu_seqlens_* have batch_size + 1 entries (leading 0 included).
    b = cu_seqlens_q.shape[0] - 1
    total_seqs_q = q.shape[0] * q.shape[1]
    total_seqs_kv = kv.shape[0] * kv.shape[1]
    h = q.shape[2]
    d = q.shape[3]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert (Bias.shape == [1, h, max_seqlen_q, max_seqlen_kv
                              ]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
        assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv."

    # BF16/FP16 fused attention API — cuDNN fMHA envelope (seqlens <= 512,
    # head dim 64); anything else fails fast below.
    if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen_q <= 512) and (
            max_seqlen_kv <= 512) and (d == 64):
        assert (qkv_layout == "kv_interleaved" and bias_type == "no_bias"
                and (attn_mask_type in ("padding", "causal"))
               ), """The fused attention currently only supports kv_interleaved layout,
               no_bias type, and padding attention mask type."""
    else:
        assert False, "No support for this dtype and max_seqlen combination."

    # set_zero pre-fills the output; otherwise it is left uninitialized.
    if set_zero:
        out = paddle.full(shape=[total_seqs_q, h, d], fill_value=0, dtype=q.dtype)
    else:
        out = paddle.empty(shape=[total_seqs_q, h, d], dtype=q.dtype)
    # softmax_aux is only needed by the backward pass.
    if is_training:
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    else:
        softmax_aux = None

    # execute kernel (writes out / softmax_aux in place)
    tex.te_fused_attn_fwd_kvpacked(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        total_seqs_q,
        total_seqs_kv,
        max_seqlen_q,
        max_seqlen_kv,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
    )

    return out, softmax_aux
def fused_attn_bwd_kvpacked(
    q: paddle.Tensor,
    kv: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "kv_interleaved",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Fused Attention BWD for packed KV input"""
    # Shapes mirror fused_attn_fwd_kvpacked.
    b = cu_seqlens_q.shape[0] - 1
    total_seqs_q = q.shape[0] * q.shape[1]
    total_seqs_kv = kv.shape[0] * kv.shape[1]
    h = q.shape[2]
    d = q.shape[3]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    # BF16/FP16 fused attention API — same support envelope as the forward.
    if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen_q <= 512) and (
            max_seqlen_kv <= 512) and (d == 64):
        assert (qkv_layout == "kv_interleaved" and bias_type == "no_bias"
                and (attn_mask_type in ("padding", "causal"))
               ), """The fused attention currently only supports kv_interleaved layout,
               no_bias type, and padding attention mask type."""
    else:
        assert False, "No support for this dtype and max_seqlen combination."

    # Gradient buffers handed to the kernel; zero-filled when set_zero.
    if set_zero:
        dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
        dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
    else:
        dq = paddle.empty(shape=q.shape, dtype=q.dtype)
        dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype)
    if bias_type != "no_bias":
        dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    else:
        dbias = None

    # execute kernel (fills dq / dkv / dbias in place)
    tex.te_fused_attn_bwd_kvpacked(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        o,
        d_o,
        softmax_aux,
        dq,
        dkv,
        dbias,
        b,
        h,
        d,
        total_seqs_q,
        total_seqs_kv,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
    )

    return dq, dkv, dbias
def scaled_softmax_forward(
    inp: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Compute softmax(scale_factor * inp) with the fused TE kernel."""
    result = tex.te_scaled_softmax_forward(inp, scale_factor)
    return result
def scaled_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Backward of scaled softmax.

    The TE kernel overwrites ``out_grad`` in place with the input gradient,
    so the same tensor object is returned to the caller.
    """
    tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
    return out_grad
def scaled_masked_softmax_forward(
    inp: paddle.Tensor,
    mask: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Compute masked softmax(scale_factor * inp) with the fused TE kernel."""
    result = tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor)
    return result
def scaled_masked_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """ scaled masked softmax backward"""
    # NOTE(review): this intentionally calls the plain te_scaled_softmax_backward
    # kernel — given the softmax output, its gradient computation appears not
    # to need the mask separately (masked positions carry zero probability).
    # Confirm against the extension implementation.
    # The kernel updates out_grad in place; the same tensor is returned.
    tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
    return out_grad
def scaled_upper_triang_masked_softmax_forward(
    inp: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Causal (upper-triangular masked) scaled softmax via the fused TE kernel."""
    result = tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor)
    return result
def scaled_upper_triang_masked_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Backward of causal scaled softmax.

    The TE kernel overwrites ``out_grad`` in place with the input gradient,
    so the same tensor object is returned to the caller.
    """
    tex.te_scaled_upper_triang_masked_softmax_backward(out_grad, softmax_results, scale_factor)
    return out_grad
......@@ -9,8 +9,9 @@
namespace transformer_engine {
namespace paddle_ext {
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type) {
return TensorWrapper(data_ptr, shape, type);
// Wrap a read-only buffer in a TensorWrapper. TensorWrapper only takes a
// mutable pointer, hence the const_cast; callers must ensure the downstream
// kernel treats this tensor as input-only.
TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
                             const DType type) {
  return TensorWrapper(const_cast<void *>(data_ptr), shape, type);
}
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) {
......
......@@ -8,9 +8,12 @@
#include <cublasLt.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <vector>
......@@ -78,11 +81,6 @@ inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x) { // NOLIN
return x ? x->data() : nullptr;
}
inline std::vector<size_t> GetShapeArray(const paddle::optional<paddle::Tensor> &x) {
if (x) return GetShapeArray(x.get());
return {0};
}
inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
std::vector<size_t> shapes;
for (auto dim : x.shape()) {
......@@ -91,6 +89,11 @@ inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
return shapes;
}
// Shape of an optional tensor; returns {0} (a zero-element 1-D shape) when
// the optional is empty. Declared after the paddle::Tensor overload so the
// delegating call resolves correctly.
inline std::vector<size_t> GetShapeArray(const paddle::optional<paddle::Tensor> &x) {
  if (x) return GetShapeArray(x.get());
  return {0};
}
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
bool init_to_zeros = 0);
......@@ -176,7 +179,8 @@ class cudaDevicePropertiesManager {
};
// NVTE Tensor Utils
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type);
TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
void *amax_ptr, void *scale_ptr, void *scale_inv_ptr);
......
......@@ -5,10 +5,52 @@
************************************************************************/
#include <vector>
#include "../common.h"
#include "common.h"
namespace transformer_engine {
namespace paddle_ext {
// MHA utils
// Map a QKV layout name onto the corresponding NVTE enum value.
// Unrecognized names abort via NVTE_ERROR.
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout) {
  if (qkv_layout == "not_interleaved") return NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED;
  if (qkv_layout == "qkv_interleaved") return NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
  if (qkv_layout == "kv_interleaved") return NVTE_QKV_Layout::NVTE_KV_INTERLEAVED;
  NVTE_ERROR("Invalid QKV layout. \n");
}
// Map a bias-type name onto the corresponding NVTE enum value.
// Unrecognized names abort via NVTE_ERROR.
NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) {
  if (bias_type == "no_bias") return NVTE_Bias_Type::NVTE_NO_BIAS;
  if (bias_type == "pre_scale_bias") return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS;
  if (bias_type == "post_scale_bias") return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
  NVTE_ERROR("Invalid bias type. \n");
}
// Map an attention-mask-type name onto the corresponding NVTE enum value.
// Unrecognized names abort via NVTE_ERROR.
NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) {
  if (mask_type == "padding") return NVTE_Mask_Type::NVTE_PADDING_MASK;
  if (mask_type == "causal") return NVTE_Mask_Type::NVTE_CAUSAL_MASK;
  if (mask_type == "no_mask") return NVTE_Mask_Type::NVTE_NO_MASK;
  NVTE_ERROR("Invalid attention mask type. \n");
}
std::vector<paddle::Tensor> cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &amax, paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
......@@ -337,6 +379,594 @@ std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const pad
return {dx, dgamma, dbeta};
}
// Non-FP8 RMSNorm forward. Returns {output, rsigma}, where rsigma is the
// per-row reciprocal RMS needed by te_rmsnorm_bwd.
std::vector<paddle::Tensor> te_rmsnorm_fwd(const paddle::Tensor &input,
                                           const paddle::Tensor &weight, float eps, int64_t otype,
                                           int64_t sm_margin) {
  auto shape = GetShapeArray(input);
  // Fixed error message: it previously said "grad_output", copied from the
  // backward path; this function validates the forward input.
  NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");

  size_t N = shape[0];
  size_t H = shape[1];

  auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
  auto rsigma =
      paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
  auto input_cu = MakeNvteTensor(input);
  auto gamma_cu = MakeNvteTensor(weight);
  auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype));
  auto rsigma_cu = MakeNvteTensor(rsigma);

  TensorWrapper workspace, barrier;

  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();

  // This call populates workspace and barrier tensors with the required config
  nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
                   input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());

  // Fill workspace and barrier
  auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
  auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true);
  workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
  barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());

  // Actual call to fwd kernel
  nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
                   input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());

  return {ln_out, rsigma};
}
// RMSNorm forward with FP8 output. Returns {fp8_output, rsigma}; the amax /
// scale_inv entries selected by `index` are updated through the z tensor.
std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(const paddle::Tensor &input,
                                               const paddle::Tensor &weight,
                                               const paddle::Tensor &scale,
                                               paddle::Tensor &amax,       // NOLINT
                                               paddle::Tensor &scale_inv,  // NOLINT
                                               float eps, int64_t index, int64_t otype,
                                               int64_t sm_margin) {
  auto shape = GetShapeArray(input);
  // Fixed error message: it previously said "grad_output", copied from the
  // backward path; this function validates the forward input.
  NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");

  size_t N = shape[0];
  size_t H = shape[1];

  auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
  auto rsigma =
      paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
  auto input_cu = MakeNvteTensor(input);
  auto gamma_cu = MakeNvteTensor(weight);
  // FP8 output tensor carries amax/scale/scale_inv pointers so the kernel can
  // update the quantization bookkeeping in place.
  auto z_cu = MakeNvteTensor(
      ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr<float>(amax, index),
      const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
  auto rsigma_cu = MakeNvteTensor(rsigma);

  TensorWrapper workspace, barrier;

  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();

  // This call populates workspace and barrier tensors with the required config
  nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
                   input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());

  // Fill workspace and barrier
  auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
  auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true);
  workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
  barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());

  // Actual call to fwd kernel
  nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
                   input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());

  return {ln_out, rsigma};
}
// RMSNorm backward: computes gradients w.r.t. the input (dx) and the gamma
// weight (dgamma) from the incoming gradient dz, the forward input x, the
// saved per-row statistic rsigma, and gamma. Returns {dx, dgamma}.
std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x,
                                           const paddle::Tensor &rsigma,
                                           const paddle::Tensor &gamma, int64_t sm_margin) {
    auto dx = paddle::empty_like(x, x.dtype(), x.place());
    auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
    // Empty wrappers: the first kernel call below only fills in the required
    // shapes/dtypes for the workspace, barrier and partial-dgamma buffers.
    TensorWrapper workspace, barrier, dgamma_part;
    auto dz_cu = MakeNvteTensor(dz);
    auto x_cu = MakeNvteTensor(x);
    auto rsigma_cu = MakeNvteTensor(rsigma);
    auto gamma_cu = MakeNvteTensor(gamma);
    auto dx_cu = MakeNvteTensor(dx);
    auto dgamma_cu = MakeNvteTensor(dgamma);
    // The kernel runs on all SMs minus the caller-requested margin.
    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    // This call populates tensors with the required config.
    nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
                     dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin,
                     workspace.data(), barrier.data());
    // Alloc space for Tensors.
    auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place())
    ;
    auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true);
    auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place());
    workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
    barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());
    dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype());
    // Actual call to bwd kernel.
    nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
                     dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin,
                     workspace.data(), barrier.data());
    return {dx, dgamma};
}
// Fused attention forward with packed QKV (Q, K and V interleaved in one
// tensor). Writes the attention output into O (in place) and, when training,
// the softmax statistics into softmax_aux for the backward pass.
// Only b (batch) and max_seqlen are read directly here; h, d and total_seqs
// are part of the op signature but unused in this wrapper.
void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens,
                                 const paddle::optional<paddle::Tensor> &Bias,
                                 paddle::Tensor &O,  // NOLINT
                                 paddle::optional<paddle::Tensor> &softmax_aux,  // NOLINT
                                 paddle::Tensor &rng_state,  // NOLINT
                                 int64_t b, int64_t h, int64_t d, int64_t total_seqs,
                                 int64_t max_seqlen, bool is_training, float attn_scale,
                                 float p_dropout, const std::string &qkv_layout,
                                 const std::string &bias_type, const std::string &attn_mask_type,
                                 const int64_t qkv_type) {
    if (is_training && !softmax_aux) {
        NVTE_ERROR("softmax_aux must be provided when training. \n");
    }
    auto qkv_dtype = Int2NvteDType(qkv_type);
    // construct NVTE tensors
    TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens;
    if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
        // BF16 or FP16
        te_QKV = MakeNvteTensor(QKV);
        // S is an internal scratch tensor; shape/dtype are set by the kernel.
        te_S = MakeNvteTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32);
        te_O = MakeNvteTensor(O);
    } else {  // TODO: support fp8
        NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
    }
    if ((bias_type != "no_bias") && Bias) {
        auto bias_shape = Bias->shape();
        std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
        te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32);
    }
    // cu_seqlens holds b+1 cumulative sequence offsets.
    te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
    // convert strings to enums
    NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
    NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
    NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
    // extract random number generator seed and offset
    auto te_rng_state = MakeNvteTensor(rng_state);
    // create auxiliary output tensors
    NVTETensorPack nvte_aux_tensor_pack;
    nvte_tensor_pack_create(&nvte_aux_tensor_pack);
    // create workspace
    TensorWrapper workspace;
    // First call: populate tensors with appropriate shapes and dtypes only.
    nvte_fused_attn_fwd_qkvpacked(
        te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
        te_cu_seqlens.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout,
        qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
    // allocate memory for workspace and auxiliary output tensors
    auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
    workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
    // Point the first aux tensor (softmax stats) at the caller-provided buffer.
    auto *output_s =
        reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
    output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
    // execute the kernel
    nvte_fused_attn_fwd_qkvpacked(
        te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
        te_cu_seqlens.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout,
        qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
    // destroy tensor wrappers, but not allocated memory
    nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
// fused attention BWD with packed QKV
// Fused attention backward with packed QKV. Consumes the forward output O,
// its gradient dO, and the saved softmax statistics; writes gradients into
// dQKV (and dBias, when a bias was used) in place.
void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens,
                                 const paddle::Tensor &O, const paddle::Tensor &dO,
                                 const paddle::Tensor &softmax_aux,
                                 paddle::Tensor &dQKV,  // NOLINT
                                 paddle::optional<paddle::Tensor> &dBias,  // NOLINT
                                 int64_t b, int64_t h, int64_t d, int64_t total_seqs,
                                 int64_t max_seqlen, float attn_scale, float p_dropout,
                                 const std::string &qkv_layout, const std::string &bias_type,
                                 const std::string &attn_mask_type, int64_t qkv_type) {
    TensorWrapper te_dBias;
    if (bias_type != "no_bias" && dBias) {
        auto bias_shape = dBias->shape();
        std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
        te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32);
    }
    auto qkv_dtype = Int2NvteDType(qkv_type);
    // construct NVTE tensors
    TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
    if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
        // BF16 or FP16
        te_QKV = MakeNvteTensor(QKV);
        te_O = MakeNvteTensor(O);
        te_dO = MakeNvteTensor(dO);
        // S and dP are internal scratch tensors configured by the kernel.
        te_S = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
        te_dP = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
        te_dQKV = MakeNvteTensor(dQKV);
    } else {
        NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
    }
    // convert strings to enums
    NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
    NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
    NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
    // Hand the forward softmax statistics back to the kernel as aux tensor 0,
    // with shape [b, h, max_seqlen, max_seqlen].
    NVTETensorPack nvte_aux_tensor_pack;
    nvte_tensor_pack_create(&nvte_aux_tensor_pack);
    nvte_aux_tensor_pack.size = 1;
    auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
    output_s->data.shape =
        std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
                             static_cast<size_t>(max_seqlen), static_cast<size_t>(max_seqlen)});
    output_s->data.dptr = const_cast<void *>(softmax_aux.data());
    // create cu_seqlens tensorwrappers
    TensorWrapper te_cu_seqlens;
    te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
    // create workspace
    TensorWrapper workspace;
    // First call: populate tensors with appropriate shapes and dtypes only.
    nvte_fused_attn_bwd_qkvpacked(
        te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
        te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), max_seqlen, attn_scale, p_dropout,
        qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
    // allocate memory for workspace
    auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place())
    ;
    workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
    // execute kernel
    nvte_fused_attn_bwd_qkvpacked(
        te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
        te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), max_seqlen, attn_scale, p_dropout,
        qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
    // destroy tensor wrappers
    nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
// Fused attention forward with packed KV (cross-attention layout: Q separate,
// K and V interleaved). Writes the attention output into O in place and, when
// training, the softmax statistics into softmax_aux.
// Q is viewed as [total_seqs_q, h, d] and KV as [total_seqs_kv, 2, h, d].
void te_fused_attn_fwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &KV,
                                const paddle::Tensor &cu_seqlens_q,
                                const paddle::Tensor &cu_seqlens_kv,
                                const paddle::optional<paddle::Tensor> &Bias,
                                paddle::Tensor &O,  // NOLINT
                                paddle::optional<paddle::Tensor> &softmax_aux,  // NOLINT
                                paddle::Tensor &rng_state,  // NOLINT
                                int64_t b, int64_t h, int64_t d, int64_t total_seqs_q,
                                int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv,
                                bool is_training, float attn_scale, float p_dropout,
                                const std::string &qkv_layout, const std::string &bias_type,
                                const std::string &attn_mask_type, const int64_t qkv_type) {
    if (is_training && !softmax_aux) {
        NVTE_ERROR("softmax_aux must be provided when training. \n");
    }
    auto qkv_dtype = Int2NvteDType(qkv_type);
    // construct NVTE tensors
    TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
    if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
        // BF16 or FP16
        te_Q = MakeNvteTensor(
            Q.data(),
            {static_cast<size_t>(total_seqs_q), static_cast<size_t>(h), static_cast<size_t>(d)},
            qkv_dtype);
        te_KV = MakeNvteTensor(
            KV.data(),
            {static_cast<size_t>(total_seqs_kv), 2, static_cast<size_t>(h), static_cast<size_t>(d)},
            qkv_dtype);
        // S is an internal scratch tensor; shape/dtype are set by the kernel.
        te_S = MakeNvteTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32);
        te_O = MakeNvteTensor(
            O.data(),
            {static_cast<size_t>(total_seqs_q), static_cast<size_t>(h), static_cast<size_t>(d)},
            qkv_dtype);
    } else {
        NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
    }
    if ((bias_type != "no_bias") && Bias) {
        auto bias_shape = Bias->shape();
        std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
        te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32);
    }
    // Each cu_seqlens tensor holds b+1 cumulative sequence offsets.
    te_cu_seqlens_q =
        MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
    te_cu_seqlens_kv =
        MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
    // convert strings to enums
    NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
    NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
    NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
    auto te_rng_state = MakeNvteTensor(rng_state);
    // create auxiliary output tensors
    NVTETensorPack nvte_aux_tensor_pack;
    nvte_tensor_pack_create(&nvte_aux_tensor_pack);
    // create workspace
    TensorWrapper workspace;
    // First call: populate tensors with appropriate shapes and dtypes only.
    nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(),
                                 te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
                                 te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q,
                                 max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
                                 bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
    // allocate memory for workspace and auxiliary output tensors
    auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
    workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
    // Point the first aux tensor (softmax stats) at the caller-provided buffer.
    auto *output_s =
        reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
    output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
    // execute the kernel
    nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(),
                                 te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
                                 te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q,
                                 max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
                                 bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
    // destroy tensor wrappers, but not allocated memory
    nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
// fused attention BWD with packed KV
// Fused attention backward with packed KV. Consumes the forward output O, its
// gradient dO, and the saved softmax statistics; writes gradients into dQ and
// dKV (and dBias, when a bias was used) in place.
void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &KV,
                                const paddle::Tensor &cu_seqlens_q,
                                const paddle::Tensor &cu_seqlens_kv, const paddle::Tensor &O,
                                const paddle::Tensor &dO, const paddle::Tensor &softmax_aux,
                                paddle::Tensor &dQ,  // NOLINT
                                paddle::Tensor &dKV,  // NOLINT
                                paddle::optional<paddle::Tensor> &dBias,  // NOLINT
                                int64_t b, int64_t h, int64_t d, int64_t total_seqs_q,
                                int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv,
                                float attn_scale, float p_dropout, const std::string &qkv_layout,
                                const std::string &bias_type, const std::string &attn_mask_type,
                                int64_t qkv_type) {
    TensorWrapper te_dBias;
    if (bias_type != "no_bias" && dBias) {
        auto bias_shape = dBias->shape();
        std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
        te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32);
    }
    auto qkv_dtype = Int2NvteDType(qkv_type);
    // construct NVTE tensors
    TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
    if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
        // BF16 or FP16
        te_Q = MakeNvteTensor(Q);
        te_KV = MakeNvteTensor(KV);
        te_O = MakeNvteTensor(O);
        te_dO = MakeNvteTensor(dO);
        // S and dP are internal scratch tensors configured by the kernel.
        te_S = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
        te_dP = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
        te_dQ = MakeNvteTensor(dQ)
        ;
        te_dKV = MakeNvteTensor(dKV);
    } else {
        NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
    }
    // convert strings to enums
    NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
    NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
    NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
    // Hand the forward softmax statistics back to the kernel as aux tensor 0,
    // with shape [b, h, max_seqlen_q, max_seqlen_kv].
    NVTETensorPack nvte_aux_tensor_pack;
    nvte_tensor_pack_create(&nvte_aux_tensor_pack);
    nvte_aux_tensor_pack.size = 1;
    auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
    output_s->data.shape = std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
                                                static_cast<size_t>(max_seqlen_q),
                                                static_cast<size_t>(max_seqlen_kv)});
    output_s->data.dptr = const_cast<void *>(softmax_aux.data());
    // create cu_seqlens tensorwrappers
    TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
    te_cu_seqlens_q =
        MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
    te_cu_seqlens_kv =
        MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
    // create workspace
    TensorWrapper workspace;
    // First call: populate tensors with appropriate shapes and dtypes only.
    nvte_fused_attn_bwd_kvpacked(
        te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
        &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
        te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
        qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
    // allocate memory for workspace
    auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
    workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
    // execute kernel
    nvte_fused_attn_bwd_kvpacked(
        te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
        &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
        te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
        qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
    // destroy tensor wrappers
    nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
// Fused scaled softmax forward over the last dimension of a 4D
// [batches, heads, q_seq_len, k_seq_len] fp16/bf16 tensor: the input is
// multiplied by `scale_factor` before the softmax. Returns {softmax_results}.
std::vector<paddle::Tensor> te_scaled_softmax_forward(const paddle::Tensor &input,
                                                      float scale_factor) {
    NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor");
    NVTE_CHECK((input.dtype() == paddle::DataType::FLOAT16) ||
                   (input.dtype() == paddle::DataType::BFLOAT16),
               "Only fp16 and bf16 are supported");
    // Removed unused `batches`/`attn_heads` locals; only the sequence
    // dimensions are validated here.
    const int query_seq_len = input.shape()[2];
    const int key_seq_len = input.shape()[3];
    // Fused-kernel limits: key length up to 4096, more than one query position.
    NVTE_CHECK(key_seq_len <= 4096);
    NVTE_CHECK(query_seq_len > 1);
    // Output
    auto softmax_results = paddle::empty_like(input, input.dtype(), input.place());
    auto input_cu = MakeNvteTensor(input);
    auto softmax_results_cu = MakeNvteTensor(softmax_results);
    nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor,
                                input.stream());
    return {softmax_results};
}
// Backward of the fused scaled softmax. Overwrites `output_grads` in place
// with the input gradients, using the saved forward softmax results.
void te_scaled_softmax_backward(paddle::Tensor &output_grads,  // NOLINT
                                const paddle::Tensor &softmax_results, float scale_factor) {
    NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor");
    NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor");
    NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) ||
                   (output_grads.dtype() == paddle::DataType::BFLOAT16),
               "Only fp16 and bf16 are supported");
    NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) ||
                   (softmax_results.dtype() == paddle::DataType::BFLOAT16),
               "Only fp16 and bf16 are supported");
    auto grads_nvte = MakeNvteTensor(output_grads);
    auto softmax_nvte = MakeNvteTensor(softmax_results);
    // The gradient buffer serves as both kernel input and output.
    nvte_scaled_softmax_backward(grads_nvte.data(), softmax_nvte.data(), grads_nvte.data(),
                                 scale_factor, softmax_results.stream());
}
// Fused scaled masked softmax forward. The input is scaled by `scale_factor`,
// positions selected by `mask` are masked out, and a softmax is taken over the
// last dimension. `mask` must be [1 or batches, 1, q_seq_len, k_seq_len] so it
// broadcasts over the head dimension. Returns {softmax_results}.
std::vector<paddle::Tensor> te_scaled_masked_softmax_forward(const paddle::Tensor &input,
                                                             const paddle::Tensor &mask,
                                                             float scale_factor) {
    NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor");
    NVTE_CHECK(mask.shape().size() == 4, "expected 4D tensor");
    NVTE_CHECK((input.dtype() == paddle::DataType::FLOAT16) ||
                   (input.dtype() == paddle::DataType::BFLOAT16),
               "Only fp16 and bf16 are supported");
    // Removed unused `attn_heads` local; the head dimension is not validated
    // against the mask because the mask broadcasts over it.
    const int batches = input.shape()[0];
    const int pad_batches = mask.shape()[0];
    const int query_seq_len = input.shape()[2];
    const int key_seq_len = input.shape()[3];
    // Fused-kernel limits: key length up to 4096, more than one query position.
    NVTE_CHECK(key_seq_len <= 4096);
    NVTE_CHECK(query_seq_len > 1);
    // The mask batch dim is either shared (1) or per-sample (batches).
    NVTE_CHECK(pad_batches == 1 || pad_batches == batches);
    NVTE_CHECK(mask.shape()[1] == 1);
    NVTE_CHECK(mask.shape()[2] == query_seq_len);
    NVTE_CHECK(mask.shape()[3] == key_seq_len);
    // Output
    auto softmax_results = paddle::empty_like(input, input.dtype(), input.place());
    auto input_cu = MakeNvteTensor(input);
    auto mask_cu = MakeNvteTensor(mask);
    auto softmax_results_cu = MakeNvteTensor(softmax_results);
    nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(),
                                       scale_factor, input.stream());
    return {softmax_results};
}
// Backward of the fused scaled masked softmax. Overwrites `output_grads` in
// place with the input gradients.
// NOTE(review): this intentionally calls the *unmasked* backward kernel —
// presumably because masked positions already carry zero softmax output, so
// their gradients vanish without consulting the mask. Confirm against the
// NVTE softmax kernel documentation.
void te_scaled_masked_softmax_backward(paddle::Tensor &output_grads,  // NOLINT
                                       const paddle::Tensor &softmax_results, float scale_factor) {
    NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor");
    NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor");
    NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) ||
                   (output_grads.dtype() == paddle::DataType::BFLOAT16),
               "Only fp16 and bf16 are supported");
    NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) ||
                   (softmax_results.dtype() == paddle::DataType::BFLOAT16),
               "Only fp16 and bf16 are supported");
    auto output_grads_cu = MakeNvteTensor(output_grads)
    ;
    auto softmax_results_cu = MakeNvteTensor(softmax_results);
    // Produce gradients in place.
    nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(),
                                 output_grads_cu.data(), scale_factor, softmax_results.stream());
}
// Fused scaled softmax forward with an implicit upper-triangular (causal)
// mask, over a 3D [attn_batches, seq_len, seq_len] fp16/bf16 tensor.
// Returns {softmax_results}.
std::vector<paddle::Tensor> te_scaled_upper_triang_masked_softmax_forward(
    const paddle::Tensor &input, float scale_factor) {
    NVTE_CHECK(input.shape().size() == 3, "expected 3D tensor");
    NVTE_CHECK((input.dtype() == paddle::DataType::FLOAT16) ||
                   (input.dtype() == paddle::DataType::BFLOAT16),
               "Only fp16 and bf16 are supported");
    // Removed unused `attn_batches` local.
    const int seq_len = input.shape()[1];
    // The triangular mask assumes a square score matrix; the backward pass
    // already enforced this, so enforce it here too for a clear early error.
    NVTE_CHECK(input.shape()[1] == input.shape()[2]);
    // Fused-kernel limit: sequence length up to 2048.
    NVTE_CHECK(seq_len <= 2048);
    // Output
    auto softmax_results = paddle::empty_like(input, input.dtype(), input.place());
    auto input_cu = MakeNvteTensor(input);
    auto softmax_results_cu = MakeNvteTensor(softmax_results);
    nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(),
                                                    scale_factor, input.stream());
    return {softmax_results};
}
// Backward of the causal (upper-triangular masked) scaled softmax.
// Overwrites `output_grads` in place with the input gradients.
void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads,  // NOLINT
                                                    const paddle::Tensor &softmax_results,
                                                    float scale_factor) {
    NVTE_CHECK(output_grads.shape().size() == 3, "expected 3D tensor");
    NVTE_CHECK(softmax_results.shape().size() == 3, "expected 3D tensor");
    NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) ||
                   (output_grads.dtype() == paddle::DataType::BFLOAT16),
               "Only fp16 and bf16 are supported");
    NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) ||
                   (softmax_results.dtype() == paddle::DataType::BFLOAT16),
               "Only fp16 and bf16 are supported");
    // The causal mask requires a square score matrix.
    NVTE_CHECK(output_grads.shape()[1] == output_grads.shape()[2]);
    auto grads_nvte = MakeNvteTensor(output_grads);
    auto softmax_nvte = MakeNvteTensor(softmax_results);
    // The gradient buffer serves as both kernel input and output.
    nvte_scaled_upper_triang_masked_softmax_backward(
        grads_nvte.data(), softmax_nvte.data(), grads_nvte.data(), scale_factor,
        softmax_results.stream());
}
} // namespace paddle_ext
} // namespace transformer_engine
......@@ -422,3 +1052,108 @@ PD_BUILD_OP(te_layernorm_bwd)
.Outputs({"Dx", "Dgamma", "Dbeta"})
.Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_bwd));
// ---- Paddle custom-op registrations: RMSNorm ----
// RMSNorm forward (non-FP8 output).
PD_BUILD_OP(te_rmsnorm_fwd)
    .Inputs({"Input", "Weight"})
    .Outputs({"Output", "InvVariance"})
    .Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t"})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd));
// RMSNorm forward with FP8 output; Amax/ScaleInv meta tensors are updated
// in place (underscore-prefixed names are the in-place inputs).
PD_BUILD_OP(te_rmsnorm_fwd_fp8)
    .Inputs({"Input", "Weight", "Scale", "_Amax", "_ScaleInv"})
    .Outputs({"Output", "InvVariance", "Amax", "ScaleInv"})
    .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
    .Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t"})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd_fp8))
    ;
// RMSNorm backward: input and gamma gradients.
PD_BUILD_OP(te_rmsnorm_bwd)
    .Inputs({"Dz", "X", "Rsigma", "Gamma"})
    .Outputs({"Dx", "Dgamma"})
    .Attrs({"sm_margin: int64_t"})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_bwd));
// ---- Paddle custom-op registrations: fused attention ----
// Self-attention forward, packed QKV. O and softmax_aux are written in place.
PD_BUILD_OP(te_fused_attn_fwd_qkvpacked)
    .Inputs({"QKV", "cu_seqlens", paddle::Optional("Bias"), "_O", paddle::Optional("_softmax_aux"),
             "rng_state"})
    .Outputs({"O", paddle::Optional("softmax_aux")})
    .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t",
            "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
            "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"})
    .SetInplaceMap({{"_O", "O"},
                    {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked));
// Self-attention backward, packed QKV. dQKV (and optional dBias) in place.
PD_BUILD_OP(te_fused_attn_bwd_qkvpacked)
    .Inputs({"QKV", "cu_seqlens", "O", "dO", "softmax_aux", "_dQKV", paddle::Optional("_dBias")})
    .Outputs({"dQKV", paddle::Optional("dBias")})
    .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t",
            "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
            "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"})
    .SetInplaceMap({{"_dQKV", "dQKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_qkvpacked));
// Cross-attention forward, separate Q with packed KV.
PD_BUILD_OP(te_fused_attn_fwd_kvpacked)
    .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O",
             paddle::Optional("_softmax_aux"), "rng_state"})
    .Outputs({"O", paddle::Optional("softmax_aux")})
    .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t",
            "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t",
            "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
            "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"})
    .SetInplaceMap({{"_O", "O"},
                    {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_kvpacked));
// Cross-attention backward, packed KV. dQ/dKV (and optional dBias) in place.
PD_BUILD_OP(te_fused_attn_bwd_kvpacked)
    .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dKV",
             paddle::Optional("_dBias")})
    .Outputs({"dQ", "dKV", paddle::Optional("dBias")})
    .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t",
            "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t",
            "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
            "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"})
    .SetInplaceMap({{"_dQ", "dQ"},
                    {"_dKV", "dKV"},
                    {paddle::Optional("_dBias"), paddle::Optional("dBias")}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_kvpacked));
// ---- Paddle custom-op registrations: fused softmax variants ----
// All backward ops overwrite the incoming gradient ("out_grad_") in place.
PD_BUILD_OP(te_scaled_softmax_forward)
    .Inputs({"input"})
    .Outputs({"softmax_results"})
    .Attrs({"scale_factor: float"})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_softmax_forward));
PD_BUILD_OP(te_scaled_softmax_backward)
    .Inputs({"out_grad_", "softmax_results"})
    .Outputs({"out_grad"})
    .Attrs({"scale_factor: float"})
    .SetInplaceMap({{"out_grad_", "out_grad"}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_softmax_backward));
// Softmax with an explicit padding mask tensor.
PD_BUILD_OP(te_scaled_masked_softmax_forward)
    .Inputs({"input", "mask"})
    .Outputs({"softmax_results"})
    .Attrs({"scale_factor: float"})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_masked_softmax_forward));
PD_BUILD_OP(te_scaled_masked_softmax_backward)
    .Inputs({"out_grad_", "softmax_results"})
    .Outputs({"out_grad"})
    .Attrs({"scale_factor: float"})
    .SetInplaceMap({{"out_grad_", "out_grad"}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_masked_softmax_backward));
// Softmax with an implicit causal (upper-triangular) mask.
PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_forward)
    .Inputs({"input"})
    .Outputs({"softmax_results"})
    .Attrs({"scale_factor: float"})
    .SetKernelFn(
        PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_forward));
PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward)
    .Inputs({"out_grad_", "softmax_results"})
    .Outputs({"out_grad"})
    .Attrs({"scale_factor: float"})
    .SetInplaceMap({{"out_grad_", "out_grad"}})
    .SetKernelFn(
        PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward));
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment