Commit 1858932a authored by Jiashi Li's avatar Jiashi Li
Browse files

Code format

parent 7f55c715
import enum
import torch
def quantize_k_cache(
......@@ -19,19 +17,19 @@ def quantize_k_cache(
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()
result = torch.empty((num_blocks, block_size, dv + num_tiles*4 + input_elem_size*(d-dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device)
result = torch.empty((num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv:dv + num_tiles*4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles*4:].view(input_k_cache.dtype)
result_k_scale_factor = result[..., dv:dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles * 4:].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size]
cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope
cur_quantized_nope = (input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_quantized_nope
result = result.view(num_blocks, block_size, 1, -1)
return result
......@@ -55,14 +53,14 @@ def dequantize_k_cache(
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv:dv + num_tiles*4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles*4:].view(torch.bfloat16)
input_scale = quant_k_cache[..., dv:dv + num_tiles * 4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles * 4:].view(torch.bfloat16)
result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32)
cur_nope = input_nope[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales
result[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_nope * cur_scales
result = result.view(num_blocks, block_size, 1, d)
return result
......@@ -2,13 +2,13 @@ import argparse
import math
import random
import dataclasses
from typing import Optional, Tuple, List
from typing import Optional, Tuple
import torch
import triton
import quant
import flash_mla
import quant
from lib import cdiv, check_is_allclose
@dataclasses.dataclass
......@@ -71,7 +71,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
cur_num_blocks = cdiv(cur_len, t.block_size)
blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
if cur_len % t.block_size != 0:
blocked_k[block_table[i][cur_num_blocks-1]][cur_len % t.block_size:] = float("nan")
blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan")
block_table[i][cur_num_blocks:] = 2147480000
return cache_seqlens, q, block_table, blocked_k, None, None
else:
......@@ -82,7 +82,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
# Generate indices
for j in range(t.s_q):
cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk]
cur_blocked_indices = block_table_cpu[i, cur_abs_indices//t.block_size]*t.block_size + (cur_abs_indices%t.block_size)
cur_blocked_indices = block_table_cpu[i, cur_abs_indices // t.block_size] * t.block_size + (cur_abs_indices % t.block_size)
if len(cur_abs_indices) < t.topk:
pad_len = t.topk - len(cur_abs_indices)
cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')])
......@@ -109,7 +109,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu')
blocked_k = blocked_k.view(-1, t.h_kv, t.d)
nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu')
nonused_indices_mask = torch.ones(blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device='cpu')
nonused_indices_mask[all_indices] = False
blocked_k[nonused_indices_mask, :, :] = float("nan")
blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d)
......@@ -248,27 +248,27 @@ def test_flash_mla(t: TestParam):
out_ans, lse_ans = run_flash_mla()
out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices)
assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=5e-6)
assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6)
assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536)
if t.test_performance:
time_usage: float = triton.testing.do_bench(run_flash_mla)/1000 # type: ignore
time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore
mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk
compute_volume_flop = t.b*t.h_q*t.s_q*sum([
2*t.d*mean_attended_seqlens, # Q * K^T
2*mean_attended_seqlens*t.dv, # attention * V
compute_volume_flop = t.b * t.h_q * t.s_q * sum([
2 * t.d * mean_attended_seqlens, # Q * K^T
2 * mean_attended_seqlens * t.dv, # attention * V
])
q_elem_size = torch.bfloat16.itemsize
kv_token_size = 656 if t.is_fp8 else t.d*torch.bfloat16.itemsize
memory_volume_B = t.b*sum([
t.s_q*t.h_q*(t.d*q_elem_size), # Q
(t.s_q if t.topk is not None else 1) * mean_attended_seqlens*t.h_kv*kv_token_size, # K/V
t.s_q*t.h_q*(t.dv*q_elem_size), # Output
kv_token_size = 656 if t.is_fp8 else t.d * torch.bfloat16.itemsize
memory_volume_B = t.b * sum([
t.s_q * t.h_q * (t.d * q_elem_size), # Q
(t.s_q if t.topk is not None else 1) * mean_attended_seqlens * t.h_kv * kv_token_size, # K/V
t.s_q * t.h_q * (t.dv * q_elem_size), # Output
])
achieved_tflops = compute_volume_flop / time_usage / 1e12
achieved_gBps = memory_volume_B / time_usage / 1e9
print(f"{time_usage*1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s")
print(f"{time_usage * 1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s")
def main(torch_dtype):
......
......@@ -35,8 +35,8 @@ def generate_testcase(t: TestParam) -> Testcase:
torch.manual_seed(t.seed)
torch.cuda.manual_seed(t.seed)
random.seed(t.seed)
q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10
kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10
q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10
kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10
q.clamp_(-10, 10)
kv.clamp_(-10, 10)
......@@ -48,7 +48,7 @@ def generate_testcase(t: TestParam) -> Testcase:
# NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31
cur_indices = torch.randperm(t.s_kv)[:t.topk]
cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),))
cur_indices[near_mask] = torch.randint(max(0, t.s_kv - 20000), t.s_kv - 1, (near_mask.sum().item(),))
if len(cur_indices) < t.topk:
cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)])
cur_indices = cur_indices[torch.randperm(t.topk)]
......@@ -110,9 +110,9 @@ def run_test(p: TestParam) -> bool:
if p.benchmark:
flop = get_flop(p)
prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20)/1000 # type: ignore
prefill_flops = flop/prefill_ans_time/1e12
print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:.3f} TFlops")
prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore
prefill_flops = flop / prefill_ans_time / 1e12
print(f"Prefill: {prefill_ans_time * 1e6:4.0f} us, {prefill_flops:.3f} TFlops")
if p.check_correctness:
torch.cuda.synchronize()
......@@ -120,9 +120,9 @@ def run_test(p: TestParam) -> bool:
torch.cuda.synchronize()
is_correct = True
is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=7e-6)
is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6)
is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01 / 65536)
is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01 / 65536)
return is_correct
else:
......@@ -194,4 +194,3 @@ if __name__ == '__main__':
print(f" {case}")
else:
print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
......@@ -5,7 +5,6 @@ from torch.utils.checkpoint import checkpoint
import triton
from flash_mla import flash_attn_varlen_func
from lib import check_is_allclose
def get_window_size(causal, window):
......@@ -71,10 +70,10 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
causal, window) == 0).sum().item() for i in range(b)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q = torch.randn(total_q, h, d)/10
k = torch.randn(total_k, h_k, d)/10
v = torch.randn(total_k, h_k, dv)/10
grad_out = torch.randn(total_q, h, dv)/10
q = torch.randn(total_q, h, d) / 10
k = torch.randn(total_k, h_k, d) / 10
v = torch.randn(total_k, h_k, dv) / 10
grad_out = torch.randn(total_q, h, dv) / 10
softmax_scale = (d + 100) ** (-0.5)
q1 = q.clone().requires_grad_()
......@@ -123,14 +122,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
if check_correctness:
out_torch, lse_torch = torch_attn()
assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01/65536)
assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536)
if has_bwd:
out_torch.backward(grad_out, retain_graph=True)
assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
def forward():
return flash_attn()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment