Unverified commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
......@@ -20,9 +20,12 @@ ENV LIBGL_ALWAYS_INDIRECT=1
RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev \
build-essential cmake libedit-dev libxml2-dev cython3
RUN pip install cython
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && ./install_cuda.sh
&& cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash
......@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
conda run -n py_3.10 bash -c "cd tilelang && ./install_rocm.sh"
conda run -n py_3.10 bash -c "cd tilelang && USE_ROCM=1 pip install -e . -v"
RUN conda init bash
......
......@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
### Timeline View
```
generic initialize_descriptor → generic shared-store → async wgmma
generic initialize_wgmma_descriptor → generic shared-store → async wgmma
│ │ │
└─ generic proxy ┴─ generic proxy ┴─ async proxy
│ fence inserted here ↑
......@@ -53,7 +53,7 @@ def kernel():
with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0)
T.ptx_wgmma_ss(
"float16",
......@@ -83,7 +83,7 @@ def kernel():
with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0)
T.fence_proxy_async()
T.ptx_wgmma_ss(
......
......@@ -5,6 +5,7 @@ import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional
@triton.jit
......@@ -94,7 +95,7 @@ def triton_kernel(
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64
......@@ -130,7 +131,7 @@ def main(
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
......
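(Editorial aside, not part of the diff.) The `window_size: int | None` → `Optional[int]` changes above swap the PEP 604 union spelling for the `typing.Optional` form; the two annotations are equivalent, but evaluating `int | None` at runtime requires Python 3.10+. A minimal sketch with a hypothetical helper (`clamp_window` is illustrative only):

```python
from typing import Optional


def clamp_window(window_size: Optional[int] = None, seq_len: int = 256) -> int:
    """Treat None as 'unlimited window', otherwise clamp to the sequence length."""
    return seq_len if window_size is None else min(window_size, seq_len)


# Equivalent on Python 3.10+ only (PEP 604):
#   def clamp_window(window_size: int | None = None, seq_len: int = 256) -> int: ...
assert clamp_window() == 256
assert clamp_window(64) == 64
```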
......@@ -5,6 +5,7 @@ import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional
@triton.jit
......@@ -93,7 +94,7 @@ def triton_kernel(
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
seq_kv = K.shape[2]
BLOCK_M = 64
......@@ -125,7 +126,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
......
......@@ -81,13 +81,10 @@ def flashattn_fwd(
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
start = T.max(0,
(bx * block_M - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i
......@@ -266,14 +263,11 @@ def flashattn_bwd(batch,
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32')
if window_size is not None:
loop_ed[0] = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N),
T.ceildiv(seq_len, block_N))
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
loop_ed = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -444,7 +438,7 @@ def main(BATCH: int = 1,
N_CTX: int = 512,
D_HEAD: int = 64,
groups: int = 2,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
......
......@@ -172,14 +172,11 @@ def flashattn(
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
for k in T.Pipelined(
start[0],
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
......@@ -272,7 +269,7 @@ def main(
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
......
......@@ -78,13 +78,10 @@ def flashattn_fwd(
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
start = T.max(0,
(bx * block_M - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i
......@@ -267,14 +264,10 @@ def flashattn_bwd(
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32')
if window_size is not None:
loop_ed[0] = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N),
T.ceildiv(seq_len, block_N))
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
loop_ed = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -440,7 +433,7 @@ def main(BATCH: int = 1,
H: int = 1,
N_CTX: int = 512,
D_HEAD: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
......
......@@ -162,13 +162,10 @@ def flashattn(
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
for k in T.Pipelined(start, end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
......@@ -253,7 +250,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
......
......@@ -165,14 +165,11 @@ def flashattn(
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
for k in T.Pipelined(
start[0],
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
......@@ -263,7 +260,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
......
......@@ -25,10 +25,10 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def test_example_triton_sparse_gqa_decode_varlen_indice():
example_triton_sparse_gqa_decode_varlen_indice.main(
batch=16,
heads=16,
heads_kv=8,
max_cache_seqlen=4096,
batch=8,
heads=8,
heads_kv=4,
max_cache_seqlen=2048,
dim=128,
dim_v=128,
sparse_ratio=0.8,
......@@ -40,7 +40,7 @@ def test_example_triton_sparse_gqa_decode_varlen_mask():
batch=16,
heads=16,
heads_kv=8,
max_cache_seqlen=4096,
max_cache_seqlen=1024,
dim=128,
dim_v=128,
sparse_ratio=0.8,
......
......@@ -161,7 +161,9 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
return x_fp8
def main(M=8192, N=8192, BG=2, blk_m=8):
def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
if batch_sizes is None:
batch_sizes = [2048, 6144]
if dtype == "float":
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
elif dtype == "float16":
......@@ -170,7 +172,7 @@ def main(M=8192, N=8192, BG=2, blk_m=8):
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
else:
raise ValueError(f"Unsupported dtype: {dtype}")
batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32)
batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
print("batch_sizes:", batch_sizes)
......
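(Editorial aside, not part of the diff.) The new `batch_sizes=None` parameter above follows the usual sentinel idiom: a mutable default such as `batch_sizes=[2048, 6144]` would be evaluated once at definition time and shared across calls, so the default is resolved inside the function instead. A minimal illustration:

```python
def bad(sizes=[]):      # the default list is created once and reused
    sizes.append(1)
    return sizes


def good(sizes=None):   # None sentinel: build a fresh list per call
    if sizes is None:
        sizes = []
    sizes.append(1)
    return sizes


assert bad() == [1] and bad() == [1, 1]    # state leaks across calls
assert good() == [1] and good() == [1]     # calls stay independent
```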
......@@ -4,11 +4,12 @@ import example_per_token_cast_to_fp8
def test_example_group_per_split_token_cast_to_fp8():
example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)
example_group_per_split_token_cast_to_fp8.main(
M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
def test_example_per_token_cast_to_fp8():
example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)
example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8)
if __name__ == "__main__":
......
......@@ -127,7 +127,7 @@ def mqa_attn_return_logits(
index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
s_reshaped = T.reshape(s, (block_N, block_Q, heads))
logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
weights = T.alloc_fragment([block_Q, heads], accum_dtype)
......@@ -165,7 +165,7 @@ def mqa_attn_return_logits(
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
s_reshaped[bn_i, bq_i,
h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) *
weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
......@@ -258,6 +258,8 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
# fix the random seed to make the results reproducible
torch.manual_seed(0)
q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
......
# ruff: noqa
import tilelang.testing
from topk_selector import test_topk_selector
from fp8_lighting_indexer import test_fp8_lighting_indexer
from sparse_mla_fwd import test_sparse_mla_fwd
from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
from sparse_mla_bwd import test_sparse_mla_bwd
import topk_selector
import fp8_lighting_indexer
import sparse_mla_fwd
import sparse_mla_fwd_pipelined
import sparse_mla_bwd
def test_example_topk_selector():
test_topk_selector()
topk_selector.test_topk_selector()
def test_example_fp8_lighting_indexer():
test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)
fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd():
# small shapes for testing
test_sparse_mla_fwd(
sparse_mla_fwd.test_sparse_mla_fwd(
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
......@@ -28,15 +28,15 @@ def test_example_sparse_mla_fwd():
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd_pipelined():
# small shapes for testing
test_sparse_mla_fwd_pipelined(
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd():
test_sparse_mla_bwd(
S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
sparse_mla_bwd.test_sparse_mla_bwd(
S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
if __name__ == "__main__":
......
import tilelang.testing
import example_elementwise_add
import example_elementwise_add_tma_1d
def test_example_elementwise_add():
example_elementwise_add.main()
def test_example_elementwise_add_tma_1d():
example_elementwise_add_tma_1d.main()
if __name__ == "__main__":
tilelang.testing.main()
......@@ -5,6 +5,8 @@ import tilelang.language as T
from tilelang.contrib import nvcc
import argparse
tilelang.disable_cache()
@tilelang.jit(
out_idx=[3, 4], pass_configs={
......@@ -44,7 +46,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# Warning: in causal/varlen/unaligned-seqlen scenarios, -inf can cause undefined behavior (NaN) in exp ops,
# so we set it to a large negative number instead.
T.fill(scores_max, T.Cast(accum_dtype, -1e30))
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
......@@ -53,7 +57,7 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
T.Cast(accum_dtype, -1e30))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
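(Editorial aside, not part of the diff.) The reason for replacing `-T.infinity(...)` with `-1e30` in the two hunks above: when a row is fully masked, the online-softmax rescale factor becomes `exp2(prev_max * scale - new_max * scale)` with both maxima stuck at the sentinel, and `-inf - (-inf)` is NaN, which then poisons the accumulator; a large finite negative constant keeps the arithmetic defined. A standalone PyTorch sketch of just that effect:

```python
import torch

neg_inf = torch.tensor(float("-inf"))
big_neg = torch.tensor(-1e30)

# Rescale factor for a row whose running max never left the sentinel value:
print(torch.exp2(neg_inf - neg_inf))  # tensor(nan) -> NaN propagates into the output
print(torch.exp2(big_neg - big_neg))  # tensor(1.)  -> accumulator stays well defined
```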
......@@ -265,17 +269,17 @@ def flashattn_bwd_atomic_add(batch,
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
def flashattn_bwd_split_novarlen(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
......@@ -424,7 +428,7 @@ class _attention(torch.autograd.Function):
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, dk, dv = mod_post(dq, dk, dv)
else:
kernel = flashattn_bwd_split(
kernel = flashattn_bwd_split_novarlen(
BATCH,
H,
N_CTX,
......
......@@ -7,6 +7,8 @@ import argparse
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input
# tilelang.disable_cache()
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
......@@ -29,6 +31,7 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
def flashattn_fwd(batch,
total_q,
total_kv,
N_CTX,
heads,
max_seq_len,
dim_qk,
......@@ -54,7 +57,7 @@ def flashattn_fwd(batch,
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
lse: T.Tensor([total_q, heads], accum_dtype), # type: ignore
lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
......@@ -86,7 +89,9 @@ def flashattn_fwd(batch,
T.fill(acc_o, 0.0)
T.fill(logsum, 0.0)
T.fill(scores_max, -T.infinity(accum_dtype))
# Warning: in causal/varlen/unaligned-seqlen scenarios, -inf can cause undefined behavior (NaN) in exp ops,
# so we set it to a large negative number instead.
T.fill(scores_max, T.Cast(accum_dtype, -1e30))
loop_range = T.ceildiv(k_current_seqlen, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
for i, d in T.Parallel(block_N, dim_qk):
......@@ -100,12 +105,12 @@ def flashattn_fwd(batch,
acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
(bx * block_M + i < q_current_seqlen and
k * block_N + j < k_current_seqlen), 0,
-T.infinity(acc_s.dtype))
T.Cast(accum_dtype, -1e30))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i < q_current_seqlen and
k * block_N + j < k_current_seqlen, 0, -T.infinity(acc_s.dtype))
k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, d in T.Parallel(block_N, dim_v):
if k * block_N + i < k_current_seqlen:
......@@ -135,7 +140,7 @@ def flashattn_fwd(batch,
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
if bx * block_M + i < q_current_seqlen:
lse[q_start_idx + bx * block_M + i, by] = logsum[i]
lse[bz, by, bx * block_M + i] = logsum[i]
return flash_fwd
......@@ -144,7 +149,7 @@ def flashattn_fwd(batch,
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, total_q, max_seq_len, dim_v):
def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v):
dtype = "float16"
accum_dtype = "float"
shape = [total_q, heads, dim_v]
......@@ -155,7 +160,7 @@ def flashattn_bwd_preprocess(batch, heads, total_q, max_seq_len, dim_v):
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
Delta: T.Tensor([total_q, heads], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
......@@ -183,14 +188,14 @@ def flashattn_bwd_preprocess(batch, heads, total_q, max_seq_len, dim_v):
for i in T.Parallel(blk):
if by * blk + i < q_current_seqlen:
Delta[q_start_idx + by * blk + i, bx] = delta[i]
Delta[bz, bx, by * blk + i] = delta[i]
return flash_bwd_prep
def make_dq_layout(dQ):
# bshd -> bhld to use tma reduction instruction
return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])
# bshd -> bhsd to use tma reduction instruction
return T.Layout(dQ.shape, lambda l, h, d: [h, l, d])
@tilelang.jit(
......@@ -215,13 +220,13 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
dV_out: T.Tensor(v_shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by):
# T.annotate_layout({dQ: make_dq_layout(dQ)})
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(dQ[bx * blk:(bx + 1) * blk, by, :], dQ_out[bx * blk:(bx + 1) * blk, by, :])
with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by):
# T.annotate_layout({
# dK: make_dq_layout(dK),
# dV: make_dq_layout(dV),
# })
T.annotate_layout({
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
})
T.copy(dK[bx * blk:(bx + 1) * blk, by, :], dK_out[bx * blk:(bx + 1) * blk, by, :])
T.copy(dV[bx * blk:(bx + 1) * blk, by, :], dV_out[bx * blk:(bx + 1) * blk, by, :])
......@@ -234,6 +239,7 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
def flashattn_bwd_atomic_add(batch,
total_q,
total_kv,
N_CTX,
heads,
max_seq_len,
dim_qk,
......@@ -260,8 +266,8 @@ def flashattn_bwd_atomic_add(batch,
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor(do_shape, dtype), # type: ignore
lse: T.Tensor([total_q, heads], accum_dtype), # type: ignore
Delta: T.Tensor([total_q, heads], accum_dtype), # type: ignore
lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
......@@ -284,6 +290,9 @@ def flashattn_bwd_atomic_add(batch,
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
q_start_idx = cu_seqlens_q[bz]
k_start_idx = cu_seqlens_k[bz]
......@@ -293,39 +302,32 @@ def flashattn_bwd_atomic_add(batch,
k_current_seqlen = k_end_idx - k_start_idx
T.annotate_layout({
# dQ: make_dq_layout(dQ),
# dK: make_dq_layout(dK),
# dV: make_dq_layout(dV),
dQ: make_dq_layout(dQ),
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
for i, d in T.Parallel(block_M, dim_qk):
if by * block_M + i < k_current_seqlen:
K_shared[i, d] = K[k_start_idx + by * block_M + i, bx // groups, d]
V_shared[i, d] = V[k_start_idx + by * block_M + i, bx // groups, d]
else:
K_shared[i, d] = 0.0
V_shared[i, d] = 0.0
T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :],
K_shared)
T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :],
V_shared)
T.clear(dv)
T.clear(dk)
loop_st = (T.floordiv(by * block_M, block_N) if is_causal else 0)
loop_st = T.min(
T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen,
block_N)) if is_causal else 0
loop_ed = T.ceildiv(q_current_seqlen, block_N)
for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
for i, d in T.Parallel(block_N, dim_qk):
if k_base * block_N + i < q_current_seqlen:
q[i, d] = Q[q_start_idx + k_base * block_N + i, bx, d]
else:
q[i, d] = 0.0
T.copy(
Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i in T.Parallel(block_N):
if k_base * block_N + i < q_current_seqlen:
lse_shared[i] = lse[q_start_idx + k_base * block_N + i, bx]
else:
lse_shared[i] = 0.0
T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
......@@ -341,22 +343,16 @@ def flashattn_bwd_atomic_add(batch,
by * block_M + i < k_current_seqlen and
k_base * block_N + j < q_current_seqlen, qkT[i, j], 0)
for i, d in T.Parallel(block_N, dim_v):
if k_base * block_N + i < q_current_seqlen:
do[i, d] = dO[q_start_idx + k_base * block_N + i, bx, d]
else:
do[i, d] = 0.0
T.copy(
dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
do)
T.clear(dsT)
# dsT: (block_kv, block_q)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
for i in T.Parallel(block_N):
if k_base * block_N + i < q_current_seqlen:
delta[i] = Delta[q_start_idx + k_base * block_N + i, bx]
else:
delta[i] = 0.0
T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
......@@ -364,22 +360,28 @@ def flashattn_bwd_atomic_add(batch,
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.copy(dq, dq_shared)
T.atomic_add(
dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N,
bx, :],
dq,
memory_order="release")
dq_shared,
memory_order="relaxed",
use_tma=True)
T.copy(dv, dv_shared)
T.atomic_add(
dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
bx // groups, :],
dv,
memory_order="release")
dv_shared,
memory_order="relaxed",
use_tma=True)
T.copy(dk, dk_shared)
T.atomic_add(
dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
bx // groups, :],
dk,
memory_order="release")
dk_shared,
memory_order="relaxed",
use_tma=True)
return flash_bwd
......@@ -390,6 +392,7 @@ def flashattn_bwd_atomic_add(batch,
def flashattn_bwd_split(batch,
total_q,
total_kv,
N_CTX,
heads,
max_seq_len,
dim_qk,
......@@ -418,8 +421,8 @@ def flashattn_bwd_split(batch,
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor(do_shape, dtype), # type: ignore
lse: T.Tensor([total_q, heads], accum_dtype), # type: ignore
Delta: T.Tensor([total_q, heads], accum_dtype), # type: ignore
lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
......@@ -453,46 +456,41 @@ def flashattn_bwd_split(batch,
k_current_seqlen = k_end_idx - k_start_idx
T.annotate_layout({
# dQ: make_dq_layout(dQ),
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
for i, d in T.Parallel(block_M, dim_qk):
if by * block_M + i < k_current_seqlen:
K_shared[i, d] = K[k_start_idx + by * block_M + i, bx // groups, d]
V_shared[i, d] = V[k_start_idx + by * block_M + i, bx // groups, d]
else:
K_shared[i, d] = 0.0
V_shared[i, d] = 0.0
T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :],
K_shared)
T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :],
V_shared)
T.clear(dv)
T.clear(dk)
loop_st = (T.floordiv(by * block_M, block_N) if is_causal else 0)
loop_st = T.min(
T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen,
block_N)) if is_causal else 0
loop_ed = T.ceildiv(q_current_seqlen, block_N)
for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
for i, d in T.Parallel(block_N, dim_qk):
if k_base * block_N + i < q_current_seqlen:
q[i, d] = Q[q_start_idx + k_base * block_N + i, bx, d]
else:
q[i, d] = 0.0
# Note: T.copy must account for the zero padding introduced by varlen.
T.copy(
Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, d in T.Parallel(block_N, dim_v):
if k_base * block_N + i < q_current_seqlen:
do[i, d] = dO[q_start_idx + k_base * block_N + i, bx, d]
else:
do[i, d] = 0.0
T.copy(
dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i in T.Parallel(block_N):
if k_base * block_N + i < q_current_seqlen:
lse_shared[i] = lse[q_start_idx + k_base * block_N + i, bx]
else:
lse_shared[i] = 0.0
T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
......@@ -508,11 +506,8 @@ def flashattn_bwd_split(batch,
k_base * block_N + j < q_current_seqlen, qkT[i, j], 0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
for i in T.Parallel(block_N):
if k_base * block_N + i < q_current_seqlen:
delta[i] = Delta[q_start_idx + k_base * block_N + i, bx]
else:
delta[i] = 0.0
T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
......@@ -526,16 +521,18 @@ def flashattn_bwd_split(batch,
T.atomic_add(
dQ[q_start_idx + k_base * block_N + i, bx, j],
dq[i, j],
memory_order="release")
memory_order="relaxed")
T.copy(dv, dv_shared)
for i, d in T.Parallel(block_M, dim_v):
if by * block_M + i < k_current_seqlen:
dV[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dv[i, d]
T.copy(
dv_shared,
dV[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
bx // groups, :])
T.copy(dk, dk_shared)
for i, d in T.Parallel(block_M, dim_qk):
if by * block_M + i < k_current_seqlen:
dK[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dk[i, d]
T.copy(
dk_shared,
dK[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
bx // groups, :])
return flash_bwd
......@@ -571,12 +568,13 @@ class _attention(torch.autograd.Function):
total_q = q_unpad.shape[0]
total_kv = k_unpad.shape[0]
mod = flashattn_fwd(BATCH, total_q, total_kv, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal,
block_M, block_N, groups)
mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V,
causal, block_M, block_N, groups)
o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k)
o = pad_input(o_unpad, indices_q, BATCH, N_CTX)
ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k,
cu_seqlens_q, cu_seqlens_k)
ctx.batch = BATCH
ctx.causal = causal
ctx.use_atomic = use_atomic
ctx.max_seqlen_q = max_seqlen_q
......@@ -588,7 +586,8 @@ class _attention(torch.autograd.Function):
@staticmethod
def backward(ctx, do):
N_CTX = do.shape[1]
q, k, v, o, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
# lse_clone = lse.clone()
do_unpad, _, _, _ = unpad_input(
do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
total_q, H, D_HEAD_QK = q.shape
......@@ -604,7 +603,7 @@ class _attention(torch.autograd.Function):
do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, ctx.max_seqlen_q, D_HEAD_V)
mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V)
delta = mod_prep(o, do, cu_seqlens_q)
......@@ -613,6 +612,7 @@ class _attention(torch.autograd.Function):
BATCH,
total_q,
total_kv,
N_CTX,
H,
ctx.max_seqlen_q,
D_HEAD_QK,
......@@ -626,13 +626,14 @@ class _attention(torch.autograd.Function):
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.zeros_like(k, dtype=torch.float32)
dv = torch.zeros_like(v, dtype=torch.float32)
kernel(q, k, v, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
dq, dk, dv = mod_post(dq, dk, dv)
else:
kernel = flashattn_bwd_split(
BATCH,
total_q,
total_kv,
N_CTX,
H,
ctx.max_seqlen_q,
D_HEAD_QK,
......@@ -646,7 +647,7 @@ class _attention(torch.autograd.Function):
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device)
dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32),
torch.zeros_like(v, dtype=torch.float32))
dk, dv = dk.sum(0), dv.sum(0)
......@@ -739,12 +740,6 @@ def main(BATCH: int = 1,
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
......@@ -760,6 +755,15 @@ def main(BATCH: int = 1,
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
print(
"Note: in Nsight Compute this varlen kernel performs as well as the non-varlen kernel; "
"the TFLOPS reported above is slightly lower only because the unpad operation is included in this benchmark."
)
if __name__ == "__main__":
arch = nvcc.get_target_compute_version()
......@@ -778,6 +782,8 @@ if __name__ == "__main__":
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args()
# Can be set to True/False for testing
args.causal = True
# Handle backward compatibility and logic
if args.use_split:
......@@ -785,8 +791,8 @@ if __name__ == "__main__":
elif args.use_atomic:
use_atomic = True
else:
# Default: use split
use_atomic = False
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
use_atomic)
......@@ -24,21 +24,32 @@ def attention_ref(
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
dim = q.shape[-1]
scale = (1.0 / dim)**0.5
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
b, T, Hq, D = q.shape
S = k.shape[1]
scale = (1.0 / D)**0.5
k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k)
left, right = window_size
left = S if left is None or left < 0 else int(left)
right = S if right is None or right < 0 else int(right)
t_idx = torch.arange(T, device=scores.device)[:, None]
s_idx = torch.arange(S, device=scores.device)[None, :]
visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
visible_mask = visible_mask & k_keep
neg_inf = torch.finfo(scores.dtype).min
scores = scores * scale
scores = scores.masked_fill(~visible_mask, neg_inf)
attention = torch.softmax(scores, dim=-1).to(v.dtype)
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
attention = attention.masked_fill(~q_keep, 0.0)
output = torch.einsum("bhts,bshd->bthd", attention, v)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
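(Editorial aside, not part of the diff.) A small standalone check of the sliding-window visibility mask built in `attention_ref` above, with toy sizes chosen for readability (`T = S = 4`, `window_size = (1, 0)`), so each query sees itself and one key to its left:

```python
import torch

T_len, S_len, left, right = 4, 4, 1, 0
t_idx = torch.arange(T_len)[:, None]
s_idx = torch.arange(S_len)[None, :]
visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
print(visible_ts.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [0, 1, 1, 0],
#         [0, 0, 1, 1]], dtype=torch.int32)
```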
......@@ -91,53 +102,53 @@ def flashattn(batch_size,
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
})
batch_idx = bz
head_idx = by
kv_head_idx = head_idx // groups
q_start_idx = cu_seqlens_q[batch_idx]
k_start_idx = cu_seqlens_k[batch_idx]
v_start_idx = cu_seqlens_k[batch_idx]
kv_start_idx = cu_seqlens_k[batch_idx]
q_end_idx = cu_seqlens_q[batch_idx + 1]
k_end_idx = cu_seqlens_k[batch_idx + 1]
v_end_idx = cu_seqlens_k[batch_idx + 1]
q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
v_current_seqlen = v_end_idx - v_start_idx
kv_current_seqlen = k_end_idx - kv_start_idx
T.copy(
Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
Q_shared)
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i >= q_current_seqlen:
Q_shared[i, d] = 0
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(k_current_seqlen, block_N)
loop_range = (
T.min(
T.ceildiv(q_current_seqlen +
(bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
if is_causal else T.ceildiv(kv_current_seqlen, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
kv_head_idx, :], K_shared)
for i, d in T.Parallel(block_N, dim):
if k * block_N + i >= k_current_seqlen:
K_shared[i, d] = 0
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
(bx * block_M + i >= q_current_seqlen or
k * block_N + j >= k_current_seqlen),
-T.infinity(acc_s.dtype), 0)
acc_s[i,
j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
(bx * block_M + i >= q_current_seqlen or
k * block_N + j >= kv_current_seqlen), -1e9, 0)
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
k * block_N + j >= k_current_seqlen),
-T.infinity(acc_s.dtype), 0)
k * block_N + j >= kv_current_seqlen), -1e9,
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -145,6 +156,9 @@ def flashattn(batch_size,
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
......@@ -158,11 +172,8 @@ def flashattn(batch_size,
acc_o[i, j] *= scores_scale[i]
T.copy(
V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
kv_head_idx, :], V_shared)
for i, d in T.Parallel(block_N, dim):
if k * block_N + i >= v_current_seqlen:
V_shared[i, d] = 0
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
......@@ -191,8 +202,7 @@ def main(batch: int = 1,
tilelang.testing.set_random_seed(0)
causal = False
if causal:
if is_causal:
total_flops *= 0.5
tilelang.testing.set_random_seed(0)
......@@ -201,9 +211,9 @@ def main(batch: int = 1,
device = torch.device("cuda")
head_kv = heads // groups
q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
......@@ -236,10 +246,10 @@ def main(batch: int = 1,
heads,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
block_M=128,
block_N=128,
num_stages=2,
threads=256)
out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
out = output_pad_fn(out_unpad)
......@@ -255,7 +265,9 @@ def main(batch: int = 1,
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
latency = do_bench(
lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
_n_warmup=5,
_n_repeat=5)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
......
......@@ -33,18 +33,30 @@ def test_example_gqa_bwd_wgmma_pipelined():
@tilelang.testing.requires_cuda
def test_example_mha_bwd():
example_mha_bwd.main(BATCH=1)
example_mha_bwd.main(
BATCH=1,
H=16,
N_CTX=512,
D_HEAD=64,
causal=False,
)
@tilelang.testing.requires_cuda
def test_example_mha_bwd_bhsd():
example_mha_bwd_bhsd.main(BATCH=1)
example_mha_bwd_bhsd.main(
BATCH=1,
H=16,
N_CTX=512,
D_HEAD=64,
causal=False,
)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_bwd_wgmma_pipelined():
example_mha_bwd_wgmma_pipelined.main(BATCH=1)
example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
@tilelang.testing.requires_cuda
......@@ -84,7 +96,7 @@ def test_example_mha_fwd_bshd():
@tilelang.testing.requires_cuda
def test_example_mha_fwd_varlen():
example_mha_fwd_varlen.main()
example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)
if __name__ == "__main__":
......