Commit c28eca99 authored by Shengyu Liu
Browse files

Reorganize files and add sparse prefill/decoding kernels on hopper

parent 261330bb
...@@ -6,6 +6,7 @@ import triton ...@@ -6,6 +6,7 @@ import triton
from flash_mla import flash_attn_varlen_func from flash_mla import flash_attn_varlen_func
from lib import check_is_allclose
def get_window_size(causal, window): def get_window_size(causal, window):
if window > 0: if window > 0:
...@@ -28,21 +29,15 @@ def get_attn_bias(s_q, s_k, causal, window): ...@@ -28,21 +29,15 @@ def get_attn_bias(s_q, s_k, causal, window):
return attn_bias return attn_bias
def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert cos_diff < 1e-5, f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}"
def sdpa(query, key, value, attn_bias, softmax_scale=None): def sdpa(query, key, value, attn_bias, softmax_scale=None):
query = query.float().transpose(-3, -2)
key = key.float().transpose(-3, -2)
value = value.float().transpose(-3, -2)
key = key.repeat_interleave(h // h_k, dim=-3) key = key.repeat_interleave(h // h_k, dim=-3)
value = value.repeat_interleave(h // h_k, dim=-3) value = value.repeat_interleave(h // h_k, dim=-3)
if softmax_scale is None: if softmax_scale is None:
softmax_scale = query.shape[-1] ** (-0.5) softmax_scale = query.shape[-1] ** (-0.5)
attn_weight = query @ key.transpose(-2, -1) * softmax_scale attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale
attn_weight += attn_bias attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1) lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
...@@ -53,8 +48,8 @@ def sdpa_checkpoint(*args, **kwargs): ...@@ -53,8 +48,8 @@ def sdpa_checkpoint(*args, **kwargs):
return checkpoint(sdpa, *args, use_reentrant=False, **kwargs) return checkpoint(sdpa, *args, use_reentrant=False, **kwargs)
def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd): def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd, check_correctness: bool = True):
print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}") print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}, {has_bwd=}, {check_correctness=}")
torch.manual_seed(0) torch.manual_seed(0)
random.seed(0) random.seed(0)
...@@ -76,16 +71,17 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win ...@@ -76,16 +71,17 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
causal, window) == 0).sum().item() for i in range(b)]) causal, window) == 0).sum().item() for i in range(b)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q = torch.randn(total_q, h, d) q = torch.randn(total_q, h, d)/10
k = torch.randn(total_k, h_k, d) k = torch.randn(total_k, h_k, d)/10
v = torch.randn(total_k, h_k, dv) v = torch.randn(total_k, h_k, dv)/10
grad_out = torch.randn(total_q, h, dv) grad_out = torch.randn(total_q, h, dv)/10
softmax_scale = (d + 100) ** (-0.5) softmax_scale = (d + 100) ** (-0.5)
q1 = q.clone().requires_grad_() q1 = q.clone().requires_grad_()
k1 = k.clone().requires_grad_() k1 = k.clone().requires_grad_()
v1 = v.clone().requires_grad_() v1 = v.clone().requires_grad_()
if check_correctness:
q2 = q.clone().requires_grad_() q2 = q.clone().requires_grad_()
k2 = k.clone().requires_grad_() k2 = k.clone().requires_grad_()
v2 = v.clone().requires_grad_() v2 = v.clone().requires_grad_()
...@@ -106,9 +102,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win ...@@ -106,9 +102,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
lse = [] lse = []
for i in range(b): for i in range(b):
OUT, LSE = sdpa_checkpoint( OUT, LSE = sdpa_checkpoint(
q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()].float().transpose(-3, -2), q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()],
k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()],
v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()],
attn_bias=get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window), attn_bias=get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window),
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
) )
...@@ -119,20 +115,23 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win ...@@ -119,20 +115,23 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
return out, lse return out, lse
out_flash, lse_flash = flash_attn() out_flash, lse_flash = flash_attn()
out_torch, lse_torch = torch_attn()
assert_close(out_flash, out_torch, "out")
assert_close(lse_flash, lse_torch, "lse")
if has_bwd: if has_bwd:
out_flash.backward(grad_out, retain_graph=True) out_flash.backward(grad_out, retain_graph=True)
out_torch.backward(grad_out, retain_graph=True)
assert_close(q1.grad, q2.grad, "dq")
assert_close(k1.grad, k2.grad, "dk")
assert_close(v1.grad, v2.grad, "dv")
dq1 = q1.grad.clone() dq1 = q1.grad.clone()
dk1 = k1.grad.clone() dk1 = k1.grad.clone()
dv1 = v1.grad.clone() dv1 = v1.grad.clone()
if check_correctness:
out_torch, lse_torch = torch_attn()
assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01/65536)
if has_bwd:
out_torch.backward(grad_out, retain_graph=True)
assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6)
def forward(): def forward():
return flash_attn() return flash_attn()
...@@ -150,12 +149,6 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win ...@@ -150,12 +149,6 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
assert torch.equal(k1.grad, dk1), "dk deterministic check failed!" assert torch.equal(k1.grad, dk1), "dk deterministic check failed!"
assert torch.equal(v1.grad, dv1), "dv deterministic check failed!" assert torch.equal(v1.grad, dv1), "dv deterministic check failed!"
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
# forward()
# if has_bwd:
# backward()
# print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=120))
def timer(func, name): def timer(func, name):
t = triton.testing.do_bench(func, warmup=2, rep=3) t = triton.testing.do_bench(func, warmup=2, rep=3)
FLOPS = total_attn_compute * h * 2 * ((d + dv) if name == "fwd" else ((d * 3 + dv * 2))) FLOPS = total_attn_compute * h * 2 * ((d + dv) if name == "fwd" else ((d * 3 + dv * 2)))
...@@ -173,18 +166,20 @@ if __name__ == "__main__": ...@@ -173,18 +166,20 @@ if __name__ == "__main__":
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
torch.set_float32_matmul_precision("high")
b = 4 b = 2
window = 0 window = 0
has_bwd = False has_bwd = False
for (mean_sq, mean_sk) in [(4096, 4096), (8192, 8192)]: for (mean_sq, mean_sk) in [(4096, 4096), (8192, 8192)]:
for varlen in [False, True]: for varlen in [False, True]:
for (h, h_k) in [(32, 32), (32, 4)]: for (h, h_k) in [(128, 128), (32, 4)]:
if h != h_k: if h != h_k:
has_bwd = False has_bwd = False
else: else:
has_bwd = True has_bwd = True
for (d, dv) in [(128, 128), (192, 128)]: for (d, dv) in [(128, 128), (192, 128)]:
for causal in [False, True]: for causal in [False, True]:
test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd) skip_correctness_check = mean_sq == 8192 and mean_sk == 8192 and h == 128
test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd, not skip_correctness_check)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment