Commit 1858932a authored by Jiashi Li's avatar Jiashi Li
Browse files

Code format

parent 7f55c715
import enum
import torch import torch
def quantize_k_cache( def quantize_k_cache(
...@@ -19,20 +17,20 @@ def quantize_k_cache( ...@@ -19,20 +17,20 @@ def quantize_k_cache(
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size() input_elem_size = input_k_cache.element_size()
result = torch.empty((num_blocks, block_size, dv + num_tiles*4 + input_elem_size*(d-dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device) result = torch.empty((num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device)
result_k_nope_part = result[..., :dv] result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv:dv + num_tiles*4].view(torch.float32) result_k_scale_factor = result[..., dv:dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles*4:].view(input_k_cache.dtype) result_k_rope_part = result[..., dv + num_tiles * 4:].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:] result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles): for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size] cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) cur_quantized_nope = (input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope result_k_nope_part[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_quantized_nope
result = result.view(num_blocks, block_size, 1, -1) result = result.view(num_blocks, block_size, 1, -1)
return result return result
...@@ -55,14 +53,14 @@ def dequantize_k_cache( ...@@ -55,14 +53,14 @@ def dequantize_k_cache(
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv] input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv:dv + num_tiles*4].view(torch.float32) input_scale = quant_k_cache[..., dv:dv + num_tiles * 4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles*4:].view(torch.bfloat16) input_rope = quant_k_cache[..., dv + num_tiles * 4:].view(torch.bfloat16)
result[..., dv:] = input_rope result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles): for tile_idx in range(0, num_tiles):
cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32) cur_nope = input_nope[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1) cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales result[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_nope * cur_scales
result = result.view(num_blocks, block_size, 1, d) result = result.view(num_blocks, block_size, 1, d)
return result return result
...@@ -2,20 +2,20 @@ import argparse ...@@ -2,20 +2,20 @@ import argparse
import math import math
import random import random
import dataclasses import dataclasses
from typing import Optional, Tuple, List from typing import Optional, Tuple
import torch import torch
import triton import triton
import quant
import flash_mla import flash_mla
import quant
from lib import cdiv, check_is_allclose from lib import cdiv, check_is_allclose
@dataclasses.dataclass @dataclasses.dataclass
class TestParam: class TestParam:
b: int # Batch size b: int # Batch size
s_q: int # Number of queries for one request s_q: int # Number of queries for one request
s_k: int # Seq len, or mean seq len if varlen == True s_k: int # Seq len, or mean seq len if varlen == True
is_varlen: bool is_varlen: bool
is_causal: bool is_causal: bool
is_fp8: bool is_fp8: bool
...@@ -24,8 +24,8 @@ class TestParam: ...@@ -24,8 +24,8 @@ class TestParam:
is_all_indices_invalid: bool = False is_all_indices_invalid: bool = False
have_zero_seqlen_k: bool = False have_zero_seqlen_k: bool = False
block_size: int = 64 block_size: int = 64
h_q: int = 128 # Number of q heads h_q: int = 128 # Number of q heads
h_kv: int = 1 # Number of kv heads h_kv: int = 1 # Number of kv heads
d: int = 576 # Q/K head dim (= dv + RoPE dim) d: int = 576 # Q/K head dim (= dv + RoPE dim)
dv: int = 512 # V head dim dv: int = 512 # V head dim
seed: int = 0 seed: int = 0
...@@ -71,7 +71,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. ...@@ -71,7 +71,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
cur_num_blocks = cdiv(cur_len, t.block_size) cur_num_blocks = cdiv(cur_len, t.block_size)
blocked_k[block_table[i][cur_num_blocks:]] = float("nan") blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
if cur_len % t.block_size != 0: if cur_len % t.block_size != 0:
blocked_k[block_table[i][cur_num_blocks-1]][cur_len % t.block_size:] = float("nan") blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan")
block_table[i][cur_num_blocks:] = 2147480000 block_table[i][cur_num_blocks:] = 2147480000
return cache_seqlens, q, block_table, blocked_k, None, None return cache_seqlens, q, block_table, blocked_k, None, None
else: else:
...@@ -82,12 +82,12 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. ...@@ -82,12 +82,12 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
# Generate indices # Generate indices
for j in range(t.s_q): for j in range(t.s_q):
cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk] cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk]
cur_blocked_indices = block_table_cpu[i, cur_abs_indices//t.block_size]*t.block_size + (cur_abs_indices%t.block_size) cur_blocked_indices = block_table_cpu[i, cur_abs_indices // t.block_size] * t.block_size + (cur_abs_indices % t.block_size)
if len(cur_abs_indices) < t.topk: if len(cur_abs_indices) < t.topk:
pad_len = t.topk - len(cur_abs_indices) pad_len = t.topk - len(cur_abs_indices)
cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')]) cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')])
cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')]) cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')])
# Mask KV # Mask KV
perm = torch.randperm(t.topk, device='cpu') perm = torch.randperm(t.topk, device='cpu')
cur_abs_indices = cur_abs_indices[perm] cur_abs_indices = cur_abs_indices[perm]
...@@ -100,7 +100,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. ...@@ -100,7 +100,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
abs_indices[i, j, :] = cur_abs_indices abs_indices[i, j, :] = cur_abs_indices
indices_in_kvcache[i, j, :] = cur_blocked_indices indices_in_kvcache[i, j, :] = cur_blocked_indices
# Mask nonused KV as NaN # Mask nonused KV as NaN
all_indices = indices_in_kvcache.flatten().tolist() all_indices = indices_in_kvcache.flatten().tolist()
all_indices = list(set(all_indices)) all_indices = list(set(all_indices))
...@@ -109,11 +109,11 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. ...@@ -109,11 +109,11 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu') all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu')
blocked_k = blocked_k.view(-1, t.h_kv, t.d) blocked_k = blocked_k.view(-1, t.h_kv, t.d)
nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu') nonused_indices_mask = torch.ones(blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device='cpu')
nonused_indices_mask[all_indices] = False nonused_indices_mask[all_indices] = False
blocked_k[nonused_indices_mask, :, :] = float("nan") blocked_k[nonused_indices_mask, :, :] = float("nan")
blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d) blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d)
abs_indices = abs_indices.to(q.device) abs_indices = abs_indices.to(q.device)
indices_in_kvcache = indices_in_kvcache.to(q.device) indices_in_kvcache = indices_in_kvcache.to(q.device)
...@@ -139,7 +139,7 @@ def reference_torch( ...@@ -139,7 +139,7 @@ def reference_torch(
valid_indices = cur_indices[cur_indices != -1] valid_indices = cur_indices[cur_indices != -1]
mask[i, valid_indices] = True mask[i, valid_indices] = True
return mask return mask
def scaled_dot_product_attention( def scaled_dot_product_attention(
batch_idx: int, batch_idx: int,
query: torch.Tensor, # [h_q, s_q, d] query: torch.Tensor, # [h_q, s_q, d]
...@@ -157,7 +157,7 @@ def reference_torch( ...@@ -157,7 +157,7 @@ def reference_torch(
if h_kv != 1: if h_kv != 1:
kv = kv.repeat_interleave(h_q // h_kv, dim=0) kv = kv.repeat_interleave(h_q // h_kv, dim=0)
kv[kv != kv] = 0.0 kv[kv != kv] = 0.0
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
if (is_causal and query.size(1) > 1) or indices is not None: if (is_causal and query.size(1) > 1) or indices is not None:
mask = torch.ones(s_q, s_k, dtype=torch.bool) mask = torch.ones(s_q, s_k, dtype=torch.bool)
if is_causal: if is_causal:
...@@ -169,14 +169,14 @@ def reference_torch( ...@@ -169,14 +169,14 @@ def reference_torch(
attn_bias.masked_fill_(mask.logical_not(), float("-inf")) attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
attn_weight += attn_bias.to(q.dtype) attn_weight += attn_bias.to(q.dtype)
attn_weight /= math.sqrt(query.size(-1)) attn_weight /= math.sqrt(query.size(-1))
lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q]
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv] output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv]
# Correct for q tokens which has no attendable k # Correct for q tokens which has no attendable k
lonely_q_mask = (lse == float("-inf")) lonely_q_mask = (lse == float("-inf"))
output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0 output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
lse[lonely_q_mask] = float("+inf") lse[lonely_q_mask] = float("+inf")
return output, lse return output, lse
b, s_q, h_q, d = q.size() b, s_q, h_q, d = q.size()
...@@ -202,7 +202,7 @@ def reference_torch( ...@@ -202,7 +202,7 @@ def reference_torch(
lse_ref[i] = cur_lse lse_ref[i] = cur_lse
out_ref = out_ref.to(torch.bfloat16) out_ref = out_ref.to(torch.bfloat16)
return out_ref, lse_ref return out_ref, lse_ref
@torch.inference_mode() @torch.inference_mode()
def test_flash_mla(t: TestParam): def test_flash_mla(t: TestParam):
...@@ -235,7 +235,7 @@ def test_flash_mla(t: TestParam): ...@@ -235,7 +235,7 @@ def test_flash_mla(t: TestParam):
def run_flash_mla(): def run_flash_mla():
return flash_mla.flash_mla_with_kvcache( return flash_mla.flash_mla_with_kvcache(
q, q,
blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore
block_table, block_table,
cache_seqlens, cache_seqlens,
t.dv, t.dv,
...@@ -248,27 +248,27 @@ def test_flash_mla(t: TestParam): ...@@ -248,27 +248,27 @@ def test_flash_mla(t: TestParam):
out_ans, lse_ans = run_flash_mla() out_ans, lse_ans = run_flash_mla()
out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices) out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices)
assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=5e-6) assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6)
assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536)
if t.test_performance: if t.test_performance:
time_usage: float = triton.testing.do_bench(run_flash_mla)/1000 # type: ignore time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore
mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk
compute_volume_flop = t.b*t.h_q*t.s_q*sum([ compute_volume_flop = t.b * t.h_q * t.s_q * sum([
2*t.d*mean_attended_seqlens, # Q * K^T 2 * t.d * mean_attended_seqlens, # Q * K^T
2*mean_attended_seqlens*t.dv, # attention * V 2 * mean_attended_seqlens * t.dv, # attention * V
]) ])
q_elem_size = torch.bfloat16.itemsize q_elem_size = torch.bfloat16.itemsize
kv_token_size = 656 if t.is_fp8 else t.d*torch.bfloat16.itemsize kv_token_size = 656 if t.is_fp8 else t.d * torch.bfloat16.itemsize
memory_volume_B = t.b*sum([ memory_volume_B = t.b * sum([
t.s_q*t.h_q*(t.d*q_elem_size), # Q t.s_q * t.h_q * (t.d * q_elem_size), # Q
(t.s_q if t.topk is not None else 1) * mean_attended_seqlens*t.h_kv*kv_token_size, # K/V (t.s_q if t.topk is not None else 1) * mean_attended_seqlens * t.h_kv * kv_token_size, # K/V
t.s_q*t.h_q*(t.dv*q_elem_size), # Output t.s_q * t.h_q * (t.dv * q_elem_size), # Output
]) ])
achieved_tflops = compute_volume_flop / time_usage / 1e12 achieved_tflops = compute_volume_flop / time_usage / 1e12
achieved_gBps = memory_volume_B / time_usage / 1e9 achieved_gBps = memory_volume_B / time_usage / 1e9
print(f"{time_usage*1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s") print(f"{time_usage * 1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s")
def main(torch_dtype): def main(torch_dtype):
...@@ -324,7 +324,7 @@ def main(torch_dtype): ...@@ -324,7 +324,7 @@ def main(torch_dtype):
cc_major, cc_minor = torch.cuda.get_device_capability() cc_major, cc_minor = torch.cuda.get_device_capability()
if cc_major == 10: if cc_major == 10:
testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)] testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)]
for testcase in testcases: for testcase in testcases:
test_flash_mla(testcase) test_flash_mla(testcase)
......
...@@ -35,8 +35,8 @@ def generate_testcase(t: TestParam) -> Testcase: ...@@ -35,8 +35,8 @@ def generate_testcase(t: TestParam) -> Testcase:
torch.manual_seed(t.seed) torch.manual_seed(t.seed)
torch.cuda.manual_seed(t.seed) torch.cuda.manual_seed(t.seed)
random.seed(t.seed) random.seed(t.seed)
q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10
kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10
q.clamp_(-10, 10) q.clamp_(-10, 10)
kv.clamp_(-10, 10) kv.clamp_(-10, 10)
...@@ -48,7 +48,7 @@ def generate_testcase(t: TestParam) -> Testcase: ...@@ -48,7 +48,7 @@ def generate_testcase(t: TestParam) -> Testcase:
# NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31
cur_indices = torch.randperm(t.s_kv)[:t.topk] cur_indices = torch.randperm(t.s_kv)[:t.topk]
cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),)) cur_indices[near_mask] = torch.randint(max(0, t.s_kv - 20000), t.s_kv - 1, (near_mask.sum().item(),))
if len(cur_indices) < t.topk: if len(cur_indices) < t.topk:
cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)]) cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)])
cur_indices = cur_indices[torch.randperm(t.topk)] cur_indices = cur_indices[torch.randperm(t.topk)]
...@@ -72,9 +72,9 @@ def get_flop(p: TestParam) -> float: ...@@ -72,9 +72,9 @@ def get_flop(p: TestParam) -> float:
def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e) return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
assert p.b == 1 assert p.b == 1
indices = t.indices[0, :, 0, :] # [s_q, topk] indices = t.indices[0, :, 0, :] # [s_q, topk]
invalid_indices_mask = (indices < 0) | (indices >= p.s_kv) invalid_indices_mask = (indices < 0) | (indices >= p.s_kv)
qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk] qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk]
kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk] kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk]
...@@ -104,15 +104,15 @@ def run_test(p: TestParam) -> bool: ...@@ -104,15 +104,15 @@ def run_test(p: TestParam) -> bool:
return flash_mla_sparse_fwd( return flash_mla_sparse_fwd(
t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale
) )
ans_out, ans_max_logits, ans_lse = run_ans() ans_out, ans_max_logits, ans_lse = run_ans()
torch.cuda.synchronize() torch.cuda.synchronize()
if p.benchmark: if p.benchmark:
flop = get_flop(p) flop = get_flop(p)
prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20)/1000 # type: ignore prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore
prefill_flops = flop/prefill_ans_time/1e12 prefill_flops = flop / prefill_ans_time / 1e12
print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:.3f} TFlops") print(f"Prefill: {prefill_ans_time * 1e6:4.0f} us, {prefill_flops:.3f} TFlops")
if p.check_correctness: if p.check_correctness:
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -120,9 +120,9 @@ def run_test(p: TestParam) -> bool: ...@@ -120,9 +120,9 @@ def run_test(p: TestParam) -> bool:
torch.cuda.synchronize() torch.cuda.synchronize()
is_correct = True is_correct = True
is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=7e-6) is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6)
is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01 / 65536)
is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01 / 65536)
return is_correct return is_correct
else: else:
...@@ -187,11 +187,10 @@ if __name__ == '__main__': ...@@ -187,11 +187,10 @@ if __name__ == '__main__':
is_correct = run_test(test) is_correct = run_test(test)
if not is_correct: if not is_correct:
failed_cases.append(test) failed_cases.append(test)
if len(failed_cases) > 0: if len(failed_cases) > 0:
print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m") print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
for case in failed_cases: for case in failed_cases:
print(f" {case}") print(f" {case}")
else: else:
print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
...@@ -5,7 +5,6 @@ from torch.utils.checkpoint import checkpoint ...@@ -5,7 +5,6 @@ from torch.utils.checkpoint import checkpoint
import triton import triton
from flash_mla import flash_attn_varlen_func from flash_mla import flash_attn_varlen_func
from lib import check_is_allclose from lib import check_is_allclose
def get_window_size(causal, window): def get_window_size(causal, window):
...@@ -71,10 +70,10 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win ...@@ -71,10 +70,10 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
causal, window) == 0).sum().item() for i in range(b)]) causal, window) == 0).sum().item() for i in range(b)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q = torch.randn(total_q, h, d)/10 q = torch.randn(total_q, h, d) / 10
k = torch.randn(total_k, h_k, d)/10 k = torch.randn(total_k, h_k, d) / 10
v = torch.randn(total_k, h_k, dv)/10 v = torch.randn(total_k, h_k, dv) / 10
grad_out = torch.randn(total_q, h, dv)/10 grad_out = torch.randn(total_q, h, dv) / 10
softmax_scale = (d + 100) ** (-0.5) softmax_scale = (d + 100) ** (-0.5)
q1 = q.clone().requires_grad_() q1 = q.clone().requires_grad_()
...@@ -123,14 +122,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win ...@@ -123,14 +122,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
if check_correctness: if check_correctness:
out_torch, lse_torch = torch_attn() out_torch, lse_torch = torch_attn()
assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01/65536) assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536)
if has_bwd: if has_bwd:
out_torch.backward(grad_out, retain_graph=True) out_torch.backward(grad_out, retain_graph=True)
assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
def forward(): def forward():
return flash_attn() return flash_attn()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment