Commit 1858932a authored by Jiashi Li's avatar Jiashi Li
Browse files

Code format

parent 7f55c715
import enum
import torch
def quantize_k_cache(
......@@ -19,20 +17,20 @@ def quantize_k_cache(
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()
result = torch.empty((num_blocks, block_size, dv + num_tiles*4 + input_elem_size*(d-dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device)
result = torch.empty((num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv:dv + num_tiles*4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles*4:].view(input_k_cache.dtype)
result_k_scale_factor = result[..., dv:dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles * 4:].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size]
cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope
cur_quantized_nope = (input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_quantized_nope
result = result.view(num_blocks, block_size, 1, -1)
return result
......@@ -55,14 +53,14 @@ def dequantize_k_cache(
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv:dv + num_tiles*4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles*4:].view(torch.bfloat16)
input_scale = quant_k_cache[..., dv:dv + num_tiles * 4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles * 4:].view(torch.bfloat16)
result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32)
cur_nope = input_nope[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales
result[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_nope * cur_scales
result = result.view(num_blocks, block_size, 1, d)
return result
......@@ -2,20 +2,20 @@ import argparse
import math
import random
import dataclasses
from typing import Optional, Tuple, List
from typing import Optional, Tuple
import torch
import triton
import quant
import flash_mla
import quant
from lib import cdiv, check_is_allclose
@dataclasses.dataclass
class TestParam:
b: int # Batch size
s_q: int # Number of queries for one request
s_k: int # Seq len, or mean seq len if varlen == True
b: int # Batch size
s_q: int # Number of queries for one request
s_k: int # Seq len, or mean seq len if varlen == True
is_varlen: bool
is_causal: bool
is_fp8: bool
......@@ -24,8 +24,8 @@ class TestParam:
is_all_indices_invalid: bool = False
have_zero_seqlen_k: bool = False
block_size: int = 64
h_q: int = 128 # Number of q heads
h_kv: int = 1 # Number of kv heads
h_q: int = 128 # Number of q heads
h_kv: int = 1 # Number of kv heads
d: int = 576 # Q/K head dim (= dv + RoPE dim)
dv: int = 512 # V head dim
seed: int = 0
......@@ -71,7 +71,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
cur_num_blocks = cdiv(cur_len, t.block_size)
blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
if cur_len % t.block_size != 0:
blocked_k[block_table[i][cur_num_blocks-1]][cur_len % t.block_size:] = float("nan")
blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan")
block_table[i][cur_num_blocks:] = 2147480000
return cache_seqlens, q, block_table, blocked_k, None, None
else:
......@@ -82,12 +82,12 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
# Generate indices
for j in range(t.s_q):
cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk]
cur_blocked_indices = block_table_cpu[i, cur_abs_indices//t.block_size]*t.block_size + (cur_abs_indices%t.block_size)
cur_blocked_indices = block_table_cpu[i, cur_abs_indices // t.block_size] * t.block_size + (cur_abs_indices % t.block_size)
if len(cur_abs_indices) < t.topk:
pad_len = t.topk - len(cur_abs_indices)
cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')])
cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')])
# Mask KV
perm = torch.randperm(t.topk, device='cpu')
cur_abs_indices = cur_abs_indices[perm]
......@@ -100,7 +100,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
abs_indices[i, j, :] = cur_abs_indices
indices_in_kvcache[i, j, :] = cur_blocked_indices
# Mask nonused KV as NaN
all_indices = indices_in_kvcache.flatten().tolist()
all_indices = list(set(all_indices))
......@@ -109,11 +109,11 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu')
blocked_k = blocked_k.view(-1, t.h_kv, t.d)
nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu')
nonused_indices_mask = torch.ones(blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device='cpu')
nonused_indices_mask[all_indices] = False
blocked_k[nonused_indices_mask, :, :] = float("nan")
blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d)
abs_indices = abs_indices.to(q.device)
indices_in_kvcache = indices_in_kvcache.to(q.device)
......@@ -139,7 +139,7 @@ def reference_torch(
valid_indices = cur_indices[cur_indices != -1]
mask[i, valid_indices] = True
return mask
def scaled_dot_product_attention(
batch_idx: int,
query: torch.Tensor, # [h_q, s_q, d]
......@@ -157,7 +157,7 @@ def reference_torch(
if h_kv != 1:
kv = kv.repeat_interleave(h_q // h_kv, dim=0)
kv[kv != kv] = 0.0
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
if (is_causal and query.size(1) > 1) or indices is not None:
mask = torch.ones(s_q, s_k, dtype=torch.bool)
if is_causal:
......@@ -169,14 +169,14 @@ def reference_torch(
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
attn_weight += attn_bias.to(q.dtype)
attn_weight /= math.sqrt(query.size(-1))
lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q]
lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q]
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv]
# Correct for q tokens which has no attendable k
lonely_q_mask = (lse == float("-inf"))
output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
lse[lonely_q_mask] = float("+inf")
return output, lse
b, s_q, h_q, d = q.size()
......@@ -202,7 +202,7 @@ def reference_torch(
lse_ref[i] = cur_lse
out_ref = out_ref.to(torch.bfloat16)
return out_ref, lse_ref
@torch.inference_mode()
def test_flash_mla(t: TestParam):
......@@ -235,7 +235,7 @@ def test_flash_mla(t: TestParam):
def run_flash_mla():
return flash_mla.flash_mla_with_kvcache(
q,
blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore
blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore
block_table,
cache_seqlens,
t.dv,
......@@ -248,27 +248,27 @@ def test_flash_mla(t: TestParam):
out_ans, lse_ans = run_flash_mla()
out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices)
assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=5e-6)
assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6)
assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536)
if t.test_performance:
time_usage: float = triton.testing.do_bench(run_flash_mla)/1000 # type: ignore
time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore
mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk
compute_volume_flop = t.b*t.h_q*t.s_q*sum([
2*t.d*mean_attended_seqlens, # Q * K^T
2*mean_attended_seqlens*t.dv, # attention * V
compute_volume_flop = t.b * t.h_q * t.s_q * sum([
2 * t.d * mean_attended_seqlens, # Q * K^T
2 * mean_attended_seqlens * t.dv, # attention * V
])
q_elem_size = torch.bfloat16.itemsize
kv_token_size = 656 if t.is_fp8 else t.d*torch.bfloat16.itemsize
memory_volume_B = t.b*sum([
t.s_q*t.h_q*(t.d*q_elem_size), # Q
(t.s_q if t.topk is not None else 1) * mean_attended_seqlens*t.h_kv*kv_token_size, # K/V
t.s_q*t.h_q*(t.dv*q_elem_size), # Output
kv_token_size = 656 if t.is_fp8 else t.d * torch.bfloat16.itemsize
memory_volume_B = t.b * sum([
t.s_q * t.h_q * (t.d * q_elem_size), # Q
(t.s_q if t.topk is not None else 1) * mean_attended_seqlens * t.h_kv * kv_token_size, # K/V
t.s_q * t.h_q * (t.dv * q_elem_size), # Output
])
achieved_tflops = compute_volume_flop / time_usage / 1e12
achieved_gBps = memory_volume_B / time_usage / 1e9
print(f"{time_usage*1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s")
print(f"{time_usage * 1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s")
def main(torch_dtype):
......@@ -324,7 +324,7 @@ def main(torch_dtype):
cc_major, cc_minor = torch.cuda.get_device_capability()
if cc_major == 10:
testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)]
for testcase in testcases:
test_flash_mla(testcase)
......
......@@ -35,8 +35,8 @@ def generate_testcase(t: TestParam) -> Testcase:
torch.manual_seed(t.seed)
torch.cuda.manual_seed(t.seed)
random.seed(t.seed)
q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10
kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10
q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10
kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10
q.clamp_(-10, 10)
kv.clamp_(-10, 10)
......@@ -48,7 +48,7 @@ def generate_testcase(t: TestParam) -> Testcase:
# NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31
cur_indices = torch.randperm(t.s_kv)[:t.topk]
cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),))
cur_indices[near_mask] = torch.randint(max(0, t.s_kv - 20000), t.s_kv - 1, (near_mask.sum().item(),))
if len(cur_indices) < t.topk:
cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)])
cur_indices = cur_indices[torch.randperm(t.topk)]
......@@ -72,9 +72,9 @@ def get_flop(p: TestParam) -> float:
def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
assert p.b == 1
indices = t.indices[0, :, 0, :] # [s_q, topk]
indices = t.indices[0, :, 0, :] # [s_q, topk]
invalid_indices_mask = (indices < 0) | (indices >= p.s_kv)
qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk]
kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk]
......@@ -104,15 +104,15 @@ def run_test(p: TestParam) -> bool:
return flash_mla_sparse_fwd(
t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale
)
ans_out, ans_max_logits, ans_lse = run_ans()
torch.cuda.synchronize()
if p.benchmark:
flop = get_flop(p)
prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20)/1000 # type: ignore
prefill_flops = flop/prefill_ans_time/1e12
print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:.3f} TFlops")
prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore
prefill_flops = flop / prefill_ans_time / 1e12
print(f"Prefill: {prefill_ans_time * 1e6:4.0f} us, {prefill_flops:.3f} TFlops")
if p.check_correctness:
torch.cuda.synchronize()
......@@ -120,9 +120,9 @@ def run_test(p: TestParam) -> bool:
torch.cuda.synchronize()
is_correct = True
is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=7e-6)
is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6)
is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01 / 65536)
is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01 / 65536)
return is_correct
else:
......@@ -187,11 +187,10 @@ if __name__ == '__main__':
is_correct = run_test(test)
if not is_correct:
failed_cases.append(test)
if len(failed_cases) > 0:
print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
for case in failed_cases:
print(f" {case}")
else:
print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
......@@ -5,7 +5,6 @@ from torch.utils.checkpoint import checkpoint
import triton
from flash_mla import flash_attn_varlen_func
from lib import check_is_allclose
def get_window_size(causal, window):
......@@ -71,10 +70,10 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
causal, window) == 0).sum().item() for i in range(b)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q = torch.randn(total_q, h, d)/10
k = torch.randn(total_k, h_k, d)/10
v = torch.randn(total_k, h_k, dv)/10
grad_out = torch.randn(total_q, h, dv)/10
q = torch.randn(total_q, h, d) / 10
k = torch.randn(total_k, h_k, d) / 10
v = torch.randn(total_k, h_k, dv) / 10
grad_out = torch.randn(total_q, h, dv) / 10
softmax_scale = (d + 100) ** (-0.5)
q1 = q.clone().requires_grad_()
......@@ -123,14 +122,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
if check_correctness:
out_torch, lse_torch = torch_attn()
assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01/65536)
assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536)
if has_bwd:
out_torch.backward(grad_out, retain_graph=True)
assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
def forward():
return flash_attn()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment