Unverified commit a9d823b8, authored by Yu Cheng, committed by GitHub

[Example] Update GQA varlen fwd (#1173)

* [Example] Update GQA varlen fwd

* fix
parent 298ab480
@@ -24,21 +24,32 @@ def attention_ref(
     dtype_og = q.dtype
     if upcast:
         q, k, v = q.float(), k.float(), v.float()
-    dim = q.shape[-1]
-    scale = (1.0 / dim)**0.5
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    b, T, Hq, D = q.shape
+    S = k.shape[1]
+    scale = (1.0 / D)**0.5
+    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
     scores = torch.einsum("bthd,bshd->bhts", q, k)
+    left, right = window_size
+    left = S if left is None or left < 0 else int(left)
+    right = S if right is None or right < 0 else int(right)
+    t_idx = torch.arange(T, device=scores.device)[:, None]
+    s_idx = torch.arange(S, device=scores.device)[None, :]
+    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
+    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
     if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+        k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
+        visible_mask = visible_mask & k_keep
+    neg_inf = torch.finfo(scores.dtype).min
     scores = scores * scale
+    scores = scores.masked_fill(~visible_mask, neg_inf)
     attention = torch.softmax(scores, dim=-1).to(v.dtype)
     if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+        q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
+        attention = attention.masked_fill(~q_keep, 0.0)
     output = torch.einsum("bhts,bshd->bthd", attention, v)
     if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
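The reference now builds an explicit sliding-window visibility mask rather than only masking padded keys; `None` or negative entries in `window_size` degenerate to full attention. A minimal standalone sketch of the same index arithmetic (the toy sizes and the `(None, 0)` window here are illustrative assumptions, not values from the diff) shows that `(None, 0)` reproduces a plain causal mask:

```python
import torch

T, S = 4, 4
left, right = None, 0  # (None, 0): unlimited look-back, no look-ahead
left = S if left is None or left < 0 else int(left)
right = S if right is None or right < 0 else int(right)
t_idx = torch.arange(T)[:, None]
s_idx = torch.arange(S)[None, :]
visible = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
print(visible.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
```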
@@ -91,53 +102,53 @@ def flashattn(batch_size,
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)
-            T.annotate_layout({
-                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
-                Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
-            })
             batch_idx = bz
             head_idx = by
             kv_head_idx = head_idx // groups
             q_start_idx = cu_seqlens_q[batch_idx]
-            k_start_idx = cu_seqlens_k[batch_idx]
-            v_start_idx = cu_seqlens_k[batch_idx]
+            kv_start_idx = cu_seqlens_k[batch_idx]
             q_end_idx = cu_seqlens_q[batch_idx + 1]
             k_end_idx = cu_seqlens_k[batch_idx + 1]
-            v_end_idx = cu_seqlens_k[batch_idx + 1]
             q_current_seqlen = q_end_idx - q_start_idx
-            k_current_seqlen = k_end_idx - k_start_idx
-            v_current_seqlen = v_end_idx - v_start_idx
+            kv_current_seqlen = k_end_idx - kv_start_idx
             T.copy(
                 Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
                 Q_shared)
+            for i, d in T.Parallel(block_M, dim):
+                if bx * block_M + i >= q_current_seqlen:
+                    Q_shared[i, d] = 0
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
-            loop_range = T.ceildiv(k_current_seqlen, block_N)
+            loop_range = (
+                T.min(
+                    T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N),
+                    T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 T.copy(
-                    K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
+                    K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], K_shared)
+                for i, d in T.Parallel(block_N, dim):
+                    if k * block_N + i >= kv_current_seqlen:
+                        K_shared[i, d] = 0
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
-                                                     (bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i, j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
+                                                     (bx * block_M + i >= q_current_seqlen or
+                                                      k * block_N + j >= kv_current_seqlen),
+                                                     -1e9, 0)
                 else:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
+                                                      k * block_N + j >= kv_current_seqlen),
+                                                     -1e9, 0)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
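Besides the `k_start_idx`/`v_start_idx` merge, this hunk changes the mask fill value from `-T.infinity` to `-1e9`. The finite fill matters because query rows past `q_current_seqlen` are masked in their entirety, and a softmax over an all-`-inf` row yields NaNs that the online rescaling would then propagate into the output. A quick illustration in plain PyTorch:

```python
import torch

row_inf = torch.full((4,), float("-inf"))
row_big = torch.full((4,), -1e9)
print(torch.softmax(row_inf, dim=-1))  # tensor([nan, nan, nan, nan])
print(torch.softmax(row_big, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```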
@@ -145,6 +156,9 @@ def flashattn(batch_size,
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, block_N):
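The added `T.max` loop merges the previous running maximum into the freshly reduced block maximum, which is what lets `scores_scale` rescale the earlier partial sums safely. The same streaming logic in plain PyTorch (toy data assumed; the kernel works in base 2 via `exp2` with the softmax scale folded in, natural `exp` is used here for brevity):

```python
import torch

x = torch.randn(16)
m_prev = torch.tensor(float("-inf"))  # running max (scores_max)
l = torch.tensor(0.0)                 # running denominator (logsum)
for block in x.split(4):
    m_new = torch.maximum(m_prev, block.max())  # the added max-merge step
    l = l * torch.exp(m_prev - m_new) + torch.exp(block - m_new).sum()
    m_prev = m_new
print(m_prev + l.log())           # streaming logsumexp
print(torch.logsumexp(x, dim=0))  # reference: identical up to float error
```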
@@ -158,11 +172,8 @@ def flashattn(batch_size,
                     acc_o[i, j] *= scores_scale[i]
                 T.copy(
-                    V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
+                    V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], V_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= v_current_seqlen:
-                        V_shared[i, d] = 0
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
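The V tail no longer needs zeroing: out-of-range key columns already carry `-1e9` scores before the softmax, so their attention weights underflow to zero and whatever the tail of `V_shared` holds is multiplied away. A sketch with made-up numbers:

```python
import torch

scores = torch.tensor([0.3, 1.2, -1e9, -1e9])      # last two keys out of range
w = torch.softmax(scores, dim=-1)                  # tail weights underflow to 0
v = torch.tensor([[1.0], [2.0], [123.0], [-7.0]])  # garbage in the tail rows
print(w @ v)                                       # same as attending to rows 0-1 only
print(torch.softmax(scores[:2], dim=-1) @ v[:2])
```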
@@ -191,8 +202,7 @@ def main(batch: int = 1,
     tilelang.testing.set_random_seed(0)
-    causal = False
-    if causal:
+    if is_causal:
         total_flops *= 0.5
     tilelang.testing.set_random_seed(0)
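Previously the hard-coded `causal = False` meant the 0.5 correction never fired; it is now keyed off the real `is_causal` flag. For reference, `total_flops` itself is computed earlier in the file, outside this diff; the sketch below is the standard convention (4x for the QK^T and PV GEMMs at 2 FLOPs per multiply-add), not code from the commit:

```python
def attention_flops(batch, seqlen_q, seqlen_k, heads, dim, is_causal):
    # two GEMMs, 2 FLOPs per multiply-add each -> factor of 4
    flops = 4 * batch * seqlen_q * seqlen_k * heads * dim
    # a causal mask keeps roughly half of the score matrix
    return flops * 0.5 if is_causal else flops
```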
@@ -201,9 +211,9 @@ def main(batch: int = 1,
     device = torch.device("cuda")
     head_kv = heads // groups
-    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
-    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
-    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
+    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
+    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
+    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
     query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
     key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
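The tensors no longer request gradients, since this forward-only example never backpropagates. Downstream (outside this hunk) the padding masks are turned into the `cu_seqlens_q`/`cu_seqlens_k` offsets the kernel indexes with; a sketch of that standard conversion, with an assumed toy mask (the example's actual unpad helpers live elsewhere in the file):

```python
import torch
import torch.nn.functional as F

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.bool)  # (batch, seqlen), assumed
seqlens = mask.sum(dim=-1, dtype=torch.int32)          # tensor([3, 2])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 5]); batch i owns unpadded rows [cu[i], cu[i+1])
```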
@@ -236,10 +246,10 @@ def main(batch: int = 1,
         heads,
         dim,
         is_causal,
-        block_M=64,
-        block_N=64,
-        num_stages=1,
-        threads=128)
+        block_M=128,
+        block_N=128,
+        num_stages=2,
+        threads=256)
     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
@@ -255,7 +265,9 @@ def main(batch: int = 1,
     torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
     print("All checks passed.✅")
     latency = do_bench(
-        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
+        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
+        _n_warmup=5,
+        _n_repeat=5)
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))