Commit 082094b7 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Multiple updates and refactorings (#150)

* Multiple updates and refactorings

* Remove dead code
parent 1408756a
import torch
def _get_new_non_contiguous_tensor_shape(shape):
"""
Get the expanded shape for a non-contiguous tensor.
The last dimension is increased by 128 (for alignment), and all other dimensions are increased by 1
"""
return [dim+128 if dim_idx == len(shape)-1 else dim+1 for dim_idx, dim in enumerate(shape)]
def gen_non_contiguous_randn_tensor(shape, *args, **kwargs):
    """
    Create a tensor of the given `shape` filled with `torch.randn` samples, stored as a
    slice of a larger buffer.

    Extra positional/keyword arguments are forwarded to `torch.randn` (e.g. dtype, device).
    The result is the leading `[0:dim]` corner of a buffer whose shape comes from
    `_get_new_non_contiguous_tensor_shape`.
    """
    new_shape = _get_new_non_contiguous_tensor_shape(shape)
    base_tensor = torch.randn(new_shape, *args, **kwargs)
    # Index with a tuple, not a list: a list of slices only works through the legacy
    # "treat sequence as tuple" path of tensor indexing.
    slices = tuple(slice(0, dim) for dim in shape)
    return base_tensor[slices]
def gen_non_contiguous_tensor(shape, *args, **kwargs):
    """
    Create an uninitialized tensor of the given `shape`, stored as a slice of a larger buffer.

    Extra positional/keyword arguments are forwarded to `torch.empty` (e.g. dtype, device).
    The result is the leading `[0:dim]` corner of a buffer whose shape comes from
    `_get_new_non_contiguous_tensor_shape`.
    """
    new_shape = _get_new_non_contiguous_tensor_shape(shape)
    base_tensor = torch.empty(new_shape, *args, **kwargs)
    # Index with a tuple, not a list: a list of slices only works through the legacy
    # "treat sequence as tuple" path of tensor indexing.
    slices = tuple(slice(0, dim) for dim in shape)
    return base_tensor[slices]
def non_contiguousify(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return a copy of `tensor` (same shape, dtype, device and values) whose storage is a
    slice of a larger buffer allocated by `gen_non_contiguous_tensor`.
    """
    strided_copy = gen_non_contiguous_tensor(
        tensor.shape,
        dtype=tensor.dtype,
        device=tensor.device,
    )
    strided_copy[:] = tensor
    return strided_copy
import torch
# Stack of `enabled` flags pushed by (possibly nested) `LowPrecisionMode` contexts.
# The last entry is the currently active setting; an empty stack means "disabled".
_is_low_precision_mode_stack = []
class LowPrecisionMode:
    """
    Context manager that turns low-precision mode on (or off) for the enclosed code.

    On entry the `enabled` flag is pushed onto the module-level stack and on exit it is
    popped again, so contexts can be nested and the innermost one wins.
    """

    def __init__(self, enabled: bool = True):
        self.enabled = enabled

    def __enter__(self):
        # The list is only mutated, never rebound, so no `global` declaration is needed.
        _is_low_precision_mode_stack.append(self.enabled)
        # Return the manager so `with LowPrecisionMode() as mode:` binds it instead of None.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so an exception raised inside the `with` body still propagates.
        _is_low_precision_mode_stack.pop()
def is_low_precision_mode() -> bool:
    """
    Report the flag of the innermost active `LowPrecisionMode` context.

    Returns False when no context is active (the stack is empty).
    """
    if _is_low_precision_mode_stack:
        return _is_low_precision_mode_stack[-1]
    return False
def optional_cast_to_bf16_and_cast_back(tensor: torch.Tensor) -> torch.Tensor:
    """
    Round-trip a float32 tensor through bfloat16 when low-precision mode is active.

    Outside low-precision mode the input tensor itself is returned unchanged.
    The input must be float32 (checked with an assertion).
    """
    assert tensor.dtype == torch.float32, "Input tensor must be of dtype torch.float32 for optional casting."
    if not is_low_precision_mode():
        return tensor
    # float32 -> bfloat16 -> float32: the result is float32 again but carries bfloat16 rounding
    return tensor.to(torch.bfloat16).to(torch.float32)
import os
import functools
# ANSI escape sequences for coloring terminal output.
# `*_FG` entries set the foreground color, `*_BG` entries set the background color,
# and `CLEAR` resets all attributes.
colors = {
'RED_FG': '\033[31m',
'GREEN_FG': '\033[32m',
'CYAN_FG': '\033[36m',
'GRAY_FG': '\033[90m',
'YELLOW_FG': '\033[33m',
'RED_BG': '\033[41m',
'GREEN_BG': '\033[42m',
'CYAN_BG': '\033[46m',
'YELLOW_BG': '\033[43m',
'GRAY_BG': '\033[100m',
'CLEAR': '\033[0m'
}
def cdiv(a: int, b: int) -> int:
    """Ceiling division: for a positive divisor `b`, the smallest integer >= a / b."""
    biased_numerator = a + b - 1
    return biased_numerator // b
@functools.lru_cache()
def is_using_profiling_tools() -> bool:
    """
    Return whether we are running under a profiling/inspection tool (nsys, ncu or compute-sanitizer)
    Detection is based on environment variables; the answer is cached after the first call.
    NOTE cuda-gdb will also cause conflict with CUPTI (bench_kineto) but currently we lack ways to detect it
    """
    marker_env_vars = (
        'NSYS_PROFILING_SESSION_ID',                # nsys
        'NV_COMPUTE_PROFILER_PERFWORKS_DIR',        # ncu
        'NV_SANITIZER_INJECTION_PORT_RANGE_BEGIN',  # compute-sanitizer
    )
    return any(env_var in os.environ for env_var in marker_env_vars)
def set_random_seed(seed: int):
    """
    Seed every random number generator used by the tests with the same `seed`:
    Python's `random`, NumPy, PyTorch (CPU) and, when CUDA is available, all CUDA devices.
    """
    import random
    import numpy as np
    import torch

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
class Counter:
    """Hands out consecutive integers starting from 0."""

    def __init__(self):
        self.count = 0

    def next(self) -> int:
        """Return the current value, then advance the counter by one."""
        issued = self.count
        self.count = issued + 1
        return issued
from typing import List
import dataclasses
import os
import enum
from typing import List, Optional
import random
import torch
import kernelkit as kk
import flash_mla
def cdiv(x: int, y: int):
    """Ceiling division: for a positive divisor `y`, the smallest integer >= x / y."""
    rounded_up = x + y - 1
    return rounded_up // y
import quant
def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7):
@enum.unique
class TestTarget(enum.Enum):
    """
    Selector for what a test exercises.
    NOTE(review): presumably FWD is the sparse forward kernel and DECODE the decoding kernel - confirm against the callers.
    """
    FWD = 0
    DECODE = 1
@dataclasses.dataclass
class ExtraTestParamForDecode:
    """
    Decode-only test parameters, embedded into `TestParam.decode`.
    """
    b: int                                  # Batch size
    is_varlen: bool                         # Draw a random KV length per sequence instead of one fixed length
    have_zero_seqlen_k: bool                # Randomly force some sequences to a KV length of 0
    extra_s_k: Optional[int] = None         # KV length of the extra KV scope (filled as extra_topk*2 when left None)
    extra_topk: Optional[int] = None        # topk of the extra KV scope; None disables the extra scope
    block_size: int = 64                    # KV-cache block size of the main scope
    extra_block_size: Optional[int] = None  # Block size of the extra scope (filled with block_size when left None)
    have_extra_topk_length: bool = False    # Generate a topk_length tensor for the extra scope
@dataclasses.dataclass
class TestParam:
    """
    Parameters describing one test case.
    Forward (prefill) tests use the fields directly; decoding tests additionally set `decode`.
    """
    s_q: int                                # Number of query tokens (per sequence for decoding)
    s_kv: int                               # Number of KV tokens
    topk: int                               # Number of KV indices selected per query token
    h_q: int = 128                          # Number of query heads
    h_kv: int = 1                           # Number of KV heads
    d_qk: int = 512                         # Q/K head dim
    d_v: int = 512                          # V head dim
    seed: int = -1                          # -1: to be filled automatically
    check_correctness: bool = True
    is_all_indices_invalid: bool = False    # All indices are invalid, i.e., all indices are set to a large number (e.g., 2147483647)
    num_runs: int = 10
    have_attn_sink: bool = False            # Generate a per-head attention sink tensor
    have_topk_length: bool = False          # Generate a topk_length tensor
    decode: Optional[ExtraTestParamForDecode] = None    # Decode-only parameters; None for forward tests
@dataclasses.dataclass
class RawTestParamForDecode:
    """
    "Flattened" test parameters for decoding tests.
    In our test script, to maintain compatibility with TestParam, decode-only parameters are embedded
    into TestParam.decode, which is not very convenient when constructing testcases. This class holds
    the same parameters in a single flat record; use `to_test_param` to convert.
    """
    b: int
    h_q: int
    s_q: int
    h_kv: int
    s_kv: int
    is_varlen: bool
    topk: int
    is_all_indices_invalid: bool = False
    have_zero_seqlen_k: bool = False
    have_topk_length: bool = False
    enable_attn_sink: bool = True
    extra_s_k: Optional[int] = None
    extra_topk: Optional[int] = None
    block_size: int = 64
    extra_block_size: Optional[int] = None
    have_extra_topk_length: bool = False
    d_qk: int = 576     # Q/K head dim (= dv + RoPE dim)
    d_v: int = 512      # V head dim
    check_correctness: bool = True
    num_runs: int = 10
    seed: int = -1

    def to_test_param(self) -> TestParam:
        """
        Convert to the nested `TestParam` / `ExtraTestParamForDecode` representation.
        `enable_attn_sink` is stored as `TestParam.have_attn_sink`.
        """
        decode_extras = ExtraTestParamForDecode(
            b=self.b,
            is_varlen=self.is_varlen,
            have_zero_seqlen_k=self.have_zero_seqlen_k,
            extra_s_k=self.extra_s_k,
            extra_topk=self.extra_topk,
            block_size=self.block_size,
            extra_block_size=self.extra_block_size,
            have_extra_topk_length=self.have_extra_topk_length,
        )
        return TestParam(
            s_q=self.s_q,
            s_kv=self.s_kv,
            topk=self.topk,
            h_q=self.h_q,
            h_kv=self.h_kv,
            d_qk=self.d_qk,
            d_v=self.d_v,
            seed=self.seed,
            check_correctness=self.check_correctness,
            is_all_indices_invalid=self.is_all_indices_invalid,
            num_runs=self.num_runs,
            have_attn_sink=self.enable_attn_sink,
            have_topk_length=self.have_topk_length,
            decode=decode_extras,
        )
@dataclasses.dataclass
class Testcase:
    """
    Generated inputs for one sparse-attention forward test (see `generate_testcase`).
    """
    p: TestParam                            # Parameters this testcase was generated from
    dOut: torch.Tensor                      # [s_q, h_q, d_v]
    q: torch.Tensor                         # [s_q, h_q, d_qk]
    kv: torch.Tensor                        # [s_kv, h_kv, d_qk]
    indices: torch.Tensor                   # [s_q, h_kv, topk]
    sm_scale: float                         # Softmax scale applied to the attention logits
    attn_sink: Optional[torch.Tensor]       # [h_q]
    topk_length: Optional[torch.Tensor]     # [s_q]
def _randperm_batch(batch_size: int, perm_range: torch.Tensor, perm_size: int, paddings: List[int]) -> torch.Tensor:
"""
Generate random permutations in batch
The return tensor, denoted as `res`, has a shape of [batch_size, perm_size]. `0 <= res[i, :] < perm_range[i]` holds.
Values within each row are unique.
If, for some `i`, `perm_range[i] < perm_size` holds, then `res[i, :]` contains values in `[0, perm_range[i])` as many as possible, and the rest are filled with `padding`.
"""
assert not torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(True)
perm_range_max = max(int(torch.max(perm_range).item()), perm_size)
rand = torch.rand(batch_size, perm_range_max, dtype=torch.float32)
rand[torch.arange(0, perm_range_max).broadcast_to(batch_size, perm_range_max) >= perm_range.view(batch_size, 1)] = float("-inf") # Fill invalid positions, so that the following `topk` operators will select positions within `perm_range` first
res = rand.topk(perm_size, dim=-1, sorted=True).indices.to(torch.int32)
if len(paddings) == 1:
res[res >= perm_range.view(batch_size, 1)] = paddings[0]
else:
fillers = torch.tensor(paddings, dtype=torch.int32).index_select(0, torch.randint(0, len(paddings), (res.numel(), ), dtype=torch.int32))
res.masked_scatter_(res >= perm_range.view(batch_size, 1), fillers)
torch.use_deterministic_algorithms(False)
return res
def generate_testcase(t: TestParam) -> Testcase:
    """
    Build the random inputs for one sparse-attention forward test described by `t`.

    The RNGs are seeded with `t.seed` through `kk.set_random_seed` before any data is drawn.
    q, kv, dOut and indices are passed through `kk.non_contiguousify` before being returned.
    `sm_scale` is fixed to 0.5.
    """
    kk.set_random_seed(t.seed)
    # bf16 data: randn/10 plus one random scalar offset per tensor in [-0.05, 0.05)
    q = torch.randn((t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10
    kv = torch.randn((t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10
    do = torch.randn((t.s_q, t.h_q, t.d_v), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10

    q.clamp_(-10, 10)
    kv.clamp_(-10, 10)
    do.clamp_(-10, 10)

    # Values used to pad rows when fewer than topk valid indices exist.
    # NOTE(review): 114514 / 1919810 / 2147480000 are only out of range while s_kv is not larger than them - confirm this is intended
    invalid_indices_candidate = [-2147483648, -123456, -1, t.s_kv, 114514, 1919810, 2147480000, 2147483647]
    # _randperm_batch returns [s_q, topk]; the view to [s_q, h_kv, topk] therefore only works for h_kv == 1
    indices = _randperm_batch(t.s_q, torch.full((t.s_q, ), t.s_kv, dtype=torch.int32), t.topk, invalid_indices_candidate).view(t.s_q, t.h_kv, t.topk)
    if t.is_all_indices_invalid:
        # Select a random subset of query rows (randn < -2) and overwrite all their indices with one invalid value.
        # NOTE(review): despite the flag name, only the selected rows are invalidated, not every row - confirm
        all_indices_invalid_mask = torch.randn(t.s_q, device='cpu') < -2
        indices[all_indices_invalid_mask[:, None, None].broadcast_to(indices.shape)] = random.choice(invalid_indices_candidate)
    indices = indices.to(q.device)

    attn_sink = None
    if t.have_attn_sink:
        # Random per-head sink values, with some heads forced to -inf / +inf (chosen via an independent randn mask)
        attn_sink = torch.randn((t.h_q, ), dtype=torch.float32)
        mask = torch.randn((t.h_q, ), dtype=torch.float32)
        attn_sink[mask < -0.5] = float("-inf")
        attn_sink[mask > +0.5] = float("+inf")

    topk_length = None
    if t.have_topk_length:
        # Draw from [0, max(topk+1, 64)) and clamp to topk, so every value ends up in [0, topk]
        topk_length = torch.randint(0, max(t.topk + 1, 64), (t.s_q, ), dtype=torch.int32, device=q.device).clamp_max(t.topk)

    q = kk.non_contiguousify(q)
    kv = kk.non_contiguousify(kv)
    do = kk.non_contiguousify(do)
    indices = kk.non_contiguousify(indices)

    return Testcase(
        p=t,
        dOut=do,
        q=q,
        kv=kv,
        indices=indices,
        sm_scale=0.5,   # Otherwise dK is too small compared to dV
        attn_sink=attn_sink,
        topk_length=topk_length
    )
@dataclasses.dataclass
class KVScope:
    """
    One paged KV cache together with the sparse indices that address it (decoding tests).
    """
    t: TestParam
    cache_seqlens: torch.Tensor             # [b]
    block_table: torch.Tensor               # [b, max_blocks_per_seq]
    blocked_k: torch.Tensor                 # [num_blocks, block_size, h_kv, d_qk]
    abs_indices: torch.Tensor               # [b, s_q, topk], logical token indices (-1 = invalid)
    indices_in_kvcache: torch.Tensor        # [b, s_q, topk], indices into the flattened blocked KV cache
    topk_length: Optional[torch.Tensor]     # [b]
    blocked_k_quantized: Optional[torch.Tensor] = None  # Filled by `quant_and_dequant_`

    def quant_and_dequant_(self):
        """
        For FP8 cases, we need to quantize the KV cache for Flash MLA.
        Besides, the quantization error may be too large to be distinguished from wrong kernels, so `blocked_k`
        is replaced in place by its de-quantized version to mitigate quantization error.
        The FP8 layout is selected from `t.d_qk` (576 or 512); any other value fails an assertion.
        """
        layout = None
        head_dim = self.t.d_qk
        if head_dim == 576:
            layout = quant.FP8KVCacheLayout.V32_FP8Sparse
        elif head_dim == 512:
            assert self.abs_indices is not None
            layout = quant.FP8KVCacheLayout.MODEL1_FP8Sparse
        else:
            assert False
        quantized = quant.quantize_k_cache(self.blocked_k, layout)
        self.blocked_k_quantized = quantized
        self.blocked_k = quant.dequantize_k_cache(quantized, layout)

    def get_kvcache_for_flash_mla(self) -> torch.Tensor:
        """
        Return the quantized blocked_k for Flash MLA (requires a prior `quant_and_dequant_` call)
        """
        quantized = self.blocked_k_quantized
        assert quantized is not None, "Please call `quant_and_dequant_` first before calling `get_kvcache_for_flash_mla`"
        return quantized

    def apply_perm(self, perm: torch.Tensor) -> "KVScope":
        """
        Apply a batch permutation to this KVScope. Used for batch-invariance test
        Per-batch tensors are re-indexed with `perm`; the KV cache tensors themselves are shared, not copied.
        """
        permuted_topk_length = None if self.topk_length is None else self.topk_length[perm]
        return KVScope(
            t=self.t,
            cache_seqlens=self.cache_seqlens[perm],
            block_table=self.block_table[perm],
            blocked_k=self.blocked_k,
            abs_indices=self.abs_indices[perm],
            indices_in_kvcache=self.indices_in_kvcache[perm],
            topk_length=permuted_topk_length,
            blocked_k_quantized=self.blocked_k_quantized,
        )
@dataclasses.dataclass
class TestcaseForDecode:
    """
    Generated inputs for one sparse decoding test (see `generate_testcase_for_decode`).
    """
    p: TestParam                        # Parameters this testcase was generated from
    q: torch.Tensor                     # [b, s_q, h_q, d_qk]
    attn_sink: Optional[torch.Tensor]   # [h_q]
    sm_scale: float                     # Softmax scale applied to the attention logits
    kv_scope: KVScope                   # Main KV scope
    extra_kv_scope: Optional[KVScope]   # Extra KV scope; None when `p.decode.extra_topk` is None
def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode:
    """
    Build the random inputs for one sparse decoding test described by `t` (`t.decode` must be set).

    The RNGs are seeded with `t.seed` through `kk.set_random_seed` before any data is drawn.
    A main KV scope is always generated; an extra KV scope is generated when `t.decode.extra_topk` is set,
    in which case missing `extra_s_k` / `extra_block_size` are filled in on `t.decode` (the parameter object is mutated).
    Both scopes are quantized and de-quantized via `KVScope.quant_and_dequant_`.
    """
    kk.set_random_seed(t.seed)

    assert t.h_q % t.h_kv == 0
    assert t.decode is not None

    q = torch.randn((t.decode.b, t.s_q, t.h_q, t.d_qk))
    q.clamp_(min=-1.0, max=1.0)

    attn_sink = None
    if t.have_attn_sink:
        # Random per-head sink values, with some heads forced to +inf / -inf (chosen via an independent randn mask)
        attn_sink = torch.randn((t.h_q, ), dtype=torch.float32)
        inf_mask = torch.randn((t.h_q, ), dtype=torch.float32)
        attn_sink[inf_mask > 0.5] = float("inf")
        attn_sink[inf_mask < -0.5] = float("-inf")

    def generate_one_k_scope(s_k: int, block_size: int, topk: int, is_varlen: bool, have_zero_seqlen: bool, is_all_indices_invalid: bool, have_topk_length: bool) -> KVScope:
        """
        Generate one paged KV cache with `s_k` tokens per sequence (randomized when `is_varlen`) plus the
        sparse indices addressing it.
        """
        b = t.decode.b  # type: ignore
        cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device='cpu')
        if is_varlen:
            for i in range(b):
                # Normal around s_k with std s_k/2, but never shorter than s_q (the float is stored into an int32 tensor)
                cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), t.s_q)

        if have_zero_seqlen:
            # Zero out the length of a random subset of sequences (randn > 0)
            zeros_mask = torch.randn(b, dtype=torch.float32, device='cpu') > 0
            cache_seqlens_cpu[zeros_mask] = 0

        # Pad the per-sequence capacity to a multiple of 4 blocks, with at least 4 blocks
        max_seqlen_alignment = 4 * block_size
        max_seqlen_pad = max(kk.cdiv(int(cache_seqlens_cpu.max().item()), max_seqlen_alignment), 1) * max_seqlen_alignment
        cache_seqlens = cache_seqlens_cpu.cuda()

        assert max_seqlen_pad % block_size == 0
        # Assign every (sequence, logical block) a distinct physical block, in random order
        block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
        block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1)

        blocked_k = kk.gen_non_contiguous_randn_tensor((block_table.numel(), block_size, t.h_kv, t.d_qk)) / 10
        blocked_k.clamp_(min=-1.0, max=1.0)

        abs_indices = torch.empty((b, t.s_q, topk), dtype=torch.int32)
        if is_all_indices_invalid:
            abs_indices.fill_(-1)
        else:
            # Unique random token indices within each sequence's length, padded with -1
            abs_indices[:] = _randperm_batch(b*t.s_q, cache_seqlens.repeat_interleave(t.s_q), topk, [-1]).view(b, t.s_q, topk)
        indices_in_kvcache = quant.abs_indices2indices_in_kvcache(abs_indices, block_table, block_size)

        topk_length = torch.randint(0, topk+1, (b, ), dtype=torch.int32, device=q.device) if have_topk_length else None

        # Mask nonused KV as NaN
        if have_topk_length:
            # Indices at positions >= topk_length are not attended; mark them -1 in a copy used only for the NaN masking below
            indices_in_kvcache_masked = indices_in_kvcache.clone()
            indices_in_kvcache_masked[torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, t.s_q, topk) >= (topk_length.view(b, 1, 1) if have_topk_length else topk)] = -1
        else:
            indices_in_kvcache_masked = indices_in_kvcache

        blocked_k = blocked_k.view(-1, t.h_kv, t.d_qk)
        nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu')
        # NOTE(review): -1 entries index the last element of the mask, so the last token slot is always treated as used - confirm this is acceptable
        nonused_indices_mask[indices_in_kvcache_masked] = False
        # NOTE(review): the mask has size(0)*size(1) entries but indexes dim 0 only, which matches only when h_kv == 1
        blocked_k[nonused_indices_mask, :, :] = float("nan")
        blocked_k = blocked_k.view(-1, block_size, t.h_kv, t.d_qk)

        block_table = kk.non_contiguousify(block_table)
        abs_indices = kk.non_contiguousify(abs_indices)
        indices_in_kvcache = kk.non_contiguousify(indices_in_kvcache)

        return KVScope(t, cache_seqlens, block_table, blocked_k, abs_indices, indices_in_kvcache, topk_length)

    kv_scope0 = generate_one_k_scope(t.s_kv, t.decode.block_size, t.topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.have_topk_length)
    kv_scope0.quant_and_dequant_()
    if t.decode.extra_topk is not None:
        # Fill in defaults for the extra scope (written back into t.decode)
        if t.decode.extra_s_k is None:
            t.decode.extra_s_k = t.decode.extra_topk*2
        if t.decode.extra_block_size is None:
            t.decode.extra_block_size = t.decode.block_size
        kv_scope1 = generate_one_k_scope(t.decode.extra_s_k, t.decode.extra_block_size, t.decode.extra_topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.decode.have_extra_topk_length)
        kv_scope1.quant_and_dequant_()
    else:
        # Extra-scope options make no sense without extra_topk
        assert t.decode.extra_block_size is None and t.decode.extra_s_k is None and not t.decode.have_extra_topk_length
        kv_scope1 = None

    sm_scale = t.d_qk ** -0.55
    q = kk.non_contiguousify(q)
    return TestcaseForDecode(t, q, attn_sink, sm_scale, kv_scope0, kv_scope1)
def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool):
    """
    Run `flash_mla.flash_mla_sparse_fwd` on the tensors of testcase `t` and return its result.
    `return_p_sum` is not supported and must be False (asserted). `p` is not used by the call.
    """
    assert not return_p_sum
    kernel_kwargs = dict(
        sm_scale=t.sm_scale,
        attn_sink=t.attn_sink,
        topk_length=t.topk_length,
    )
    return flash_mla.flash_mla_sparse_fwd(t.q, t.kv, t.indices, **kernel_kwargs)
def run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits):
    """
    Run `flash_mla.flash_mla_with_kvcache` on decoding testcase `t` and return its result.
    The quantized KV caches of both scopes are used, so `quant_and_dequant_` must have been called on them.
    When `t.extra_kv_scope` is None, all extra-scope arguments are passed as None.
    """
    assert p.decode is not None
    main_scope = t.kv_scope
    main_kvcache = main_scope.get_kvcache_for_flash_mla()

    extra_scope = t.extra_kv_scope
    if extra_scope is not None:
        extra_kvcache = extra_scope.get_kvcache_for_flash_mla()
        extra_indices = extra_scope.indices_in_kvcache
        extra_topk_length = extra_scope.topk_length
    else:
        extra_kvcache = None
        extra_indices = None
        extra_topk_length = None

    # Arguments are positional; the meaning of the literal `None, None` and `False, True`
    # slots is defined by flash_mla_with_kvcache's signature, which is not visible here.
    return flash_mla.flash_mla_with_kvcache(
        t.q,
        main_kvcache,
        None, None, p.d_v,
        tile_scheduler_metadata, num_splits,
        t.sm_scale, False, True,
        main_scope.indices_in_kvcache,
        t.attn_sink,
        extra_kvcache,
        extra_indices,
        main_scope.topk_length,
        extra_topk_length
    )
@dataclasses.dataclass
class FlopsAndMemVolStatistics:
    """
    FLOPs and memory volume statistics for prefilling
    """
    fwd_flop: float
    fwd_mem_vol: float


def count_flop_and_mem_vol(p: TestParam, t: Testcase) -> FlopsAndMemVolStatistics:
    """
    Count the FLOPs and the memory volume of one sparse forward pass.

    FLOPs count every attended (query, index) pair, limited by `topk_length` when present.
    Memory volume counts only indices that are in `[0, s_kv)` (and below `topk_length`), plus Q and O.
    """
    total_topk = (p.s_q*p.topk) if t.topk_length is None else t.topk_length.sum().item()
    indices_valid_mask = (t.indices >= 0) & (t.indices < p.s_kv)
    if t.topk_length is not None:
        indices_valid_mask &= (torch.arange(p.topk)[None, None, :].broadcast_to(p.s_q, p.h_kv, p.topk)) < t.topk_length[:, None, None]
    num_valid_indices = indices_valid_mask.sum().item()
    fwd_flop = 2 * total_topk * p.h_q * (p.d_qk + p.d_v)
    # The factor 2 is the element size in bytes (the forward testcase tensors are bf16)
    fwd_mem_vol = num_valid_indices*p.d_qk*2 + p.s_q*p.h_q*(p.d_qk+p.d_v)*2
    return FlopsAndMemVolStatistics(
        fwd_flop,
        fwd_mem_vol,
    )


@dataclasses.dataclass
class FlopsAndMemVolStatisticsForDecode:
    """
    FLOPs and memory volume statistics for decoding
    """
    flop: float
    mem_vol: float


def get_cos_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    """
    Calculate the cosine diff between two tensors
    Computed in float64 as `1 - 2*sum(x*y) / sum(x*x + y*y)`; returns 0 when both tensors are all-zero.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum().item()
    if denominator == 0:
        return 0.0
    sim = 2 * (x * y).sum().item() / denominator
    return 1 - sim


def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7) -> bool:
    """
    Check if two tensors are close enough

    An element passes when its absolute error is below `abs_tol` or its relative error is below `rel_tol`.
    In addition, +inf / -inf / NaN must appear at exactly the same positions in both tensors, and the
    cosine diff of the remaining values must not exceed `cos_diff_tol`.
    Prints a diagnostic and returns False on the first failed check; returns True otherwise.
    The inputs are not modified (the function works on float32 copies).
    """
    assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}"

    ans = ans.clone().to(torch.float)
    ref = ref.clone().to(torch.float)

    # Deal with anomalies
    def deal_with_anomalies(val: float):
        # `val == val` is False only for NaN, which must be located with `x != x` instead
        ref_mask = (ref == val) if (val == val) else (ref != ref)
        ans_mask = (ans == val) if (val == val) else (ans != ans)
        # Zero the anomalies so that they do not disturb the numeric comparison below
        ref[ref_mask] = 0.0
        ans[ans_mask] = 0.0
        if not torch.equal(ref_mask, ans_mask):
            print(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref")
            return False
        return True

    anomalies_check_passed = True
    anomalies_check_passed &= deal_with_anomalies(float("inf"))
    anomalies_check_passed &= deal_with_anomalies(float("-inf"))
    anomalies_check_passed &= deal_with_anomalies(float("nan"))

    if not anomalies_check_passed:
        return False

    cos_diff = get_cos_diff(ans, ref)
    raw_abs_err = torch.abs(ans-ref)
    raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6))
    # An error only counts if the element fails the *other* criterion as well
    rel_err = raw_rel_err.masked_fill(raw_abs_err<abs_tol, 0)
    abs_err = raw_abs_err.masked_fill(raw_rel_err<rel_tol, 0)
    pass_mask = (abs_err < abs_tol) | (rel_err < rel_tol)

    if not pass_mask.all():
        print(f"`{name}` mismatch")
        max_abs_err_pos: int = torch.argmax(abs_err, keepdim=True).item()   # type: ignore
        max_rel_err_pos: int = torch.argmax(rel_err, keepdim=True).item()   # type: ignore

        def get_pos_in_tensor(t: torch.Tensor, pos: int) -> List[int]:
            # Convert a flat position into a multi-dimensional index
            result = []
            for size in t.shape[::-1]:
                result.append(pos % size)
                pos = pos // size
            assert pos == 0
            return result[::-1]

        print(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}")
        print(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}")
        print(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)")
        print(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})")
        return False

    if abs(cos_diff) > cos_diff_tol:
        print(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})")
        return False
    return True


def count_flop_and_mem_vol_for_decode(p: TestParam, t: TestcaseForDecode) -> FlopsAndMemVolStatisticsForDecode:
    """
    Count the FLOPs and the memory volume of one sparse decoding pass over the main and (if present) extra KV scope.
    """
    assert p.decode
    b = p.decode.b

    def get_num_attended_tokens(kv_scope: KVScope) -> int:
        # Number of (query token, KV token) pairs that are attended
        topk = kv_scope.indices_in_kvcache.shape[-1]
        if kv_scope.topk_length is None:
            return b * p.s_q * topk
        else:
            return int(kv_scope.topk_length.sum().item()) * p.s_q

    def get_num_retrieved_tokens(kv_scope: KVScope) -> int:
        # Number of distinct index values (positions beyond topk_length are masked to -1 first)
        if kv_scope.topk_length is None:
            indices = kv_scope.indices_in_kvcache
        else:
            indices = kv_scope.indices_in_kvcache.clone()
            batch, s_q, topk = indices.shape
            mask = torch.arange(0, topk, device=indices.device).view(1, 1, topk).broadcast_to(batch, s_q, topk) >= kv_scope.topk_length.view(batch, 1, 1)
            indices[mask] = -1
        # NOTE(review): the invalid marker -1, when present, is counted as one unique token - confirm this is acceptable
        num_unique_tokens = indices.unique().numel()    # type: ignore
        return num_unique_tokens

    num_attended_tokens = get_num_attended_tokens(t.kv_scope) + (get_num_attended_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0)
    num_retrieved_tokens = get_num_retrieved_tokens(t.kv_scope) + (get_num_retrieved_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0)

    compute_flop = 2 * p.h_q * num_attended_tokens * (p.d_qk + p.d_v)
    kv_token_size = 656 if p.d_qk == 576 else 576  # Assume FP8 KV Cache
    mem_vol = sum([
        2 * b * p.s_q * p.h_q * p.d_qk,  # Q
        num_retrieved_tokens * kv_token_size,  # K
        2 * b * p.s_q * p.h_q * p.d_v,  # O
    ])
    return FlopsAndMemVolStatisticsForDecode(
        compute_flop,
        mem_vol
    )
def is_no_cooldown() -> bool:
    """Return True when the environment variable NO_COOLDOWN is set to `1`, `yes` or `y` (case-insensitive)."""
    raw_flag = os.environ.get('NO_COOLDOWN', '')
    truthy_values = ['1', 'yes', 'y']
    return raw_flag.lower() in truthy_values
import enum
from typing import Tuple
import torch
class FP8KVCacheLayout(enum.Enum):
    """Supported memory layouts of the FP8-quantized KV cache."""
    V32_FP8Sparse = 1
    MODEL1_FP8Sparse = 2

    def get_meta(self) -> Tuple[int, int, int, int, int]:
        """
        Return the layout constants as a tuple: (d, d_nope, d_rope, tile_size, num_tiles)
        """
        if self is FP8KVCacheLayout.V32_FP8Sparse:
            return (576, 512, 64, 128, 4)
        if self is FP8KVCacheLayout.MODEL1_FP8Sparse:
            return (512, 448, 64, 64, 7)
        # Unreachable for the members defined above; mirrors a failed table lookup
        raise KeyError(self)
def _cast_scale_inv_to_ue8m0(scales_inv: torch.Tensor, out_dtype = torch.float32) -> torch.Tensor:
return torch.pow(2, torch.clamp_min(scales_inv, 1e-4).log2().ceil()).to(out_dtype)
def quantize_k_cache(
    input_k_cache: torch.Tensor,    # (num_blocks, block_size, h_k, d)
    kvcache_layout: FP8KVCacheLayout,
) -> torch.Tensor:
    """
    Quantize the k-cache
    The "nope" part of every token is quantized to float8_e4m3fn tile by tile, each tile with its own
    power-of-two scale factor; the "rope" part is stored unquantized in the input dtype.
    Return a tensor of dtype float8_e4m3fn with shape (num_blocks, block_size, 1, bytes_per_token), where
    bytes_per_token depends on `kvcache_layout`. The returned tensor is a view into a larger (padded) buffer.
    Only h_k == 1 is supported (asserted).
    For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py
    """
    d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta()
    assert input_k_cache.shape[-1] == d
    num_blocks, block_size, h_k, _ = input_k_cache.shape
    assert h_k == 1
    input_k_cache = input_k_cache.squeeze(2)    # [num_blocks, block_size, d]
    input_elem_size = input_k_cache.element_size()

    if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse:
        # Per token: [nope (fp8) | one float32 scale per tile | rope (input dtype)]
        bytes_per_token = d_nope + num_tiles*4 + input_elem_size*d_rope
        # Allocate one extra token slot per block and slice it away again
        result = torch.empty((num_blocks, block_size+1, bytes_per_token), dtype=torch.float8_e4m3fn, device=input_k_cache.device)[:, :block_size, :]
        result_k_nope_part = result[..., :d_nope]
        result_k_scale_factor = result[..., d_nope: d_nope + num_tiles*4].view(torch.float32)
        result_k_rope_part = result[..., d_nope + num_tiles*4:].view(input_k_cache.dtype)
        result_k_rope_part[:] = input_k_cache[..., d_nope:]

        for tile_idx in range(0, num_tiles):
            # 448.0 is the divisor that maps the tile's max magnitude onto the float8_e4m3fn range
            cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values.float() / 448.0  # [num_blocks, block_size]
            cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv)
            result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv

            cur_scale_factors_inv.unsqueeze_(-1)    # [num_blocks, block_size, 1]
            cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
            result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope

        result = result.view(num_blocks, block_size, 1, -1)
        return result

    elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse:
        # Per block: [nope (fp8) + rope (2 bytes/elem) for every token | 8 scale bytes per token (7 used)]
        bytes_per_token = d_nope + 2*d_rope + num_tiles + 1
        # Each block is padded to a multiple of 576 bytes
        size_per_block_padded = (block_size*bytes_per_token + 576-1) // 576 * 576
        result = torch.empty((num_blocks, size_per_block_padded), dtype=torch.float8_e4m3fn, device=input_k_cache.device)[:, :block_size*bytes_per_token]
        result_k_nope_rope_part = result[:, :block_size*(d_nope+2*d_rope)].view(num_blocks, block_size, d_nope + 2*d_rope)
        result_k_nope = result_k_nope_rope_part[:, :, :d_nope]  # [num_blocks, block_size, d_nope]
        result_k_rope = result_k_nope_rope_part[:, :, d_nope:].view(input_k_cache.dtype)    # [num_blocks, block_size, d_rope]
        result_k_scale_factor = result[:, block_size*(d_nope+2*d_rope):].view(num_blocks, block_size, 8)[:, :, :7].view(torch.float8_e8m0fnu)    # [num_blocks, block_size, num_tiles]
        result_k_rope[:] = input_k_cache[..., d_nope:]

        for tile_idx in range(0, num_tiles):
            cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values.float() / 448.0  # [num_blocks, block_size]
            cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv)
            result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv.to(torch.float8_e8m0fnu)

            cur_scale_factors_inv = cur_scale_factors_inv.view(num_blocks, block_size, 1)
            cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
            result_k_nope[:, :, tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope

        result = result.view(num_blocks, block_size, 1, -1)
        return result

    else:
        raise NotImplementedError(f"Unsupported kvcache_layout: {kvcache_layout}")
def dequantize_k_cache(
    quant_k_cache: torch.Tensor,    # (num_blocks, block_size, 1, bytes_per_token)
    kvcache_layout: FP8KVCacheLayout,
) -> torch.Tensor:
    """
    De-quantize the k-cache
    Inverse of `quantize_k_cache` for the same `kvcache_layout`: every tile of the "nope" part is multiplied
    by its stored scale factor and the "rope" part is copied as is (it is read back as bfloat16).
    Return a bfloat16 tensor of shape (num_blocks, block_size, 1, d). Only h_k == 1 is supported (asserted).
    """
    d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta()
    num_blocks, block_size, h_k, _ = quant_k_cache.shape
    assert h_k == 1
    result = torch.empty((num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device)

    if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse:
        # Per token: [nope (fp8) | one float32 scale per tile | rope (bf16)]
        quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)

        input_nope = quant_k_cache[..., :d_nope]
        input_scale = quant_k_cache[..., d_nope:d_nope + num_tiles*4].view(torch.float32)
        input_rope = quant_k_cache[..., d_nope + num_tiles*4:].view(torch.bfloat16)
        result[..., d_nope:] = input_rope

        for tile_idx in range(0, num_tiles):
            cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32)
            cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
            result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales

    elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse:
        # Per block: [nope (fp8) + rope (bf16) for every token | 8 scale bytes per token (7 used)]
        quant_k_cache = quant_k_cache.view(num_blocks, -1)  # [num_blocks, ...]
        input_nope_rope = quant_k_cache[:, :block_size*(d_nope+2*d_rope)].view(num_blocks, block_size, d_nope + 2*d_rope)
        input_nope = input_nope_rope[:, :, :d_nope]
        input_rope = input_nope_rope[:, :, d_nope:].view(torch.bfloat16)
        input_scale = quant_k_cache[:, block_size*(d_nope+2*d_rope):].view(num_blocks, block_size, 8)[:, :, :7].view(torch.float8_e8m0fnu)    # [num_blocks, block_size, num_tiles]

        result[..., d_nope:] = input_rope
        for tile_idx in range(0, num_tiles):
            cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.bfloat16)
            cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1)
            result[..., tile_idx*tile_size: (tile_idx+1)*tile_size] = cur_nope * cur_scales

    else:
        raise NotImplementedError(f"Unsupported kvcache_layout: {kvcache_layout}")

    result = result.view(num_blocks, block_size, 1, d)
    return result
def abs_indices2indices_in_kvcache(
    abs_indices: torch.Tensor,  # [b, s_q, topk]
    block_table: torch.Tensor,  # [b, /]
    block_size: int,
) -> torch.Tensor:
    """
    Convert abs_indices (logical index, ranging from 0 to s_k-1) to index expected by the sparse attn kernel
    Entries equal to -1 are treated as invalid and stay -1 in the result. The input tensors are not modified.
    `block_table` may be non-contiguous.

    Equivalent to:

    b, s_q, topk = abs_indices.shape
    indices_in_kvcache = torch.empty_like(abs_indices)
    for i in range(b):
        cur_abs_indices = abs_indices[i, :, :].clone()  # [s_q, topk]
        invalid_mask = cur_abs_indices == -1
        cur_abs_indices[invalid_mask] = 0
        cur_indices_in_kvcache = block_table[i].index_select(0, cur_abs_indices.flatten()//block_size).view(s_q, topk)*block_size + cur_abs_indices%block_size
        cur_indices_in_kvcache[invalid_mask] = -1
        indices_in_kvcache[i] = cur_indices_in_kvcache
    return indices_in_kvcache
    """
    b, s_q, topk = abs_indices.shape
    _, max_blocks_per_seq = block_table.shape

    abs_indices = abs_indices.clone()
    invalid_mask = abs_indices == -1
    # Temporarily point invalid entries at token 0 so that the lookup below stays in range
    abs_indices[invalid_mask] = 0

    # Offset of every batch's row inside the flattened block table.
    # Created on the indices' device so the result does not depend on the default device.
    batch_offsets = torch.arange(0, b, device=abs_indices.device).view(b, 1, 1)*max_blocks_per_seq
    flat_block_slots = (abs_indices//block_size + batch_offsets).reshape(-1)
    # reshape (not view): the block table may be a non-contiguous slice
    real_block_idxs = block_table.reshape(-1).index_select(0, flat_block_slots)
    indices_in_kvcache = real_block_idxs.view(b, s_q, topk)*block_size + abs_indices%block_size
    indices_in_kvcache[invalid_mask] = -1

    return indices_in_kvcache
\ No newline at end of file
from typing import Optional, Tuple
import torch
from lib import TestParam, Testcase, TestcaseForDecode, KVScope
def _merge_two_lse(lse0: torch.Tensor, lse1: Optional[torch.Tensor], s_q: int, h_q: int) -> torch.Tensor:
if lse1 is None:
return lse0
else:
return torch.logsumexp(
torch.stack([
lse0.view(s_q, h_q),
lse1.broadcast_to(s_q, h_q)
], dim=0),
dim=0
)
def ref_sparse_attn_fwd(p: TestParam, t: Testcase) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    A reference implementation of the sparse attention forward pass in PyTorch

    Indices that are negative, >= s_kv, or at a position >= topk_length (when given) are not attended.
    When `t.attn_sink` is given, it takes part in the softmax normalization of the output, but not in the returned lse.

    Returns:
    - o: [s_q, h_q, dv] (bfloat16)
    - o_fp32: [s_q, h_q, dv]
    - max_logits: [s_q, h_q]
    - lse: [s_q, h_q] (+inf for query rows without any attendable key)
    """
    # [s_q, h_kv, topk] -> [s_q, topk]; assumes the h_kv dimension has size 1
    indices = t.indices.clone().squeeze(1)
    if t.topk_length is not None:
        mask = torch.arange(p.topk, device=t.topk_length.device).unsqueeze(0).broadcast_to(p.s_q, p.topk) >= t.topk_length.unsqueeze(1)   # [s_q, topk]
        indices[mask] = -1
    invalid_mask = (indices < 0) | (indices >= p.s_kv)  # [s_q, topk]
    # Point invalid entries at token 0 so that the gather stays in range; their logits are masked below
    indices[invalid_mask] = 0

    q = t.q.float()
    gathered_kv = t.kv.index_select(dim=0, index=indices.flatten()).reshape(p.s_q, p.topk, p.d_qk).float()   # [s_q, topk, d_qk]
    P = (q @ gathered_kv.transpose(1, 2))   # [s_q, h_q, topk]
    P *= t.sm_scale
    P[invalid_mask.unsqueeze(1).broadcast_to(P.shape)] = float("-inf")

    orig_lse = torch.logsumexp(P, dim=-1)   # [s_q, h_q]
    max_logits = P.max(dim=-1).values   # [s_q, h_q]

    lse_for_o = _merge_two_lse(orig_lse, t.attn_sink, p.s_q, p.h_q)
    if not torch.is_inference_mode_enabled():
        # _merge_two_lse may return orig_lse itself; clone so that the in-place write below does not alias it
        lse_for_o = lse_for_o.clone()
    lse_for_o[lse_for_o == float("-inf")] = float("+inf")   # So that corresponding O will be 0
    s_for_o = torch.exp(P - lse_for_o.unsqueeze(-1))
    out = s_for_o @ gathered_kv[..., :p.d_v]    # [s_q, h_q, dv]

    # Query rows with no attendable key report lse = +inf
    lonely_q_mask = orig_lse == float("-inf")   # [s_q, h_q]
    orig_lse[lonely_q_mask] = float("+inf")
    return (out.to(torch.bfloat16), out, max_logits, orig_lse)
def ref_sparse_attn_decode(
    p: TestParam,
    t: TestcaseForDecode
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation of sparse decoding attention in PyTorch

    Only h_kv == 1 is supported, and `p.decode` must be set.

    Returns:
    - output: [b, s_q, h_q, d_v], cast to bfloat16; all-zero for q tokens without any attendable k
    - lse: [b, h_q, s_q]; +inf for q tokens without any attendable k
    """
    assert p.h_kv == 1
    assert p.decode is not None
    b = p.decode.b

    def process_kv_scope(kv_scope: KVScope) -> Tuple[torch.Tensor, torch.Tensor]:
        """Gather the KV rows selected by one KV scope and build its invalid-slot mask."""
        assert kv_scope.indices_in_kvcache is not None
        topk = kv_scope.indices_in_kvcache.size(-1)
        indices_in_kv_cache_fixed = torch.clamp_min(kv_scope.indices_in_kvcache, 0)  # Otherwise torch.index_select will complain
        gathered_kv = kv_scope.blocked_k.view(-1, p.d_qk).index_select(0, indices_in_kv_cache_fixed.view(-1)).view(b, p.s_q, topk, p.d_qk)  # [b, s_q, topk, d]
        invalid_mask = kv_scope.indices_in_kvcache == -1
        if kv_scope.topk_length is not None:
            # Build the slot ids on topk_length's device rather than relying on the
            # process-wide default device (same approach as ref_sparse_attn_fwd)
            slot_ids = torch.arange(0, topk, device=kv_scope.topk_length.device)
            invalid_mask |= slot_ids.view(1, 1, topk).broadcast_to(b, p.s_q, topk) >= kv_scope.topk_length.view(b, 1, 1)
        return gathered_kv, invalid_mask

    gathered_kv, invalid_mask = process_kv_scope(t.kv_scope)
    if t.extra_kv_scope is not None:
        gathered_kv1, invalid_mask1 = process_kv_scope(t.extra_kv_scope)
        gathered_kv = torch.cat([gathered_kv, gathered_kv1], dim=2)  # [b, s_q, topk+extra_topk, d]
        invalid_mask = torch.cat([invalid_mask, invalid_mask1], dim=2)  # [b, s_q, topk+extra_topk]

    gathered_kv = gathered_kv.view(b*p.s_q, -1, p.d_qk).float()
    gathered_kv[gathered_kv != gathered_kv] = 0.0  # NaN != NaN, so this zeroes out NaN entries
    q = t.q.float().view(b*p.s_q, p.h_q, p.d_qk)
    attn_weight = q @ gathered_kv.transpose(-1, -2)  # [t.b*t.s_q, t.h_q, topk+extra_topk]
    attn_weight *= t.sm_scale
    attn_weight[invalid_mask.view(b*p.s_q, 1, -1).broadcast_to(b*p.s_q, p.h_q, invalid_mask.size(-1))] = float("-inf")
    lse = attn_weight.logsumexp(dim=-1)  # [t.b*t.s_q, t.h_q]
    attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1))
    output = attn_weight @ gathered_kv[..., :p.d_v]  # [t.b*t.s_q, t.h_q, t.dv]
    output = output.view(b, p.s_q, p.h_q, p.d_v)
    lse = lse.view(b, p.s_q, p.h_q)

    # Attention sink
    if t.attn_sink is not None:
        output *= (1.0 / (1.0 + torch.exp(t.attn_sink.view(1, 1, p.h_q) - lse))).unsqueeze(-1)

    # Correct for q tokens which has no attendable k
    lonely_q_mask = (lse == float("-inf"))
    output[lonely_q_mask.unsqueeze(-1).broadcast_to(b, p.s_q, p.h_q, p.d_v)] = 0.0
    lse[lonely_q_mask] = float("+inf")
    return output.to(torch.bfloat16), lse.transpose(1, 2)
\ No newline at end of file
......@@ -2,14 +2,12 @@ import argparse
import math
import random
import dataclasses
from typing import Optional, Tuple
from typing import Tuple
import torch
import triton
import kernelkit as kk
import flash_mla
import quant
from lib import cdiv, check_is_allclose
@dataclasses.dataclass
class TestParam:
......@@ -18,10 +16,7 @@ class TestParam:
s_k: int # Seq len, or mean seq len if varlen == True
is_varlen: bool
is_causal: bool
is_fp8: bool
topk: Optional[int] = None
test_performance: bool = True
is_all_indices_invalid: bool = False
have_zero_seqlen_k: bool = False
block_size: int = 64
h_q: int = 128 # Number of q heads
......@@ -31,7 +26,7 @@ class TestParam:
seed: int = 0
def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Generate test data from a given configuration
Return: [cache_seqlens, q, block_table, blocked_k]
......@@ -53,11 +48,11 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0
cache_seqlens_cpu[zeros_mask] = 0
max_seqlen = cache_seqlens_cpu.max().item()
max_seqlen_pad = cdiv(max_seqlen, 256) * 256
max_seqlen = int(cache_seqlens_cpu.max().item())
max_seqlen_pad = kk.cdiv(max_seqlen, 256) * 256
cache_seqlens = cache_seqlens_cpu.cuda()
q = torch.randn(t.b, t.s_q, t.h_q, t.d)
q = torch.randn(t.b, t.s_q, t.h_q, t.d) / 10
q.clamp_(min=-1.0, max=1.0)
block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size)
......@@ -65,59 +60,14 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10
blocked_k.clamp_(min=-1.0, max=1.0)
if t.topk is None:
for i in range(t.b):
cur_len = cache_seqlens_cpu[i].item()
cur_num_blocks = cdiv(cur_len, t.block_size)
cur_len = int(cache_seqlens_cpu[i].item())
cur_num_blocks = kk.cdiv(cur_len, t.block_size)
blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
if cur_len % t.block_size != 0:
blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan")
block_table[i][cur_num_blocks:] = 2147480000
return cache_seqlens, q, block_table, blocked_k, None, None
else:
block_table_cpu = block_table.cpu()
abs_indices = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu")
indices_in_kvcache = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu")
for i in range(t.b):
# Generate indices
for j in range(t.s_q):
cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk]
cur_blocked_indices = block_table_cpu[i, cur_abs_indices // t.block_size] * t.block_size + (cur_abs_indices % t.block_size)
if len(cur_abs_indices) < t.topk:
pad_len = t.topk - len(cur_abs_indices)
cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')])
cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')])
# Mask KV
perm = torch.randperm(t.topk, device='cpu')
cur_abs_indices = cur_abs_indices[perm]
cur_blocked_indices = cur_blocked_indices[perm]
# Fill it with invalid indices if needed
if t.is_all_indices_invalid:
cur_abs_indices.fill_(-1)
cur_blocked_indices.fill_(-1)
abs_indices[i, j, :] = cur_abs_indices
indices_in_kvcache[i, j, :] = cur_blocked_indices
# Mask nonused KV as NaN
all_indices = indices_in_kvcache.flatten().tolist()
all_indices = list(set(all_indices))
if -1 in all_indices:
all_indices.remove(-1)
all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu')
blocked_k = blocked_k.view(-1, t.h_kv, t.d)
nonused_indices_mask = torch.ones(blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device='cpu')
nonused_indices_mask[all_indices] = False
blocked_k[nonused_indices_mask, :, :] = float("nan")
blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d)
abs_indices = abs_indices.to(q.device)
indices_in_kvcache = indices_in_kvcache.to(q.device)
return cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache
return cache_seqlens, q, block_table, blocked_k
def reference_torch(
......@@ -127,18 +77,10 @@ def reference_torch(
blocked_k: torch.Tensor, # [?, block_size, h_kv, d]
dv: int,
is_causal: bool,
indices: Optional[torch.Tensor] = None # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
A reference implementation in PyTorch
"""
def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
mask = torch.zeros(s_q, s_k, dtype=torch.bool)
for i in range(s_q):
cur_indices = indices[i]
valid_indices = cur_indices[cur_indices != -1]
mask[i, valid_indices] = True
return mask
def scaled_dot_product_attention(
batch_idx: int,
......@@ -146,7 +88,6 @@ def reference_torch(
kv: torch.Tensor, # [h_kv, s_k, d]
dv: int,
is_causal,
indices: Optional[torch.Tensor], # [s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
h_q = query.size(0)
h_kv = kv.size(0)
......@@ -158,13 +99,10 @@ def reference_torch(
kv = kv.repeat_interleave(h_q // h_kv, dim=0)
kv[kv != kv] = 0.0
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
if (is_causal and query.size(1) > 1) or indices is not None:
if is_causal and query.size(1) > 1:
mask = torch.ones(s_q, s_k, dtype=torch.bool)
if is_causal:
assert indices is None
mask = mask.tril(diagonal=s_k - s_q)
if indices is not None:
mask &= get_topk_attn_mask(s_q, s_k, indices)
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float)
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
attn_weight += attn_bias.to(q.dtype)
......@@ -186,8 +124,8 @@ def reference_torch(
out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
cur_len = cache_seqlens_cpu[i].item()
cur_num_blocks = cdiv(cur_len, block_size)
cur_len = int(cache_seqlens_cpu[i].item())
cur_num_blocks = kk.cdiv(cur_len, block_size)
cur_block_indices = block_table[i][0: cur_num_blocks]
cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
cur_out, cur_lse = scaled_dot_product_attention(
......@@ -195,12 +133,11 @@ def reference_torch(
q[i].transpose(0, 1),
cur_kv.transpose(0, 1),
dv,
is_causal,
indices[i] if indices is not None else None
is_causal
)
out_ref[i] = cur_out.transpose(0, 1)
lse_ref[i] = cur_lse
out_ref = out_ref.to(torch.bfloat16)
out_ref = out_ref.to(q.dtype)
return out_ref, lse_ref
......@@ -211,58 +148,42 @@ def test_flash_mla(t: TestParam):
# Generating test data
torch.cuda.synchronize()
cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache = generate_test_data(t)
cache_seqlens, q, block_table, blocked_k, = generate_test_data(t)
if t.is_fp8:
# The quantization error may be too large to be distinguished from wrong kernels
# So we quantize and de-quantize kv-cache here to mitigate quantization error
blocked_k_quantized = quant.quantize_k_cache(blocked_k, t.dv, 128)
blocked_k_dequantized = quant.dequantize_k_cache(blocked_k_quantized)
blocked_k = blocked_k_dequantized
# Get schedule metadata
torch.cuda.synchronize()
tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
cache_seqlens,
t.s_q * t.h_q // t.h_kv,
t.h_kv,
t.h_q,
t.is_fp8,
t.topk
)
torch.cuda.synchronize()
tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
def run_flash_mla():
return flash_mla.flash_mla_with_kvcache(
q,
blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore
blocked_k,
block_table,
cache_seqlens,
t.dv,
tile_scheduler_metadata,
num_splits,
causal=t.is_causal,
is_fp8_kvcache=t.is_fp8,
indices=indices_in_kvcache
causal=t.is_causal
)
out_ans, lse_ans = run_flash_mla()
out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices)
assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6)
assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536)
out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal)
is_correct = True
is_correct &= kk.check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6)
is_correct &= kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536)
assert is_correct
if t.test_performance:
time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore
mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk
time_usage = kk.bench_kineto(run_flash_mla, 10).get_kernel_time("flash_fwd_splitkv_mla_kernel")
mean_attended_seqlens = cache_seqlens.float().mean().item()
compute_volume_flop = t.b * t.h_q * t.s_q * sum([
2 * t.d * mean_attended_seqlens, # Q * K^T
2 * mean_attended_seqlens * t.dv, # attention * V
])
q_elem_size = torch.bfloat16.itemsize
kv_token_size = 656 if t.is_fp8 else t.d * torch.bfloat16.itemsize
kv_token_size = t.d * torch.bfloat16.itemsize
memory_volume_B = t.b * sum([
t.s_q * t.h_q * (t.d * q_elem_size), # Q
(t.s_q if t.topk is not None else 1) * mean_attended_seqlens * t.h_kv * kv_token_size, # K/V
mean_attended_seqlens * t.h_kv * kv_token_size, # K/V
t.s_q * t.h_q * (t.dv * q_elem_size), # Output
])
achieved_tflops = compute_volume_flop / time_usage / 1e12
......@@ -277,54 +198,39 @@ def main(torch_dtype):
torch.set_default_device(device)
torch.cuda.set_device(device)
cc_major, cc_minor = torch.cuda.get_device_capability()
assert cc_major == 9, "Dense MLA decoding is only supported on sm90 (Hopper) currently."
correctness_cases = [
TestParam(b, s_q, s_k, is_varlen, is_causal, is_fp8, topk, test_performance=False)
TestParam(b, s_q, s_k, is_varlen, is_causal, test_performance=False, have_zero_seqlen_k=False, block_size=64, h_q=h_q, h_kv=h_kv)
for b in [1, 2, 6, 64]
for s_q in [1, 2, 4]
for s_k in [20, 140, 4096]
for h_q in [1, 3, 9, 63, 64, 126, 128]
for h_kv in [1, 2, 3, 8]
for is_varlen in [False, True]
for is_causal in [False, True]
for (is_fp8, topk) in [
(False, None),
(True, 128),
(True, 2048)
]
if not (is_causal and topk is not None)
if h_q % h_kv == 0
]
corner_cases = [
# Cases where all topk indices are invalid
TestParam(128, 2, 4096, is_varlen=True, is_causal=False, is_fp8=True, topk=topk, test_performance=False, is_all_indices_invalid=True)
for topk in [128, 2048, 4096]
] + [
# Cases where some kv cache have zero length
TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=False, have_zero_seqlen_k=True)
for (is_causal, is_fp8, topk) in [
(False, False, None),
(True, False, None),
(False, True, 128),
(False, True, 2048),
]
TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, test_performance=False, have_zero_seqlen_k=True, h_q=h_q, h_kv=h_kv)
for h_q in [1, 3, 9, 63, 64, 126, 128]
for h_kv in [1, 2, 3, 8]
for is_causal in [False, True]
if h_q % h_kv == 0
]
performance_cases = [
TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=True)
for (is_causal, is_fp8, topk) in [
(False, False, None),
(True, False, None),
(False, True, 2048),
]
TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, test_performance=True)
for is_causal in [False, True]
for s_q in [1, 2]
for s_k in [4096, 8192, 16384, 32768]
]
testcases = correctness_cases + corner_cases + performance_cases
# Prune out unsupported cases
cc_major, cc_minor = torch.cuda.get_device_capability()
if cc_major == 10:
testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)]
for testcase in testcases:
test_flash_mla(testcase)
......
import math
import time
from typing import Tuple
import random
import dataclasses
import torch
import triton
from flash_mla import flash_mla_sparse_fwd
from lib import check_is_allclose
@dataclasses.dataclass
class TestParam:
    """Configuration of one sparse-attention forward test case."""
    b: int  # Batch size (run_test and reference_torch assert that it is 1)
    s_q: int  # Number of query tokens
    s_kv: int  # Number of KV tokens
    topk: int  # Number of KV indices per query token
    h_q: int = 128  # Number of query heads
    h_kv: int = 1  # Number of KV heads
    d_qk: int = 576  # Last dim of q and kv
    d_v: int = 512  # The reference multiplies by the first d_v channels of each gathered KV row
    seed: int = 0  # Seed given to torch / random before generating data
    check_correctness: bool = True  # Whether run_test compares against reference_torch
    benchmark: bool = True  # Whether run_test times the kernel
@dataclasses.dataclass
class Testcase:
    """Input tensors produced by generate_testcase for one TestParam."""
    t: TestParam  # The configuration this testcase was generated from
    q: torch.Tensor  # [b, s_q, h_q, d_qk]
    kv: torch.Tensor  # [b, s_kv, h_kv, d_qk]
    indices: torch.Tensor  # [b, s_q, h_kv, topk], int32; reference_torch treats entries < 0 or >= s_kv as invalid
def generate_testcase(t: TestParam) -> Testcase:
    """
    Generate random inputs for one sparse-attention test.

    Seeds torch and `random` with `t.seed`, then builds:
    - q: [b, s_q, h_q, d_qk], bfloat16
    - kv: [b, s_kv, h_kv, d_qk], bfloat16
    - indices: [b, s_q, h_kv, topk], int32. When topk > s_kv, the slots that
      cannot be filled with real indices hold an out-of-range value (2147480000).
    """
    torch.manual_seed(t.seed)
    torch.cuda.manual_seed(t.seed)
    random.seed(t.seed)

    q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10
    kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10
    q.clamp_(-10, 10)
    kv.clamp_(-10, 10)

    indices = torch.full((t.b, t.s_q, t.h_kv, t.topk), t.s_kv, dtype=torch.int32)
    # Loop invariants
    num_real_indices = min(t.topk, t.s_kv)
    near_low = max(0, t.s_kv - 20000)
    for b in range(t.b):
        for s in range(t.s_q):
            for h in range(t.h_kv):
                # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
                near_mask = torch.randint(0, 32, (num_real_indices,)) < 31
                cur_indices = torch.randperm(t.s_kv)[:t.topk]
                # torch.randint's upper bound is exclusive: pass s_kv (not s_kv - 1) so that the
                # last KV token can be drawn and s_kv == 1 does not produce an empty range
                cur_indices[near_mask] = torch.randint(near_low, t.s_kv, (int(near_mask.sum().item()),))
                if len(cur_indices) < t.topk:
                    # Pad up to topk with an out-of-range index
                    cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)])
                # Shuffle so that padding is not always at the tail
                cur_indices = cur_indices[torch.randperm(t.topk)]
                indices[b, s, h] = cur_indices
    indices = indices.to(q.device)

    return Testcase(
        t=t,
        q=q,
        kv=kv,
        indices=indices
    )
def get_flop(p: TestParam) -> float:
    """
    Count the floating-point operations of one sparse attention forward pass.

    Every (batch, query token, head) performs `topk` multiply-adds over d_qk
    channels and `topk` multiply-adds over d_v channels; each multiply-add
    counts as 2 FLOPs.
    """
    qk_macs = p.h_q * p.d_qk * p.topk
    pv_macs = p.h_q * p.d_v * p.topk
    return 2 * (qk_macs + pv_macs) * p.b * p.s_q
def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Reference sparse attention in PyTorch (batch size 1 only), computed in base 2.

    Returns (max_logits, lse, result):
    - max_logits: [s_q, h_q], max over topk of logits * sm_scale * log2(e)
    - lse: [s_q, h_q], log2 of the sum over topk of 2**(scaled logit)
    - result: [s_q, h_q, d_v], float32
    """
    assert p.b == 1
    log2_e = math.log2(math.e)
    ln_2 = math.log(2)

    topk_indices = t.indices[0, :, 0, :]  # [s_q, topk]
    is_invalid = (topk_indices < 0) | (topk_indices >= p.s_kv)
    # Point invalid slots at row 0 so the gather stays in range; they are masked below
    safe_indices = topk_indices.masked_fill(is_invalid, 0)

    queries = t.q[0, :, :, :].float()  # [s_q, h_q, d_qk]
    all_kv = t.kv[0, :, 0, :].float()  # [s_kv, d_qk]
    picked_kv = torch.index_select(all_kv, 0, safe_indices.flatten()).view(p.s_q, p.topk, p.d_qk)  # [s_q, topk, d_qk]

    logits = queries @ picked_kv.transpose(1, 2)  # [s_q, h_q, topk]
    logits.masked_fill_(is_invalid.unsqueeze(1), float('-inf'))
    logits *= sm_scale * log2_e

    max_logits = torch.max(logits, dim=-1)[0]  # [s_q, h_q]
    # log2(sum(2**x)) expressed through the natural-log logsumexp
    lse = torch.logsumexp(logits * ln_2, dim=-1) * log2_e  # [s_q, h_q]
    probs = torch.exp2(logits - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    result = probs @ picked_kv[:, :, :p.d_v]
    return (max_logits, lse, result)
@torch.inference_mode()
def run_test(p: TestParam) -> bool:
    """
    Run one sparse prefill testcase.

    Launches flash_mla_sparse_fwd once, optionally benchmarks it (p.benchmark),
    and optionally compares its outputs against reference_torch
    (p.check_correctness). Returns True when the check passes or is skipped.
    """
    print("================")
    print(f"Running on {p}")
    torch.cuda.empty_cache()
    assert p.b == 1

    t = generate_testcase(p)
    sm_scale = 1 / math.sqrt(p.d_qk)
    torch.cuda.synchronize()

    def launch_kernel():
        # The kernel takes un-batched tensors, hence the squeeze of the batch dim
        return flash_mla_sparse_fwd(
            t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale
        )

    ans_out, ans_max_logits, ans_lse = launch_kernel()
    torch.cuda.synchronize()

    if p.benchmark:
        total_flop = get_flop(p)
        elapsed: float = triton.testing.do_bench(launch_kernel, warmup=10, rep=20) / 1000  # type: ignore
        tflops = total_flop / elapsed / 1e12
        print(f"Prefill: {elapsed * 1e6:4.0f} us, {tflops:.3f} TFlops")

    if not p.check_correctness:
        return True

    torch.cuda.synchronize()
    ref_max_logits, ref_lse, ref_out = reference_torch(p, t, sm_scale)
    torch.cuda.synchronize()

    comparisons = (
        ("out", ans_out, ref_out, dict(abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6)),
        ("max_logits", ans_max_logits, ref_max_logits, dict(abs_tol=1e-6, rel_tol=2.01 / 65536)),
        ("lse", ans_lse, ref_lse, dict(abs_tol=1e-6, rel_tol=2.01 / 65536)),
    )
    is_correct = True
    # Every comparison is executed, even after an earlier one has failed
    for name, answer, expected, tolerances in comparisons:
        is_correct &= check_is_allclose(name, answer, expected, **tolerances)
    return is_correct
if __name__ == '__main__':
    # Everything runs in bfloat16 on the first CUDA device by default
    device = torch.device("cuda:0")
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.set_float32_matmul_precision('high')

    # (s_kv, topk) pairs for the correctness cases
    correctness_shapes = [
        # Regular shapes
        (128, 128),
        (256, 256),
        (512, 512),
        # Irregular shapes
        (592, 128),
        (1840, 256),
        (1592, 384),
        (1521, 512),
        # Irregular shapes with OOB TopK
        (95, 128),
        (153, 256),
        (114, 384),
    ]
    correctness_cases = [
        TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False)
        for s_kv, topk in correctness_shapes
        for s_q in (1, 62)
    ]

    # In these cases, some blocks may not have any valid topk indices
    corner_shapes = [
        (32, 2048),
        (64, 8192),
    ]
    corner_cases = [
        TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False)
        for s_kv, topk in corner_shapes
        for s_q in (1, 1024)
    ]

    performance_s_kvs = [4096, 8192, 16384, 32768, 49152, 65536, 81920, 98304, 114688, 131072]
    performance_cases = [
        TestParam(1, s_q, s_kv, topk, h_q=128)
        for s_q in (4096,)
        for s_kv in performance_s_kvs
        for topk in (2048,)
    ]

    testcases = correctness_cases + corner_cases + performance_cases

    failed_cases = []
    for test in testcases:
        if test.benchmark:
            # Short pause before each benchmarked case
            time.sleep(0.2)
        if not run_test(test):
            failed_cases.append(test)

    if failed_cases:
        print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
        for case in failed_cases:
            print(f" {case}")
    else:
        print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
import time
import dataclasses
from typing import Tuple, List, Dict, Optional
import copy
import rich.console
import rich.table
import torch
import kernelkit as kk
import flash_mla
import lib
from lib import TestParam
from lib import RawTestParamForDecode as RawTestParam
import ref
"""
Generate testcase for unit test
"""
def gen_testcase() -> List[RawTestParam]:
    """
    Build the list of raw test parameters for the sparse decoding test.

    The returned list is the concatenation, in this order, of:
    - correctness cases (check_correctness=True, num_runs=0, attention sink enabled),
    - corner cases (at least one of: all indices invalid, zero-length KV, attention sink),
    - performance cases (the "production" base configs at several batch sizes, then peak-perf configs).
    """
    # NOTE(review): the positional RawTestParam arguments below look like
    # (b, h_q, s_q, h_kv, s_k, is_varlen, topk) — confirm against lib.RawTestParamForDecode
    correctness_cases = []
    corner_cases = []
    for d_qk in [576, 512]:
        # Extra KV, extra topk length and topk length are only toggled on for d_qk == 512
        for have_extra_k in ([False, True] if d_qk == 512 else [False]):
            for have_extra_topk_len in ([False, True] if have_extra_k else [False]):
                for have_topk_len in ([False, True] if d_qk == 512 else [False]):
                    for h_q in [64, 128]:
                        cur_correctness_cases = [
                            RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,
                                         have_topk_length=have_topk_len,
                                         enable_attn_sink=True,
                                         extra_s_k=extra_s_k,
                                         extra_topk=extra_topk,
                                         block_size=block_size,
                                         extra_block_size=extra_block_size,
                                         have_extra_topk_length=have_extra_topk_len,
                                         d_qk=d_qk,
                                         check_correctness=True,
                                         num_runs=0)
                            for (s_k, topk, block_size) in [
                                (512, 64, 2),
                                (512, 64, 64),
                                (512, 64, 69),
                                (1024, 576, 2),
                                (1024, 576, 61),
                                (2046, 2048, 2),
                                (2046, 2048, 64),
                                (2046, 2048, 576)
                            ]
                            # Without extra KV, a single placeholder triple keeps the comprehension running once
                            for (extra_s_k, extra_topk, extra_block_size) in ([
                                (512, 64, 2),
                                (512, 64, 64),
                                (512, 64, 69),
                                (1024, 576, 2),
                                (1024, 576, 61),
                                (2046, 2048, 2),
                                (2046, 2048, 64),
                                (2046, 2048, 576)
                            ] if have_extra_k else [(None, None, None)])
                            for b in [4, 74, 321]
                            for s_q in [1, 3]
                            # The non-varlen variant is only generated for b == 74 without any topk-length feature
                            for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True])
                        ]
                        correctness_cases.extend(cur_correctness_cases)

                        cur_corner_cases = [
                            RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,
                                         is_all_indices_invalid=is_all_indices_invalid,
                                         have_zero_seqlen_k=have_zero_seqlen_k,
                                         have_topk_length=have_topk_len,
                                         enable_attn_sink=enable_attn_sink,
                                         extra_s_k=extra_s_k,
                                         extra_topk=extra_topk,
                                         block_size=block_size,
                                         extra_block_size=extra_block_size,
                                         have_extra_topk_length=have_extra_topk_len,
                                         d_qk=d_qk,
                                         check_correctness=True,
                                         num_runs=0,
                                         )
                            for (s_k, topk, block_size) in [
                                (512, 64, 61),
                                (650, 576, 53),
                            ]
                            for (extra_s_k, extra_topk, extra_block_size) in ([
                                (512, 64, 61),
                                (650, 576, 53),
                            ] if have_extra_k else [(None, None, None)])
                            for b in [4, 74, 321]
                            for s_q in [3]
                            for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True])
                            for is_all_indices_invalid in [True, False]
                            for have_zero_seqlen_k in [True, False]
                            for enable_attn_sink in [True, False]
                            # Skip the combination where all three corner features are off
                            if (is_all_indices_invalid or have_zero_seqlen_k or enable_attn_sink)
                        ]
                        corner_cases.extend(cur_corner_cases)

    # Each entry: (base config with a placeholder batch size of 0, batch sizes to instantiate it with)
    base_and_bszs = [
        # V3.2
        (RawTestParam(0, 128, 2, 1, 32768, True, topk=2048, d_qk=576), [2, 64, 74, 128]),
        # MODEL1 CONFIG1
        (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=512, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]),
        # MODEL1 CONFIG2
        (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]),
        # MODEL1 CONFIG3
        (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
        # MODEL1 CONFIG4
        (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
    ]
    performance_cases = [
        # Production cases
        dataclasses.replace(base, b=b)
        for base, bszs in base_and_bszs
        for b in bszs
    ] + [
        # Peak perf cases
        RawTestParam(74*2, h_q, 2, 1, 32768, True, topk=16384, d_qk=d_qk)
        for h_q in [64, 128]
        for d_qk in [512, 576]
    ]

    return correctness_cases + corner_cases + performance_cases
@dataclasses.dataclass
class Result:
    """Outcome of one test_flash_mla run; all numeric fields are 0.0 when no benchmark was run."""
    is_correct: bool  # True when the correctness check passed or was skipped
    compute_memory_ratio: float  # FLOP count divided by memory volume
    time_usage_per_us: float  # End-to-end time, in microseconds
    splitkv_time_usage_us: float  # Time of the split-KV kernel in microseconds (0.0 if not found)
    combine_time_usage_us: float  # Time of the combine kernel in microseconds (0.0 if not found)
    achieved_tflops: float  # FLOP count / end-to-end time / 1e12
    achieved_gBps: float  # Memory volume / end-to-end time / 1e9
# Source of fresh seeds for testcases whose seed is -1
_counter = kk.Counter()


@torch.inference_mode()
def test_flash_mla(p: TestParam) -> Result:
    """
    Run one sparse decoding testcase.

    Steps, in order: generate the testcase, run the kernel once to capture the
    answer (if p.check_correctness), benchmark (if p.num_runs != 0), then
    compare the captured answer against ref.ref_sparse_attn_decode.

    Note: `p.seed` is overwritten in place when it is -1.

    Returns a Result with the benchmark numbers (zeros when not benchmarked)
    and the correctness verdict.
    """
    if p.seed == -1:
        global _counter
        p.seed = _counter.next()
    assert p.decode

    print("================")
    print(f"Running on {p}")
    torch.cuda.empty_cache()

    t = lib.generate_testcase_for_decode(p)
    tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()

    def run_decode():
        return lib.run_flash_mla_decode(p, t, tile_scheduler_metadata, None)

    # We first run the kernel once to generate output data for the correctness test
    # We must do this first, otherwise when allocating tensors for storing answers,
    # it may re-use memory that contains the correct answer, leading to false positives
    if p.check_correctness:
        torch.cuda.synchronize()
        out_ans, lse_ans = run_decode()
        torch.cuda.synchronize()
    # torch.set_printoptions(profile='full')
    # print(tile_scheduler_metadata.tile_scheduler_metadata[:, :7])

    # We run the performance test before generating the answer for the correctness test to avoid interference
    performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    if p.num_runs == 0:
        # NOTE(review): same value as the default assigned just above
        performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    else:
        result = kk.bench_kineto(run_decode, p.num_runs)
        splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel"
        combine_kernel_name = "flash_fwd_mla_combine_kernel"

        # Get individual kernel time usages
        kernel_time_usages_us: Dict[str, Optional[float]] = {}
        def pick_kernel_time_usage(kernel_name: str):
            # Substring match against the profiled kernel names; at most one kernel may match.
            # The local `t` here shadows the testcase `t` of the enclosing function.
            t = [kernel_name in s for s in result.get_kernel_names()]
            if any(t):
                assert sum(t) == 1
                kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6
            else:
                kernel_time_usages_us[kernel_name] = None
        pick_kernel_time_usage(splitkv_kernel_name)
        pick_kernel_time_usage(combine_kernel_name)

        # Get E2E time usages
        def have_kernel(name: str):
            return kernel_time_usages_us[name] is not None
        if kk.is_using_profiling_tools():
            # Fixed placeholder value when running under a profiling tool
            e2e_time_usage_us = 1e6
        else:
            assert have_kernel(splitkv_kernel_name)
            if have_kernel(combine_kernel_name):
                e2e_time_usage_us = result.get_e2e_time(splitkv_kernel_name, combine_kernel_name) * 1e6
            else:
                # No combine kernel was profiled: the split-KV kernel alone is the E2E time
                e2e_time_usage_us = kernel_time_usages_us[splitkv_kernel_name]
                assert e2e_time_usage_us is not None

        flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t)
        e2e_time_usage_s = e2e_time_usage_us / 1e6
        theoritical_compute_memory_ratio = flops_and_mem_vol.flop / flops_and_mem_vol.mem_vol
        achieved_tflops = flops_and_mem_vol.flop / e2e_time_usage_s / 1e12
        achieved_gBps = flops_and_mem_vol.mem_vol / e2e_time_usage_s / 1e9

        def print_kernel_time_usage(name: str, short_name: str):
            # Only prints kernels that were actually found in the profile
            if kernel_time_usages_us[name] is not None:
                print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us')
        print(f'Compute/Memory: {theoritical_compute_memory_ratio:.2f}')
        print(f'Time (per): {e2e_time_usage_us:.1f} us')
        print_kernel_time_usage(splitkv_kernel_name, "Splitkv")
        print_kernel_time_usage(combine_kernel_name, "Combine")
        print(f'TFlops: {achieved_tflops:.1f}')
        print(f'GB/s: {achieved_gBps:.0f}')
        performance_result = Result(True, theoritical_compute_memory_ratio, e2e_time_usage_us, kernel_time_usages_us[splitkv_kernel_name] or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps)

    is_correct = True
    if p.check_correctness:
        torch.cuda.synchronize()
        with torch.profiler.record_function("reference_flash_mla"):
            out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t)
        # Both comparisons are always executed, even when the first one fails
        is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)
        is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
        is_correct &= is_out_correct and is_lse_correct

    performance_result.is_correct = is_correct
    return performance_result
def main():
    """
    Run every testcase from gen_testcase(), print a per-case result table,
    a pass/fail summary, and the geometric mean of the achieved TFlops.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda:0")
    torch.set_default_dtype(dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.set_float32_matmul_precision('high')
    torch.set_num_threads(32)

    raw_testcases = gen_testcase()
    testcases = [t.to_test_param() for t in raw_testcases]

    print(f"{kk.colors['CYAN_BG']}{len(testcases)} testcases to run{kk.colors['CLEAR']}")
    is_no_cooldown = lib.is_no_cooldown()
    # Width used to right-align the running index in the progress prefix
    num_testcases_len = len(str(len(testcases)))
    failed_cases = []
    results: List[Tuple[TestParam, Result]] = []
    for testcase_idx, testcase in enumerate(testcases):
        # Sleep before benchmarked cases, except for cases that compare equal to the first one
        if testcase != testcases[0] and testcase.num_runs > 0 and not is_no_cooldown:
            time.sleep(0.3) # Cooldown
        print(f"[{testcase_idx+1:{num_testcases_len}d}/{len(testcases)}, {testcase_idx/len(testcases)*100:3.0f}%] ", end='')
        result = test_flash_mla(testcase)
        results.append((testcase, result))
        if not result.is_correct:
            failed_cases.append(testcase)
            # NOTE(review): this appears to exit on the first failing case, in which case the
            # failure summary further down is never reached with failures — confirm fail-fast is intended
            import sys
            sys.exit(1)

    console = rich.console.Console(width=120)
    table = rich.table.Table(show_header=True, header_style="bold cyan")
    table.add_column("topk")
    table.add_column("Bsz")
    table.add_column("h_q&k")
    table.add_column("sq")
    table.add_column("sk")
    table.add_column("d_qk")
    table.add_column("Feats")
    table.add_column("C/M")
    table.add_column("TFlops")
    table.add_column("GBps")
    table.add_column("us")
    table.add_column(" ")
    for testcase, result in results:
        assert testcase.decode
        topk_str = f"{testcase.topk}" if testcase.decode.extra_topk is None else f"{testcase.topk}+{testcase.decode.extra_topk}"
        table.add_row(
            topk_str,
            str(testcase.decode.b),
            f"{testcase.h_q:3d} {testcase.h_kv}",
            str(testcase.s_q),
            str(testcase.s_kv),
            str(testcase.d_qk),
            # Indexing a 2-char string with a bool: False -> " ", True -> the feature letter
            " V"[testcase.decode.is_varlen] + " L"[testcase.have_topk_length] + " E"[testcase.decode.have_extra_topk_length],
            f"{result.compute_memory_ratio:3.0f}",
            f"{result.achieved_tflops:3.0f}",
            f"{result.achieved_gBps:4.0f}",
            f"{result.time_usage_per_us:4.1f}",
            "" if result.is_correct else "X"
        )
    console.print(table)

    def geomean(l) -> float:
        """Geometric mean of `l`, computed as exp(mean(log(l)))."""
        import numpy
        return numpy.exp(numpy.mean(numpy.log(l)))

    # Only cases that requested a correctness check count towards the summary
    num_correct_testcases = [result.is_correct for t, result in results if t.check_correctness].count(True)
    num_correctness_cases = sum([1 for t in testcases if t.check_correctness])
    if num_correct_testcases == num_correctness_cases:
        print(f"{kk.colors['GREEN_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}")
    else:
        print(f"{kk.colors['RED_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}")
        for t in failed_cases:
            print(f"\t{t},")

    valid_achieved_tflops = [result.achieved_tflops for _, result in results if result.achieved_tflops > 0.1]
    if len(valid_achieved_tflops) > 0:
        achieved_tflops_geomean = geomean(valid_achieved_tflops) # > 0.1 to prune out correctness cases
        print(f"TFlops geomean: {achieved_tflops_geomean:.1f}")


if __name__ == "__main__":
    main()
import time
import sys
import torch
import kernelkit as kk
from lib import TestParam
import lib
import ref
# Hands out a fresh seed to every testcase whose seed is -1
_counter = kk.Counter()


@torch.inference_mode()
def run_test(p: TestParam) -> bool:
    """
    Run one sparse prefill testcase.

    Launches the kernel once, benchmarks it when p.num_runs > 0, and compares
    its outputs against ref.ref_sparse_attn_fwd when p.check_correctness is
    set. `p.seed` is overwritten in place when it is -1.

    Returns True when the check passes or is skipped.
    """
    if p.seed == -1:
        p.seed = _counter.next()

    print("================")
    print(f"Running on {p}")
    torch.cuda.empty_cache()

    t = lib.generate_testcase(p)
    torch.cuda.synchronize()

    def launch_kernel():
        return lib.run_flash_mla_sparse_fwd(p, t, False)

    ans_out, ans_max_logits, ans_lse = launch_kernel()
    torch.cuda.synchronize()

    if p.num_runs > 0:
        volume = lib.count_flop_and_mem_vol(p, t)
        elapsed = kk.bench_kineto(launch_kernel, num_tests=p.num_runs).get_kernel_time("sparse_attn_fwd")
        tflops = volume.fwd_flop/elapsed/1e12
        tbps = volume.fwd_mem_vol/elapsed/1e12
        print(f"Prefill: {elapsed*1e6:4.0f} us, {tflops:6.1f} TFlops, {tbps:4.2f} TBps")

    if not p.check_correctness:
        return True

    torch.cuda.synchronize()
    # The bfloat16 reference output is not used; the comparison is done in fp32
    _, ref_out_fp32, ref_max_logits, ref_lse = ref.ref_sparse_attn_fwd(p, t)
    ref_lse[ref_lse == float("-inf")] = float("+inf")
    torch.cuda.synchronize()

    comparisons = (
        ("out", ans_out.float(), ref_out_fp32, dict(abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6)),
        ("max_logits", ans_max_logits, ref_max_logits, dict(abs_tol=1e-6, rel_tol=2.01/65536)),
        ("lse", ans_lse, ref_lse, dict(abs_tol=1e-6, rel_tol=2.01/65536)),
    )
    is_correct = True
    # Every comparison is executed, even after an earlier one has failed
    for name, answer, expected, tolerances in comparisons:
        is_correct &= kk.check_is_allclose(name, answer, expected, **tolerances)
    return is_correct
if __name__ == '__main__':
    # Run everything on the first CUDA device, with bfloat16 as the default dtype.
    device = torch.device("cuda:0")
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.set_float32_matmul_precision('high')

    # Values swept by every correctness / corner-case family below.
    d_qk_choices = [512, 576]
    h_q_choices = [128, 64]
    on_off = [False, True]

    # (s_kv, topk) pairs
    regular_shapes = [
        (128, 128),
        (256, 256),
        (512, 512),
    ]
    irregular_shapes = [
        # Irregular shapes
        (592, 128),
        (1840, 256),
        (1592, 384),
        (1521, 512),
        # Irregular shapes with OOB TopK
        (95, 128),
        (153, 256),
        (114, 384),
    ]

    correctness_cases = [
        TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk)
        for d_qk in d_qk_choices
        for h_q in h_q_choices
        for s_kv, topk in regular_shapes + irregular_shapes
        for s_q in [1, 62, 213]
    ]

    correctness_cases_with_features = [
        TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, have_attn_sink=have_attn_sink, have_topk_length=have_topk_length, d_qk=d_qk)
        for d_qk in d_qk_choices
        for h_q in h_q_choices
        for s_kv, topk in irregular_shapes
        for s_q in [62, 213]
        # NOTE(review): have_sink_lse is never passed to TestParam, so this
        # loop only repeats each configuration twice -- confirm it is intended.
        for have_sink_lse in on_off
        for have_attn_sink in on_off
        for have_topk_length in on_off
    ]

    # Every topk index is invalid.
    all_invalid_cases = [
        TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=True, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
        for d_qk in d_qk_choices
        for h_q in h_q_choices
        for s_q, s_kv, topk in [
            (1, 128, 128),
            (1, 256, 256),
            (1234, 4321, 4096),
            (4096, 2048, 2048),
        ]
    ]
    # In these cases, some blocks may not have any valid topk indices
    sparse_block_cases = [
        TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
        for d_qk in d_qk_choices
        for h_q in h_q_choices
        for s_kv, topk in [
            (32, 2048),
            (64, 8192),
        ]
        for s_q in [1, 1024]
    ]
    # In this testcase, s_q is really large, so we cannot put it on the second dimension of grid shape
    huge_s_q_cases = [
        TestParam(70000, 256, 256, h_q=h_q, check_correctness=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
        for d_qk in d_qk_choices
        for h_q in h_q_choices
    ]
    corner_cases = all_invalid_cases + sparse_block_cases + huge_s_q_cases

    # Each template is (d_qk, h_q, topk, list of s_kv values to sweep).
    performance_case_templates = [
        # V3.2
        (576, 128, 2048, [8192, 32768, 65536, 98304, 131072]),
        # MODEL1 CONFIG1
        (512, 64, 512, [8192, 32768, 49152, 65536]),
        # MODEL1 CONFIG2
        (512, 128, 1024, [8192, 32768, 49152, 65536]),
    ]
    performance_cases = [
        TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=True)
        for (d_qk, h_q, topk, s_kv_list) in performance_case_templates
        for s_q in [4096]
        for s_kv in s_kv_list
    ]

    testcases = correctness_cases + correctness_cases_with_features + corner_cases + performance_cases

    is_no_cooldown = lib.is_no_cooldown()
    failed_cases = []
    for test in testcases:
        # Pause briefly before each benchmarked case (other than the first)
        # unless cooldown has been disabled.
        needs_cooldown = test != testcases[0] and test.num_runs > 0 and not is_no_cooldown
        if needs_cooldown:
            time.sleep(0.3)
        if not run_test(test):
            failed_cases.append(test)

    if not failed_cases:
        print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
    else:
        print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
        for case in failed_cases:
            print(f" {case}")
        # Non-zero exit status signals the failure to the caller.
        sys.exit(1)
......@@ -5,7 +5,7 @@ from torch.utils.checkpoint import checkpoint
import triton
from flash_mla import flash_attn_varlen_func
from lib import check_is_allclose
from kernelkit import check_is_allclose
def get_window_size(causal, window):
if window > 0:
......@@ -116,14 +116,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
out_flash, lse_flash = flash_attn()
if has_bwd:
out_flash.backward(grad_out, retain_graph=True)
dq1 = q1.grad.clone()
_dq1 = q1.grad.clone()
dk1 = k1.grad.clone()
dv1 = v1.grad.clone()
if check_correctness:
out_torch, lse_torch = torch_attn()
assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536)
assert check_is_allclose("out", out_flash.float(), out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("lse", lse_flash.float(), lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536)
if has_bwd:
out_torch.backward(grad_out, retain_graph=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment