# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
import sys
from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt

import jax
import jax.numpy as jnp
import pytest
from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import NVTE_Fused_Attn_Backend

from utils import assert_allclose


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep resources clean between tests.
    """
    # Calling a custom call before any other JAX operation may cause a CUDA
    # uninitialized error, so trigger JAX/CUDA initialization first.
    _ = jnp.zeros(0)
    yield
    for arr in jax.live_arrays():
        arr.delete()


def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
                                  bias: ArrayLike, mask: ArrayLike, deterministic: bool,
                                  scale_factor: float, dropout_rate: float,
                                  dropout_rng: ArrayLike, dtype: DTypeLike) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, logits, jnp.finfo(dtype).min)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context


def is_causal_mask(mask: AttnMaskType):
    """
    Check if the mask is a causal mask
    """
    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Create padded causal mask
    """
    q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
    kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
    causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
    padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
    return combine_masks(causal_mask, padding_mask)


def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    attn_mask_type = kwargs['attn_mask_type']
    if is_causal_mask(attn_mask_type):
        mask = make_decoder_mask(q_token, kv_token)
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

    output = general_dot_product_attention(query,
                                           key,
                                           value,
                                           bias=bias,
                                           mask=mask,
                                           deterministic=not kwargs['is_training'],
                                           scale_factor=kwargs['scaling_factor'],
                                           dropout_rate=kwargs['dropout_probability'],
                                           dropout_rng=dropout_rng,
                                           dtype=jnp.float32)
    return output.astype(query.dtype)
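

# customcall_fused_dpa packs q/k/v to match the requested QKV layout before dispatching
# to the corresponding TE custom call:
#   BS3HD          -> q, k and v stacked along a new axis (self-attention path)
#   BSHD_BS2HD     -> k and v stacked along a new axis (cross-attention path)
#   BSHD_BSHD_BSHD -> q, k and v passed as separate tensors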
def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    TE customcall dot product attention implementation
    """
    attn_mask_type = kwargs['attn_mask_type']
    if is_causal_mask(attn_mask_type):
        mask = make_decoder_mask(q_token, kv_token)
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

    # Invert the mask: flax masks mark attended positions with True, whereas the
    # TE custom calls expect True to mark masked-out positions.
    mask = jnp.logical_not(mask)

    qkv_layout = kwargs.pop('qkv_layout')
    match qkv_layout:
        case QKVLayout.BS3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
        case QKVLayout.BSHD_BS2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            return cross_fused_attn(query, kv, bias, mask, dropout_rng,
                                    **kwargs).astype(query.dtype)
        case QKVLayout.BSHD_BSHD_BSHD:
            return fused_attn(query, key, value, bias, mask, dropout_rng,
                              **kwargs).astype(query.dtype)
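

# Bias shape names encode the four bias dimensions in order (batch, heads, seqlen_q,
# seqlen_kv); a '1' denotes a size-one dimension that is broadcast.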
class BiasShape(Enum):
    BIAS_1HSS = '1HSS'
    BIAS_B1SS = 'B1SS'
    BIAS_BHSS = 'BHSS'
    BIAS_11SS = '11SS'


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape

    def _check_configs(self):
        if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
            pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")

        if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
            pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")

        self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
                                       self.attn_bias_type.value, self.attn_mask_type.value,
                                       self.dropout_prob, self.num_heads_q, self.num_heads_kv,
                                       self.max_seqlen_q, self.max_seqlen_kv,
                                       self.head_dim).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        if self.bias_shape != BiasShape.BIAS_1HSS:
            if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS:
                pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.")
            elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                            "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                            "the F16_arbitrary_seqlen backend.")

    def _setup_inputs(self):
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv,
                             self.head_dim)

        if self.bias_shape == BiasShape.BIAS_1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_BHSS:
            bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q,
                          self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.xfail("PyTest attempted to use an unrecognized bias layout!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape.BIAS_1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128    # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
                    self.bias = \
                        self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.)
        else:
            self.bias = None

        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)),
                                      jnp.zeros((bs, pad_len))], axis=-1)
            return valid_len, tokens

        self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
        self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv,
                                                     pad_ratio)

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1. / sqrt(self.head_dim)
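
    # Both tests compare the TE fused-attention custom call against the pure-JAX
    # reference and additionally check that padded (invalid) positions are zeroed.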
    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
        reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)

        # Skip elementwise comparison when dropout is enabled
        if self.is_training and self.dropout_prob > 0.:
            return

        primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
        reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward passes
        """
        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
            # Gradients are small, so use a multiplier to amplify them
            gradient_multiplier = self.valid_len_q * self.num_heads_q
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only the valid part of the result for the gradient
            ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
            return (jnp.mean(ret_valid, dtype=jnp.float32) *
                    gradient_multiplier).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)

        # Summing the results in FP16/BF16 may overflow, so perform the summation in FP32
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias,
                                                       *args, **kwargs), arg_nums))
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums))

        primitive_out, primitive_dgrad = jitted_primitive(*args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout is enabled
        if self.dropout_prob > 0.:
            return

        assert_allclose(primitive_out.astype(jnp.float32),
                        reference_out.astype(jnp.float32),
                        dtype=self.dtype)

        def check_dqkv(primitive, reference, valid_len):
            primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
            reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid),
                            dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        # Convert the gradients to float32 for the elementwise comparison
        primitive_dq, primitive_dk, primitive_dv = map(jnp.float32, primitive_dgrad[:3])
        reference_dq, reference_dk, reference_dv = map(jnp.float32, reference_dgrad[:3])

        check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
        check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
        check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
            # Convert the bias gradients to float32 for the elementwise comparison
            primitive_dbias = jnp.float32(primitive_dgrad[3])
            reference_dbias = jnp.float32(reference_dgrad[3])

            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
                                                           self.valid_len_kv:]),
                            dtype=self.dtype)

            # dbias padded part
            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            reference_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            dtype=self.dtype)

            # dbias valid part
            assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                            reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                            dtype=self.dtype)


@pytest.mark.parametrize('bias_shape', [
    pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'),
    pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'),
    pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'),
    pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'),
])
@pytest.mark.parametrize('attn_bias_type', [
    pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
])
@pytest.mark.parametrize('attn_mask_type', [
    pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
    pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'),
    pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'),
    pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
])
@pytest.mark.parametrize('qkv_layout', [
    pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'),
    pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'),
])
@pytest.mark.parametrize('dtype', [
    pytest.param(jnp.bfloat16, id="BF16"),
    pytest.param(jnp.float16, id="FP16"),
])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
    pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
    pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
    pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
    pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1024-12-12-64-CROSS'),
    pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
    pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
])
@pytest.mark.parametrize('dropout_prob', [
    pytest.param(0.0, id="DROP_0.0"),
    pytest.param(0.1, id="DROP_0.1"),
])
@pytest.mark.parametrize('is_training', [
    pytest.param(True, id='TRAINING'),
    pytest.param(False, id='INFERENCE'),
])
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
                     dtype, is_training, qkv_layout, bias_shape):
        """
        Test forward with parameterized configs
        """
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, is_training, qkv_layout, bias_shape)
        runner.test_forward()

    @staticmethod
    def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
                      dtype, is_training, qkv_layout, bias_shape):
        """
        Test backward with parameterized configs
        """
        if not is_training:
            pytest.skip("The backward pass is only run for training configurations.")

        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, True, qkv_layout, bias_shape)
        runner.test_backward()