Unverified commit 4ae34765 authored by Tim Moon, committed by GitHub

Debug CI tests on Ada (#397)

* Debug PyTorch and Paddle tests on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Only run Paddle layer tests with cuDNN fMHA on supported archs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug PyTorch fMHA tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Reduce JAX FP8 GEMM sizes

Avoid split-k kernels on Ada.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable JAX fused self-attention test on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Run supported fused attention tests on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Run supported fused attention JAX tests on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Enable Paddle fused attention on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update reference scale calculation in TensorFlow test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Restore backend support to reference FP8 attention impl in PyTorch test

Review suggestion from @cyanguwa
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Paddle tests on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Loosen tolerances for Paddle attention tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Assume causal mask implies equal seqlens in Paddle attention tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 1f4c3979
parent 1f4c3979
@@ -24,8 +24,13 @@ from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import fp8_ln_mlp
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)]
GEMM_CASES = [
(256, 256, 512),
(32, 32, 32),
(2048, 1024, 2048),
(2048, 2048, 1024),
(2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [_format2dtypes(Format.E4M3), _format2dtypes(Format.HYBRID)]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
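The shrunken cases drop the 16384-row GEMMs so that cuBLASLt no longer selects split-k kernels on Ada, per the commit message. For orientation, a minimal sketch of how such (m, n, k) tuples typically drive a parametrized test; the test name and body below are illustrative, not this file's actual code:

import jax.numpy as jnp
import pytest

@pytest.mark.parametrize('m, n, k', [(256, 256, 512), (32, 32, 32), (2048, 1024, 2048)])
def test_gemm_shape(m, n, k):
    # Plain bf16 matmul as a stand-in for the FP8 GEMM under test.
    a = jnp.ones((m, k), dtype=jnp.bfloat16)
    b = jnp.ones((k, n), dtype=jnp.bfloat16)
    assert (a @ b).shape == (m, n)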
...
@@ -167,6 +167,12 @@ class TestSelfFusedAttn():
dropout_probability, s, s, head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")
compute_capability = get_device_compute_capability(0)
if (backend == Backend.Max512
and not (compute_capability == 80 or compute_capability >= 90)):
pytest.skip("Unsupported compute capability for "
"fused attention with <=512 sequence length")
def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
"""Setup the test inputs"""
...
@@ -5,14 +5,16 @@
import math
import os
import pytest
from utils import assert_allclose
import paddle
import pytest
from transformer_engine.common.recipe import DelayedScaling
import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast
from transformer_engine.common.recipe import DelayedScaling
from utils import is_fused_attention_supported
is_fp8_supported, reason = is_fp8_available()
LINEAR_CASES = [(16, 16, 32), (32, 32, 64)]
@@ -614,8 +616,6 @@ class TestLayerNormMLP:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@@ -630,8 +630,21 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
paddle.set_default_dtype(math_dtype)
rtol = 1e-4
atol = 2e-2
head_size = hidden_size // num_heads
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
head_size=head_size,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
dtype=math_dtype,
dropout=0.0,
qkv_layout="qkv_interleaved" if attn_type == "self" else "kv_interleaved",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
self_attn_qkv_input = paddle.normal(mean=0.0,
std=0.02,
shape=(bs, q_seqlen, 3, num_heads,
@@ -727,8 +740,6 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@@ -749,6 +760,19 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
atol = 5e-2
eps = 1e-3
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
head_size=hidden_size // num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
dtype=math_dtype,
dropout=0.0,
qkv_layout="qkv_interleaved",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
@@ -892,8 +916,6 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
atol=0.5)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@@ -916,6 +938,30 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
atol = 6e-2
eps = 1e-3
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
head_size=hidden_size // num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
dtype=math_dtype,
dropout=0.0,
qkv_layout="qkv_interleaved",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
if not is_fused_attention_supported(
head_size=hidden_size // num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
dtype=math_dtype,
dropout=0.0,
qkv_layout="kv_interleaved",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
encoder_output = paddle.uniform(shape=(bs, kv_seqlen, hidden_size), dtype=math_dtype)
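A decoder layer exercises both self-attention (qkv_interleaved) and cross-attention (kv_interleaved), hence the two back-to-back support checks. A compact equivalent, assuming the test's local variables and the is_fused_attention_supported helper from the test utils:

for layout in ("qkv_interleaved", "kv_interleaved"):
    if not is_fused_attention_supported(
            head_size=hidden_size // num_heads,
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
            dtype=math_dtype,
            dropout=0.0,
            qkv_layout=layout,
            bias_type="no_bias",
            mask_type=mask_type,
    ):
        pytest.skip("cuDNN fused attention is not supported")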
...
@@ -6,10 +6,9 @@
import struct
import numpy as np
import pytest
import paddle
import paddle.nn.functional as F
from utils import assert_allclose, create_fp8_meta
import pytest
import transformer_engine # pylint: disable=unused-import
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
@@ -45,6 +44,13 @@ from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling
from utils import (
assert_allclose,
create_fp8_meta,
get_fused_attention_backend,
is_fused_attention_supported,
)
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()
@@ -300,7 +306,7 @@ class TestGemm:
actual_out, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, False, "TN",
None, None, False)
assert_allclose(actual_out, ref_out)
assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5)
@staticmethod
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
@@ -332,8 +338,8 @@ class TestGemm:
"""
Test "TN" FP8 GEMM
"""
min_val = -8
max_val = 8
min_val = -4
max_val = 4
fp8_dtype = tex.DType.kFloat8E4M3
out_dtype = paddle.float32
fp8_meta = create_fp8_meta(num_gemms=1)
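Halving the input range is plausibly about FP8 quantization error: E4M3 carries 3 mantissa bits, so values near magnitude x are spaced roughly x * 2**-3 apart, and a narrower range feeds a smaller absolute step into the GEMM. A back-of-envelope check (our reasoning, not the file's):

# E4M3 step size near |x| is about x * 2**-3.
for x in (4.0, 8.0):
    print(f"|x| ~ {x}: quantization step ~ {x * 2 ** -3}")
# |x| ~ 4 gives a step of 0.5, versus 1.0 at |x| ~ 8.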
@@ -582,31 +588,38 @@ class TestFusedAttn:
else:
self.kv = _random(self.kv_shape)
self.q_actual_seqlen = np.random.randint(
low=20,
high=self.q_seqlen,
size=(self.batch_size,),
dtype=np.int32,
)
self.q_actual_seqlen = None
if self.is_causal_masking:
self.q_actual_seqlen = np.full(
self.batch_size,
self.q_seqlen,
dtype=np.int32,
)
else:
self.q_actual_seqlen = np.random.randint(
low=20,
high=self.q_seqlen,
size=(self.batch_size,),
dtype=np.int32,
)
self.kv_actual_seqlen = self.q_actual_seqlen
self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen)
self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0)
self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen)
self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0)
self.attn_mask = np.zeros(
self.attn_mask = np.ones(
shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
dtype=np.int32,
)
for i in range(0, self.batch_size):
self.attn_mask[i, 0, 0:self.q_actual_seqlen[i], 0:self.kv_actual_seqlen[i],] = 1
if self.is_causal_masking:
assert attn_mode == "self_attn", "causal masking is only supported for self attention"
col_beg, col_end = 1, self.q_actual_seqlen[i]
for row in range(0, self.q_actual_seqlen[i]):
self.attn_mask[i, 0, row, col_beg:col_end] = 0
col_beg += 1
if self.is_causal_masking:
assert attn_mode == "self_attn", "causal masking is only supported for self attention"
for i in range(0, self.batch_size):
for j in range(self.q_actual_seqlen[i]):
self.attn_mask[i, :, j, :j+1] = 0
else:
for i in range(0, self.batch_size):
self.attn_mask[i, :, :self.q_actual_seqlen[i], :self.kv_actual_seqlen[i]] = 0
dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
self.dout = paddle.to_tensor(dout, dtype=self.dtype)
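The rewritten setup flips the mask convention: entries start at 1 (masked out) and are zeroed where attention is allowed, and the reference path below applies the mask with paddle.where(mask, -1e4, qk) rather than adding a scaled bias. A minimal NumPy sketch of the two mask shapes, assuming sequence length 4 and an actual length of 3:

import numpy as np

q_len, actual = 4, 3

padding_mask = np.ones((q_len, q_len), dtype=np.int32)
padding_mask[:actual, :actual] = 0          # attend anywhere in the valid region

causal_mask = np.ones((q_len, q_len), dtype=np.int32)
for j in range(actual):
    causal_mask[j, :j + 1] = 0              # row j attends to keys 0..j only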
@@ -628,10 +641,12 @@ class TestFusedAttn:
transpose_y=True,
)
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
attn_mask = (paddle.cast(attn_mask, self.dtype) - 1.0) * 1e4
attn_mask_out = qk_out + attn_mask
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast('bool')
attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype)
attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out)
attn_mask_out = paddle.cast(attn_mask_out, 'float32')
softmax_out = F.softmax(attn_mask_out)
softmax_out = paddle.cast(softmax_out, self.dtype)
if self.dropout_prob:
dropout_out = F.dropout(
@@ -667,9 +682,21 @@ class TestFusedAttn:
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
fused_attention_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen if (
self.q_seqlen <= 512
and self.kv_seqlen <= 512) else tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
qkv_layout = (
"qkv_interleaved"
if self.attn_mode == "self_attn"
else "kv_interleaved"
)
fused_attention_backend = get_fused_attention_backend(
head_size=self.head_size,
q_seqlen=self.q_seqlen,
kv_seqlen=self.kv_seqlen,
dtype=self.dtype,
dropout=self.dropout_prob,
qkv_layout=qkv_layout,
bias_type="no_bias",
mask_type="causal" if self.is_causal_masking else "padding",
)
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
@@ -739,8 +766,6 @@ class TestFusedAttn:
return out, q_grad, k_grad, v_grad
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() not in ((8, 0), (9, 0)),
reason="cuDNN fMHA requires Ampere and Hopper GPU")
@pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True, False])
@@ -748,6 +773,17 @@ class TestFusedAttn:
"""
test self attention forward + backward
"""
if not is_fused_attention_supported(
head_size=d,
q_seqlen=s,
kv_seqlen=s,
dtype=dtype,
dropout=0.0,
qkv_layout="qkv_interleaved",
bias_type="no_bias",
mask_type="causal" if is_causal_masking else "padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
@@ -756,14 +792,23 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() not in ((8, 0), (9, 0)),
reason="cuDNN fMHA requires Ampere and Hopper GPU")
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
"""
test cross attention forward + backward
"""
if not is_fused_attention_supported(
head_size=d,
q_seqlen=s_q,
kv_seqlen=s_kv,
dtype=dtype,
dropout=0.0,
qkv_layout="kv_interleaved",
bias_type="no_bias",
mask_type="padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn")
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
@@ -772,8 +817,6 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True])
@@ -781,6 +824,17 @@ class TestFusedAttn:
"""
test flash attention forward + backward
"""
if not is_fused_attention_supported(
head_size=d,
q_seqlen=s,
kv_seqlen=s,
dtype=dtype,
dropout=0.0,
qkv_layout="qkv_interleaved",
bias_type="no_bias",
mask_type="causal" if is_causal_masking else "padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
...
@@ -4,15 +4,23 @@
"""Utils for testing"""
import random
import numpy as np
from typing import Union
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
import transformer_engine # pylint: disable=unused-import
from transformer_engine.paddle.constants import (
TE_DType,
QKVLayout,
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
)
from transformer_engine.paddle.fp8 import FP8TensorMeta
import transformer_engine_paddle as tex
def create_fp8_meta(num_gemms=1, amax_history_len=10):
@@ -92,3 +100,55 @@ def set_random_seed(seed):
tracker.add("local_seed", local_seed)
paddle.seed(global_seed)
def get_fused_attention_backend(
head_size: int,
q_seqlen: int,
kv_seqlen: int,
dtype: Union[paddle.dtype, str],
dropout: float,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
mask_type: str = "causal",
) -> tex.NVTE_Fused_Attn_Backend:
"""Get cuDNN fused attention backend for attention config"""
if isinstance(dtype, str):
dtype = dict(
float32=paddle.float32,
bfloat16=paddle.bfloat16,
float16=paddle.float16,
)[dtype]
return tex.get_fused_attn_backend(
TE_DType[dtype],
TE_DType[dtype],
QKVLayout[qkv_layout],
AttnBiasType[bias_type],
AttnMaskType[mask_type],
dropout,
q_seqlen,
kv_seqlen,
head_size,
)
def is_fused_attention_supported(
head_size: int,
q_seqlen: int,
kv_seqlen: int,
dtype: Union[paddle.dtype, str],
dropout: float,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
mask_type: str = "causal",
) -> bool:
"""Check if cuDNN fused attention is supported for attention config"""
backend = get_fused_attention_backend(
head_size=head_size,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
dtype=dtype,
dropout=dropout,
qkv_layout=qkv_layout,
bias_type=bias_type,
mask_type=mask_type,
)
return backend != FusedAttnBackend["No_Backend"]
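A sketch of how the new helpers are meant to be called from a Paddle test; the values are illustrative:

import pytest
from utils import is_fused_attention_supported  # the helper defined above

def test_example():
    if not is_fused_attention_supported(
            head_size=64,
            q_seqlen=512,
            kv_seqlen=512,
            dtype="bfloat16",
            dropout=0.0,
            qkv_layout="qkv_interleaved",
            bias_type="no_bias",
            mask_type="padding",
    ):
        pytest.skip("cuDNN fused attention is not supported")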
@@ -2,23 +2,46 @@
#
# See LICENSE for license information.
import torch
from importlib.metadata import version
import os
from typing import Any, Dict, List, Tuple, Union
from pkg_resources import packaging
import pytest
import torch
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast
from transformer_engine.pytorch.attention import (
DotProductAttention,
RotaryPositionEmbedding,
)
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
QKVLayout,
fused_attn_bwd,
fused_attn_fwd,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_qkvpacked,
)
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import (
TransformerEngineBaseModule,
_prepare_backward,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
get_device_compute_capability,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention, RotaryPositionEmbedding
import os
import transformer_engine_extensions as tex
from pkg_resources import packaging
from importlib.metadata import version
from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
_cudnn_version = [int(i) for i in os.environ['CUDNN_VERSION'].split('.')]
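Splitting CUDNN_VERSION into a list of ints makes version gates compare numerically instead of lexically, since Python compares lists elementwise:

assert [int(i) for i in "8.9.5".split('.')] == [8, 9, 5]
assert [8, 10, 0] > [8, 9, 5]   # numeric comparison: 10 > 9
assert "8.10.0" < "8.9.5"       # string comparison gets the same versions wrong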
@@ -63,32 +86,93 @@ param_types_lean = [torch.bfloat16]
batch_sizes_lean = [2]
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
def _is_fused_attention_supported(
config: ModelConfig,
dtype: torch.dtype,
qkv_layout: str = "sbh3d",
bias_type: str = "no_bias",
) -> bool:
backend = tex.get_fused_attn_backend(
TE_DType[dtype],
TE_DType[dtype],
QKVLayout[qkv_layout],
AttnBiasType[bias_type],
AttnMaskType[config.attn_mask_type],
config.dropout_p,
config.seq_len,
config.seq_len,
config.head_dim,
)
return backend != FusedAttnBackend["No_Backend"]
def _is_flash_attention_supported(bias_type: str = "no_bias") -> bool:
if get_device_compute_capability() < (8, 0):
return False
if bias_type != "no_bias":
return False
return True
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [True, False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
"""Test DotProductAttention module with three backends,
FlashAttention, FusedAttention and UnfusedDotProductAttention"""
"""Test DotProductAttention module with different backends"""
# Get configs
config = model_configs[model]
if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
# Skip if only unfused backend is supported
fused_attn_supported = _is_fused_attention_supported(
config,
dtype,
bias_type=bias_type,
)
flash_attn_supported = _is_flash_attention_supported(bias_type=bias_type)
if not (fused_attn_supported or flash_attn_supported):
pytest.skip(
"Neither FusedAttention nor FlashAttention support this model config"
)
# UnfusedDotProductAttention backend
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)
dtype,
bs,
config,
"UnfusedDotProductAttention",
ckpt_attn,
bias_type,
)
# FusedAttention backend
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype,
bs,
config,
"FusedAttention",
ckpt_attn,
bias_type,
)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
if bias_type == "no_bias":
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype,
bs,
config,
"FlashAttention",
ckpt_attn,
bias_type,
)
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
@@ -155,9 +239,7 @@ qkv_layouts = [
]
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.skipif(
_cudnn_version >= [8,9,5], reason="cuDNN 8.9.5+ is required.")
_cudnn_version < [8,9,5], reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs_lean.keys())
@@ -166,23 +248,39 @@ qkv_layouts = [
def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
# Get configs
config = model_configs_lean[model]
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
# Skip if only unfused backend is supported
fused_attn_supported = _is_fused_attention_supported(config, dtype)
flash_attn_supported = _is_flash_attention_supported()
if not (fused_attn_supported or flash_attn_supported):
pytest.skip(
"Neither FusedAttention nor FlashAttention support this model config"
)
flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt)
# UnfusedDotProductAttention backend
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "UnfusedDotProductAttention", qkv_layout, workspace_opt)
dtype, bs, config, "UnfusedDotProductAttention", qkv_layout, workspace_opt)
# FusedAttention backend
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i in range(len(unfused_attn_bwd)):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
atol, rtol = (5e-2, 5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3)
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
for i in range(len(flash_attn_bwd)):
torch.testing.assert_close(flash_attn_bwd[i], unfused_attn_bwd[i], atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], atol = atol, rtol = rtol)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
for i in range(len(unfused_attn_bwd)):
torch.testing.assert_close(flash_attn_bwd[i], unfused_attn_bwd[i], **tols)
def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
@@ -272,8 +370,6 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
return op, (inp[0].grad, inp[1].grad, inp[2].grad)
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs_lean.keys())
@@ -284,22 +380,61 @@ def test_transformer_layer(dtype, bs, model, bias_type, fused_qkv_params, RoPE):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
# Get configs
config = model_configs_lean[model]
tols = dict(atol=5e-1, rtol=5e-2)
if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention", bias_type, fused_qkv_params, RoPE)
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FusedAttention", bias_type, fused_qkv_params, RoPE)
# Skip if only unfused backend is supported
fused_attn_supported = _is_fused_attention_supported(
config,
dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
bias_type=bias_type,
)
flash_attn_supported = _is_flash_attention_supported(bias_type=bias_type)
if not (fused_attn_supported or flash_attn_supported):
pytest.skip(
"Neither FusedAttention nor FlashAttention support this model config"
)
# UnfusedDotProductAttention backend
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "UnfusedDotProductAttention", bias_type, fused_qkv_params, RoPE)
dtype,
bs,
config,
"UnfusedDotProductAttention",
bias_type,
fused_qkv_params,
RoPE,
)
# FusedAttention backend
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype,
bs,
config,
"FusedAttention",
bias_type,
fused_qkv_params,
RoPE,
)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
atol, rtol = (5e-1, 5e-2)
if bias_type == "no_bias":
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype,
bs,
config,
"FlashAttention",
bias_type,
fused_qkv_params,
RoPE,
)
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_params, RoPE):
@@ -385,15 +520,12 @@ def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_para
return op, inp.grad
@pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available")
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs_lean.keys())
def test_transformer_layer_gqa(dtype, bs, model):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
FlashAttention or UnfusedDotProductAttention backend"""
config = model_configs_lean[model]
def find_factors(x):
@@ -403,6 +535,10 @@ def test_transformer_layer_gqa(dtype, bs, model):
f.append(i)
return f
# Skip if only unfused backend is supported
if not (_flash_attn_2_available and _is_flash_attention_supported()):
pytest.skip("FlashAttention does not support this model config")
num_querys_per_gqa_group = find_factors(config.num_attention_heads)
for num_q_per_gqa_group in num_querys_per_gqa_group:
@@ -419,8 +555,11 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
@@ -495,25 +634,48 @@ param_types_fp8 = [torch.float16]
@pytest.mark.parametrize("bs", batch_sizes_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, bs, model):
"""Test DotProductAttention module with FP8,
using cpp_extensions import fused_attn_fwd/bwd_qkvpacked and UnfusedDotProductAttention"""
"""Test FP8 dot-product attention with different backends
FusedAttention uses fused_attn_fwd/bwd_qkvpacked from
cpp_extensions. UnfusedDotProductAttention uses plain PyTorch
operations.
"""
config = model_configs_fp8[model]
# Skip if not supported
if not _is_fused_attention_supported(config, dtype):
pytest.skip("FusedAttention does not support this model config")
# Run dot-product attention with different backends
fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
dtype, bs, config, "FusedAttention")
dtype,
bs,
config,
"FusedAttention"
)
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
dtype, bs, config, "UnfusedDotProductAttention")
dtype,
bs,
config,
"UnfusedDotProductAttention",
)
atol, rtol = (2.5e-2, 2.5e-2)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
# Check that results match
tols = dict(atol=2.5e-2, rtol=2.5e-2)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
def _run_dpa_fp8(dtype, bs, config, backend):
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim,
@@ -585,21 +747,6 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
return op, inp.grad
from torch.nn.parameter import Parameter
import transformer_engine.pytorch.cpp_extensions as ext
import transformer_engine_extensions as tex
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch import fp8_autocast
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule, _prepare_backward
from transformer_engine.common import recipe
from typing import Union, Dict, Any, Tuple, List
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd,
fused_attn_bwd,
FusedAttnBackend)
_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
@@ -864,7 +1011,7 @@ class DPA_FP8(TransformerEngineBaseModule):
self.head_dim = config.head_dim
self.fast_zero_fill = True
self.qkv_weight = Parameter(
self.qkv_weight = torch.nn.Parameter(
torch.empty(
self.hidden_size * 3,
self.hidden_size,
@@ -873,7 +1020,7 @@ class DPA_FP8(TransformerEngineBaseModule):
)
)
self.fp8_weight_shapes.append(self.qkv_weight.shape)
self.qkv_bias = Parameter(
self.qkv_bias = torch.nn.Parameter(
torch.empty(
self.hidden_size * 3,
device=torch.cuda.current_device(),
...
@@ -170,10 +170,7 @@ class DotProductAttention(paddle.nn.Layer):
self.backend = backend
arch = paddle.device.cuda.get_device_capability()
self.is_fused_attn_supported = arch in ((8, 0), (9, 0))
self.use_fused_attention = (int(os.getenv("NVTE_FUSED_ATTN", "1"))
and self.is_fused_attn_supported)
self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))
if not self.use_fused_attention and backend == 'transformer_engine':
warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
@@ -231,7 +228,9 @@ class DotProductAttention(paddle.nn.Layer):
Whether to use the fast path to set output tensors to 0 or not.
"""
if self.backend == 'transformer_engine':
backend = self.backend
if backend == 'transformer_engine':
max_s_q = query_layer.shape[1]
max_s_kv = max_s_q if self.attention_type == "self" else key_value_layer.shape[1]
self.fused_attention_backend = tex.get_fused_attn_backend(
@@ -247,16 +246,16 @@ class DotProductAttention(paddle.nn.Layer):
return self._te_forward(query_layer, key_value_layer, attention_mask,
core_attention_bias_type, core_attention_bias, set_zero)
warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
self.backend = 'paddle'
backend = 'paddle'
self.scale_mask_softmax = FusedScaleMaskSoftmax(self.attn_mask_type,
attention_mask_func,
backend=self.backend)
if self.backend == 'paddle':
backend=backend)
if backend == 'paddle':
if core_attention_bias_type != "no_bias":
warnings.warn("Paddle backend dot product attention does not support bias yet. "
"Bias will be ignored.")
return self._pd_forward(query_layer, key_value_layer, attention_mask)
raise AttributeError(f"Backend {self.backend} is not supported.")
raise AttributeError(f"Backend {backend} is not supported.")
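Using a local backend variable instead of mutating self.backend means a failed fused-attention probe downgrades only the current call, not the layer itself. A toy illustration of the pattern (names are ours):

class _FallbackExample:
    def __init__(self):
        self.backend = "transformer_engine"

    def forward(self, fused_supported: bool) -> str:
        backend = self.backend  # local copy; self.backend stays untouched
        if backend == "transformer_engine" and not fused_supported:
            backend = "paddle"  # fall back for this call only
        return backend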
def _te_forward(
self,
...
@@ -1619,7 +1619,7 @@ class FusedAttention(torch.nn.Module):
self.attention_type = attention_type
self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
and _flash_attn_2_available
and get_device_compute_capability() == 9.0)
and get_device_compute_capability() == (9, 0))
def forward(
self,
@@ -1849,7 +1849,7 @@ class DotProductAttention(torch.nn.Module):
self.use_flash_attention = (
int(os.getenv("NVTE_FLASH_ATTN", "1"))
and self.device_compute_capability >= 8.0
and self.device_compute_capability >= (8, 0)
)
if _flash_attn_2_available and self.deterministic:
self.use_flash_attention = False
@@ -1861,7 +1861,7 @@ class DotProductAttention(torch.nn.Module):
self.use_fused_attention = (
int(os.getenv("NVTE_FUSED_ATTN", "1"))
and self.device_compute_capability >= 8.0
and self.device_compute_capability >= (8, 0)
)
assert (
@@ -2091,9 +2091,12 @@ class DotProductAttention(torch.nn.Module):
# Filter: Device and dimensions.
if key_layer.shape[-1] > 64:
if self.device_compute_capability in (8.6, 8.7):
if self.device_compute_capability in ((8, 6), (8, 7)):
use_flash_attention = False
elif not _flash_attn_2_available and self.device_compute_capability == 8.9:
elif (
not _flash_attn_2_available
and self.device_compute_capability == (8, 9)
):
use_flash_attention = False
if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
...
@@ -22,9 +22,9 @@ __all__ = ["fp8_autocast"]
def check_fp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if get_device_compute_capability() >= 9.0: # hopper and above
if get_device_compute_capability() >= (9, 0): # hopper and above
return True, ""
if get_device_compute_capability() < 8.9: # pre-ada
if get_device_compute_capability() < (8, 9): # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if tex.get_cublasLt_version() < 120103:
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
...
@@ -8,11 +8,10 @@ from typing import Any, Callable, Optional, Tuple
import torch
def get_device_compute_capability() -> float:
"""Returns the cuda compute capability of current GPU"""
major = torch.cuda.get_device_properties(torch.cuda.current_device()).major
minor = torch.cuda.get_device_properties(torch.cuda.current_device()).minor
return major + minor / 10
def get_device_compute_capability() -> Tuple[int, int]:
"""CUDA compute capability of current GPU"""
props = torch.cuda.get_device_properties(torch.cuda.current_device())
return (props.major, props.minor)
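Returning a (major, minor) tuple instead of major + minor / 10 avoids fragile float equality and an outright collision once a minor revision reaches 10:

major, minor = 8, 10
assert major + minor / 10 == 9.0   # the old float encoding conflates (8, 10) with (9, 0)
assert (8, 10) < (9, 0)            # tuples compare lexicographically and stay distinct
assert (8, 9) >= (8, 9)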
def attention_mask_func(
...