Unverified Commit 4ae34765 authored by Tim Moon, committed by GitHub

Debug CI tests on Ada (#397)



* Debug PyTorch and Paddle tests on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Only run Paddle layer tests with cuDNN fMHA on supported archs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug PyTorch fMHA tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Reduce JAX FP8 GEMM sizes

Avoid split-k kernels on Ada.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable JAX fused self-attention test on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Run supported fused attention tests on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Run supported fused attention JAX tests on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Enable Paddle fused attention on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update reference scale calculation in TensorFlow test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Restore backend support to reference FP8 attention impl in PyT test

Review suggestion from @cyanguwa
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Paddle tests on Ada
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Loosen tolerances for Paddle attention tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Assume causal mask implies equal seqlens in Paddle attention tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 1f4c3979
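The recurring pattern in the hunks below is to replace hard-coded compute-capability skips in the framework test suites with a runtime query of the cuDNN fused-attention backend. A minimal sketch of that pattern (hypothetical test; the is_fused_attention_supported helper is the one added to the Paddle test utilities further down):

import pytest

def test_fused_attention_example():
    # Ask Transformer Engine which cuDNN fused-attention backend (if any)
    # supports this exact configuration, and skip rather than fail on GPUs
    # such as Ada where some configurations are unsupported.
    if not is_fused_attention_supported(
        head_size=64,
        q_seqlen=512,
        kv_seqlen=512,
        dtype="bfloat16",
        dropout=0.0,
        qkv_layout="qkv_interleaved",
        bias_type="no_bias",
        mask_type="padding",
    ):
        pytest.skip("cuDNN fused attention is not supported")
    # ... run the fused path and compare against a reference implementation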
@@ -24,8 +24,13 @@ from transformer_engine.jax.fp8 import is_fp8_available
 from transformer_engine.jax.layernorm import layernorm
 from transformer_engine.jax.mlp import fp8_ln_mlp
 
-GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
-              (16384, 1024, 1024)]
+GEMM_CASES = [
+    (256, 256, 512),
+    (32, 32, 32),
+    (2048, 1024, 2048),
+    (2048, 2048, 1024),
+    (2048, 1024, 1024),
+]
 FP8_COMPUTE_TYPE = [_format2dtypes(Format.E4M3), _format2dtypes(Format.HYBRID)]
 LN_CASES = [(512, 1024)]
 DTYPES = [jnp.bfloat16, jnp.float32]
@@ -167,6 +167,12 @@ class TestSelfFusedAttn():
                                                dropout_probability, s, s, head_dim):
             pytest.skip("Unsupported inputs combination or device compute capability.")
 
+        compute_capability = get_device_compute_capability(0)
+        if (backend == Backend.Max512
+                and not (compute_capability == 80 or compute_capability >= 90)):
+            pytest.skip("Unsupported compute capability for "
+                        "fused attention with <=512 sequence length")
+
     def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
                     dropout_probability, dtype, is_training, pad_ratio):
         """Setup the test inputs"""
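For context, the gate added above keeps the cuDNN Max512 (sequence length <= 512) backend tests on Ampere (compute capability 80) and Hopper or newer (90+), and skips them on Ada (89). A standalone illustration, assuming the JAX helper reports the capability as an integer (major * 10 + minor), as the comparison against 80/90 implies:

# Illustrative values only; not part of the diff.
for name, cc in [("A100 (sm_80)", 80), ("Ada (sm_89)", 89), ("H100 (sm_90)", 90)]:
    run_max512_tests = cc == 80 or cc >= 90
    print(f"{name}: run Max512 fused-attention tests = {run_max512_tests}")
# A100: True, Ada: False (skipped), H100: True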
@@ -5,14 +5,16 @@
 import math
 import os
 
-import pytest
 from utils import assert_allclose
 
 import paddle
+import pytest
 
+from transformer_engine.common.recipe import DelayedScaling
 import transformer_engine.paddle as te
 from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast
-from transformer_engine.common.recipe import DelayedScaling
+from utils import is_fused_attention_supported
 
 is_fp8_supported, reason = is_fp8_available()
 
 LINEAR_CASES = [(16, 16, 32), (32, 32, 64)]

@@ -614,8 +616,6 @@ class TestLayerNormMLP:
     assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
 
 
-@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
-                    reason="cuDNN fMHA requires Ampere+ GPU")
 @pytest.mark.parametrize('bs', [1, 2, 8])
 @pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
 @pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])

@@ -630,8 +630,21 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
     paddle.set_default_dtype(math_dtype)
     rtol = 1e-4
     atol = 2e-2
     head_size = hidden_size // num_heads
 
+    # Skip if cuDNN fused attention is not supported
+    if not is_fused_attention_supported(
+        head_size=head_size,
+        q_seqlen=q_seqlen,
+        kv_seqlen=kv_seqlen,
+        dtype=math_dtype,
+        dropout=0.0,
+        qkv_layout="qkv_interleaved" if attn_type == "self" else "kv_interleaved",
+        bias_type="no_bias",
+        mask_type=mask_type,
+    ):
+        pytest.skip("cuDNN fused attention is not supported")
+
     self_attn_qkv_input = paddle.normal(mean=0.0,
                                         std=0.02,
                                         shape=(bs, q_seqlen, 3, num_heads,

@@ -727,8 +740,6 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
         assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)
 
 
-@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
-                    reason="cuDNN fMHA requires Ampere+ GPU")
 @pytest.mark.parametrize('bs', [1, 2, 8])
 @pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
 @pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])

@@ -749,6 +760,19 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     atol = 5e-2
     eps = 1e-3
 
+    # Skip if cuDNN fused attention is not supported
+    if not is_fused_attention_supported(
+        head_size=hidden_size // num_heads,
+        q_seqlen=q_seqlen,
+        kv_seqlen=kv_seqlen,
+        dtype=math_dtype,
+        dropout=0.0,
+        qkv_layout="qkv_interleaved",
+        bias_type="no_bias",
+        mask_type=mask_type,
+    ):
+        pytest.skip("cuDNN fused attention is not supported")
+
     encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
     q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen

@@ -892,8 +916,6 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
                     atol=0.5)
 
 
-@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
-                    reason="cuDNN fMHA requires Ampere+ GPU")
 @pytest.mark.parametrize('bs', [1, 2, 8])
 @pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
 @pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])

@@ -916,6 +938,30 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     atol = 6e-2
     eps = 1e-3
 
+    # Skip if cuDNN fused attention is not supported
+    if not is_fused_attention_supported(
+        head_size=hidden_size // num_heads,
+        q_seqlen=q_seqlen,
+        kv_seqlen=kv_seqlen,
+        dtype=math_dtype,
+        dropout=0.0,
+        qkv_layout="qkv_interleaved",
+        bias_type="no_bias",
+        mask_type=mask_type,
+    ):
+        pytest.skip("cuDNN fused attention is not supported")
+    if not is_fused_attention_supported(
+        head_size=hidden_size // num_heads,
+        q_seqlen=q_seqlen,
+        kv_seqlen=kv_seqlen,
+        dtype=math_dtype,
+        dropout=0.0,
+        qkv_layout="kv_interleaved",
+        bias_type="no_bias",
+        mask_type=mask_type,
+    ):
+        pytest.skip("cuDNN fused attention is not supported")
+
     encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
     encoder_output = paddle.uniform(shape=(bs, kv_seqlen, hidden_size), dtype=math_dtype)
@@ -6,10 +6,9 @@
 import struct
 
 import numpy as np
-import pytest
 import paddle
 import paddle.nn.functional as F
-from utils import assert_allclose, create_fp8_meta
+import pytest
 
 import transformer_engine    # pylint: disable=unused-import
 import transformer_engine_paddle as tex    # pylint: disable=wrong-import-order

@@ -45,6 +44,13 @@ from transformer_engine.paddle.fp8 import is_fp8_available
 from transformer_engine.paddle.constants import FP8FwdTensors
 from transformer_engine.common.recipe import DelayedScaling
+from utils import (
+    assert_allclose,
+    create_fp8_meta,
+    get_fused_attention_backend,
+    is_fused_attention_supported,
+)
 
 GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
               (16384, 1024, 1024)]
 is_fp8_supported, reason = is_fp8_available()

@@ -300,7 +306,7 @@ class TestGemm:
         actual_out, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, False, "TN",
                                 None, None, False)
 
-        assert_allclose(actual_out, ref_out)
+        assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5)
 
     @staticmethod
     @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),

@@ -332,8 +338,8 @@
         """
         Test "TN" FP8 GEMM
         """
-        min_val = -8
-        max_val = 8
+        min_val = -4
+        max_val = 4
         fp8_dtype = tex.DType.kFloat8E4M3
         out_dtype = paddle.float32
         fp8_meta = create_fp8_meta(num_gemms=1)

@@ -582,31 +588,38 @@ class TestFusedAttn:
         else:
             self.kv = _random(self.kv_shape)
 
-        self.q_actual_seqlen = np.random.randint(
-            low=20,
-            high=self.q_seqlen,
-            size=(self.batch_size,),
-            dtype=np.int32,
-        )
+        self.q_actual_seqlen = None
+        if self.is_causal_masking:
+            self.q_actual_seqlen = np.full(
+                self.batch_size,
+                self.q_seqlen,
+                dtype=np.int32,
+            )
+        else:
+            self.q_actual_seqlen = np.random.randint(
+                low=20,
+                high=self.q_seqlen,
+                size=(self.batch_size,),
+                dtype=np.int32,
+            )
         self.kv_actual_seqlen = self.q_actual_seqlen
 
         self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen)
         self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0)
         self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen)
         self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0)
 
-        self.attn_mask = np.zeros(
+        self.attn_mask = np.ones(
             shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
             dtype=np.int32,
         )
-        for i in range(0, self.batch_size):
-            self.attn_mask[i, 0, 0:self.q_actual_seqlen[i], 0:self.kv_actual_seqlen[i],] = 1
-            if self.is_causal_masking:
-                assert attn_mode == "self_attn", "only support causal masking for self attention"
-                col_beg, col_end = 1, self.q_actual_seqlen[i]
-                for row in range(0, self.q_actual_seqlen[i]):
-                    self.attn_mask[i, 0, row, col_beg:col_end] = 0
-                    col_beg += 1
+        if self.is_causal_masking:
+            assert attn_mode == "self_attn", "only support causal masking for self attention"
+            for i in range(0, self.batch_size):
+                for j in range(self.q_actual_seqlen[i]):
+                    self.attn_mask[i, :, j, :j+1] = 0
+        else:
+            for i in range(0, self.batch_size):
+                self.attn_mask[i, :, :self.q_actual_seqlen[i], :self.kv_actual_seqlen[i]] = 0
 
         dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
         self.dout = paddle.to_tensor(dout, dtype=self.dtype)

@@ -628,10 +641,12 @@ class TestFusedAttn:
             transpose_y=True,
         )
 
-        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
-        attn_mask = (paddle.cast(attn_mask, self.dtype) - 1.0) * 1e4
-        attn_mask_out = qk_out + attn_mask
+        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast('bool')
+        attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype)
+        attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out)
+        attn_mask_out = paddle.cast(attn_mask_out, 'float32')
         softmax_out = F.softmax(attn_mask_out)
+        softmax_out = paddle.cast(softmax_out, self.dtype)
 
         if self.dropout_prob:
             dropout_out = F.dropout(

@@ -667,9 +682,21 @@ class TestFusedAttn:
         q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
         kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
 
-        fused_attention_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen if (
-            self.q_seqlen <= 512
-            and self.kv_seqlen <= 512) else tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
+        qkv_layout = (
+            "qkv_interleaved"
+            if self.attn_mode == "self_attn"
+            else "kv_interleaved"
+        )
+        fused_attention_backend = get_fused_attention_backend(
+            head_size=self.head_size,
+            q_seqlen=self.q_seqlen,
+            kv_seqlen=self.kv_seqlen,
+            dtype=self.dtype,
+            dropout=self.dropout_prob,
+            qkv_layout=qkv_layout,
+            bias_type="no_bias",
+            mask_type="causal" if self.is_causal_masking else "padding",
+        )
 
         qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
         out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None

@@ -739,8 +766,6 @@ class TestFusedAttn:
         return out, q_grad, k_grad, v_grad
 
-    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() not in ((8, 0), (9, 0)),
-                        reason="cuDNN fMHA requires Ampere and Hopper GPU")
     @pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
     @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
     @pytest.mark.parametrize('is_causal_masking', [True, False])

@@ -748,6 +773,17 @@ class TestFusedAttn:
         """
         test self attention forward + backward
         """
+        if not is_fused_attention_supported(
+            head_size=d,
+            q_seqlen=s,
+            kv_seqlen=s,
+            dtype=dtype,
+            dropout=0.0,
+            qkv_layout="qkv_interleaved",
+            bias_type="no_bias",
+            mask_type="causal" if is_causal_masking else "padding",
+        ):
+            pytest.skip("cuDNN fused attention is not supported")
         self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
         reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
         fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()

@@ -756,14 +792,23 @@ class TestFusedAttn:
         assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
         assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
 
-    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() not in ((8, 0), (9, 0)),
-                        reason="cuDNN fMHA requires Ampere and Hopper GPU")
     @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES)
     @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
     def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
         """
         test cross attention forward + backward
         """
+        if not is_fused_attention_supported(
+            head_size=d,
+            q_seqlen=s_q,
+            kv_seqlen=s_kv,
+            dtype=dtype,
+            dropout=0.0,
+            qkv_layout="kv_interleaved",
+            bias_type="no_bias",
+            mask_type="padding",
+        ):
+            pytest.skip("cuDNN fused attention is not supported")
         self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn")
         reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
         fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()

@@ -772,8 +817,6 @@ class TestFusedAttn:
         assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
         assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
 
-    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
-                        reason="cuDNN fMHA requires Ampere+ GPU")
     @pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
     @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
     @pytest.mark.parametrize('is_causal_masking', [True])

@@ -781,6 +824,17 @@ class TestFusedAttn:
         """
         test flash attention forward + backward
         """
+        if not is_fused_attention_supported(
+            head_size=d,
+            q_seqlen=s,
+            kv_seqlen=s,
+            dtype=dtype,
+            dropout=0.0,
+            qkv_layout="qkv_interleaved",
+            bias_type="no_bias",
+            mask_type="causal" if is_causal_masking else "padding",
+        ):
+            pytest.skip("cuDNN fused attention is not supported")
         self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
         reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
         fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
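Note the flipped mask convention in the reworked Paddle reference above: mask entries equal to 1 now mean "masked out" and are replaced with a large negative value via paddle.where, while 0 means "attend". A standalone sketch of that mechanism (values are illustrative only, not from the diff):

import paddle
import paddle.nn.functional as F

scores = paddle.to_tensor([[0.5, 1.0, 2.0]], dtype='float32')     # raw attention scores
mask = paddle.to_tensor([[0, 0, 1]], dtype='int32').cast('bool')  # last position masked
neg_inf = paddle.full(scores.shape, -1e4, scores.dtype)
masked_scores = paddle.where(mask, neg_inf, scores)
print(F.softmax(masked_scores))  # masked position gets ~0 probability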
@@ -4,15 +4,23 @@
 """Utils for testing"""
 
 import random
+from typing import Union
 
 import numpy as np
 import paddle
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 
 import transformer_engine    # pylint: disable=unused-import
+from transformer_engine.paddle.constants import (
+    TE_DType,
+    QKVLayout,
+    AttnBiasType,
+    AttnMaskType,
+    FusedAttnBackend,
+)
 from transformer_engine.paddle.fp8 import FP8TensorMeta
+import transformer_engine_paddle as tex
 
 
 def create_fp8_meta(num_gemms=1, amax_history_len=10):

@@ -92,3 +100,55 @@ def set_random_seed(seed):
     tracker.add("local_seed", local_seed)
 
     paddle.seed(global_seed)
+
+
+def get_fused_attention_backend(
+    head_size: int,
+    q_seqlen: int,
+    kv_seqlen: int,
+    dtype: Union[paddle.dtype, str],
+    dropout: float,
+    qkv_layout: str = "qkv_interleaved",
+    bias_type: str = "no_bias",
+    mask_type: str = "causal",
+) -> tex.NVTE_Fused_Attn_Backend:
+    """Get cuDNN fused attention backend for attention config"""
+    if isinstance(dtype, str):
+        dtype = dict(
+            float32=paddle.float32,
+            bfloat16=paddle.bfloat16,
+            float16=paddle.float16,
+        )[dtype]
+    return tex.get_fused_attn_backend(
+        TE_DType[dtype],
+        TE_DType[dtype],
+        QKVLayout[qkv_layout],
+        AttnBiasType[bias_type],
+        AttnMaskType[mask_type],
+        dropout,
+        q_seqlen,
+        kv_seqlen,
+        head_size,
+    )
+
+
+def is_fused_attention_supported(
+    head_size: int,
+    q_seqlen: int,
+    kv_seqlen: int,
+    dtype: Union[paddle.dtype, str],
+    dropout: float,
+    qkv_layout: str = "qkv_interleaved",
+    bias_type: str = "no_bias",
+    mask_type: str = "causal",
+) -> bool:
+    """Check if cuDNN fused attention is supported for attention config"""
+    backend = get_fused_attention_backend(
+        head_size=head_size,
+        q_seqlen=q_seqlen,
+        kv_seqlen=kv_seqlen,
+        dtype=dtype,
+        dropout=dropout,
+        qkv_layout=qkv_layout,
+        bias_type=bias_type,
+        mask_type=mask_type,
+    )
+    return backend != FusedAttnBackend["No_Backend"]
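The helper added above returns the backend enum directly, which can also be inspected by hand. A hypothetical interactive check (the backend member names mentioned in the comment follow the old test code and may differ across Transformer Engine versions):

backend = get_fused_attention_backend(
    head_size=64,
    q_seqlen=512,
    kv_seqlen=512,
    dtype="bfloat16",
    dropout=0.0,
    qkv_layout="qkv_interleaved",
    bias_type="no_bias",
    mask_type="padding",
)
# Expected values include NVTE_F16_max512_seqlen, NVTE_F16_arbitrary_seqlen,
# or a "no backend" member when the config is unsupported on this GPU.
print(backend)
print(backend != FusedAttnBackend["No_Backend"])  # the check is_fused_attention_supported makes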
@@ -2,23 +2,46 @@
 #
 # See LICENSE for license information.
 
-import torch
+from importlib.metadata import version
+import os
+from typing import Any, Dict, List, Tuple, Union
+
+from pkg_resources import packaging
 import pytest
+import torch
+
+from transformer_engine.common import recipe
+from transformer_engine.pytorch import TransformerLayer, fp8_autocast
+from transformer_engine.pytorch.attention import (
+    DotProductAttention,
+    RotaryPositionEmbedding,
+)
+from transformer_engine.pytorch.constants import TE_DType
+import transformer_engine.pytorch.cpp_extensions as ext
+from transformer_engine.pytorch.cpp_extensions.fused_attn import (
+    AttnBiasType,
+    AttnMaskType,
+    FusedAttnBackend,
+    QKVLayout,
+    fused_attn_bwd,
+    fused_attn_fwd,
+    fused_attn_bwd_qkvpacked,
+    fused_attn_fwd_qkvpacked,
+)
+import transformer_engine.pytorch.fp8 as fp8
+from transformer_engine.pytorch.module.base import (
+    TransformerEngineBaseModule,
+    _prepare_backward,
+)
 from transformer_engine.pytorch.utils import (
+    get_device_compute_capability,
     init_method_normal,
     scaled_init_method_normal,
-    get_device_compute_capability,
 )
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch import TransformerLayer
-from transformer_engine.pytorch.attention import DotProductAttention, RotaryPositionEmbedding
-import os
-from pkg_resources import packaging
-from importlib.metadata import version
+import transformer_engine_extensions as tex
 
 from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
 
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
 _cudnn_version = [int(i) for i in os.environ['CUDNN_VERSION'].split('.')]

@@ -63,32 +86,93 @@ param_types_lean = [torch.bfloat16]
 batch_sizes_lean = [2]
 
 
-@pytest.mark.skipif(
-    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
+def _is_fused_attention_supported(
+    config: ModelConfig,
+    dtype: torch.dtype,
+    qkv_layout: str = "sbh3d",
+    bias_type: str = "no_bias",
+) -> bool:
+    backend = tex.get_fused_attn_backend(
+        TE_DType[dtype],
+        TE_DType[dtype],
+        QKVLayout[qkv_layout],
+        AttnBiasType[bias_type],
+        AttnMaskType[config.attn_mask_type],
+        config.dropout_p,
+        config.seq_len,
+        config.seq_len,
+        config.head_dim,
+    )
+    return backend != FusedAttnBackend["No_Backend"]
+
+
+def _is_flash_attention_supported(bias_type: str = "no_bias") -> bool:
+    if get_device_compute_capability() < (8, 0):
+        return False
+    if bias_type != "no_bias":
+        return False
+    return True
+
+
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes_lean)
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("ckpt_attn", [True, False])
 @pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
 def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
-    """Test DotProductAttention module with three backends,
-    FlashAttention, FusedAttention and UnfusedDotProductAttention"""
+    """Test DotProductAttention module with different backends"""
 
+    # Get configs
     config = model_configs[model]
+    tols = dict(atol=5e-3, rtol=5e-3)
+    if dtype == torch.bfloat16:
+        tols = dict(atol=2.5e-2, rtol=2.5e-2)
 
-    if bias_type == "no_bias":
-        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-            dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
-    fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-        dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
+    # Skip if only unfused backend is supported
+    fused_attn_supported = _is_fused_attention_supported(
+        config,
+        dtype,
+        bias_type=bias_type,
+    )
+    flash_attn_supported = _is_flash_attention_supported(bias_type=bias_type)
+    if not (fused_attn_supported or flash_attn_supported):
+        pytest.skip(
+            "Neither FusedAttention nor FlashAttention support this model config"
+        )
+
+    # UnfusedDotProductAttention backend
     unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
-        dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)
+        dtype,
+        bs,
+        config,
+        "UnfusedDotProductAttention",
+        ckpt_attn,
+        bias_type,
+    )
 
-    atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
-    if bias_type == "no_bias":
-        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
-        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
-    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
-    torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
+    # FusedAttention backend
+    if fused_attn_supported:
+        fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+            dtype,
+            bs,
+            config,
+            "FusedAttention",
+            ckpt_attn,
+            bias_type,
+        )
+        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
+        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
+
+    # FlashAttention backend
+    if flash_attn_supported:
+        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
+            dtype,
+            bs,
+            config,
+            "FlashAttention",
+            ckpt_attn,
+            bias_type,
+        )
+        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
+        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
 
 
 def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):

@@ -155,9 +239,7 @@ qkv_layouts = [
 ]
 
 
 @pytest.mark.skipif(
-    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
-@pytest.mark.skipif(
-    _cudnn_version >= [8,9,5], reason="cuDNN 8.9.5+ is required.")
+    _cudnn_version < [8,9,5], reason="cuDNN 8.9.5+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("bs", batch_sizes_lean)
 @pytest.mark.parametrize("model", model_configs_lean.keys())

@@ -166,23 +248,39 @@
 def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout):
     """Test DotProductAttention module with different QKV layouts"""
 
+    # Get configs
     config = model_configs_lean[model]
+    tols = dict(atol=5e-3, rtol=5e-3)
+    if dtype == torch.bfloat16:
+        tols = dict(atol=2.5e-2, rtol=2.5e-2)
+
+    # Skip if only unfused backend is supported
+    fused_attn_supported = _is_fused_attention_supported(config, dtype)
+    flash_attn_supported = _is_flash_attention_supported()
+    if not (fused_attn_supported or flash_attn_supported):
+        pytest.skip(
+            "Neither FusedAttention nor FlashAttention support this model config"
+        )
 
-    flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
-        dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
-    fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout(
-        dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt)
+    # UnfusedDotProductAttention backend
     unfused_attn_fwd, unfused_attn_bwd = _run_dpa_qkv_layout(
         dtype, bs, config, "UnfusedDotProductAttention", qkv_layout, workspace_opt)
 
-    atol, rtol = (5e-2, 5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3)
-    torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
-    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
-    torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
-    for i in range(len(flash_attn_bwd)):
-        torch.testing.assert_close(flash_attn_bwd[i], unfused_attn_bwd[i], atol = atol, rtol = rtol)
-        torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], atol = atol, rtol = rtol)
-        torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], atol = atol, rtol = rtol)
+    # FusedAttention backend
+    if fused_attn_supported:
+        fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout(
+            dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt)
+        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
+        for i in range(len(unfused_attn_bwd)):
+            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
+
+    # FlashAttention backend
+    if flash_attn_supported:
+        flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
+            dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
+        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
+        for i in range(len(unfused_attn_bwd)):
+            torch.testing.assert_close(flash_attn_bwd[i], unfused_attn_bwd[i], **tols)
 
 
 def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):

@@ -272,8 +370,6 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
     return op, (inp[0].grad, inp[1].grad, inp[2].grad)
 
 
-@pytest.mark.skipif(
-    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs_lean.keys())

@@ -284,22 +380,61 @@ def test_transformer_layer(dtype, bs, model, bias_type, fused_qkv_params, RoPE):
     """Test TransformerLayer module when its DotProductAttention is enabled with
     FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
 
+    # Get configs
     config = model_configs_lean[model]
+    tols = dict(atol=5e-1, rtol=5e-2)
+
+    # Skip if only unfused backend is supported
+    fused_attn_supported = _is_fused_attention_supported(
+        config,
+        dtype,
+        qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
+        bias_type=bias_type,
+    )
+    flash_attn_supported = _is_flash_attention_supported(bias_type=bias_type)
+    if not (fused_attn_supported or flash_attn_supported):
+        pytest.skip(
+            "Neither FusedAttention nor FlashAttention support this model config"
+        )
 
-    if bias_type == "no_bias":
-        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
-            dtype, bs, config, "FlashAttention", bias_type, fused_qkv_params, RoPE)
-    fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
-        dtype, bs, config, "FusedAttention", bias_type, fused_qkv_params, RoPE)
+    # UnfusedDotProductAttention backend
     unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
-        dtype, bs, config, "UnfusedDotProductAttention", bias_type, fused_qkv_params, RoPE)
+        dtype,
+        bs,
+        config,
+        "UnfusedDotProductAttention",
+        bias_type,
+        fused_qkv_params,
+        RoPE,
+    )
+
+    # FusedAttention backend
+    if fused_attn_supported:
+        fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
+            dtype,
+            bs,
+            config,
+            "FusedAttention",
+            bias_type,
+            fused_qkv_params,
+            RoPE,
+        )
+        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
+        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
 
-    atol, rtol = (5e-1, 5e-2)
-    if bias_type == "no_bias":
-        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
-        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
-    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
-    torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
+    # FlashAttention backend
+    if flash_attn_supported:
+        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
+            dtype,
+            bs,
+            config,
+            "FlashAttention",
+            bias_type,
+            fused_qkv_params,
+            RoPE,
+        )
+        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
+        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
 
 
 def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_params, RoPE):

@@ -385,15 +520,12 @@ def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_para
     return op, inp.grad
 
 
-@pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available")
-@pytest.mark.skipif(
-    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("bs", batch_sizes_lean)
 @pytest.mark.parametrize("model", model_configs_lean.keys())
 def test_transformer_layer_gqa(dtype, bs, model):
     """Test TransformerLayer module when its DotProductAttention is enabled with
-    FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
+    FlashAttention or UnfusedDotProductAttention backend"""
 
     config = model_configs_lean[model]
 
     def find_factors(x):

@@ -403,6 +535,10 @@ def test_transformer_layer_gqa(dtype, bs, model):
                 f.append(i)
         return f
 
+    # Skip if only unfused backend is supported
+    if not (_flash_attn_2_available and _is_flash_attention_supported()):
+        pytest.skip("FlashAttention does not support this model config")
+
     num_querys_per_gqa_group = find_factors(config.num_attention_heads)
 
     for num_q_per_gqa_group in num_querys_per_gqa_group:

@@ -419,8 +555,11 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr
     reset_rng_states()
 
     os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
+    if backend == "FusedAttention":
+        os.environ["NVTE_FUSED_ATTN"] = "1"
 
     inp = torch.randn(
         config.seq_len, bs, config.num_attention_heads * config.head_dim,

@@ -495,25 +634,48 @@ param_types_fp8 = [torch.float16]
 @pytest.mark.parametrize("bs", batch_sizes_fp8)
 @pytest.mark.parametrize("model", model_configs_fp8.keys())
 def test_dpa_fp8(dtype, bs, model):
-    """Test DotProductAttention module with FP8,
-    using cpp_extensions import fused_attn_fwd/bwd_qkvpacked and UnfusedDotProductAttention"""
+    """Test FP8 dot-product attention with different backends
+
+    FusedAttention uses fused_attn_fwd/bwd_qkvpacked from
+    cpp_extensions. UnfusedDotProductAttention uses plain PyTorch
+    operations.
+
+    """
 
     config = model_configs_fp8[model]
 
+    # Skip if not supported
+    if not _is_fused_attention_supported(config, dtype):
+        pytest.skip("FusedAttention does not support this model config")
+
+    # Run dot-product attention with different backends
     fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
-        dtype, bs, config, "FusedAttention")
+        dtype,
+        bs,
+        config,
+        "FusedAttention"
+    )
     unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
-        dtype, bs, config, "UnfusedDotProductAttention")
+        dtype,
+        bs,
+        config,
+        "UnfusedDotProductAttention",
+    )
 
-    atol, rtol = (2.5e-2, 2.5e-2)
-    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
-    torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
+    # Check that results match
+    tols = dict(atol=2.5e-2, rtol=2.5e-2)
+    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
+    torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
 
 
 def _run_dpa_fp8(dtype, bs, config, backend):
 
     reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    if backend == "FlashAttention":
+        os.environ["NVTE_FLASH_ATTN"] = "1"
+    if backend == "FusedAttention":
+        os.environ["NVTE_FUSED_ATTN"] = "1"
 
     inp = 0.01 * torch.randn(
         bs * config.seq_len, config.num_attention_heads * config.head_dim,

@@ -585,21 +747,6 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
     return op, inp.grad
 
 
-from torch.nn.parameter import Parameter
-import transformer_engine.pytorch.cpp_extensions as ext
-import transformer_engine_extensions as tex
-import transformer_engine.pytorch.fp8 as fp8
-from transformer_engine.pytorch import fp8_autocast
-from transformer_engine.pytorch.module.base import TransformerEngineBaseModule, _prepare_backward
-from transformer_engine.common import recipe
-from typing import Union, Dict, Any, Tuple, List
-from transformer_engine.pytorch.cpp_extensions.fused_attn import (
-    fused_attn_fwd_qkvpacked,
-    fused_attn_bwd_qkvpacked,
-    fused_attn_fwd,
-    fused_attn_bwd,
-    FusedAttnBackend)
 _CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
 _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = False

@@ -864,7 +1011,7 @@ class DPA_FP8(TransformerEngineBaseModule):
         self.head_dim = config.head_dim
         self.fast_zero_fill = True
 
-        self.qkv_weight = Parameter(
+        self.qkv_weight = torch.nn.Parameter(
             torch.empty(
                 self.hidden_size * 3,
                 self.hidden_size,

@@ -873,7 +1020,7 @@ class DPA_FP8(TransformerEngineBaseModule):
             )
         )
         self.fp8_weight_shapes.append(self.qkv_weight.shape)
-        self.qkv_bias = Parameter(
+        self.qkv_bias = torch.nn.Parameter(
             torch.empty(
                 self.hidden_size * 3,
                 device=torch.cuda.current_device(),
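The _run_* helpers above pick a single backend per run through the NVTE_FLASH_ATTN and NVTE_FUSED_ATTN environment variables. Roughly (a sketch; the variables must be set before the attention modules are constructed, and the library may still fall back to the unfused path when a configuration is unsupported):

import os

# Force the FusedAttention path only; "0" disables a backend, "1" allows it.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
# ... build DotProductAttention / TransformerLayer and run forward + backward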
@@ -170,10 +170,7 @@ class DotProductAttention(paddle.nn.Layer):
 
         self.backend = backend
 
-        arch = paddle.device.cuda.get_device_capability()
-        self.is_fused_attn_supported = arch in ((8, 0), (9, 0))
-        self.use_fused_attention = (int(os.getenv("NVTE_FUSED_ATTN", "1"))
-                                    and self.is_fused_attn_supported)
+        self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))
 
         if not self.use_fused_attention and backend == 'transformer_engine':
             warnings.warn("Fused attention is not enabled, falling back to Paddle backend")

@@ -231,7 +228,9 @@ class DotProductAttention(paddle.nn.Layer):
             Whether to use the fast path to set output tensors to 0 or not.
         """
 
-        if self.backend == 'transformer_engine':
+        backend = self.backend
+
+        if backend == 'transformer_engine':
             max_s_q = query_layer.shape[1]
             max_s_kv = max_s_q if self.attention_type == "self" else key_value_layer.shape[1]
             self.fused_attention_backend = tex.get_fused_attn_backend(

@@ -247,16 +246,16 @@ class DotProductAttention(paddle.nn.Layer):
                 return self._te_forward(query_layer, key_value_layer, attention_mask,
                                         core_attention_bias_type, core_attention_bias, set_zero)
             warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
-            self.backend = 'paddle'
+            backend = 'paddle'
             self.scale_mask_softmax = FusedScaleMaskSoftmax(self.attn_mask_type,
                                                             attention_mask_func,
-                                                            backend=self.backend)
-        if self.backend == 'paddle':
+                                                            backend=backend)
+        if backend == 'paddle':
             if core_attention_bias_type != "no_bias":
                 warnings.warn("Paddle backend dot product attention does not support bias yet. "
                               "Bias will be ignored.")
             return self._pd_forward(query_layer, key_value_layer, attention_mask)
-        raise AttributeError(f"Backend {self.backend} is not supported.")
+        raise AttributeError(f"Backend {backend} is not supported.")
 
     def _te_forward(
         self,
@@ -1619,7 +1619,7 @@ class FusedAttention(torch.nn.Module):
         self.attention_type = attention_type
         self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
                         and _flash_attn_2_available
-                        and get_device_compute_capability() == 9.0)
+                        and get_device_compute_capability() == (9, 0))
 
     def forward(
         self,

@@ -1849,7 +1849,7 @@ class DotProductAttention(torch.nn.Module):
         self.use_flash_attention = (
             int(os.getenv("NVTE_FLASH_ATTN", "1"))
-            and self.device_compute_capability >= 8.0
+            and self.device_compute_capability >= (8, 0)
         )
         if _flash_attn_2_available and self.deterministic:
             self.use_flash_attention = False

@@ -1861,7 +1861,7 @@ class DotProductAttention(torch.nn.Module):
         self.use_fused_attention = (
             int(os.getenv("NVTE_FUSED_ATTN", "1"))
-            and self.device_compute_capability >= 8.0
+            and self.device_compute_capability >= (8, 0)
         )
 
         assert (

@@ -2091,9 +2091,12 @@ class DotProductAttention(torch.nn.Module):
         # Filter: Device and dimensions.
         if key_layer.shape[-1] > 64:
-            if self.device_compute_capability in (8.6, 8.7):
+            if self.device_compute_capability in ((8, 6), (8, 7)):
                 use_flash_attention = False
-            elif not _flash_attn_2_available and self.device_compute_capability == 8.9:
+            elif (
+                not _flash_attn_2_available
+                and self.device_compute_capability == (8, 9)
+            ):
                 use_flash_attention = False
 
         if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
@@ -22,9 +22,9 @@ __all__ = ["fp8_autocast"]
 def check_fp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
-    if get_device_compute_capability() >= 9.0:  # hopper and above
+    if get_device_compute_capability() >= (9, 0):  # hopper and above
         return True, ""
-    if get_device_compute_capability() < 8.9:  # pre-ada
+    if get_device_compute_capability() < (8, 9):  # pre-ada
         return False, "Device compute capability 8.9 or higher required for FP8 execution."
     if tex.get_cublasLt_version() < 120103:
         return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
@@ -8,11 +8,10 @@ from typing import Any, Callable, Optional, Tuple
 import torch
 
 
-def get_device_compute_capability() -> float:
-    """Returns the cuda compute capability of current GPU"""
-    major = torch.cuda.get_device_properties(torch.cuda.current_device()).major
-    minor = torch.cuda.get_device_properties(torch.cuda.current_device()).minor
-    return major + minor / 10
+def get_device_compute_capability() -> Tuple[int, int]:
+    """CUDA compute capability of current GPU"""
+    props = torch.cuda.get_device_properties(torch.cuda.current_device())
+    return (props.major, props.minor)
 
 
 def attention_mask_func(
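The switch above from a float (major + minor / 10) to a (major, minor) tuple makes compute-capability checks exact and unambiguous, which matters for Ada (8.9). A standalone illustration of the tuple comparisons the updated checks rely on (not taken from the diff):

ada = (8, 9)      # e.g. L40 / RTX 4090
ampere = (8, 0)   # e.g. A100
hopper = (9, 0)   # e.g. H100

print(ada >= ampere)   # True  -> Ada passes "Ampere or newer" checks
print(ada >= hopper)   # False -> Ada does not satisfy "Hopper or newer" checks
print(ada < (8, 9))    # False -> Ada is not "pre-Ada" in check_fp8_support above

# The old float encoding would map a hypothetical capability (8, 10) to 9.0,
# colliding with Hopper; tuple comparison avoids that class of problem entirely.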