Unverified Commit 50ff8116 authored by zlsh80826 and committed by GitHub

[JAX/Paddle] Deprecate QKV_INTERLEAVED enum (#504)



* Deprecate QKV_INTERLEAVED use in JAX
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Deprecate QKV_INTERLEAVED use in Paddle
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance QKV enum mappings
Signed-off-by: rewang <rewang@nvidia.com>

* Fix LD_LIBRARY_PATH issue
Signed-off-by: rewang <rewang@nvidia.com>

* Arbitrary seqlen kernels currently support only self attention
Signed-off-by: rewang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Signed-off-by: rewang <rewang@nvidia.com>
parent 8ec01e5e
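
For orientation before the diff: the change replaces the interleave-based layout names with shape-based names (b = batch, s = sequence, h = heads, d = head dim). A hedged Python summary of the mapping, distilled from the hunks below; the "not_interleaved" row is an assumption, since that name is simply deleted rather than remapped in this commit:

```python
# Deprecated NVTE_QKV_Layout names -> new shape-based names used below.
DEPRECATED_TO_NEW = {
    "qkv_interleaved": "bs3hd",            # packed QKV, self attention
    "kv_interleaved": "bshd_bs2hd",        # separate Q, packed KV, cross attention
    "not_interleaved": "bshd_bshd_bshd",   # assumption: fully separate Q, K, V
}
```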
@@ -18,7 +18,7 @@ from flax.linen import make_attention_mask
 from flax.linen import make_causal_mask
 from jax import value_and_grad, jit
-from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
+from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
 from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
 from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
 from transformer_engine_jax import get_device_compute_capability
@@ -163,8 +163,8 @@ class TestSelfFusedAttn():
         if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0:
             pytest.skip("Arbitrary seqlen backend hasn't support padded input.")
-        if not is_fused_attn_kernel_available(dtype, dtype, attn_bias_type, attn_mask_type,
-                                              dropout_probability, s, s, head_dim):
+        if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
+                                              attn_mask_type, dropout_probability, s, s, head_dim):
             pytest.skip("Unsupported inputs combination or device compute capability.")
         compute_capability = get_device_compute_capability(0)
...
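
With the new signature, JAX callers pass the layout explicitly between the dtype pair and the bias type. A minimal availability check in the spirit of the test above; AttnBiasType.NO_BIAS is assumed to exist alongside the mask enum shown in this diff, and the sizes are placeholders:

```python
import jax.numpy as jnp
from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType, QKVLayout,
                                               is_fused_attn_kernel_available)

dtype, seqlen, head_dim = jnp.bfloat16, 512, 64    # placeholder test parameters

# Layout is now an explicit argument, third in line after the two dtypes.
ok = is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD,
                                    AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
                                    0.0, seqlen, seqlen, head_dim)
print("fused self-attention kernel available:", ok)
```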
@@ -639,7 +639,7 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
             kv_seqlen=kv_seqlen,
             dtype=math_dtype,
             dropout=0.0,
-            qkv_layout="qkv_interleaved" if attn_type == "self" else "kv_interleaved",
+            qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
             bias_type="no_bias",
             mask_type=mask_type,
     ):
@@ -767,7 +767,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
             kv_seqlen=kv_seqlen,
             dtype=math_dtype,
             dropout=0.0,
-            qkv_layout="qkv_interleaved",
+            qkv_layout="bs3hd",
             bias_type="no_bias",
             mask_type=mask_type,
     ):
@@ -945,7 +945,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
             kv_seqlen=kv_seqlen,
             dtype=math_dtype,
             dropout=0.0,
-            qkv_layout="qkv_interleaved",
+            qkv_layout="bs3hd",
             bias_type="no_bias",
             mask_type=mask_type,
     ):
@@ -956,7 +956,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
             kv_seqlen=kv_seqlen,
             dtype=math_dtype,
             dropout=0.0,
-            qkv_layout="kv_interleaved",
+            qkv_layout="bshd_bs2hd",
             bias_type="no_bias",
             mask_type=mask_type,
     ):
...
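
All four tests follow one convention: the layout string is keyed off the attention type. A compact restatement (string values exactly as in the hunks above; the helper name is illustrative):

```python
def select_qkv_layout(attn_type: str) -> str:
    # Self attention packs Q, K, V into one bs3hd tensor; cross attention
    # keeps Q separate (bshd) and packs K and V together (bs2hd).
    return "bs3hd" if attn_type == "self" else "bshd_bs2hd"

assert select_qkv_layout("self") == "bs3hd"
assert select_qkv_layout("cross") == "bshd_bs2hd"
```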
@@ -683,9 +683,9 @@ class TestFusedAttn:
         kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
         qkv_layout = (
-            "qkv_interleaved"
+            "bs3hd"
             if self.attn_mode == "self_attn"
-            else "kv_interleaved"
+            else "bshd_bs2hd"
         )
         fused_attention_backend = get_fused_attention_backend(
             head_size=self.head_size,
@@ -779,7 +779,7 @@ class TestFusedAttn:
         kv_seqlen=s,
         dtype=dtype,
         dropout=0.0,
-        qkv_layout="qkv_interleaved",
+        qkv_layout="bs3hd",
         bias_type="no_bias",
         mask_type="causal" if is_causal_masking else "padding",
     ):
@@ -804,7 +804,7 @@ class TestFusedAttn:
         kv_seqlen=s_kv,
         dtype=dtype,
         dropout=0.0,
-        qkv_layout="kv_interleaved",
+        qkv_layout="bshd_bs2hd",
         bias_type="no_bias",
         mask_type="padding",
     ):
@@ -830,7 +830,7 @@ class TestFusedAttn:
         kv_seqlen=s,
         dtype=dtype,
         dropout=0.0,
-        qkv_layout="qkv_interleaved",
+        qkv_layout="bs3hd",
         bias_type="no_bias",
         mask_type="causal" if is_causal_masking else "padding",
     ):
...
@@ -14,13 +14,12 @@ from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 import transformer_engine    # pylint: disable=unused-import
 from transformer_engine.paddle.constants import (
     TE_DType,
-    QKVLayout,
     AttnBiasType,
     AttnMaskType,
     FusedAttnBackend,
 )
 from transformer_engine.paddle.fp8 import FP8TensorMeta
-import transformer_engine_paddle as tex
+import transformer_engine_paddle as tex    # pylint: disable=wrong-import-order

 def create_fp8_meta(num_gemms=1, amax_history_len=10):
@@ -101,13 +100,14 @@ def set_random_seed(seed):
     paddle.seed(global_seed)

 def get_fused_attention_backend(
     head_size: int,
     q_seqlen: int,
     kv_seqlen: int,
     dtype: Union[paddle.dtype, str],
     dropout: float,
-    qkv_layout: str = "qkv_interleaved",
+    qkv_layout: str = "bs3hd",
     bias_type: str = "no_bias",
     mask_type: str = "causal",
 ) -> tex.NVTE_Fused_Attn_Backend:
@@ -121,7 +121,7 @@ def get_fused_attention_backend(
     return tex.get_fused_attn_backend(
         TE_DType[dtype],
         TE_DType[dtype],
-        QKVLayout[qkv_layout],
+        tex.get_nvte_qkv_layout(qkv_layout),
         AttnBiasType[bias_type],
         AttnMaskType[mask_type],
         dropout,
@@ -130,13 +130,14 @@ def get_fused_attention_backend(
         head_size,
     )

 def is_fused_attention_supported(
     head_size: int,
     q_seqlen: int,
     kv_seqlen: int,
     dtype: Union[paddle.dtype, str],
     dropout: float,
-    qkv_layout: str = "qkv_interleaved",
+    qkv_layout: str = "bs3hd",
     bias_type: str = "no_bias",
     mask_type: str = "causal",
 ) -> bool:
...
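
The helper keeps its shape; only the default string and the string-to-enum conversion change. A usage sketch, assuming the helper lives in the Paddle test utilities module shown above (the import path and argument values are illustrative):

```python
from utils import get_fused_attention_backend    # path assumed: Paddle tests' utils.py

backend = get_fused_attention_backend(
    head_size=64,
    q_seqlen=512,
    kv_seqlen=512,
    dtype="bfloat16",
    dropout=0.0,
    qkv_layout="bs3hd",    # new name; formerly "qkv_interleaved"
    bias_type="no_bias",
    mask_type="causal",
)
```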
@@ -133,8 +133,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
         && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
         && ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
-            || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)
-            || (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
+            || (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD)
+            || (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD))) {
       flag_arb = true;
     }
     if (((max_seqlen_q > 512) || (max_seqlen_kv > 512))
...
@@ -83,6 +83,7 @@ class FusedAttnHelper:
     q_type: jnp.dtype
     kv_type: jnp.dtype
+    qkv_layout: NVTE_QKV_Layout
     attn_bias_type: NVTE_Bias_Type
     attn_mask_type: NVTE_Mask_Type
     dropout_probability: float
@@ -96,10 +97,13 @@ class FusedAttnHelper:
     def get_fused_attn_backend(self):
         """Get the fused attention kernel backend"""
-        return transformer_engine_jax.get_fused_attn_backend(
-            jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type),
-            NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, self.attn_bias_type, self.attn_mask_type,
-            self.dropout_probability, self.max_seqlen_q, self.max_seqlen_kv, self.head_dim)
+        return transformer_engine_jax.get_fused_attn_backend(jax_dtype_to_te_dtype(self.q_type),
+                                                             jax_dtype_to_te_dtype(self.kv_type),
+                                                             self.qkv_layout, self.attn_bias_type,
+                                                             self.attn_mask_type,
+                                                             self.dropout_probability,
+                                                             self.max_seqlen_q, self.max_seqlen_kv,
+                                                             self.head_dim)

 def merge_named_shape(base, new):
@@ -210,7 +214,8 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs):
     # Need to disable one pylint error as the second function
     # parameter name changed recently in JAX. Otherwise we won't be
     # compatible with multiple JAX versions.
-    out = custom_call(name,    # pylint: disable=too-many-function-args
+    out = custom_call(    # pylint: disable=too-many-function-args
+        name,
         args.output_types,
         operands=args.operands,
         operand_layouts=args.operand_layouts,
@@ -2103,8 +2108,8 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
     output_shape = (batch, max_seqlen, num_head, head_dim)
     output_dtype = qkv_dtype
-    backend = FusedAttnHelper(qkv_dtype, qkv_dtype, attn_bias_type, attn_mask_type,
-                              dropout_probability, max_seqlen, max_seqlen,
-                              head_dim).get_fused_attn_backend()
+    backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
+                              attn_mask_type, dropout_probability, max_seqlen, max_seqlen,
+                              head_dim).get_fused_attn_backend()
     if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
...
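
FusedAttnHelper now carries the layout as its third field, so backend queries no longer hard-code NVTE_QKV_INTERLEAVED. A minimal sketch of querying the backend through it, assuming NVTE_Bias_Type exposes NVTE_NO_BIAS the way the C++ enum does:

```python
import jax.numpy as jnp
from transformer_engine_jax import NVTE_Bias_Type, NVTE_Mask_Type, NVTE_QKV_Layout
from transformer_engine.jax.cpp_extensions import FusedAttnHelper

helper = FusedAttnHelper(jnp.bfloat16, jnp.bfloat16,      # q_type, kv_type
                         NVTE_QKV_Layout.NVTE_BS3HD,      # the new explicit layout field
                         NVTE_Bias_Type.NVTE_NO_BIAS,
                         NVTE_Mask_Type.NVTE_CAUSAL_MASK,
                         0.0,                             # dropout_probability
                         512, 512, 64)                    # max_seqlen_q/kv, head_dim
backend = helper.get_fused_attn_backend()
```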
@@ -86,9 +86,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
         .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);

     pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
-        .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
-        .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
-        .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);
+        .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
+        .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD);

     pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
         .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
...
@@ -762,7 +762,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
     auto dropout_probability = descriptor.dropout_probability;
     auto bias_type = descriptor.bias_type;
     auto mask_type = descriptor.mask_type;
-    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
+    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;

     NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
                "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
@@ -845,7 +845,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
     auto dropout_probability = descriptor.dropout_probability;
     auto bias_type = descriptor.bias_type;
     auto mask_type = descriptor.mask_type;
-    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
+    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;

     NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
                "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
@@ -929,7 +929,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
     auto dropout_probability = descriptor.dropout_probability;
     auto bias_type = descriptor.bias_type;
     auto mask_type = descriptor.mask_type;
-    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_KV_INTERLEAVED;
+    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;

     auto dtype = descriptor.dtype;
     auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
@@ -1064,9 +1064,8 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
         s_tensor.data(),    // not used for FP16/BF16
         &aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
-        descriptor.scaling_factor, descriptor.dropout_probability,
-        NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
-        query_workspace_tensor.data(), stream);
+        descriptor.scaling_factor, descriptor.dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD,
+        descriptor.bias_type, descriptor.mask_type, query_workspace_tensor.data(), stream);

     size_t workspace_size =
         query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
@@ -1081,9 +1080,8 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
         s_tensor.data(),    // not used for FP16/BF16
         &aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
-        descriptor.scaling_factor, descriptor.dropout_probability,
-        NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
-        workspace_tensor.data(), stream);
+        descriptor.scaling_factor, descriptor.dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD,
+        descriptor.bias_type, descriptor.mask_type, workspace_tensor.data(), stream);

     nvte_tensor_pack_destroy(&aux_output_tensors);
 }
...
@@ -23,7 +23,7 @@ from jax import lax, vmap
 from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
 from .module import LayerNorm, Softmax
-from ..fused_attn import AttnBiasType, AttnMaskType
+from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
 from ..fused_attn import is_fused_attn_kernel_available
 from ..fused_attn import self_fused_attn, cross_fused_attn
 from ..softmax import SoftmaxType
@@ -428,6 +428,8 @@ class MultiHeadAttention(nn.Module):
             raise ValueError(f"Unsupported {attn_mask_type=}, "
                              "supported attn_mask_type = {'causal', 'padding'}")

+        is_self_attn = (inputs_q is inputs_kv)
+        qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD
         attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
         canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
@@ -441,7 +443,7 @@ class MultiHeadAttention(nn.Module):
         def _check_head_dim(head_dim):
             return head_dim in [64, 128]

-        has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype,
+        has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
                                                                attn_bias_type, attn_mask_type,
                                                                self.dropout_rate, q_seqlen,
                                                                kv_seqlen, self.head_dim)
@@ -484,7 +486,7 @@ class MultiHeadAttention(nn.Module):
         residual = inputs_q
         if self.fuse_qkv:
-            if inputs_q is inputs_kv:
+            if is_self_attn:
                 qkv_proj, ln_out = LayerNormDenseGeneral(
                     enable_layernorm=not self.output_layernorm,
                     layernorm_type=self.layernorm_type,
@@ -571,7 +573,7 @@ class MultiHeadAttention(nn.Module):
                 kernel_init=query_init,
                 name='query')(inputs_q)

-            if inputs_q is inputs_kv:
+            if is_self_attn:
                 assert ln_out is not None
                 inputs_kv = ln_out
@@ -650,7 +652,7 @@ class MultiHeadAttention(nn.Module):
             # ensure the old key never used
             del dropout_rng

-            if inputs_q is inputs_kv:
+            if is_self_attn:
                 qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
                 qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
                                            HIDDEN_AXES)
...
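
Hoisting `inputs_q is inputs_kv` into `is_self_attn` names the intent once and guarantees every branch keys off the same identity test that picks the layout. The pattern in isolation (a sketch, not library code):

```python
import jax.numpy as jnp

def pick_layout(inputs_q, inputs_kv):
    # Identity, not equality: self attention is signaled by passing the
    # very same array object for both the query and key/value inputs.
    is_self_attn = inputs_q is inputs_kv
    return "BS3HD" if is_self_attn else "BSHD_BS2HD"

x = jnp.ones((2, 8, 16))
assert pick_layout(x, x) == "BS3HD"
assert pick_layout(x, jnp.ones((2, 8, 16))) == "BSHD_BS2HD"
```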
@@ -10,6 +10,7 @@ import jax.numpy as jnp

 from transformer_engine_jax import NVTE_Bias_Type
 from transformer_engine_jax import NVTE_Mask_Type
+from transformer_engine_jax import NVTE_QKV_Layout

 from .cpp_extensions import FusedAttnHelper
 from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
@@ -36,13 +37,19 @@ class AttnMaskType(Enum):
     CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK

+class QKVLayout(Enum):
+    """QKV layout"""
+    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
+    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD

-def is_fused_attn_kernel_available(q_type, kv_type, attn_bias_type, attn_mask_type,
-                                   dropout_probability, max_seqlen_q, max_seqlen_kv, head_dim):
+def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
+                                   dropout_probability, max_seqlen_q, max_seqlen_kv, head_dim):
     """
     To check whether the fused attention kernel is available
     """
-    return FusedAttnHelper(q_type, kv_type, attn_bias_type.value, attn_mask_type.value,
-                           dropout_probability, max_seqlen_q, max_seqlen_kv,
-                           head_dim).is_fused_attn_kernel_available()
+    return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value,
+                           attn_mask_type.value, dropout_probability, max_seqlen_q, max_seqlen_kv,
+                           head_dim).is_fused_attn_kernel_available()
...
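
The new Python-level enum is a thin wrapper: each member's .value is the pybind11 value that FusedAttnHelper forwards to C++. A quick check:

```python
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine.jax.fused_attn import QKVLayout

assert QKVLayout.BS3HD.value == NVTE_QKV_Layout.NVTE_BS3HD
assert QKVLayout.BSHD_BS2HD.value == NVTE_QKV_Layout.NVTE_BSHD_BS2HD
```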
@@ -53,12 +53,6 @@ dist_group_type = paddle.distributed.collective.Group

 RecomputeFunctionNames = ('unpack', 'backward')

-QKVLayout = {
-    "not_interleaved": tex.NVTE_QKV_Layout.NVTE_NOT_INTERLEAVED,
-    "qkv_interleaved": tex.NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
-    "kv_interleaved": tex.NVTE_QKV_Layout.NVTE_KV_INTERLEAVED,
-}

 AttnBiasType = {
     "no_bias": tex.NVTE_Bias_Type.NVTE_NO_BIAS,
     "pre_scale_bias": tex.NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS,
...
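
With the constants-level dict gone, Paddle code converts layout strings through the extension module instead. A hedged before/after sketch:

```python
import transformer_engine_paddle as tex

# Before: layout_enum = QKVLayout["qkv_interleaved"]   (dict removed above)
# After: the C++ side owns the mapping and accepts the new names only.
layout_enum = tex.get_nvte_qkv_layout("bs3hd")
assert layout_enum == tex.NVTE_QKV_Layout.NVTE_BS3HD
```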
@@ -431,7 +431,7 @@ def fused_attn_fwd_qkvpacked(
     attn_scale: float = None,
     dropout: float = 0.0,
     set_zero: bool = True,
-    qkv_layout: str = "qkv_interleaved",
+    qkv_layout: str = "bs3hd",
     bias_type: str = "no_bias",
     attn_mask_type: str = "padding",
 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
@@ -518,7 +518,7 @@ def fused_attn_bwd_qkvpacked(
     attn_scale: float = None,
     dropout: float = 0.0,
     set_zero: bool = True,
-    qkv_layout: str = "qkv_interleaved",
+    qkv_layout: str = "bs3hd",
     bias_type: str = "no_bias",
     attn_mask_type: str = "padding",
 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
@@ -587,7 +587,7 @@ def fused_attn_fwd_kvpacked(
     attn_scale: float = None,
     dropout: float = 0.0,
     set_zero: bool = True,
-    qkv_layout: str = "kv_interleaved",
+    qkv_layout: str = "bshd_bs2hd",
     bias_type: str = "no_bias",
     attn_mask_type: str = "padding",
 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
@@ -685,7 +685,7 @@ def fused_attn_bwd_kvpacked(
     attn_scale: float = None,
     dropout: float = 0.0,
     set_zero: bool = True,
-    qkv_layout: str = "kv_interleaved",
+    qkv_layout: str = "bshd_bs2hd",
     bias_type: str = "no_bias",
     attn_mask_type: str = "padding",
 ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
...
@@ -53,5 +53,34 @@ paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const pad
     NVTE_CHECK(false, "Should never reach here! func: AllocateSpace");
 }

+// MHA utils
+// convert QKV layout string to enum
+NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout) {
+    static const std::unordered_map<std::string, NVTE_QKV_Layout> layout_map = {
+        {"sb3hd", NVTE_QKV_Layout::NVTE_SB3HD},
+        {"sbh3d", NVTE_QKV_Layout::NVTE_SBH3D},
+        {"sbhd_sb2hd", NVTE_QKV_Layout::NVTE_SBHD_SB2HD},
+        {"sbhd_sbh2d", NVTE_QKV_Layout::NVTE_SBHD_SBH2D},
+        {"sbhd_sbhd_sbhd", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD},
+        {"bs3hd", NVTE_QKV_Layout::NVTE_BS3HD},
+        {"bsh3d", NVTE_QKV_Layout::NVTE_BSH3D},
+        {"bshd_bs2hd", NVTE_QKV_Layout::NVTE_BSHD_BS2HD},
+        {"bshd_bsh2d", NVTE_QKV_Layout::NVTE_BSHD_BSH2D},
+        {"bshd_bshd_bshd", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD},
+        {"t3hd", NVTE_QKV_Layout::NVTE_T3HD},
+        {"th3d", NVTE_QKV_Layout::NVTE_TH3D},
+        {"thd_t2hd", NVTE_QKV_Layout::NVTE_THD_T2HD},
+        {"thd_th2d", NVTE_QKV_Layout::NVTE_THD_TH2D},
+        {"thd_thd_thd", NVTE_QKV_Layout::NVTE_THD_THD_THD},
+    };
+    auto it = layout_map.find(qkv_layout);
+    if (it != layout_map.end()) {
+        return it->second;
+    } else {
+        NVTE_ERROR("Invalid QKV layout string: " + qkv_layout);
+    }
+}
+
 }    // namespace paddle_ext
 }    // namespace transformer_engine
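
The lookup table above accepts exactly the fifteen canonical strings, five per memory-format family (sbhd, bshd, thd). A sketch exercising the Python binding exported later in this diff; deprecated names now fall through to NVTE_ERROR instead of converting:

```python
import transformer_engine_paddle as tex

for name in ("sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
             "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
             "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"):
    print(name, "->", tex.get_nvte_qkv_layout(name))

# tex.get_nvte_qkv_layout("qkv_interleaved") would now raise via NVTE_ERROR.
```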
@@ -12,7 +12,6 @@
 #include "paddle/extension.h"
 #include "paddle/phi/backends/all_context.h"

-#include "common/util/logging.h"
 #include <transformer_engine/activation.h>
 #include <transformer_engine/cast.h>
 #include <transformer_engine/fused_attn.h>
@@ -22,6 +21,7 @@
 #include <transformer_engine/softmax.h>
 #include <transformer_engine/transformer_engine.h>
 #include <transformer_engine/transpose.h>
+#include "common/util/logging.h"

 namespace transformer_engine {
 namespace paddle_ext {
@@ -177,5 +177,7 @@ TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, c
 TensorWrapper MakeNvteTensor(paddle::Tensor &tensor);    // NOLINT
 TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor);

+NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout);
+
 }    // namespace paddle_ext
 }    // namespace transformer_engine
@@ -5,6 +5,7 @@
  ************************************************************************/

 #include <cub/cub.cuh>
+#include <map>
 #include <vector>

 #include "common.h"
@@ -13,20 +14,6 @@
 namespace transformer_engine {
 namespace paddle_ext {

-// MHA utils
-// convert QKV layout to enum
-NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout) {
-    if (qkv_layout == "not_interleaved") {
-        return NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED;
-    } else if (qkv_layout == "qkv_interleaved") {
-        return NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
-    } else if (qkv_layout == "kv_interleaved") {
-        return NVTE_QKV_Layout::NVTE_KV_INTERLEAVED;
-    } else {
-        NVTE_ERROR("Invalid QKV layout. \n");
-    }
-}

 // convert bias type to enum
 NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) {
     if (bias_type == "no_bias") {
...
@@ -15,6 +15,7 @@ PYBIND11_MODULE(transformer_engine_paddle, m) {
     // Misc
     m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
     m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
+    m.def("get_nvte_qkv_layout", &get_nvte_qkv_layout, "Get QKV layout enum from the string");
     // Data structures
     py::enum_<DType>(m, "DType", py::module_local())
         .value("kByte", DType::kByte)
@@ -36,9 +37,21 @@ PYBIND11_MODULE(transformer_engine_paddle, m) {
         .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
     py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
-        .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
-        .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
-        .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);
+        .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
+        .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)
+        .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
+        .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D)
+        .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD)
+        .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
+        .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D)
+        .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
+        .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D)
+        .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
+        .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
+        .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D)
+        .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
+        .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D)
+        .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
     py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", py::module_local())
         .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
...
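
All fifteen layouts are now visible from Python. Since pybind11 enums expose a __members__ mapping, the registered set can be listed without hard-coding names:

```python
import transformer_engine_paddle as tex

# pybind11 enum classes carry a __members__ dict of name -> value.
for name, value in tex.NVTE_QKV_Layout.__members__.items():
    print(name, int(value))
```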
@@ -15,8 +15,8 @@ import transformer_engine_paddle as tex
 from .layernorm_linear import LayerNormLinear
 from .linear import Linear
 from .softmax import FusedScaleMaskSoftmax
-from ..constants import (AttnTypes, TE_DType, QKVLayout, AttnBiasType, AttnMaskType,
-                         FusedAttnBackend, dist_group_type)
+from ..constants import (AttnTypes, TE_DType, AttnBiasType, AttnMaskType, FusedAttnBackend,
+                         dist_group_type)
 from ..cpp_extensions import (
     fused_attn_fwd_qkvpacked,
     fused_attn_bwd_qkvpacked,
@@ -28,7 +28,6 @@ from ..distributed import get_tp_group_and_world_size, track_rng_state
 from ..utils import attention_mask_func, divide
 from ..recompute import recompute

 __all__ = ["DotProductAttention", "MultiHeadAttention"]
@@ -168,7 +167,7 @@ class DotProductAttention(paddle.nn.Layer):
         self.attn_mask_type = attn_mask_type
         self.attention_dropout = attention_dropout
         self.attention_type = attention_type
-        self.qkv_layout = "qkv_interleaved" if attention_type == "self" else "kv_interleaved"
+        self.qkv_layout = "bs3hd" if attention_type == "self" else "bshd_bs2hd"
         self.backend = backend
@@ -237,7 +236,7 @@ class DotProductAttention(paddle.nn.Layer):
         max_s_kv = max_s_q if self.attention_type == "self" else key_value_layer.shape[1]
         self.fused_attention_backend = tex.get_fused_attn_backend(
             TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
-            QKVLayout[self.qkv_layout], AttnBiasType[core_attention_bias_type],
+            tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
             AttnMaskType[self.attn_mask_type], self.attention_dropout, max_s_q, max_s_kv,
             query_layer.shape[-1])
...
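
End to end, DotProductAttention now feeds its layout string through the C++ converter when picking a backend. The equivalent standalone query, with argument order taken from the hunk above and illustrative sizes:

```python
import paddle
import transformer_engine_paddle as tex
from transformer_engine.paddle.constants import TE_DType, AttnBiasType, AttnMaskType

backend = tex.get_fused_attn_backend(
    TE_DType[paddle.bfloat16], TE_DType[paddle.bfloat16],    # q and kv dtypes
    tex.get_nvte_qkv_layout("bs3hd"),    # was QKVLayout["qkv_interleaved"]
    AttnBiasType["no_bias"],
    AttnMaskType["causal"],
    0.0,                                 # attention dropout
    512, 512,                            # max_s_q, max_s_kv
    64)                                  # head_size
```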