import math
import random
from typing import Optional, Tuple

import pytest
import torch
import triton
from torch.utils.checkpoint import checkpoint

from sgl_kernel.flash_mla import (
    flash_mla_sparse_fwd,
    flash_mla_with_kvcache,
    get_mla_metadata,
)

skip_condition = torch.cuda.get_device_capability() < (10, 0)

# ================ prefill usage ================ #
S_Q_PREFILL = [1, 62]
KV_TOPK_PREFILL = [
    # Regular shapes
    (128, 128),
    (256, 256),
    (512, 512),
    # Irregular shapes
    (592, 128),
    (1840, 256),
    (1592, 384),
    (1521, 512),
    # Irregular shapes with OOB TopK
    (95, 128),
    (153, 256),
    (114, 384),
]

# ================= decode usage ================= #
B_DECODE = [1, 2, 6, 64]
S_Q_DECODE = [1, 2, 4]
S_K_DECODE = [20, 140, 4096]
IS_VARLEN = [False, True]
CAUSAL_TOPK = [(True, None), (False, None), (False, 128), (False, 2048)]
DTYPE = [torch.float16, torch.bfloat16]


def quantize_k_cache(
    input_k_cache: torch.Tensor,  # (num_blocks, block_size, h_k, d)
    dv: int,
    tile_size: int = 128,
) -> torch.Tensor:
    """
    Quantize the k-cache.

    Return a tensor with shape (num_blocks, block_size, h_k, dv + 4*(dv/tile_size) + t*(d-dv))
    of byte-sized dtype (torch.float8_e4m3fn), where t = input_k_cache.element_size().
    For more detail about the layout of K/V, please refer to the comments in
    flash_mla_interface.py or README.md.
    """
    assert dv % tile_size == 0
    num_tiles = dv // tile_size
    num_blocks, block_size, h_k, d = input_k_cache.shape
    assert h_k == 1
    input_k_cache = input_k_cache.squeeze(2)  # [num_blocks, block_size, d]
    input_elem_size = input_k_cache.element_size()
    result = torch.empty(
        (num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
        dtype=torch.float8_e4m3fn,
        device=input_k_cache.device,
    )
    result_k_nope_part = result[..., :dv]
    result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
    result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
    result_k_rope_part[:] = input_k_cache[..., dv:]
    for tile_idx in range(0, num_tiles):
        cur_scale_factors_inv = (
            torch.abs(
                input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
            )
            .max(dim=-1)
            .values
            / 448.0
        )  # [num_blocks, block_size]
        result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
        cur_scale_factors_inv.unsqueeze_(-1)  # [num_blocks, block_size, 1]
        cur_quantized_nope = (
            input_k_cache[
                ..., tile_idx * tile_size : (tile_idx + 1) * tile_size
            ].float()
            / cur_scale_factors_inv.float()
        ).to(torch.float8_e4m3fn)
        result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
            cur_quantized_nope
        )
    result = result.view(num_blocks, block_size, 1, -1)
    return result


def dequantize_k_cache(
    quant_k_cache: torch.Tensor,  # (num_blocks, block_size, 1, bytes_per_token)
    dv: int = 512,
    tile_size: int = 128,
    d: int = 576,
) -> torch.Tensor:
    """
    De-quantize the k-cache.
    """
    assert dv % tile_size == 0
    num_tiles = dv // tile_size
    num_blocks, block_size, h_k, _ = quant_k_cache.shape
    assert h_k == 1
    result = torch.empty(
        (num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
    )
    quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
    input_nope = quant_k_cache[..., :dv]
    input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
    input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
    result[..., dv:] = input_rope
    for tile_idx in range(0, num_tiles):
        cur_nope = input_nope[
            ..., tile_idx * tile_size : (tile_idx + 1) * tile_size
        ].to(torch.float32)
        cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
        result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
            cur_nope * cur_scales
        )
    result = result.view(num_blocks, block_size, 1, d)
    return result
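

# quantize_k_cache / dequantize_k_cache above encode the fp8 KV-cache layout used by
# the decode test: per 128-element NoPE tile an e4m3 payload plus one float32 scale,
# followed by the RoPE tail stored verbatim in bf16. The helper below is a minimal
# round-trip sketch (illustrative only; its name and shapes are arbitrary assumptions):
# the RoPE tail survives bit-exactly, while the NoPE part only matches up to e4m3
# quantization error.
def _quantize_roundtrip_sketch() -> torch.Tensor:
    kv = torch.randn(2, 64, 1, 576, dtype=torch.bfloat16, device="cuda") / 10
    deq = dequantize_k_cache(quantize_k_cache(kv, dv=512), dv=512, d=576)
    assert torch.equal(deq[..., 512:], kv[..., 512:])  # RoPE tail is copied verbatim
    # Max absolute error of the de-quantized NoPE part; small compared to the
    # per-tile max magnitude that defines the e4m3 scale.
    return (deq[..., :512].float() - kv[..., :512].float()).abs().max()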


def cdiv(x: int, y: int):
    return (x + y - 1) // y


def get_window_size(causal, window):
    if window > 0:
        window_size = (window - 1, 0) if causal else (window - 1, window - 1)
    else:
        window_size = (-1, -1)
    return window_size


def get_attn_bias(s_q, s_k, causal, window):
    attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32, device="cuda")
    if causal:
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
            diagonal=s_k - s_q
        )
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if window > 0:
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
            diagonal=s_k - s_q - window
        )
        attn_bias.masked_fill_(temp_mask, float("-inf"))
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
            diagonal=s_k - s_q + window - 1
        )
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    return attn_bias


def sdpa(query, key, value, attn_bias, softmax_scale=None):
    query = query.float().transpose(-3, -2)
    key = key.float().transpose(-3, -2)
    value = value.float().transpose(-3, -2)
    # Number of query / key-value heads after the transpose above (MQA/GQA expansion)
    h = query.size(-3)
    h_k = key.size(-3)
    key = key.repeat_interleave(h // h_k, dim=-3)
    value = value.repeat_interleave(h // h_k, dim=-3)
    if softmax_scale is None:
        softmax_scale = query.shape[-1] ** (-0.5)
    attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale
    attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight.to(query.dtype) @ value, lse


def sdpa_checkpoint(*args, **kwargs):
    return checkpoint(sdpa, *args, use_reentrant=False, **kwargs)


def reference_torch_prefill(
    s_q, s_kv, topk, indices, q, kv, sm_scale: float
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)

    indices = indices[0, :, 0, :]  # [s_q, topk]
    invalid_indices_mask = (indices < 0) | (indices >= s_kv)
    qs = q[0, :, :, :].float()  # [s_q, h_q, d_qk]
    kvs = kv[0, :, 0, :].float()  # [s_kv, d_qk]
    kvs = torch.index_select(
        kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()
    ).view(s_q, topk, 576)  # [s_q, topk, d_qk]
    attn_score = qs @ kvs.transpose(1, 2)  # [s_q, h_q, topk]
    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
    attn_score *= sm_scale * math.log2(math.e)
    max_logits = torch.max(attn_score, dim=-1)[0]  # [s_q, h_q]
    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    result = attn_score @ kvs[:, :, :512]
    return (max_logits, lse, result)
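

# reference_torch_prefill above works in the log2 domain: with logits
# l_i = sm_scale * log2(e) * (q . k_i), the LSE it compares against the kernel is
# log2(sum_i 2^{l_i}) and softmax_i = 2^{l_i - lse}. The function below is a tiny
# self-contained sketch of that identity (illustrative values only, assuming nothing
# beyond plain PyTorch):
def _log2_softmax_identity_sketch() -> torch.Tensor:
    logits = torch.tensor([1.0, 2.0, 3.0])
    lse2 = torch.logsumexp(logits * math.log(2), dim=0) * math.log2(math.e)  # log2-sum-exp2
    probs = torch.exp2(logits - lse2)
    # Same result as an ordinary natural-log softmax applied to logits * ln(2)
    assert torch.allclose(probs, torch.softmax(logits * math.log(2), dim=0))
    return probs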


def reference_torch_decode(
    cache_seqlens: torch.Tensor,  # [batch_size]
    block_table: torch.Tensor,  # [batch_size, ?]
    q: torch.Tensor,  # [batch_size, s_q, h_q, d]
    blocked_k: torch.Tensor,  # [?, block_size, h_kv, d]
    dv: int,
    is_causal: bool,
    indices: Optional[torch.Tensor] = None,  # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation in PyTorch.
    """

    def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
        mask = torch.zeros(s_q, s_k, dtype=torch.bool, device="cuda")
        for i in range(s_q):
            cur_indices = indices[i]
            valid_indices = cur_indices[cur_indices != -1]
            mask[i, valid_indices] = True
        return mask

    def scaled_dot_product_attention(
        batch_idx: int,
        query: torch.Tensor,  # [h_q, s_q, d]
        kv: torch.Tensor,  # [h_kv, s_k, d]
        dv: int,
        is_causal,
        indices: Optional[torch.Tensor],  # [s_q, topk]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        h_q = query.size(0)
        h_kv = kv.size(0)
        s_q = query.shape[-2]
        s_k = kv.shape[-2]
        query = query.float()
        kv = kv.float()
        if h_kv != 1:
            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
        kv[kv != kv] = 0.0  # zero out NaN-masked (unused) KV entries
        attn_weight = query @ kv.transpose(-2, -1)  # [h_q, s_q, s_k]
        if (is_causal and query.size(1) > 1) or indices is not None:
            mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda")
            if is_causal:
                assert indices is None
                mask = mask.tril(diagonal=s_k - s_q)
            if indices is not None:
                mask &= get_topk_attn_mask(s_q, s_k, indices)
            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device="cuda")
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
            attn_weight += attn_bias.to(q.dtype)
        attn_weight /= math.sqrt(query.size(-1))
        lse = attn_weight.logsumexp(dim=-1)  # [h_q, s_q]
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
        output = attn_weight @ kv[..., :dv]  # [h_q, s_q, dv]
        # Correct for q tokens which have no attendable k
        lonely_q_mask = lse == float("-inf")
        output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
        lse[lonely_q_mask] = float("+inf")
        return output, lse

    b, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)
    cache_seqlens_cpu = cache_seqlens.cpu()
    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device="cuda")
    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32, device="cuda")
    for i in range(b):
        cur_len = cache_seqlens_cpu[i].item()
        cur_num_blocks = cdiv(cur_len, block_size)
        cur_block_indices = block_table[i][0:cur_num_blocks]
        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
        cur_out, cur_lse = scaled_dot_product_attention(
            i,
            q[i].transpose(0, 1),
            cur_kv.transpose(0, 1),
            dv,
            is_causal,
            indices[i] if indices is not None else None,
        )
        out_ref[i] = cur_out.transpose(0, 1)
        lse_ref[i] = cur_lse

    out_ref = out_ref.to(torch.bfloat16)
    return out_ref, lse_ref
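

# Both get_attn_bias and the decode reference above align the causal mask to the
# bottom-right corner: query token j (0-based, out of s_q new tokens) may attend to
# the first s_k - s_q + j + 1 cached keys, which is what tril(diagonal=s_k - s_q)
# produces. A small illustrative sketch of that convention (hypothetical sizes):
def _bottom_right_causal_mask_sketch(s_q: int = 2, s_k: int = 5) -> torch.Tensor:
    mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
    # For s_q=2, s_k=5 this yields:
    # [[True, True, True, True, False],
    #  [True, True, True, True, True ]]
    return mask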


@pytest.mark.skipif(
    skip_condition, reason="FlashMLA kernels require compute capability >= 10.0"
)
@pytest.mark.parametrize("s_q", S_Q_PREFILL)
@pytest.mark.parametrize("kv_topk", KV_TOPK_PREFILL)
@torch.inference_mode()
def test_flashmla_prefill(
    s_q: int,
    kv_topk: Tuple[int, int],
):
    torch.cuda.empty_cache()
    q = torch.randn((1, s_q, 128, 576), dtype=torch.bfloat16, device="cuda") / 10
    kv = torch.randn((1, kv_topk[0], 1, 576), dtype=torch.bfloat16, device="cuda") / 10
    q.clamp_(-10, 10)
    kv.clamp_(-10, 10)
    indices = torch.full(
        (1, s_q, 1, kv_topk[1]), kv_topk[0], dtype=torch.int32, device="cuda"
    )
    for s in range(s_q):
        # NOTE We generate indices this way so that most indices lie within
        # [s_kv - 20000, s_kv), which is more realistic for sparse attention.
        near_mask = (
            torch.randint(0, 32, (min(kv_topk[1], kv_topk[0]),), device="cuda") < 31
        )
        cur_indices = torch.randperm(kv_topk[0], device="cuda")[: kv_topk[1]]
        cur_indices[near_mask] = torch.randint(
            max(0, kv_topk[0] - 20000),
            kv_topk[0] - 1,
            (near_mask.sum().item(),),
            device="cuda",
        )
        if len(cur_indices) < kv_topk[1]:
            cur_indices = torch.cat(
                [
                    cur_indices,
                    torch.full(
                        (kv_topk[1] - len(cur_indices),), 2147480000, device="cuda"
                    ),
                ]
            )
        cur_indices = cur_indices[torch.randperm(kv_topk[1], device="cuda")]
        indices[0, s, 0] = cur_indices
    indices = indices.to(q.device)
    sm_scale = 1 / math.sqrt(576)

    torch.cuda.synchronize()
    ans_out, ans_max_logits, ans_lse = flash_mla_sparse_fwd(
        q.squeeze(0), kv.squeeze(0), indices.squeeze(0), sm_scale=sm_scale
    )
    ans_out, ans_max_logits, ans_lse = (
        ans_out.float(),
        ans_max_logits.float(),
        ans_lse.float(),
    )
    torch.cuda.synchronize()

    ref_max_logits, ref_lse, ref_out = reference_torch_prefill(
        s_q, kv_topk[0], kv_topk[1], indices, q, kv, sm_scale
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(ans_out, ref_out, atol=8e-4, rtol=2.01 / 128)
    torch.testing.assert_close(
        ans_max_logits,
        ref_max_logits,
        atol=1e-6,
        rtol=2.01 / 65536,
    )
    torch.testing.assert_close(ans_lse, ref_lse, atol=1e-6, rtol=2.01 / 65536)
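

# The decode test below addresses every selected token twice: by its absolute position
# in the sequence (abs_indices, fed to the PyTorch reference) and by its slot in the
# paged KV cache (indices_in_kvcache, fed to the kernel). The translation is
#   paged = block_table[abs // block_size] * block_size + abs % block_size.
# A minimal sketch of that mapping with a made-up block table and block_size=4:
def _abs_to_paged_index_sketch() -> torch.Tensor:
    block_size = 4
    block_table = torch.tensor([7, 2, 5], dtype=torch.int32)  # logical -> physical block
    abs_idx = torch.tensor([0, 5, 10])  # absolute token positions
    paged = block_table[abs_idx // block_size] * block_size + (abs_idx % block_size)
    return paged  # tensor([28,  9, 22])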


@pytest.mark.skipif(
    skip_condition, reason="FlashMLA kernels require compute capability >= 10.0"
)
@pytest.mark.parametrize("b", B_DECODE)
@pytest.mark.parametrize("s_q", S_Q_DECODE)
@pytest.mark.parametrize("s_k", S_K_DECODE)
@pytest.mark.parametrize("is_varlen", IS_VARLEN)
@pytest.mark.parametrize("causal_topk", CAUSAL_TOPK)
@pytest.mark.parametrize("dtype", DTYPE)
@torch.inference_mode()
def test_flash_mla_decode(
    b: int,
    s_q: int,
    s_k: int,
    is_varlen: bool,
    causal_topk: Tuple[bool, Optional[int]],
    dtype: torch.dtype,
):
    d = 576
    dv = 512
    block_size = 64
    h_q = 128
    h_kv = 1
    is_causal = causal_topk[0]
    topk = causal_topk[1]

    # Generate test data
    torch.cuda.synchronize()
    cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device="cpu")
    if is_varlen:
        for i in range(b):
            cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), s_q)
    max_seqlen = cache_seqlens_cpu.max().item()
    max_seqlen_pad = cdiv(max_seqlen, 256) * 256
    cache_seqlens = cache_seqlens_cpu.cuda()

    q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
    q.clamp_(min=-1.0, max=1.0)
    block_table = torch.arange(
        b * max_seqlen_pad // block_size, dtype=torch.int32, device="cuda"
    ).view(b, max_seqlen_pad // block_size)
    block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1)
    blocked_k = (
        torch.randn(
            block_table.numel(),
            block_size,
            h_kv,
            d,
            dtype=torch.bfloat16,
            device="cuda",
        )
        / 10
    )
    blocked_k.clamp_(min=-1.0, max=1.0)

    if topk is None:
        # Mask out-of-range KV as NaN and mark unused blocks as invalid
        for i in range(b):
            cur_len = cache_seqlens_cpu[i].item()
            cur_num_blocks = cdiv(cur_len, block_size)
            blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
            if cur_len % block_size != 0:
                blocked_k[block_table[i][cur_num_blocks - 1]][
                    cur_len % block_size :
                ] = float("nan")
            block_table[i][cur_num_blocks:] = 2147480000
        abs_indices = None
        indices_in_kvcache = None
    else:
        block_table_cpu = block_table.cpu()
        abs_indices = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu")
        indices_in_kvcache = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu")
        for i in range(b):
            # Generate indices
            for j in range(s_q):
                cur_abs_indices = torch.randperm(
                    int(cache_seqlens_cpu[i].item()), device="cpu"
                )[:topk]
                cur_blocked_indices = block_table_cpu[
                    i, cur_abs_indices // block_size
                ] * block_size + (cur_abs_indices % block_size)
                if len(cur_abs_indices) < topk:
                    pad_len = topk - len(cur_abs_indices)
                    cur_abs_indices = torch.cat(
                        [cur_abs_indices, torch.full((pad_len,), -1, device="cpu")]
                    )
                    cur_blocked_indices = torch.cat(
                        [cur_blocked_indices, torch.full((pad_len,), -1, device="cpu")]
                    )
                # Shuffle the selected indices
                perm = torch.randperm(topk, device="cpu")
                cur_abs_indices = cur_abs_indices[perm]
                cur_blocked_indices = cur_blocked_indices[perm]
                abs_indices[i, j, :] = cur_abs_indices
                indices_in_kvcache[i, j, :] = cur_blocked_indices
        # Mask unused KV entries as NaN
        all_indices = indices_in_kvcache.flatten().tolist()
        all_indices = list(set(all_indices))
        if -1 in all_indices:
            all_indices.remove(-1)
        all_indices = torch.tensor(all_indices, dtype=torch.int32, device="cpu")
        blocked_k = blocked_k.view(-1, h_kv, d)
        nonused_indices_mask = torch.ones(
            blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device="cpu"
        )
        nonused_indices_mask[all_indices] = False
        blocked_k[nonused_indices_mask, :, :] = float("nan")
        blocked_k = blocked_k.view(-1, block_size, h_kv, d)
        abs_indices = abs_indices.to(q.device)
        indices_in_kvcache = indices_in_kvcache.to(q.device)

    is_fp8 = topk is not None
    if is_fp8:
        # The quantization error may be too large to be distinguished from wrong kernels,
        # so we quantize and de-quantize the kv-cache here to mitigate quantization error.
        blocked_k_quantized = quantize_k_cache(blocked_k, dv, 128)
        blocked_k_dequantized = dequantize_k_cache(blocked_k_quantized)
        blocked_k = blocked_k_dequantized

    # Get schedule metadata
    torch.cuda.synchronize()
    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, is_fp8, topk
    )
    torch.cuda.synchronize()

    out_ans, lse_ans = flash_mla_with_kvcache(
        q,
        blocked_k if not is_fp8 else blocked_k_quantized,  # type: ignore
        block_table,
        cache_seqlens,
        dv,
        tile_scheduler_metadata,
        num_splits,
        causal=is_causal,
        is_fp8_kvcache=is_fp8,
        indices=indices_in_kvcache,
    )

    out_ref, lse_ref = reference_torch_decode(
        cache_seqlens, block_table, q, blocked_k, dv, is_causal, abs_indices
    )
    torch.testing.assert_close(out_ans, out_ref, atol=8e-4, rtol=2.01 / 128)
    torch.testing.assert_close(lse_ans, lse_ref, atol=1e-6, rtol=8.01 / 65536)


if __name__ == "__main__":
    pytest.main([__file__])