Unverified commit 94c57e4d authored by zlsh80826, committed by GitHub

Error handling for non-sm80/sm90 GPUs when using fused attention (#393)



* Fused attention kernel only supports sm80 and sm90
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update transformer_engine/jax/csrc/modules.cpp
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Arbitrary-seqlen fused attention kernel supports sm86/sm89 with cuDNN 8.9.3 and later
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Skip sm70
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Forward is_fused_attn_kernel_available to the C++ backend
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove the C++ is_fused_attn_kernel_available API
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent b90a8d3a
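
For orientation, here is a minimal sketch of how the new parametrized availability check introduced by this commit is meant to be used from JAX code. The signature matches the `is_fused_attn_kernel_available` added in the diff below; the dtype, sequence lengths, head dimension, and dropout value are only illustrative.

```python
# Minimal sketch: query kernel availability before choosing the attention path.
# The values below (seqlen 512, head_dim 64, dropout 0.1) are illustrative only.
import jax.numpy as jnp

from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType,
                                               is_fused_attn_kernel_available)

dtype = jnp.bfloat16
if is_fused_attn_kernel_available(dtype, dtype,
                                  AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
                                  0.1, 512, 512, 64):
    ...  # use self_fused_attn / cross_fused_attn
else:
    ...  # fall back to an unfused attention implementation
```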
@@ -19,8 +19,9 @@ from flax.linen import make_causal_mask
 from jax import value_and_grad, jit
 from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
-from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
 from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
+from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
+from transformer_engine_jax import get_device_compute_capability

 # Type annotations
 Array = jnp.ndarray
@@ -146,8 +147,6 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
     return cross_fused_attn(q, kv, mask, dropout_rng, **kwargs)


-@pytest.mark.skipif(not is_fused_attn_kernel_available(),
-                    reason="Fused attention kernel is not supported.")
 @pytest.mark.parametrize('b, s, h, d', SELF_CASES)
 @pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
 @pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@@ -159,13 +158,14 @@ class TestSelfFusedAttn():
     """Tests for transformer_engine.jax.fused_attn.self_fused_attn"""

     @staticmethod
-    def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, pad_ratio):
-        # Arbitrary seqlen backend has a limited spec for now
-        # No bias, only causal mask, and no variable seqlen
-        if (s > 512 or backend == Backend.Arbitrary) and (attn_bias_type != AttnBiasType.NO_BIAS or
-                                                          attn_mask_type != AttnMaskType.CAUSAL_MASK
-                                                          or pad_ratio != 0):
-            pytest.skip("Unsupported inputs combination.")
+    def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
+                      head_dim, pad_ratio):
+        if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0:
+            pytest.skip("Arbitrary seqlen backend hasn't support padded input.")
+
+        if not is_fused_attn_kernel_available(dtype, dtype, attn_bias_type, attn_mask_type,
+                                              dropout_probability, s, s, head_dim):
+            pytest.skip("Unsupported inputs combination or device compute capability.")

     def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
                     dropout_probability, dtype, is_training, pad_ratio):
@@ -174,6 +174,9 @@ class TestSelfFusedAttn():
                            attn_bias_type=attn_bias_type,
                            attn_mask_type=attn_mask_type,
                            backend=backend,
+                           dropout_probability=dropout_probability,
+                           dtype=dtype,
+                           head_dim=d,
                            pad_ratio=pad_ratio)

         key = jax.random.PRNGKey(0)
         subkeys = jax.random.split(key, 2)
@@ -361,7 +364,7 @@ class TestSelfFusedAttn():
                             jnp.zeros_like(primitive_dbias[:, :, self.valid_len:, self.valid_len:]))


-@pytest.mark.skipif(not is_fused_attn_kernel_available(),
+@pytest.mark.skipif(get_device_compute_capability(0) not in [80, 90],
                     reason="Fused attention kernel is not supported.")
 @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
 @pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
......
@@ -44,7 +44,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
     bool flag_m512 = false;
     bool flag_arb = false;
-    if ((sm_arch_ >= 80)
+    if ((sm_arch_ == 80 || sm_arch_ == 90)
         && (head_dim == 64)
         && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
             || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
@@ -55,7 +55,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
             || (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED))) {
       flag_m512 = true;
     }
-    if ((sm_arch_ >= 80)
+    if (
+#if (CUDNN_VERSION >= 8903)
+        (sm_arch_ >= 80)
+#else
+        (sm_arch_ == 80 || sm_arch_ == 90)
+#endif
         && (max_seqlen_q == max_seqlen_kv)
         && ((head_dim == 64) || (head_dim == 128))
         && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
......
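
To make the gating above easier to follow, the sketch below mirrors the condition in Python: the max-512-seqlen kernel is limited to sm80/sm90, while the arbitrary-seqlen kernel additionally accepts sm86/sm89 (any sm >= 80) only when cuDNN is 8.9.3 or newer. The helper names are hypothetical and exist only for illustration.

```python
# Hypothetical helpers mirroring the C++ backend-selection condition above.
def max512_kernel_arch_supported(sm_arch: int) -> bool:
    # Max-512-seqlen kernel: only sm80 and sm90.
    return sm_arch in (80, 90)


def arbitrary_kernel_arch_supported(sm_arch: int, cudnn_version: int) -> bool:
    # Arbitrary-seqlen kernel: sm86/sm89 (any sm >= 80) become eligible with
    # cuDNN >= 8.9.3; older cuDNN accepts only sm80/sm90.
    if cudnn_version >= 8903:
        return sm_arch >= 80
    return sm_arch in (80, 90)
```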
@@ -68,6 +68,33 @@ def jax_dtype_to_te_dtype(jax_dtype):
     raise ValueError(f"Not support the {jax_dtype=}")


+@dataclass(frozen=True)
+class FusedAttnHelper:
+    """
+    Helper for the fused attention backend
+    """
+    q_type: jnp.dtype
+    kv_type: jnp.dtype
+    attn_bias_type: NVTE_Bias_Type
+    attn_mask_type: NVTE_Mask_Type
+    dropout_probability: float
+    max_seqlen_q: int
+    max_seqlen_kv: int
+    head_dim: int
+
+    def is_fused_attn_kernel_available(self):
+        """Check if there is available fused attention kernel"""
+        return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend
+
+    def get_fused_attn_backend(self):
+        """Get the fused attention kernel backend"""
+        return transformer_engine_jax.get_fused_attn_backend(
+            jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type),
+            NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, self.attn_bias_type, self.attn_mask_type,
+            self.dropout_probability, self.max_seqlen_q, self.max_seqlen_kv, self.head_dim)
+
+
 def merge_named_shape(base, new):
     """
     merge named shape(ie, dict), no key conflict
@@ -2053,10 +2080,9 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
         output_shape = (batch, max_seqlen, num_head, head_dim)
         output_dtype = qkv_dtype

-        backend = transformer_engine_jax.get_fused_attn_backend(
-            jax_dtype_to_te_dtype(qkv_dtype), jax_dtype_to_te_dtype(qkv_dtype),
-            NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, attn_bias_type, attn_mask_type,
-            dropout_probability, max_seqlen, max_seqlen, head_dim)
+        backend = FusedAttnHelper(qkv_dtype, qkv_dtype, attn_bias_type, attn_mask_type,
+                                  dropout_probability, max_seqlen, max_seqlen,
+                                  head_dim).get_fused_attn_backend()

         if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
             softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen)
......
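
The `FusedAttnHelper` dataclass above centralizes backend selection on the Python side. A minimal usage sketch, with illustrative dtypes and shapes (it assumes the compiled `transformer_engine_jax` extension is importable), could look like this:

```python
# Minimal sketch: ask FusedAttnHelper which fused-attention backend applies.
import jax.numpy as jnp
from transformer_engine_jax import NVTE_Bias_Type, NVTE_Mask_Type
from transformer_engine.jax.cpp_extensions import FusedAttnHelper

helper = FusedAttnHelper(q_type=jnp.bfloat16, kv_type=jnp.bfloat16,
                         attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
                         attn_mask_type=NVTE_Mask_Type.NVTE_CAUSAL_MASK,
                         dropout_probability=0.0,
                         max_seqlen_q=512, max_seqlen_kv=512, head_dim=64)

if helper.is_fused_attn_kernel_available():
    backend = helper.get_fused_attn_backend()  # e.g. NVTE_F16_max512_seqlen
```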
@@ -63,7 +63,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
     m.def("get_cuda_version", &GetCudaRuntimeVersion);
     m.def("get_device_compute_capability", &GetDeviceComputeCapability);
     m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
-    m.def("is_fused_attn_kernel_available", &IsFusedAttnKernelAvailable);
     m.def("get_fused_attn_backend", &GetFusedAttnBackend);

     pybind11::enum_<DType>(m, "DType", pybind11::module_local())
......
@@ -18,6 +18,7 @@
 #include <vector>

 #include "common/common.h"
+#include "common/util/cuda_runtime.h"
 #include "transformer_engine/activation.h"
 #include "transformer_engine/cast.h"
 #include "transformer_engine/fused_attn.h"
@@ -89,16 +90,6 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
                                   bias_type, mask_type, dtype, is_training});
 }

-bool IsFusedAttnKernelAvailable() {
-#if (CUDNN_VERSION >= 8901)
-    auto major = cudaDevicePropertiesManager::Instance().GetMajor();
-    // Fused attention requires at least Ampere
-    return major >= 8;
-#else
-    return false;
-#endif
-}
-
 void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
                    void *output) {
     auto input_shape = std::vector<size_t>{rows, cols};
......
@@ -114,8 +114,6 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
     float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type mask_type, DType dtype, bool is_training);

-bool IsFusedAttnKernelAvailable();
-
 NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                             NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                             NVTE_Mask_Type mask_type, float dropout_probability,
......
@@ -414,7 +414,21 @@ class MultiHeadAttention(nn.Module):
                 return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

-        first_sharding_type, second_sharding_type = infer_sharding_type()
+        # TODO(rewang): make it configurable for pre_scale_bias
+        attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
+
+        def canonicalize_attn_mask_type(attn_mask_type):
+            """
+            Convert the string to AttnMaskType
+            """
+            if attn_mask_type == 'causal':
+                return AttnMaskType.CAUSAL_MASK
+            if attn_mask_type == 'padding':
+                return AttnMaskType.PADDING_MASK
+            raise ValueError(f"Unsupported {attn_mask_type=}, "
+                             "supported attn_mask_type = {'causal', 'padding'}")
+
+        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)

         canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
         q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
@@ -427,11 +441,16 @@ class MultiHeadAttention(nn.Module):
         def _check_head_dim(head_dim):
             return head_dim in [64, 128]

+        has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype,
+                                                               attn_bias_type, attn_mask_type,
+                                                               self.dropout_rate, q_seqlen,
+                                                               kv_seqlen, self.head_dim)
+
         use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
             canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
             _check_seqlen(q_seqlen) and _check_seqlen(kv_seqlen) and \
             _check_head_dim(self.head_dim) and \
-            is_fused_attn_kernel_available() and \
+            has_fused_attn_kernel and \
             enable_fused_attn

         if enable_fused_attn and not use_fused_attn:
@@ -454,12 +473,14 @@ class MultiHeadAttention(nn.Module):
                           f"but got {kv_seqlen=}, "
             if not _check_head_dim(self.head_dim):
                 reason += f"head_dim should be 64 or 128 but got {self.head_dim}, "
-            if not is_fused_attn_kernel_available():
-                reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
+            if not has_fused_attn_kernel:
+                reason += "no fused attention kernel is available, "

             warnings.warn(
-                f"Fused attention is not enabled, " \
-                f"{reason}fall back to unfused attention")
+                f"Fused attention is not enabled. Because " \
+                f"{reason}fall back to unfused attention.")
+
+        first_sharding_type, second_sharding_type = infer_sharding_type()

         residual = inputs_q
         if self.fuse_qkv:
@@ -629,22 +650,6 @@ class MultiHeadAttention(nn.Module):
             # ensure the old key never used
             del dropout_rng

-        # TODO(rewang): make it configurable for pre_scale_bias
-        attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
-
-        def canonicalize_attn_mask_type(attn_mask_type):
-            """
-            Convert the string to AttnMaskType
-            """
-            if attn_mask_type == 'causal':
-                return AttnMaskType.CAUSAL_MASK
-            if attn_mask_type == 'padding':
-                return AttnMaskType.PADDING_MASK
-            raise ValueError(f"Unsupported {attn_mask_type=}, "
-                             "supported attn_mask_type = {'causal', 'padding'}")
-
-        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
-
         if inputs_q is inputs_kv:
             qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
             qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
......
@@ -8,10 +8,10 @@ from functools import partial
 import jax
 import jax.numpy as jnp

-import transformer_engine_jax
 from transformer_engine_jax import NVTE_Bias_Type
 from transformer_engine_jax import NVTE_Mask_Type

+from .cpp_extensions import FusedAttnHelper
 from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
 from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
 from .sharding import get_fused_attn_sharding_meta
@@ -22,13 +22,6 @@ jax.config.update('experimental_xmap_spmd_lowering', True)
 jax.config.update('experimental_xmap_spmd_lowering_manual', True)


-def is_fused_attn_kernel_available():
-    """
-    To check whether the fused attention kernel is available
-    """
-    return transformer_engine_jax.is_fused_attn_kernel_available()
-
-
 class AttnBiasType(Enum):
     """Attention Bias Type."""
     NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
@@ -43,6 +36,16 @@ class AttnMaskType(Enum):
     CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK


+def is_fused_attn_kernel_available(q_type, kv_type, attn_bias_type, attn_mask_type,
+                                   dropout_probability, max_seqlen_q, max_seqlen_kv, head_dim):
+    """
+    To check whether the fused attention kernel is available
+    """
+    return FusedAttnHelper(q_type, kv_type, attn_bias_type.value, attn_mask_type.value,
+                           dropout_probability, max_seqlen_q, max_seqlen_kv,
+                           head_dim).is_fused_attn_kernel_available()
+
+
 def self_fused_attn(qkv: jnp.ndarray,
                     bias: jnp.ndarray,
                     mask: jnp.ndarray,
......