Commit 082094b7 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Multiple updates and refactorings (#150)

* Multiple updates and refactorings

* Remove dead code
parent 1408756a
import torch
def _get_new_non_contiguous_tensor_shape(shape):
"""
Get the expanded shape for a non-contiguous tensor.
The last dimension is increased by 128 (for alignment), and all other dimensions are increased by 1
"""
return [dim+128 if dim_idx == len(shape)-1 else dim+1 for dim_idx, dim in enumerate(shape)]
def gen_non_contiguous_randn_tensor(shape, *args, **kwargs):
new_shape = _get_new_non_contiguous_tensor_shape(shape)
base_tensor = torch.randn(new_shape, *args, **kwargs)
slices = [slice(0, dim) for dim in shape]
return base_tensor[slices]
def gen_non_contiguous_tensor(shape, *args, **kwargs):
new_shape = _get_new_non_contiguous_tensor_shape(shape)
base_tensor = torch.empty(new_shape, *args, **kwargs)
slices = [slice(0, dim) for dim in shape]
return base_tensor[slices]
def non_contiguousify(tensor: torch.Tensor) -> torch.Tensor:
new_tensor = gen_non_contiguous_tensor(tensor.shape, dtype=tensor.dtype, device=tensor.device)
new_tensor[:] = tensor
return new_tensor
import torch
_is_low_precision_mode_stack = []
class LowPrecisionMode:
def __init__(self, enabled: bool = True):
self.enabled = enabled
def __enter__(self):
global _is_low_precision_mode_stack
_is_low_precision_mode_stack.append(self.enabled)
def __exit__(self, exc_type, exc_value, traceback):
global _is_low_precision_mode_stack
_is_low_precision_mode_stack.pop()
def is_low_precision_mode() -> bool:
global _is_low_precision_mode_stack
if len(_is_low_precision_mode_stack) == 0:
return False
return _is_low_precision_mode_stack[-1]
def optional_cast_to_bf16_and_cast_back(tensor: torch.Tensor) -> torch.Tensor:
assert tensor.dtype == torch.float32, "Input tensor must be of dtype torch.float32 for optional casting."
if is_low_precision_mode():
tensor_bf16 = tensor.to(torch.bfloat16)
tensor_fp32 = tensor_bf16.to(torch.float32)
return tensor_fp32
else:
return tensor
import os
import functools
colors = {
'RED_FG': '\033[31m',
'GREEN_FG': '\033[32m',
'CYAN_FG': '\033[36m',
'GRAY_FG': '\033[90m',
'YELLOW_FG': '\033[33m',
'RED_BG': '\033[41m',
'GREEN_BG': '\033[42m',
'CYAN_BG': '\033[46m',
'YELLOW_BG': '\033[43m',
'GRAY_BG': '\033[100m',
'CLEAR': '\033[0m'
}
def cdiv(a: int, b: int) -> int:
return (a + b - 1) // b
@functools.lru_cache()
def is_using_profiling_tools() -> bool:
"""
Return whether we are running under profiling tools like nsys or ncu
NOTE cuda-gdb will also cause conflict with CUPTI (bench_kineto) but currently we lack ways to detect it
"""
is_using_nsys = os.environ.get('NSYS_PROFILING_SESSION_ID') is not None
is_using_ncu = os.environ.get('NV_COMPUTE_PROFILER_PERFWORKS_DIR') is not None
is_using_compute_sanitizer = os.environ.get('NV_SANITIZER_INJECTION_PORT_RANGE_BEGIN') is not None
return is_using_nsys or is_using_ncu or is_using_compute_sanitizer
def set_random_seed(seed: int):
import random
import numpy as np
import torch
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
class Counter:
def __init__(self):
self.count = 0
def next(self) -> int:
self.count += 1
return self.count - 1
from typing import List import dataclasses
import os
import enum
from typing import List, Optional
import random
import torch import torch
import kernelkit as kk
import flash_mla
def cdiv(x: int, y: int): import quant
return (x+y-1) // y
def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7): class TestTarget(enum.Enum):
FWD = 0
DECODE = 1
@dataclasses.dataclass
class ExtraTestParamForDecode:
b: int
is_varlen: bool
have_zero_seqlen_k: bool
extra_s_k: Optional[int] = None
extra_topk: Optional[int] = None
block_size: int = 64
extra_block_size: Optional[int] = None
have_extra_topk_length: bool = False
@dataclasses.dataclass
class TestParam:
s_q: int
s_kv: int
topk: int
h_q: int = 128
h_kv: int = 1
d_qk: int = 512
d_v: int = 512
seed: int = -1 # -1: to be filled automatically
check_correctness: bool = True
is_all_indices_invalid: bool = False # All indices are invalid, i.e., all indices are set to a large number (e.g., 2147483647)
num_runs: int = 10
have_attn_sink: bool = False
have_topk_length: bool = False
decode: Optional[ExtraTestParamForDecode] = None
@dataclasses.dataclass
class RawTestParamForDecode:
"""
"Flattened" test parameters for decoding test
In our test script, to maintain compatibility with TestParam, we embed decode-only parameters into TestParam.decode, which is not very convinient when construct testcases. So here we have a "flattened" version of test parameters for decoding test.
""" """
Check if two tensors are close enough b: int
h_q: int
s_q: int
h_kv: int
s_kv: int
is_varlen: bool
topk: int
is_all_indices_invalid: bool = False
have_zero_seqlen_k: bool = False
have_topk_length: bool = False
enable_attn_sink: bool = True
extra_s_k: Optional[int] = None
extra_topk: Optional[int] = None
block_size: int = 64
extra_block_size: Optional[int] = None
have_extra_topk_length: bool = False
d_qk: int = 576 # Q/K head dim (= dv + RoPE dim)
d_v: int = 512 # V head dim
check_correctness: bool = True
num_runs: int = 10
seed: int = -1
def to_test_param(self) -> TestParam:
return TestParam(
self.s_q, self.s_kv, self.topk, self.h_q, self.h_kv, self.d_qk, self.d_v,
self.seed, self.check_correctness,
self.is_all_indices_invalid,
self.num_runs,
self.enable_attn_sink,
self.have_topk_length,
decode = ExtraTestParamForDecode(
self.b, self.is_varlen, self.have_zero_seqlen_k,
self.extra_s_k, self.extra_topk,
self.block_size, self.extra_block_size, self.have_extra_topk_length
)
)
@dataclasses.dataclass
class Testcase:
p: TestParam
dOut: torch.Tensor # [s_q, h_q, d_v]
q: torch.Tensor # [s_q, h_q, d_qk]
kv: torch.Tensor # [s_kv, h_kv, d_qk]
indices: torch.Tensor # [s_q, h_kv, topk]
sm_scale: float
attn_sink: Optional[torch.Tensor] # [h_q]
topk_length: Optional[torch.Tensor] # [s_q]
def _randperm_batch(batch_size: int, perm_range: torch.Tensor, perm_size: int, paddings: List[int]) -> torch.Tensor:
""" """
def get_cos_diff(x: torch.Tensor, y: torch.Tensor) -> float: Generate random permutations in batch
The return tensor, denoted as `res`, has a shape of [batch_size, perm_size]. `0 <= res[i, :] < perm_range[i]` holds.
Values within each row are unique.
If, for some `i`, `perm_range[i] < perm_size` holds, then `res[i, :]` contains values in `[0, perm_range[i])` as many as possible, and the rest are filled with `padding`.
"""
assert not torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(True)
perm_range_max = max(int(torch.max(perm_range).item()), perm_size)
rand = torch.rand(batch_size, perm_range_max, dtype=torch.float32)
rand[torch.arange(0, perm_range_max).broadcast_to(batch_size, perm_range_max) >= perm_range.view(batch_size, 1)] = float("-inf") # Fill invalid positions, so that the following `topk` operators will select positions within `perm_range` first
res = rand.topk(perm_size, dim=-1, sorted=True).indices.to(torch.int32)
if len(paddings) == 1:
res[res >= perm_range.view(batch_size, 1)] = paddings[0]
else:
fillers = torch.tensor(paddings, dtype=torch.int32).index_select(0, torch.randint(0, len(paddings), (res.numel(), ), dtype=torch.int32))
res.masked_scatter_(res >= perm_range.view(batch_size, 1), fillers)
torch.use_deterministic_algorithms(False)
return res
def generate_testcase(t: TestParam) -> Testcase:
kk.set_random_seed(t.seed)
q = torch.randn((t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10
kv = torch.randn((t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10
do = torch.randn((t.s_q, t.h_q, t.d_v), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10
q.clamp_(-10, 10)
kv.clamp_(-10, 10)
do.clamp_(-10, 10)
invalid_indices_candidate = [-2147483648, -123456, -1, t.s_kv, 114514, 1919810, 2147480000, 2147483647]
indices = _randperm_batch(t.s_q, torch.full((t.s_q, ), t.s_kv, dtype=torch.int32), t.topk, invalid_indices_candidate).view(t.s_q, t.h_kv, t.topk)
if t.is_all_indices_invalid:
all_indices_invalid_mask = torch.randn(t.s_q, device='cpu') < -2
indices[all_indices_invalid_mask[:, None, None].broadcast_to(indices.shape)] = random.choice(invalid_indices_candidate)
indices = indices.to(q.device)
attn_sink = None
if t.have_attn_sink:
attn_sink = torch.randn((t.h_q, ), dtype=torch.float32)
mask = torch.randn((t.h_q, ), dtype=torch.float32)
attn_sink[mask < -0.5] = float("-inf")
attn_sink[mask > +0.5] = float("+inf")
topk_length = None
if t.have_topk_length:
topk_length = torch.randint(0, max(t.topk + 1, 64), (t.s_q, ), dtype=torch.int32, device=q.device).clamp_max(t.topk)
q = kk.non_contiguousify(q)
kv = kk.non_contiguousify(kv)
do = kk.non_contiguousify(do)
indices = kk.non_contiguousify(indices)
return Testcase(
p=t,
dOut=do,
q=q,
kv=kv,
indices=indices,
sm_scale=0.5, # Otherwise dK is too small compared to dV
attn_sink=attn_sink,
topk_length=topk_length
)
@dataclasses.dataclass
class KVScope:
t: TestParam
cache_seqlens: torch.Tensor
block_table: torch.Tensor
blocked_k: torch.Tensor
abs_indices: torch.Tensor
indices_in_kvcache: torch.Tensor
topk_length: Optional[torch.Tensor]
blocked_k_quantized: Optional[torch.Tensor] = None
def quant_and_dequant_(self):
""" """
Calculate the cosine diff between two tensors For FP8 cases, we need to quantize the KV cache for Flash MLA.
Besides, the quantization error may be too large to be distinguished from wrong kernels, so we de-quantize kvcache here to mitigate quantization error
"""
fp8_kvcache_layout = None
if self.t.d_qk == 576:
fp8_kvcache_layout = quant.FP8KVCacheLayout.V32_FP8Sparse
elif self.t.d_qk == 512:
assert self.abs_indices is not None
fp8_kvcache_layout = quant.FP8KVCacheLayout.MODEL1_FP8Sparse
else:
assert False
self.blocked_k_quantized = quant.quantize_k_cache(self.blocked_k, fp8_kvcache_layout)
blocked_k_dequantized = quant.dequantize_k_cache(self.blocked_k_quantized, fp8_kvcache_layout)
self.blocked_k = blocked_k_dequantized
def get_kvcache_for_flash_mla(self) -> torch.Tensor:
""" """
x, y = x.double(), y.double() Return the quantized blocked_k for Flash MLA
denominator = (x * x + y * y).sum().item() """
if denominator == 0: assert self.blocked_k_quantized is not None, "Please call `quant_and_dequant_` first before calling `get_kvcache_for_flash_mla`"
return 0 return self.blocked_k_quantized
sim = 2 * (x * y).sum().item() / denominator
return 1 - sim
assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}"
ans = ans.clone().to(torch.float) def apply_perm(self, perm: torch.Tensor) -> "KVScope":
ref = ref.clone().to(torch.float) """
Apply a batch permutation to this KVScope. Used for batch-invariance test
# Deal with anomalies """
def deal_with_anomalies(val: float): new_kvscope = KVScope(
ref_mask = (ref == val) if (val == val) else (ref != ref) self.t,
ans_mask = (ans == val) if (val == val) else (ans != ans) self.cache_seqlens[perm],
ref[ref_mask] = 0.0 self.block_table[perm],
ans[ans_mask] = 0.0 self.blocked_k,
if not torch.equal(ref_mask, ans_mask): self.abs_indices[perm],
print(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref") self.indices_in_kvcache[perm],
return False self.topk_length[perm] if self.topk_length is not None else None,
return True self.blocked_k_quantized
)
return new_kvscope
anomalies_check_passed = True @dataclasses.dataclass
anomalies_check_passed &= deal_with_anomalies(float("inf")) class TestcaseForDecode:
anomalies_check_passed &= deal_with_anomalies(float("-inf")) p: TestParam
anomalies_check_passed &= deal_with_anomalies(float("nan")) q: torch.Tensor # [b, s_q, h_q, d_qk]
attn_sink: Optional[torch.Tensor] # [h_q]
if not anomalies_check_passed: sm_scale: float
return False kv_scope: KVScope
extra_kv_scope: Optional[KVScope]
cos_diff = get_cos_diff(ans, ref)
raw_abs_err = torch.abs(ans-ref) def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode:
raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6)) kk.set_random_seed(t.seed)
rel_err = raw_rel_err.masked_fill(raw_abs_err<abs_tol, 0) assert t.h_q % t.h_kv == 0
abs_err = raw_abs_err.masked_fill(raw_rel_err<rel_tol, 0) assert t.decode is not None
pass_mask = (abs_err < abs_tol) | (rel_err < rel_tol)
q = torch.randn((t.decode.b, t.s_q, t.h_q, t.d_qk))
if not pass_mask.all(): q.clamp_(min=-1.0, max=1.0)
print(f"`{name}` mismatch")
max_abs_err_pos: int = torch.argmax(abs_err, keepdim=True).item() # type: ignore attn_sink = None
max_rel_err_pos: int = torch.argmax(rel_err, keepdim=True).item() # type: ignore if t.have_attn_sink:
def get_pos_in_tensor(t: torch.Tensor, pos: int) -> List[int]: attn_sink = torch.randn((t.h_q, ), dtype=torch.float32)
result = [] inf_mask = torch.randn((t.h_q, ), dtype=torch.float32)
for size in t.shape[::-1]: attn_sink[inf_mask > 0.5] = float("inf")
result.append(pos % size) attn_sink[inf_mask < -0.5] = float("-inf")
pos = pos // size
assert pos == 0 def generate_one_k_scope(s_k: int, block_size: int, topk: int, is_varlen: bool, have_zero_seqlen: bool, is_all_indices_invalid: bool, have_topk_length: bool) -> KVScope:
return result[::-1] b = t.decode.b # type: ignore
print(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}") cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device='cpu')
print(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}") if is_varlen:
print(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)") for i in range(b):
print(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})") cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), t.s_q)
return False
if have_zero_seqlen:
zeros_mask = torch.randn(b, dtype=torch.float32, device='cpu') > 0
cache_seqlens_cpu[zeros_mask] = 0
max_seqlen_alignment = 4 * block_size
max_seqlen_pad = max(kk.cdiv(int(cache_seqlens_cpu.max().item()), max_seqlen_alignment), 1) * max_seqlen_alignment
cache_seqlens = cache_seqlens_cpu.cuda()
assert max_seqlen_pad % block_size == 0
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1)
blocked_k = kk.gen_non_contiguous_randn_tensor((block_table.numel(), block_size, t.h_kv, t.d_qk)) / 10
blocked_k.clamp_(min=-1.0, max=1.0)
abs_indices = torch.empty((b, t.s_q, topk), dtype=torch.int32)
if is_all_indices_invalid:
abs_indices.fill_(-1)
else:
abs_indices[:] = _randperm_batch(b*t.s_q, cache_seqlens.repeat_interleave(t.s_q), topk, [-1]).view(b, t.s_q, topk)
indices_in_kvcache = quant.abs_indices2indices_in_kvcache(abs_indices, block_table, block_size)
topk_length = torch.randint(0, topk+1, (b, ), dtype=torch.int32, device=q.device) if have_topk_length else None
# Mask nonused KV as NaN
if have_topk_length:
indices_in_kvcache_masked = indices_in_kvcache.clone()
indices_in_kvcache_masked[torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, t.s_q, topk) >= (topk_length.view(b, 1, 1) if have_topk_length else topk)] = -1
else:
indices_in_kvcache_masked = indices_in_kvcache
blocked_k = blocked_k.view(-1, t.h_kv, t.d_qk)
nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu')
nonused_indices_mask[indices_in_kvcache_masked] = False
blocked_k[nonused_indices_mask, :, :] = float("nan")
blocked_k = blocked_k.view(-1, block_size, t.h_kv, t.d_qk)
block_table = kk.non_contiguousify(block_table)
abs_indices = kk.non_contiguousify(abs_indices)
indices_in_kvcache = kk.non_contiguousify(indices_in_kvcache)
return KVScope(t, cache_seqlens, block_table, blocked_k, abs_indices, indices_in_kvcache, topk_length)
kv_scope0 = generate_one_k_scope(t.s_kv, t.decode.block_size, t.topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.have_topk_length)
kv_scope0.quant_and_dequant_()
if t.decode.extra_topk is not None:
if t.decode.extra_s_k is None:
t.decode.extra_s_k = t.decode.extra_topk*2
if t.decode.extra_block_size is None:
t.decode.extra_block_size = t.decode.block_size
kv_scope1 = generate_one_k_scope(t.decode.extra_s_k, t.decode.extra_block_size, t.decode.extra_topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.decode.have_extra_topk_length)
kv_scope1.quant_and_dequant_()
else: else:
if abs(cos_diff) > cos_diff_tol: assert t.decode.extra_block_size is None and t.decode.extra_s_k is None and not t.decode.have_extra_topk_length
print(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})") kv_scope1 = None
return False
return True sm_scale = t.d_qk ** -0.55
\ No newline at end of file
q = kk.non_contiguousify(q)
return TestcaseForDecode(t, q, attn_sink, sm_scale, kv_scope0, kv_scope1)
def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool):
assert not return_p_sum
return flash_mla.flash_mla_sparse_fwd(
t.q, t.kv, t.indices,
sm_scale=t.sm_scale,
attn_sink=t.attn_sink,
topk_length=t.topk_length
)
def run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits):
assert p.decode is not None
return flash_mla.flash_mla_with_kvcache(
t.q,
t.kv_scope.get_kvcache_for_flash_mla(),
None, None, p.d_v,
tile_scheduler_metadata, num_splits,
t.sm_scale, False, True,
t.kv_scope.indices_in_kvcache,
t.attn_sink,
t.extra_kv_scope.get_kvcache_for_flash_mla() if t.extra_kv_scope is not None else None,
t.extra_kv_scope.indices_in_kvcache if t.extra_kv_scope is not None else None,
t.kv_scope.topk_length,
t.extra_kv_scope.topk_length if t.extra_kv_scope is not None and t.extra_kv_scope.topk_length is not None else None
)
@dataclasses.dataclass
class FlopsAndMemVolStatistics:
"""
FLOPs and memory volume statistics for prefilling
"""
fwd_flop: float
fwd_mem_vol: float
def count_flop_and_mem_vol(p: TestParam, t: Testcase) -> FlopsAndMemVolStatistics:
total_topk = (p.s_q*p.topk) if t.topk_length is None else t.topk_length.sum().item()
indices_valid_mask = (t.indices >= 0) & (t.indices < p.s_kv)
if t.topk_length is not None:
indices_valid_mask &= (torch.arange(p.topk)[None, None, :].broadcast_to(p.s_q, p.h_kv, p.topk)) < t.topk_length[:, None, None]
num_valid_indices = indices_valid_mask.sum().item()
fwd_flop = 2 * total_topk * p.h_q * (p.d_qk + p.d_v)
fwd_mem_vol = num_valid_indices*p.d_qk*2 + p.s_q*p.h_q*(p.d_qk+p.d_v)*2
return FlopsAndMemVolStatistics(
fwd_flop,
fwd_mem_vol,
)
@dataclasses.dataclass
class FlopsAndMemVolStatisticsForDecode:
"""
FLOPs and memory volume statistics for decoding
"""
flop: float
mem_vol: float
def count_flop_and_mem_vol_for_decode(p: TestParam, t: TestcaseForDecode) -> FlopsAndMemVolStatisticsForDecode:
assert p.decode
b = p.decode.b
def get_num_attended_tokens(kv_scope: KVScope) -> int:
topk = kv_scope.indices_in_kvcache.shape[-1]
if kv_scope.topk_length is None:
return b * p.s_q * topk
else:
return int(kv_scope.topk_length.sum().item()) * p.s_q
def get_num_retrieved_tokens(kv_scope: KVScope) -> int:
if kv_scope.topk_length is None:
indices = kv_scope.indices_in_kvcache
else:
indices = kv_scope.indices_in_kvcache.clone()
batch, s_q, topk = indices.shape
mask = torch.arange(0, topk, device=indices.device).view(1, 1, topk).broadcast_to(batch, s_q, topk) >= kv_scope.topk_length.view(batch, 1, 1)
indices[mask] = -1
num_unique_tokens = indices.unique().numel() # type: ignore
return num_unique_tokens
num_attended_tokens = get_num_attended_tokens(t.kv_scope) + (get_num_attended_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0)
num_retrieved_tokens = get_num_retrieved_tokens(t.kv_scope) + (get_num_retrieved_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0)
compute_flop = 2 * p.h_q * num_attended_tokens * (p.d_qk + p.d_v)
kv_token_size = 656 if p.d_qk == 576 else 576 # Assume FP8 KV Cache
mem_vol = sum([
2 * b * p.s_q * p.h_q * p.d_qk, # Q
num_retrieved_tokens * kv_token_size, # K
2 * b * p.s_q * p.h_q * p.d_v, # O
])
return FlopsAndMemVolStatisticsForDecode(
compute_flop,
mem_vol
)
def is_no_cooldown() -> bool:
return os.environ.get('NO_COOLDOWN', '').lower() in ['1', 'yes', 'y']
import enum
from typing import Tuple
import torch import torch
class FP8KVCacheLayout(enum.Enum):
V32_FP8Sparse = 1
MODEL1_FP8Sparse = 2
def get_meta(self) -> Tuple[int, int, int, int, int]:
# Return: (d, d_nope, d_rope, tile_size, num_tiles)
return {
FP8KVCacheLayout.V32_FP8Sparse: (576, 512, 64, 128, 4),
FP8KVCacheLayout.MODEL1_FP8Sparse: (512, 448, 64, 64, 7)
}[self]
def _cast_scale_inv_to_ue8m0(scales_inv: torch.Tensor, out_dtype = torch.float32) -> torch.Tensor:
return torch.pow(2, torch.clamp_min(scales_inv, 1e-4).log2().ceil()).to(out_dtype)
def quantize_k_cache( def quantize_k_cache(
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d) input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
dv: int, kvcache_layout: FP8KVCacheLayout,
tile_size: int = 128,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Quantize the k-cache Quantize the k-cache
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size() For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
""" """
assert dv % tile_size == 0 d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta()
num_tiles = dv // tile_size assert input_k_cache.shape[-1] == d
num_blocks, block_size, h_k, d = input_k_cache.shape num_blocks, block_size, h_k, _ = input_k_cache.shape
assert h_k == 1 assert h_k == 1
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size() input_elem_size = input_k_cache.element_size()
result = torch.empty((num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device) if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse:
result_k_nope_part = result[..., :dv] bytes_per_token = d_nope + num_tiles*4 + input_elem_size*d_rope
result_k_scale_factor = result[..., dv:dv + num_tiles * 4].view(torch.float32) result = torch.empty((num_blocks, block_size+1, bytes_per_token), dtype=torch.float8_e4m3fn, device=input_k_cache.device)[:, :block_size, :]
result_k_rope_part = result[..., dv + num_tiles * 4:].view(input_k_cache.dtype) result_k_nope_part = result[..., :d_nope]
result_k_rope_part[:] = input_k_cache[..., dv:] result_k_scale_factor = result[..., d_nope: d_nope + num_tiles*4].view(torch.float32)
result_k_rope_part = result[..., d_nope + num_tiles*4:].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., d_nope:]
for tile_idx in range(0, num_tiles): for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size] cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values.float() / 448.0 # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv)
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_quantized_nope result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope
result = result.view(num_blocks, block_size, 1, -1)
return result
elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse:
bytes_per_token = d_nope + 2*d_rope + num_tiles + 1
size_per_block_padded = (block_size*bytes_per_token + 576-1) // 576 * 576
result = torch.empty((num_blocks, size_per_block_padded), dtype=torch.float8_e4m3fn, device=input_k_cache.device)[:, :block_size*bytes_per_token]
result_k_nope_rope_part = result[:, :block_size*(d_nope+2*d_rope)].view(num_blocks, block_size, d_nope + 2*d_rope)
result_k_nope = result_k_nope_rope_part[:, :, :d_nope] # [num_blocks, block_size, d_nope]
result_k_rope = result_k_nope_rope_part[:, :, d_nope:].view(input_k_cache.dtype) # [num_blocks, block_size, d_rope]
result_k_scale_factor = result[:, block_size*(d_nope+2*d_rope):].view(num_blocks, block_size, 8)[:, :, :7].view(torch.float8_e8m0fnu) # [num_blocks, block_size, num_tiles]
result = result.view(num_blocks, block_size, 1, -1) result_k_rope[:] = input_k_cache[..., d_nope:]
return result for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values.float() / 448.0 # [num_blocks, block_size]
cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv)
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv.to(torch.float8_e8m0fnu)
cur_scale_factors_inv = cur_scale_factors_inv.view(num_blocks, block_size, 1)
cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
result_k_nope[:, :, tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope
result = result.view(num_blocks, block_size, 1, -1)
return result
else:
raise NotImplementedError(f"Unsupported kvcache_layout: {kvcache_layout}")
def dequantize_k_cache( def dequantize_k_cache(
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token) quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
dv: int = 512, kvcache_layout: FP8KVCacheLayout,
tile_size: int = 128,
d: int = 576
) -> torch.Tensor: ) -> torch.Tensor:
""" """
De-quantize the k-cache De-quantize the k-cache
""" """
assert dv % tile_size == 0 d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta()
num_tiles = dv // tile_size
num_blocks, block_size, h_k, _ = quant_k_cache.shape num_blocks, block_size, h_k, _ = quant_k_cache.shape
assert h_k == 1 assert h_k == 1
result = torch.empty((num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device) result = torch.empty((num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device)
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse:
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv] input_nope = quant_k_cache[..., :d_nope]
input_scale = quant_k_cache[..., dv:dv + num_tiles * 4].view(torch.float32) input_scale = quant_k_cache[..., d_nope:d_nope + num_tiles*4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles * 4:].view(torch.bfloat16) input_rope = quant_k_cache[..., d_nope + num_tiles*4:].view(torch.bfloat16)
result[..., dv:] = input_rope result[..., d_nope:] = input_rope
for tile_idx in range(0, num_tiles): for tile_idx in range(0, num_tiles):
cur_nope = input_nope[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].to(torch.float32) cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1) cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_nope * cur_scales result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales
elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse:
quant_k_cache = quant_k_cache.view(num_blocks, -1) # [num_blocks, ...]
input_nope_rope = quant_k_cache[:, :block_size*(d_nope+2*d_rope)].view(num_blocks, block_size, d_nope + 2*d_rope)
input_nope = input_nope_rope[:, :, :d_nope]
input_rope = input_nope_rope[:, :, d_nope:].view(torch.bfloat16)
input_scale = quant_k_cache[:, block_size*(d_nope+2*d_rope):].view(num_blocks, block_size, 8)[:, :, :7].view(torch.float8_e8m0fnu) # [num_blocks, block_size, num_tiles]
result[..., d_nope:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.bfloat16)
cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1)
result[..., tile_idx*tile_size: (tile_idx+1)*tile_size] = cur_nope * cur_scales
else:
raise NotImplementedError(f"Unsupported kvcache_layout: {kvcache_layout}")
result = result.view(num_blocks, block_size, 1, d) result = result.view(num_blocks, block_size, 1, d)
return result return result
def abs_indices2indices_in_kvcache(
abs_indices: torch.Tensor, # [b, s_q, topk]
block_table: torch.Tensor, # [b, /]
block_size: int,
) -> torch.Tensor:
"""
Convert abs_indices (logical index, ranging from 0 to s_k-1) to index expected by the sparse attn kernel
Equivalent to:
b, s_q, topk = abs_indices.shape
indices_in_kvcache = torch.empty_like(abs_indices)
for i in range(b):
cur_abs_indices = abs_indices[i, :, :].clone() # [s_q, topk]
invalid_mask = cur_abs_indices == -1
cur_abs_indices[invalid_mask] = 0
cur_indices_in_kvcache = block_table[i].index_select(0, cur_abs_indices.flatten()//block_size).view(s_q, topk)*block_size + cur_abs_indices%block_size
cur_indices_in_kvcache[invalid_mask] = -1
indices_in_kvcache[i] = cur_indices_in_kvcache
return indices_in_kvcache
"""
b, s_q, topk = abs_indices.shape
_, max_blocks_per_seq = block_table.shape
abs_indices = abs_indices.clone()
invalid_mask = abs_indices == -1
abs_indices[invalid_mask] = 0
real_block_idxs = block_table.view(-1).index_select(0, (abs_indices//block_size + torch.arange(0, b).view(b, 1, 1)*max_blocks_per_seq).view(-1))
indices_in_kvcache = real_block_idxs.view(b, s_q, topk)*block_size + abs_indices%block_size
indices_in_kvcache[invalid_mask] = -1
return indices_in_kvcache
\ No newline at end of file
from typing import Optional, Tuple
import torch
from lib import TestParam, Testcase, TestcaseForDecode, KVScope
def _merge_two_lse(lse0: torch.Tensor, lse1: Optional[torch.Tensor], s_q: int, h_q: int) -> torch.Tensor:
if lse1 is None:
return lse0
else:
return torch.logsumexp(
torch.stack([
lse0.view(s_q, h_q),
lse1.broadcast_to(s_q, h_q)
], dim=0),
dim=0
)
def ref_sparse_attn_fwd(p: TestParam, t: Testcase) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Returns:
- o: [s_q, h_q, dv]
- o_fp32: [s_q, h_q, dv]
- max_logits: [s_q, h_q]
- lse: [s_q, h_q]
"""
indices = t.indices.clone().squeeze(1)
if t.topk_length is not None:
mask = torch.arange(p.topk, device=t.topk_length.device).unsqueeze(0).broadcast_to(p.s_q, p.topk) >= t.topk_length.unsqueeze(1) # [s_q, topk]
indices[mask] = -1
invalid_mask = (indices < 0) | (indices >= p.s_kv) # [s_q, topk]
indices[invalid_mask] = 0
q = t.q.float()
gathered_kv = t.kv.index_select(dim=0, index=indices.flatten()).reshape(p.s_q, p.topk, p.d_qk).float() # [s_q, topk, d_qk]
P = (q @ gathered_kv.transpose(1, 2)) # [s_q, h_q, topk]
P *= t.sm_scale
P[invalid_mask.unsqueeze(1).broadcast_to(P.shape)] = float("-inf")
orig_lse = torch.logsumexp(P, dim=-1) # [s_q, h_q]
max_logits = P.max(dim=-1).values # [s_q, h_q]
lse_for_o = _merge_two_lse(orig_lse, t.attn_sink, p.s_q, p.h_q)
if not torch.is_inference_mode_enabled():
lse_for_o = lse_for_o.clone()
lse_for_o[lse_for_o == float("-inf")] = float("+inf") # So that corresponding O will be 0
s_for_o = torch.exp(P - lse_for_o.unsqueeze(-1))
out = s_for_o @ gathered_kv[..., :p.d_v] # [s_q, h_q, dv]
lonely_q_mask = orig_lse == float("-inf") # [s_q, h_q]
orig_lse[lonely_q_mask] = float("+inf")
return (out.to(torch.bfloat16), out, max_logits, orig_lse)
def ref_sparse_attn_decode(
p: TestParam,
t: TestcaseForDecode
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
A reference implementation of sparse decoding attention in PyTorch
"""
assert p.h_kv == 1
assert p.decode is not None
b = p.decode.b
def process_kv_scope(kv_scope: KVScope) -> Tuple[torch.Tensor, torch.Tensor]:
assert kv_scope.indices_in_kvcache is not None
topk = kv_scope.indices_in_kvcache.size(-1)
indices_in_kv_cache_fixed = torch.clamp_min(kv_scope.indices_in_kvcache, 0) # Otherwise torch.index_select will complain
gathered_kv = kv_scope.blocked_k.view(-1, p.d_qk).index_select(0, indices_in_kv_cache_fixed.view(-1)).view(b, p.s_q, topk, p.d_qk) # [b, s_q, topk, d]
invalid_mask = kv_scope.indices_in_kvcache == -1
if kv_scope.topk_length is not None:
invalid_mask |= torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, p.s_q, topk) >= kv_scope.topk_length.view(b, 1, 1)
return gathered_kv, invalid_mask
gathered_kv, invalid_mask = process_kv_scope(t.kv_scope)
if t.extra_kv_scope is not None:
gathered_kv1, invalid_mask1 = process_kv_scope(t.extra_kv_scope)
gathered_kv = torch.cat([gathered_kv, gathered_kv1], dim=2) # [b, s_q, topk+extra_topk, d]
invalid_mask = torch.cat([invalid_mask, invalid_mask1], dim=2) # [b, s_q, topk+extra_topk]
gathered_kv = gathered_kv.view(b*p.s_q, -1, p.d_qk).float()
gathered_kv[gathered_kv != gathered_kv] = 0.0
q = t.q.float().view(b*p.s_q, p.h_q, p.d_qk)
attn_weight = q @ gathered_kv.transpose(-1, -2) # [t.b*t.s_q, t.h_q, topk+extra_topk]
attn_weight *= t.sm_scale
attn_weight[invalid_mask.view(b*p.s_q, 1, -1).broadcast_to(b*p.s_q, p.h_q, invalid_mask.size(-1))] = float("-inf")
lse = attn_weight.logsumexp(dim=-1) # [t.b*t.s_q, t.h_q]
attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1))
output = attn_weight @ gathered_kv[..., :p.d_v] # [t.b*t.s_q, t.h_q, t.dv]
output = output.view(b, p.s_q, p.h_q, p.d_v)
lse = lse.view(b, p.s_q, p.h_q)
# Attention sink
if t.attn_sink is not None:
output *= (1.0 / (1.0 + torch.exp(t.attn_sink.view(1, 1, p.h_q) - lse))).unsqueeze(-1)
# Correct for q tokens which has no attendable k
lonely_q_mask = (lse == float("-inf"))
output[lonely_q_mask.unsqueeze(-1).broadcast_to(b, p.s_q, p.h_q, p.d_v)] = 0.0
lse[lonely_q_mask] = float("+inf")
return output.to(torch.bfloat16), lse.transpose(1, 2)
\ No newline at end of file
...@@ -2,14 +2,12 @@ import argparse ...@@ -2,14 +2,12 @@ import argparse
import math import math
import random import random
import dataclasses import dataclasses
from typing import Optional, Tuple from typing import Tuple
import torch import torch
import triton
import kernelkit as kk
import flash_mla import flash_mla
import quant
from lib import cdiv, check_is_allclose
@dataclasses.dataclass @dataclasses.dataclass
class TestParam: class TestParam:
...@@ -18,10 +16,7 @@ class TestParam: ...@@ -18,10 +16,7 @@ class TestParam:
s_k: int # Seq len, or mean seq len if varlen == True s_k: int # Seq len, or mean seq len if varlen == True
is_varlen: bool is_varlen: bool
is_causal: bool is_causal: bool
is_fp8: bool
topk: Optional[int] = None
test_performance: bool = True test_performance: bool = True
is_all_indices_invalid: bool = False
have_zero_seqlen_k: bool = False have_zero_seqlen_k: bool = False
block_size: int = 64 block_size: int = 64
h_q: int = 128 # Number of q heads h_q: int = 128 # Number of q heads
...@@ -31,7 +26,7 @@ class TestParam: ...@@ -31,7 +26,7 @@ class TestParam:
seed: int = 0 seed: int = 0
def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Generate test data from a given configuration Generate test data from a given configuration
Return: [cache_seqlens, q, block_table, blocked_k] Return: [cache_seqlens, q, block_table, blocked_k]
...@@ -53,11 +48,11 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. ...@@ -53,11 +48,11 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0 zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0
cache_seqlens_cpu[zeros_mask] = 0 cache_seqlens_cpu[zeros_mask] = 0
max_seqlen = cache_seqlens_cpu.max().item() max_seqlen = int(cache_seqlens_cpu.max().item())
max_seqlen_pad = cdiv(max_seqlen, 256) * 256 max_seqlen_pad = kk.cdiv(max_seqlen, 256) * 256
cache_seqlens = cache_seqlens_cpu.cuda() cache_seqlens = cache_seqlens_cpu.cuda()
q = torch.randn(t.b, t.s_q, t.h_q, t.d) q = torch.randn(t.b, t.s_q, t.h_q, t.d) / 10
q.clamp_(min=-1.0, max=1.0) q.clamp_(min=-1.0, max=1.0)
block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size) block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size)
...@@ -65,59 +60,14 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. ...@@ -65,59 +60,14 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10 blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10
blocked_k.clamp_(min=-1.0, max=1.0) blocked_k.clamp_(min=-1.0, max=1.0)
if t.topk is None: for i in range(t.b):
for i in range(t.b): cur_len = int(cache_seqlens_cpu[i].item())
cur_len = cache_seqlens_cpu[i].item() cur_num_blocks = kk.cdiv(cur_len, t.block_size)
cur_num_blocks = cdiv(cur_len, t.block_size) blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
blocked_k[block_table[i][cur_num_blocks:]] = float("nan") if cur_len % t.block_size != 0:
if cur_len % t.block_size != 0: blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan")
blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan") block_table[i][cur_num_blocks:] = 2147480000
block_table[i][cur_num_blocks:] = 2147480000 return cache_seqlens, q, block_table, blocked_k
return cache_seqlens, q, block_table, blocked_k, None, None
else:
block_table_cpu = block_table.cpu()
abs_indices = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu")
indices_in_kvcache = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu")
for i in range(t.b):
# Generate indices
for j in range(t.s_q):
cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk]
cur_blocked_indices = block_table_cpu[i, cur_abs_indices // t.block_size] * t.block_size + (cur_abs_indices % t.block_size)
if len(cur_abs_indices) < t.topk:
pad_len = t.topk - len(cur_abs_indices)
cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')])
cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')])
# Mask KV
perm = torch.randperm(t.topk, device='cpu')
cur_abs_indices = cur_abs_indices[perm]
cur_blocked_indices = cur_blocked_indices[perm]
# Fill it with invalid indices if needed
if t.is_all_indices_invalid:
cur_abs_indices.fill_(-1)
cur_blocked_indices.fill_(-1)
abs_indices[i, j, :] = cur_abs_indices
indices_in_kvcache[i, j, :] = cur_blocked_indices
# Mask nonused KV as NaN
all_indices = indices_in_kvcache.flatten().tolist()
all_indices = list(set(all_indices))
if -1 in all_indices:
all_indices.remove(-1)
all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu')
blocked_k = blocked_k.view(-1, t.h_kv, t.d)
nonused_indices_mask = torch.ones(blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device='cpu')
nonused_indices_mask[all_indices] = False
blocked_k[nonused_indices_mask, :, :] = float("nan")
blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d)
abs_indices = abs_indices.to(q.device)
indices_in_kvcache = indices_in_kvcache.to(q.device)
return cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache
def reference_torch( def reference_torch(
...@@ -127,18 +77,10 @@ def reference_torch( ...@@ -127,18 +77,10 @@ def reference_torch(
blocked_k: torch.Tensor, # [?, block_size, h_kv, d] blocked_k: torch.Tensor, # [?, block_size, h_kv, d]
dv: int, dv: int,
is_causal: bool, is_causal: bool,
indices: Optional[torch.Tensor] = None # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
A reference implementation in PyTorch A reference implementation in PyTorch
""" """
def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
mask = torch.zeros(s_q, s_k, dtype=torch.bool)
for i in range(s_q):
cur_indices = indices[i]
valid_indices = cur_indices[cur_indices != -1]
mask[i, valid_indices] = True
return mask
def scaled_dot_product_attention( def scaled_dot_product_attention(
batch_idx: int, batch_idx: int,
...@@ -146,7 +88,6 @@ def reference_torch( ...@@ -146,7 +88,6 @@ def reference_torch(
kv: torch.Tensor, # [h_kv, s_k, d] kv: torch.Tensor, # [h_kv, s_k, d]
dv: int, dv: int,
is_causal, is_causal,
indices: Optional[torch.Tensor], # [s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
h_q = query.size(0) h_q = query.size(0)
h_kv = kv.size(0) h_kv = kv.size(0)
...@@ -158,13 +99,10 @@ def reference_torch( ...@@ -158,13 +99,10 @@ def reference_torch(
kv = kv.repeat_interleave(h_q // h_kv, dim=0) kv = kv.repeat_interleave(h_q // h_kv, dim=0)
kv[kv != kv] = 0.0 kv[kv != kv] = 0.0
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
if (is_causal and query.size(1) > 1) or indices is not None: if is_causal and query.size(1) > 1:
mask = torch.ones(s_q, s_k, dtype=torch.bool) mask = torch.ones(s_q, s_k, dtype=torch.bool)
if is_causal: if is_causal:
assert indices is None
mask = mask.tril(diagonal=s_k - s_q) mask = mask.tril(diagonal=s_k - s_q)
if indices is not None:
mask &= get_topk_attn_mask(s_q, s_k, indices)
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float) attn_bias = torch.zeros(s_q, s_k, dtype=torch.float)
attn_bias.masked_fill_(mask.logical_not(), float("-inf")) attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
attn_weight += attn_bias.to(q.dtype) attn_weight += attn_bias.to(q.dtype)
...@@ -186,8 +124,8 @@ def reference_torch( ...@@ -186,8 +124,8 @@ def reference_torch(
out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32) lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b): for i in range(b):
cur_len = cache_seqlens_cpu[i].item() cur_len = int(cache_seqlens_cpu[i].item())
cur_num_blocks = cdiv(cur_len, block_size) cur_num_blocks = kk.cdiv(cur_len, block_size)
cur_block_indices = block_table[i][0: cur_num_blocks] cur_block_indices = block_table[i][0: cur_num_blocks]
cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...] cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
cur_out, cur_lse = scaled_dot_product_attention( cur_out, cur_lse = scaled_dot_product_attention(
...@@ -195,12 +133,11 @@ def reference_torch( ...@@ -195,12 +133,11 @@ def reference_torch(
q[i].transpose(0, 1), q[i].transpose(0, 1),
cur_kv.transpose(0, 1), cur_kv.transpose(0, 1),
dv, dv,
is_causal, is_causal
indices[i] if indices is not None else None
) )
out_ref[i] = cur_out.transpose(0, 1) out_ref[i] = cur_out.transpose(0, 1)
lse_ref[i] = cur_lse lse_ref[i] = cur_lse
out_ref = out_ref.to(torch.bfloat16) out_ref = out_ref.to(q.dtype)
return out_ref, lse_ref return out_ref, lse_ref
...@@ -211,58 +148,42 @@ def test_flash_mla(t: TestParam): ...@@ -211,58 +148,42 @@ def test_flash_mla(t: TestParam):
# Generating test data # Generating test data
torch.cuda.synchronize() torch.cuda.synchronize()
cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache = generate_test_data(t) cache_seqlens, q, block_table, blocked_k, = generate_test_data(t)
if t.is_fp8: tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
# The quantization error may be too large to be distinguished from wrong kernels
# So we quantize and de-quantize kv-cache here to mitigate quantization error
blocked_k_quantized = quant.quantize_k_cache(blocked_k, t.dv, 128)
blocked_k_dequantized = quant.dequantize_k_cache(blocked_k_quantized)
blocked_k = blocked_k_dequantized
# Get schedule metadata
torch.cuda.synchronize()
tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
cache_seqlens,
t.s_q * t.h_q // t.h_kv,
t.h_kv,
t.h_q,
t.is_fp8,
t.topk
)
torch.cuda.synchronize()
def run_flash_mla(): def run_flash_mla():
return flash_mla.flash_mla_with_kvcache( return flash_mla.flash_mla_with_kvcache(
q, q,
blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore blocked_k,
block_table, block_table,
cache_seqlens, cache_seqlens,
t.dv, t.dv,
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
causal=t.is_causal, causal=t.is_causal
is_fp8_kvcache=t.is_fp8,
indices=indices_in_kvcache
) )
out_ans, lse_ans = run_flash_mla() out_ans, lse_ans = run_flash_mla()
out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices) out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal)
assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6) is_correct = True
assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536) is_correct &= kk.check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6)
is_correct &= kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536)
assert is_correct
if t.test_performance: if t.test_performance:
time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore time_usage = kk.bench_kineto(run_flash_mla, 10).get_kernel_time("flash_fwd_splitkv_mla_kernel")
mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk
mean_attended_seqlens = cache_seqlens.float().mean().item()
compute_volume_flop = t.b * t.h_q * t.s_q * sum([ compute_volume_flop = t.b * t.h_q * t.s_q * sum([
2 * t.d * mean_attended_seqlens, # Q * K^T 2 * t.d * mean_attended_seqlens, # Q * K^T
2 * mean_attended_seqlens * t.dv, # attention * V 2 * mean_attended_seqlens * t.dv, # attention * V
]) ])
q_elem_size = torch.bfloat16.itemsize q_elem_size = torch.bfloat16.itemsize
kv_token_size = 656 if t.is_fp8 else t.d * torch.bfloat16.itemsize kv_token_size = t.d * torch.bfloat16.itemsize
memory_volume_B = t.b * sum([ memory_volume_B = t.b * sum([
t.s_q * t.h_q * (t.d * q_elem_size), # Q t.s_q * t.h_q * (t.d * q_elem_size), # Q
(t.s_q if t.topk is not None else 1) * mean_attended_seqlens * t.h_kv * kv_token_size, # K/V mean_attended_seqlens * t.h_kv * kv_token_size, # K/V
t.s_q * t.h_q * (t.dv * q_elem_size), # Output t.s_q * t.h_q * (t.dv * q_elem_size), # Output
]) ])
achieved_tflops = compute_volume_flop / time_usage / 1e12 achieved_tflops = compute_volume_flop / time_usage / 1e12
...@@ -277,54 +198,39 @@ def main(torch_dtype): ...@@ -277,54 +198,39 @@ def main(torch_dtype):
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
cc_major, cc_minor = torch.cuda.get_device_capability()
assert cc_major == 9, "Dense MLA decoding is only supported on sm90 (Hopper) currently."
correctness_cases = [ correctness_cases = [
TestParam(b, s_q, s_k, is_varlen, is_causal, is_fp8, topk, test_performance=False) TestParam(b, s_q, s_k, is_varlen, is_causal, test_performance=False, have_zero_seqlen_k=False, block_size=64, h_q=h_q, h_kv=h_kv)
for b in [1, 2, 6, 64] for b in [1, 2, 6, 64]
for s_q in [1, 2, 4] for s_q in [1, 2, 4]
for s_k in [20, 140, 4096] for s_k in [20, 140, 4096]
for h_q in [1, 3, 9, 63, 64, 126, 128]
for h_kv in [1, 2, 3, 8]
for is_varlen in [False, True] for is_varlen in [False, True]
for is_causal in [False, True] for is_causal in [False, True]
for (is_fp8, topk) in [ if h_q % h_kv == 0
(False, None),
(True, 128),
(True, 2048)
]
if not (is_causal and topk is not None)
] ]
corner_cases = [ corner_cases = [
# Cases where all topk indices are invalid
TestParam(128, 2, 4096, is_varlen=True, is_causal=False, is_fp8=True, topk=topk, test_performance=False, is_all_indices_invalid=True)
for topk in [128, 2048, 4096]
] + [
# Cases where some kv cache have zero length # Cases where some kv cache have zero length
TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=False, have_zero_seqlen_k=True) TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, test_performance=False, have_zero_seqlen_k=True, h_q=h_q, h_kv=h_kv)
for (is_causal, is_fp8, topk) in [ for h_q in [1, 3, 9, 63, 64, 126, 128]
(False, False, None), for h_kv in [1, 2, 3, 8]
(True, False, None), for is_causal in [False, True]
(False, True, 128), if h_q % h_kv == 0
(False, True, 2048),
]
] ]
performance_cases = [ performance_cases = [
TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=True) TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, test_performance=True)
for (is_causal, is_fp8, topk) in [ for is_causal in [False, True]
(False, False, None),
(True, False, None),
(False, True, 2048),
]
for s_q in [1, 2] for s_q in [1, 2]
for s_k in [4096, 8192, 16384, 32768] for s_k in [4096, 8192, 16384, 32768]
] ]
testcases = correctness_cases + corner_cases + performance_cases testcases = correctness_cases + corner_cases + performance_cases
# Prune out unsupported cases
cc_major, cc_minor = torch.cuda.get_device_capability()
if cc_major == 10:
testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)]
for testcase in testcases: for testcase in testcases:
test_flash_mla(testcase) test_flash_mla(testcase)
...@@ -345,4 +251,4 @@ if __name__ == "__main__": ...@@ -345,4 +251,4 @@ if __name__ == "__main__":
if args.dtype == "fp16": if args.dtype == "fp16":
torch_dtype = torch.float16 torch_dtype = torch.float16
main(torch_dtype) main(torch_dtype)
\ No newline at end of file
import math
import time
from typing import Tuple
import random
import dataclasses
import torch
import triton
from flash_mla import flash_mla_sparse_fwd
from lib import check_is_allclose
@dataclasses.dataclass
class TestParam:
b: int
s_q: int
s_kv: int
topk: int
h_q: int = 128
h_kv: int = 1
d_qk: int = 576
d_v: int = 512
seed: int = 0
check_correctness: bool = True
benchmark: bool = True
@dataclasses.dataclass
class Testcase:
t: TestParam
q: torch.Tensor
kv: torch.Tensor
indices: torch.Tensor
def generate_testcase(t: TestParam) -> Testcase:
torch.manual_seed(t.seed)
torch.cuda.manual_seed(t.seed)
random.seed(t.seed)
q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10
kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10
q.clamp_(-10, 10)
kv.clamp_(-10, 10)
indices = torch.full((t.b, t.s_q, t.h_kv, t.topk), t.s_kv, dtype=torch.int32)
for b in range(t.b):
for s in range(t.s_q):
for h in range(t.h_kv):
# NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31
cur_indices = torch.randperm(t.s_kv)[:t.topk]
cur_indices[near_mask] = torch.randint(max(0, t.s_kv - 20000), t.s_kv - 1, (near_mask.sum().item(),))
if len(cur_indices) < t.topk:
cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)])
cur_indices = cur_indices[torch.randperm(t.topk)]
indices[b, s, h] = cur_indices
indices = indices.to(q.device)
return Testcase(
t=t,
q=q,
kv=kv,
indices=indices
)
def get_flop(p: TestParam) -> float:
flop = 2 * sum([
p.h_q * p.d_qk * p.topk,
p.h_q * p.d_v * p.topk
]) * p.b * p.s_q
return flop
def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
assert p.b == 1
indices = t.indices[0, :, 0, :] # [s_q, topk]
invalid_indices_mask = (indices < 0) | (indices >= p.s_kv)
qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk]
kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk]
kvs = torch.index_select(kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()).view(p.s_q, p.topk, p.d_qk) # [s_q, topk, d_qk]
attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk]
attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf'))
attn_score *= sm_scale * math.log2(math.e)
max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q]
lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
result = attn_score @ kvs[:, :, :p.d_v]
return (max_logits, lse, result)
@torch.inference_mode()
def run_test(p: TestParam) -> bool:
print("================")
print(f"Running on {p}")
torch.cuda.empty_cache()
assert p.b == 1
t = generate_testcase(p)
sm_scale = 1 / math.sqrt(p.d_qk)
torch.cuda.synchronize()
def run_ans():
return flash_mla_sparse_fwd(
t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale
)
ans_out, ans_max_logits, ans_lse = run_ans()
torch.cuda.synchronize()
if p.benchmark:
flop = get_flop(p)
prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore
prefill_flops = flop / prefill_ans_time / 1e12
print(f"Prefill: {prefill_ans_time * 1e6:4.0f} us, {prefill_flops:.3f} TFlops")
if p.check_correctness:
torch.cuda.synchronize()
ref_max_logits, ref_lse, ref_out = reference_torch(p, t, sm_scale)
torch.cuda.synchronize()
is_correct = True
is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6)
is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01 / 65536)
is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01 / 65536)
return is_correct
else:
return True
if __name__ == '__main__':
device = torch.device("cuda:0")
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.set_float32_matmul_precision('high')
correctness_cases = [
# Regular shapes
TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False)
for s_kv, topk in [
# Regular shapes
(128, 128),
(256, 256),
(512, 512),
# Irregular shapes
(592, 128),
(1840, 256),
(1592, 384),
(1521, 512),
# Irregular shapes with OOB TopK
(95, 128),
(153, 256),
(114, 384),
]
for s_q in [
1, 62
]
]
corner_cases = [
# In these cases, some blocks may not have any valid topk indices
TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False)
for s_kv, topk in [
(32, 2048),
(64, 8192)
]
for s_q in [1, 1024]
]
performance_cases = [
TestParam(1, s_q, s_kv, topk, h_q=128)
for s_q in [4096]
for s_kv in [4096, 8192, 16384, 32768, 49152, 65536, 81920, 98304, 114688, 131072]
for topk in [2048]
]
testcases = correctness_cases + corner_cases + performance_cases
failed_cases = []
for test in testcases:
if test.benchmark:
time.sleep(0.2)
is_correct = run_test(test)
if not is_correct:
failed_cases.append(test)
if len(failed_cases) > 0:
print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
for case in failed_cases:
print(f" {case}")
else:
print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
import time
import dataclasses
from typing import Tuple, List, Dict, Optional
import copy
import rich.console
import rich.table
import torch
import kernelkit as kk
import flash_mla
import lib
from lib import TestParam
from lib import RawTestParamForDecode as RawTestParam
import ref
"""
Generate testcase for unit test
"""
def gen_testcase() -> List[RawTestParam]:
correctness_cases = []
corner_cases = []
for d_qk in [576, 512]:
for have_extra_k in ([False, True] if d_qk == 512 else [False]):
for have_extra_topk_len in ([False, True] if have_extra_k else [False]):
for have_topk_len in ([False, True] if d_qk == 512 else [False]):
for h_q in [64, 128]:
cur_correctness_cases = [
RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,
have_topk_length=have_topk_len,
enable_attn_sink=True,
extra_s_k=extra_s_k,
extra_topk=extra_topk,
block_size=block_size,
extra_block_size=extra_block_size,
have_extra_topk_length=have_extra_topk_len,
d_qk=d_qk,
check_correctness=True,
num_runs=0)
for (s_k, topk, block_size) in [
(512, 64, 2),
(512, 64, 64),
(512, 64, 69),
(1024, 576, 2),
(1024, 576, 61),
(2046, 2048, 2),
(2046, 2048, 64),
(2046, 2048, 576)
]
for (extra_s_k, extra_topk, extra_block_size) in ([
(512, 64, 2),
(512, 64, 64),
(512, 64, 69),
(1024, 576, 2),
(1024, 576, 61),
(2046, 2048, 2),
(2046, 2048, 64),
(2046, 2048, 576)
] if have_extra_k else [(None, None, None)])
for b in [4, 74, 321]
for s_q in [1, 3]
for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True])
]
correctness_cases.extend(cur_correctness_cases)
cur_corner_cases = [
RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,
is_all_indices_invalid=is_all_indices_invalid,
have_zero_seqlen_k=have_zero_seqlen_k,
have_topk_length=have_topk_len,
enable_attn_sink=enable_attn_sink,
extra_s_k=extra_s_k,
extra_topk=extra_topk,
block_size=block_size,
extra_block_size=extra_block_size,
have_extra_topk_length=have_extra_topk_len,
d_qk=d_qk,
check_correctness=True,
num_runs=0,
)
for (s_k, topk, block_size) in [
(512, 64, 61),
(650, 576, 53),
]
for (extra_s_k, extra_topk, extra_block_size) in ([
(512, 64, 61),
(650, 576, 53),
] if have_extra_k else [(None, None, None)])
for b in [4, 74, 321]
for s_q in [3]
for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True])
for is_all_indices_invalid in [True, False]
for have_zero_seqlen_k in [True, False]
for enable_attn_sink in [True, False]
if (is_all_indices_invalid or have_zero_seqlen_k or enable_attn_sink)
]
corner_cases.extend(cur_corner_cases)
base_and_bszs = [
# V3.2
(RawTestParam(0, 128, 2, 1, 32768, True, topk=2048, d_qk=576), [2, 64, 74, 128]),
# MODEL1 CONFIG1
(RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=512, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]),
# MODEL1 CONFIG2
(RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]),
# MODEL1 CONFIG3
(RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
# MODEL1 CONFIG4
(RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
]
performance_cases = [
# Production cases
dataclasses.replace(base, b=b)
for base, bszs in base_and_bszs
for b in bszs
] + [
# Peak perf cases
RawTestParam(74*2, h_q, 2, 1, 32768, True, topk=16384, d_qk=d_qk)
for h_q in [64, 128]
for d_qk in [512, 576]
]
return correctness_cases + corner_cases + performance_cases
@dataclasses.dataclass
class Result:
is_correct: bool
compute_memory_ratio: float
time_usage_per_us: float
splitkv_time_usage_us: float
combine_time_usage_us: float
achieved_tflops: float
achieved_gBps: float
_counter = kk.Counter()
@torch.inference_mode()
def test_flash_mla(p: TestParam) -> Result:
if p.seed == -1:
global _counter
p.seed = _counter.next()
assert p.decode
print("================")
print(f"Running on {p}")
torch.cuda.empty_cache()
t = lib.generate_testcase_for_decode(p)
tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()
def run_decode():
return lib.run_flash_mla_decode(p, t, tile_scheduler_metadata, None)
# We first run the kernel once to generate output data for the correctness test
# We must do this first, otherwise when allocating tensors for storing answers,
# it may re-use memory that contains the correct answer, leading to false positives
if p.check_correctness:
torch.cuda.synchronize()
out_ans, lse_ans = run_decode()
torch.cuda.synchronize()
# torch.set_printoptions(profile='full')
# print(tile_scheduler_metadata.tile_scheduler_metadata[:, :7])
# We run the performance test before generating the answer for the correctness test to avoid interference
performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
if p.num_runs == 0:
performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
else:
result = kk.bench_kineto(run_decode, p.num_runs)
splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel"
combine_kernel_name = "flash_fwd_mla_combine_kernel"
# Get individual kernel time usages
kernel_time_usages_us: Dict[str, Optional[float]] = {}
def pick_kernel_time_usage(kernel_name: str):
t = [kernel_name in s for s in result.get_kernel_names()]
if any(t):
assert sum(t) == 1
kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6
else:
kernel_time_usages_us[kernel_name] = None
pick_kernel_time_usage(splitkv_kernel_name)
pick_kernel_time_usage(combine_kernel_name)
# Get E2E time usages
def have_kernel(name: str):
return kernel_time_usages_us[name] is not None
if kk.is_using_profiling_tools():
e2e_time_usage_us = 1e6
else:
assert have_kernel(splitkv_kernel_name)
if have_kernel(combine_kernel_name):
e2e_time_usage_us = result.get_e2e_time(splitkv_kernel_name, combine_kernel_name) * 1e6
else:
e2e_time_usage_us = kernel_time_usages_us[splitkv_kernel_name]
assert e2e_time_usage_us is not None
flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t)
e2e_time_usage_s = e2e_time_usage_us / 1e6
theoritical_compute_memory_ratio = flops_and_mem_vol.flop / flops_and_mem_vol.mem_vol
achieved_tflops = flops_and_mem_vol.flop / e2e_time_usage_s / 1e12
achieved_gBps = flops_and_mem_vol.mem_vol / e2e_time_usage_s / 1e9
def print_kernel_time_usage(name: str, short_name: str):
if kernel_time_usages_us[name] is not None:
print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us')
print(f'Compute/Memory: {theoritical_compute_memory_ratio:.2f}')
print(f'Time (per): {e2e_time_usage_us:.1f} us')
print_kernel_time_usage(splitkv_kernel_name, "Splitkv")
print_kernel_time_usage(combine_kernel_name, "Combine")
print(f'TFlops: {achieved_tflops:.1f}')
print(f'GB/s: {achieved_gBps:.0f}')
performance_result = Result(True, theoritical_compute_memory_ratio, e2e_time_usage_us, kernel_time_usages_us[splitkv_kernel_name] or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps)
is_correct = True
if p.check_correctness:
torch.cuda.synchronize()
with torch.profiler.record_function("reference_flash_mla"):
out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t)
is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)
is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
is_correct &= is_out_correct and is_lse_correct
performance_result.is_correct = is_correct
return performance_result
def main():
dtype = torch.bfloat16
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.set_float32_matmul_precision('high')
torch.set_num_threads(32)
raw_testcases = gen_testcase()
testcases = [t.to_test_param() for t in raw_testcases]
print(f"{kk.colors['CYAN_BG']}{len(testcases)} testcases to run{kk.colors['CLEAR']}")
is_no_cooldown = lib.is_no_cooldown()
num_testcases_len = len(str(len(testcases)))
failed_cases = []
results: List[Tuple[TestParam, Result]] = []
for testcase_idx, testcase in enumerate(testcases):
if testcase != testcases[0] and testcase.num_runs > 0 and not is_no_cooldown:
time.sleep(0.3) # Cooldown
print(f"[{testcase_idx+1:{num_testcases_len}d}/{len(testcases)}, {testcase_idx/len(testcases)*100:3.0f}%] ", end='')
result = test_flash_mla(testcase)
results.append((testcase, result))
if not result.is_correct:
failed_cases.append(testcase)
import sys
sys.exit(1)
console = rich.console.Console(width=120)
table = rich.table.Table(show_header=True, header_style="bold cyan")
table.add_column("topk")
table.add_column("Bsz")
table.add_column("h_q&k")
table.add_column("sq")
table.add_column("sk")
table.add_column("d_qk")
table.add_column("Feats")
table.add_column("C/M")
table.add_column("TFlops")
table.add_column("GBps")
table.add_column("us")
table.add_column(" ")
for testcase, result in results:
assert testcase.decode
topk_str = f"{testcase.topk}" if testcase.decode.extra_topk is None else f"{testcase.topk}+{testcase.decode.extra_topk}"
table.add_row(
topk_str,
str(testcase.decode.b),
f"{testcase.h_q:3d} {testcase.h_kv}",
str(testcase.s_q),
str(testcase.s_kv),
str(testcase.d_qk),
" V"[testcase.decode.is_varlen] + " L"[testcase.have_topk_length] + " E"[testcase.decode.have_extra_topk_length],
f"{result.compute_memory_ratio:3.0f}",
f"{result.achieved_tflops:3.0f}",
f"{result.achieved_gBps:4.0f}",
f"{result.time_usage_per_us:4.1f}",
"" if result.is_correct else "X"
)
console.print(table)
def geomean(l) -> float:
import numpy
return numpy.exp(numpy.mean(numpy.log(l)))
num_correct_testcases = [result.is_correct for t, result in results if t.check_correctness].count(True)
num_correctness_cases = sum([1 for t in testcases if t.check_correctness])
if num_correct_testcases == num_correctness_cases:
print(f"{kk.colors['GREEN_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}")
else:
print(f"{kk.colors['RED_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}")
for t in failed_cases:
print(f"\t{t},")
valid_achieved_tflops = [result.achieved_tflops for _, result in results if result.achieved_tflops > 0.1]
if len(valid_achieved_tflops) > 0:
achieved_tflops_geomean = geomean(valid_achieved_tflops) # > 0.1 to prune out correctness cases
print(f"TFlops geomean: {achieved_tflops_geomean:.1f}")
if __name__ == "__main__":
main()
import time
import sys
import torch
import kernelkit as kk
from lib import TestParam
import lib
import ref
_counter = kk.Counter()
@torch.inference_mode()
def run_test(p: TestParam) -> bool:
if p.seed == -1:
global _counter
p.seed = _counter.next()
print("================")
print(f"Running on {p}")
torch.cuda.empty_cache()
t = lib.generate_testcase(p)
torch.cuda.synchronize()
def run_prefill():
return lib.run_flash_mla_sparse_fwd(p, t, False)
prefill_ans_out, prefill_ans_max_logits, prefill_ans_lse = run_prefill()
torch.cuda.synchronize()
if p.num_runs > 0:
flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t)
prefill_ans_time = kk.bench_kineto(run_prefill, num_tests=p.num_runs).get_kernel_time("sparse_attn_fwd")
prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12
prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12
print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:6.1f} TFlops, {prefill_mem_bw:4.2f} TBps")
if p.check_correctness:
torch.cuda.synchronize()
ref_out, ref_out_fp32, ref_max_logits, ref_lse = ref.ref_sparse_attn_fwd(p, t)
ref_lse[ref_lse == float("-inf")] = float("+inf")
torch.cuda.synchronize()
is_correct = True
is_correct &= kk.check_is_allclose("out", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6)
is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
return is_correct
else:
return True
if __name__ == '__main__':
device = torch.device("cuda:0")
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.set_float32_matmul_precision('high')
correctness_cases = [
# Regular shapes
TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk)
for d_qk in [512, 576]
for h_q in [
128, 64
]
for s_kv, topk in [
# Regular shapes
(128, 128),
(256, 256),
(512, 512),
# Irregular shapes
(592, 128),
(1840, 256),
(1592, 384),
(1521, 512),
# Irregular shapes with OOB TopK
(95, 128),
(153, 256),
(114, 384),
]
for s_q in [
1, 62, 213
]
]
correctness_cases_with_features = [
TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, have_attn_sink=have_attn_sink, have_topk_length=have_topk_length, d_qk=d_qk)
for d_qk in [512, 576]
for h_q in [
128, 64
]
for s_kv, topk in [
(592, 128),
(1840, 256),
(1592, 384),
(1521, 512),
(95, 128),
(153, 256),
(114, 384),
]
for s_q in [62, 213]
for have_sink_lse in [False, True]
for have_attn_sink in [False, True]
for have_topk_length in [False, True]
]
corner_cases = [
TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=True, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
for d_qk in [512, 576]
for h_q in [
128, 64
]
for s_q, s_kv, topk in [
(1, 128, 128),
(1, 256, 256),
(1234, 4321, 4096),
(4096, 2048, 2048)
]
] + [
# In these cases, some blocks may not have any valid topk indices
TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
for d_qk in [512, 576]
for h_q in [
128, 64
]
for s_kv, topk in [
(32, 2048),
(64, 8192)
]
for s_q in [1, 1024]
] + [
# In this testcase, s_q is really large, so we cannot put it on the second dimension of grid shape
TestParam(70000, 256, 256, h_q=h_q, check_correctness=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
for d_qk in [512, 576]
for h_q in [
128, 64
]
]
performance_case_templates = [
# V3.2
(576, 128, 2048, [8192, 32768, 65536, 98304, 131072]),
# MODEL1 CONFIG1
(512, 64, 512, [8192, 32768, 49152, 65536]),
# MODEL1 CONFIG2
(512, 128, 1024, [8192, 32768, 49152, 65536]),
]
performance_cases = [
TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=True)
for (d_qk, h_q, topk, s_kv_list) in performance_case_templates
for s_q in [4096]
for s_kv in s_kv_list
]
testcases = correctness_cases + correctness_cases_with_features + corner_cases + performance_cases
is_no_cooldown = lib.is_no_cooldown()
failed_cases = []
for test in testcases:
if test != testcases[0] and test.num_runs > 0 and not is_no_cooldown:
time.sleep(0.3)
is_correct = run_test(test)
if not is_correct:
failed_cases.append(test)
if len(failed_cases) > 0:
print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
for case in failed_cases:
print(f" {case}")
sys.exit(1)
else:
print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
...@@ -5,7 +5,7 @@ from torch.utils.checkpoint import checkpoint ...@@ -5,7 +5,7 @@ from torch.utils.checkpoint import checkpoint
import triton import triton
from flash_mla import flash_attn_varlen_func from flash_mla import flash_attn_varlen_func
from lib import check_is_allclose from kernelkit import check_is_allclose
def get_window_size(causal, window): def get_window_size(causal, window):
if window > 0: if window > 0:
...@@ -116,14 +116,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win ...@@ -116,14 +116,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
out_flash, lse_flash = flash_attn() out_flash, lse_flash = flash_attn()
if has_bwd: if has_bwd:
out_flash.backward(grad_out, retain_graph=True) out_flash.backward(grad_out, retain_graph=True)
dq1 = q1.grad.clone() _dq1 = q1.grad.clone()
dk1 = k1.grad.clone() dk1 = k1.grad.clone()
dv1 = v1.grad.clone() dv1 = v1.grad.clone()
if check_correctness: if check_correctness:
out_torch, lse_torch = torch_attn() out_torch, lse_torch = torch_attn()
assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) assert check_is_allclose("out", out_flash.float(), out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536) assert check_is_allclose("lse", lse_flash.float(), lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536)
if has_bwd: if has_bwd:
out_torch.backward(grad_out, retain_graph=True) out_torch.backward(grad_out, retain_graph=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment