Commit 1858932a authored by Jiashi Li's avatar Jiashi Li
Browse files

Code format

parent 7f55c715
import enum
import torch import torch
def quantize_k_cache( def quantize_k_cache(
...@@ -19,19 +17,19 @@ def quantize_k_cache( ...@@ -19,19 +17,19 @@ def quantize_k_cache(
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size() input_elem_size = input_k_cache.element_size()
result = torch.empty((num_blocks, block_size, dv + num_tiles*4 + input_elem_size*(d-dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device) result = torch.empty((num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device)
result_k_nope_part = result[..., :dv] result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv:dv + num_tiles*4].view(torch.float32) result_k_scale_factor = result[..., dv:dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles*4:].view(input_k_cache.dtype) result_k_rope_part = result[..., dv + num_tiles * 4:].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:] result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles): for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size] cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) cur_quantized_nope = (input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope result_k_nope_part[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_quantized_nope
result = result.view(num_blocks, block_size, 1, -1) result = result.view(num_blocks, block_size, 1, -1)
return result return result
...@@ -55,14 +53,14 @@ def dequantize_k_cache( ...@@ -55,14 +53,14 @@ def dequantize_k_cache(
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv] input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv:dv + num_tiles*4].view(torch.float32) input_scale = quant_k_cache[..., dv:dv + num_tiles * 4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles*4:].view(torch.bfloat16) input_rope = quant_k_cache[..., dv + num_tiles * 4:].view(torch.bfloat16)
result[..., dv:] = input_rope result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles): for tile_idx in range(0, num_tiles):
cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32) cur_nope = input_nope[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1) cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales result[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_nope * cur_scales
result = result.view(num_blocks, block_size, 1, d) result = result.view(num_blocks, block_size, 1, d)
return result return result
...@@ -2,13 +2,13 @@ import argparse ...@@ -2,13 +2,13 @@ import argparse
import math import math
import random import random
import dataclasses import dataclasses
from typing import Optional, Tuple, List from typing import Optional, Tuple
import torch import torch
import triton import triton
import quant
import flash_mla import flash_mla
import quant
from lib import cdiv, check_is_allclose from lib import cdiv, check_is_allclose
@dataclasses.dataclass @dataclasses.dataclass
...@@ -71,7 +71,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. ...@@ -71,7 +71,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
cur_num_blocks = cdiv(cur_len, t.block_size) cur_num_blocks = cdiv(cur_len, t.block_size)
blocked_k[block_table[i][cur_num_blocks:]] = float("nan") blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
if cur_len % t.block_size != 0: if cur_len % t.block_size != 0:
blocked_k[block_table[i][cur_num_blocks-1]][cur_len % t.block_size:] = float("nan") blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan")
block_table[i][cur_num_blocks:] = 2147480000 block_table[i][cur_num_blocks:] = 2147480000
return cache_seqlens, q, block_table, blocked_k, None, None return cache_seqlens, q, block_table, blocked_k, None, None
else: else:
...@@ -82,7 +82,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. ...@@ -82,7 +82,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
# Generate indices # Generate indices
for j in range(t.s_q): for j in range(t.s_q):
cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk] cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk]
cur_blocked_indices = block_table_cpu[i, cur_abs_indices//t.block_size]*t.block_size + (cur_abs_indices%t.block_size) cur_blocked_indices = block_table_cpu[i, cur_abs_indices // t.block_size] * t.block_size + (cur_abs_indices % t.block_size)
if len(cur_abs_indices) < t.topk: if len(cur_abs_indices) < t.topk:
pad_len = t.topk - len(cur_abs_indices) pad_len = t.topk - len(cur_abs_indices)
cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')]) cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')])
...@@ -109,7 +109,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. ...@@ -109,7 +109,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu') all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu')
blocked_k = blocked_k.view(-1, t.h_kv, t.d) blocked_k = blocked_k.view(-1, t.h_kv, t.d)
nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu') nonused_indices_mask = torch.ones(blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device='cpu')
nonused_indices_mask[all_indices] = False nonused_indices_mask[all_indices] = False
blocked_k[nonused_indices_mask, :, :] = float("nan") blocked_k[nonused_indices_mask, :, :] = float("nan")
blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d) blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d)
...@@ -248,27 +248,27 @@ def test_flash_mla(t: TestParam): ...@@ -248,27 +248,27 @@ def test_flash_mla(t: TestParam):
out_ans, lse_ans = run_flash_mla() out_ans, lse_ans = run_flash_mla()
out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices) out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices)
assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=5e-6) assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6)
assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536)
if t.test_performance: if t.test_performance:
time_usage: float = triton.testing.do_bench(run_flash_mla)/1000 # type: ignore time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore
mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk
compute_volume_flop = t.b*t.h_q*t.s_q*sum([ compute_volume_flop = t.b * t.h_q * t.s_q * sum([
2*t.d*mean_attended_seqlens, # Q * K^T 2 * t.d * mean_attended_seqlens, # Q * K^T
2*mean_attended_seqlens*t.dv, # attention * V 2 * mean_attended_seqlens * t.dv, # attention * V
]) ])
q_elem_size = torch.bfloat16.itemsize q_elem_size = torch.bfloat16.itemsize
kv_token_size = 656 if t.is_fp8 else t.d*torch.bfloat16.itemsize kv_token_size = 656 if t.is_fp8 else t.d * torch.bfloat16.itemsize
memory_volume_B = t.b*sum([ memory_volume_B = t.b * sum([
t.s_q*t.h_q*(t.d*q_elem_size), # Q t.s_q * t.h_q * (t.d * q_elem_size), # Q
(t.s_q if t.topk is not None else 1) * mean_attended_seqlens*t.h_kv*kv_token_size, # K/V (t.s_q if t.topk is not None else 1) * mean_attended_seqlens * t.h_kv * kv_token_size, # K/V
t.s_q*t.h_q*(t.dv*q_elem_size), # Output t.s_q * t.h_q * (t.dv * q_elem_size), # Output
]) ])
achieved_tflops = compute_volume_flop / time_usage / 1e12 achieved_tflops = compute_volume_flop / time_usage / 1e12
achieved_gBps = memory_volume_B / time_usage / 1e9 achieved_gBps = memory_volume_B / time_usage / 1e9
print(f"{time_usage*1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s") print(f"{time_usage * 1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s")
def main(torch_dtype): def main(torch_dtype):
......
...@@ -35,8 +35,8 @@ def generate_testcase(t: TestParam) -> Testcase: ...@@ -35,8 +35,8 @@ def generate_testcase(t: TestParam) -> Testcase:
torch.manual_seed(t.seed) torch.manual_seed(t.seed)
torch.cuda.manual_seed(t.seed) torch.cuda.manual_seed(t.seed)
random.seed(t.seed) random.seed(t.seed)
q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10
kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10
q.clamp_(-10, 10) q.clamp_(-10, 10)
kv.clamp_(-10, 10) kv.clamp_(-10, 10)
...@@ -48,7 +48,7 @@ def generate_testcase(t: TestParam) -> Testcase: ...@@ -48,7 +48,7 @@ def generate_testcase(t: TestParam) -> Testcase:
# NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31
cur_indices = torch.randperm(t.s_kv)[:t.topk] cur_indices = torch.randperm(t.s_kv)[:t.topk]
cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),)) cur_indices[near_mask] = torch.randint(max(0, t.s_kv - 20000), t.s_kv - 1, (near_mask.sum().item(),))
if len(cur_indices) < t.topk: if len(cur_indices) < t.topk:
cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)]) cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)])
cur_indices = cur_indices[torch.randperm(t.topk)] cur_indices = cur_indices[torch.randperm(t.topk)]
...@@ -110,9 +110,9 @@ def run_test(p: TestParam) -> bool: ...@@ -110,9 +110,9 @@ def run_test(p: TestParam) -> bool:
if p.benchmark: if p.benchmark:
flop = get_flop(p) flop = get_flop(p)
prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20)/1000 # type: ignore prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore
prefill_flops = flop/prefill_ans_time/1e12 prefill_flops = flop / prefill_ans_time / 1e12
print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:.3f} TFlops") print(f"Prefill: {prefill_ans_time * 1e6:4.0f} us, {prefill_flops:.3f} TFlops")
if p.check_correctness: if p.check_correctness:
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -120,9 +120,9 @@ def run_test(p: TestParam) -> bool: ...@@ -120,9 +120,9 @@ def run_test(p: TestParam) -> bool:
torch.cuda.synchronize() torch.cuda.synchronize()
is_correct = True is_correct = True
is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=7e-6) is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6)
is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01 / 65536)
is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01 / 65536)
return is_correct return is_correct
else: else:
...@@ -194,4 +194,3 @@ if __name__ == '__main__': ...@@ -194,4 +194,3 @@ if __name__ == '__main__':
print(f" {case}") print(f" {case}")
else: else:
print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
...@@ -5,7 +5,6 @@ from torch.utils.checkpoint import checkpoint ...@@ -5,7 +5,6 @@ from torch.utils.checkpoint import checkpoint
import triton import triton
from flash_mla import flash_attn_varlen_func from flash_mla import flash_attn_varlen_func
from lib import check_is_allclose from lib import check_is_allclose
def get_window_size(causal, window): def get_window_size(causal, window):
...@@ -71,10 +70,10 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win ...@@ -71,10 +70,10 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
causal, window) == 0).sum().item() for i in range(b)]) causal, window) == 0).sum().item() for i in range(b)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q = torch.randn(total_q, h, d)/10 q = torch.randn(total_q, h, d) / 10
k = torch.randn(total_k, h_k, d)/10 k = torch.randn(total_k, h_k, d) / 10
v = torch.randn(total_k, h_k, dv)/10 v = torch.randn(total_k, h_k, dv) / 10
grad_out = torch.randn(total_q, h, dv)/10 grad_out = torch.randn(total_q, h, dv) / 10
softmax_scale = (d + 100) ** (-0.5) softmax_scale = (d + 100) ** (-0.5)
q1 = q.clone().requires_grad_() q1 = q.clone().requires_grad_()
...@@ -123,14 +122,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win ...@@ -123,14 +122,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
if check_correctness: if check_correctness:
out_torch, lse_torch = torch_attn() out_torch, lse_torch = torch_attn()
assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01/65536) assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536)
if has_bwd: if has_bwd:
out_torch.backward(grad_out, retain_graph=True) out_torch.backward(grad_out, retain_graph=True)
assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)
def forward(): def forward():
return flash_attn() return flash_attn()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment