Unverified commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
@@ -20,9 +20,12 @@ ENV LIBGL_ALWAYS_INDIRECT=1
RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all
-RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
-RUN pip install cython
+RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev \
+    build-essential cmake libedit-dev libxml2-dev cython3
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && ./install_cuda.sh
+    && cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash
@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
-    conda run -n py_3.10 bash -c "cd tilelang && ./install_rocm.sh"
+    conda run -n py_3.10 bash -c "cd tilelang && USE_ROCM=1 pip install -e . -v"
RUN conda init bash
...
@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
### Timeline View
```
-generic initialize_descriptor → generic shared-store → async wgmma
+generic initialize_wgmma_descriptor → generic shared-store → async wgmma
          │                         │                    │
          └─ generic proxy          ┴─ generic proxy     ┴─ async proxy
                                    │ fence inserted here ↑
@@ -53,7 +53,7 @@ def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
-        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
+        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        T.ptx_wgmma_ss(
            "float16",
@@ -83,7 +83,7 @@ def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
-        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
+        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        T.fence_proxy_async()
        T.ptx_wgmma_ss(
...
@@ -5,6 +5,7 @@ import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
+from typing import Optional

@triton.jit
@@ -94,7 +95,7 @@ def triton_kernel(
    Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)

-def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
+def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
    bs, n_heads, seq_q, head_dim = Q.shape
    _, n_heads_kv, seq_kv, _ = K.shape
    BLOCK_M = 64
@@ -130,7 +131,7 @@ def main(
    seq_kv: int = 256,
    dim: int = 128,
    groups: int = 8,
-    window_size: int | None = None,
+    window_size: Optional[int] = None,
    dtype: str = "float16",
    tune: bool = False,
):
...
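The two Triton examples above swap the `int | None` annotation for `typing.Optional[int]`. The `X | Y` union syntax in a signature is only evaluated on Python 3.10+, so `Optional[int]` keeps the scripts importable on older interpreters. A minimal, hypothetical sketch of the difference (the `windowed_len` helper below is illustrative, not part of the patch):

```python
from typing import Optional


# `Optional[int]` is just `Union[int, None]` and works on any Python 3 that
# ships `typing`; the `int | None` spelling raises a TypeError at import time
# on Python < 3.10 (unless `from __future__ import annotations` defers
# annotation evaluation).
def windowed_len(seq_len: int, window_size: Optional[int] = None) -> int:
    # Treat "no window" as attending over the whole sequence.
    return seq_len if window_size is None else min(seq_len, window_size)


assert windowed_len(256) == 256
assert windowed_len(256, 128) == 128
```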
@@ -5,6 +5,7 @@ import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
+from typing import Optional

@triton.jit
@@ -93,7 +94,7 @@ def triton_kernel(
    Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)

-def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
+def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
    bs, n_heads, seq_q, head_dim = Q.shape
    seq_kv = K.shape[2]
    BLOCK_M = 64
@@ -125,7 +126,7 @@ def main(batch: int = 1,
         seq_q: int = 256,
         seq_kv: int = 256,
         dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
         dtype: str = "float16",
         tune: bool = False):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
...
@@ -81,13 +81,10 @@ def flashattn_fwd(
                sinks[i] = Sinks[by]

            end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0,
+                          (bx * block_M - window_size) // block_N) if window_size is not None else 0

-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
                for i, j in T.Parallel(block_M, block_N):
                    q_idx = bx * block_M + i
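This hunk (and the matching ones in the other sink-attention examples) replaces a one-element `T.alloc_local` buffer plus an `if`/`else` with a single conditional expression. Since `window_size` is a host-side Python value at trace time, the branch can be folded away instead of materializing a scratch register. A plain-Python model of the same block-index arithmetic (names and values are illustrative, not TileLang):

```python
# First K/V block that can still intersect the sliding window of query block `bx`.
def first_kv_block(bx: int, block_M: int, block_N: int, window_size) -> int:
    if window_size is not None:
        return max(0, (bx * block_M - window_size) // block_N)
    return 0  # no window: start from block 0


bx, block_M, block_N, window = 3, 64, 32, 128
# The refactored one-liner computes the same value as the old if/else form.
start = max(0, (bx * block_M - window) // block_N) if window is not None else 0
assert start == first_kv_block(bx, block_M, block_N, window) == 2
```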
@@ -266,14 +263,11 @@ def flashattn_bwd(batch,
            T.clear(dk)

            loop_st = T.floordiv(by * block_M, block_N)
-            loop_ed = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                loop_ed[0] = T.min(
-                    T.ceildiv((by + 1) * block_M + window_size, block_N),
-                    T.ceildiv(seq_len, block_N))
-            else:
-                loop_ed[0] = T.ceildiv(seq_len, block_N)
+            loop_ed = T.min(
+                T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
+                    seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)

-            for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -444,7 +438,7 @@ def main(BATCH: int = 1,
         N_CTX: int = 512,
         D_HEAD: int = 64,
         groups: int = 2,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
         dtype: str = "float16"):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
...
@@ -172,14 +172,11 @@ def flashattn(
            end = T.min(
                T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0

            for k in T.Pipelined(
-                    start[0],
+                    start,
                    end,
                    num_stages=num_stages,
                    order=[-1, 0, 3, 1, -1, 2],
@@ -272,7 +269,7 @@ def main(
    seq_kv: int = 256,
    dim: int = 128,
    groups: int = 8,
-    window_size: int | None = None,
+    window_size: Optional[int] = None,
    dtype: str = "float16",
    tune: bool = False,
):
...
@@ -78,13 +78,10 @@ def flashattn_fwd(
                sinks[i] = Sinks[by]

            end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0,
+                          (bx * block_M - window_size) // block_N) if window_size is not None else 0

-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
                for i, j in T.Parallel(block_M, block_N):
                    q_idx = bx * block_M + i
@@ -267,14 +264,10 @@ def flashattn_bwd(
            T.clear(dk)

            loop_st = T.floordiv(by * block_M, block_N)
-            loop_ed = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                loop_ed[0] = T.min(
-                    T.ceildiv((by + 1) * block_M + window_size, block_N),
-                    T.ceildiv(seq_len, block_N))
-            else:
-                loop_ed[0] = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
+            loop_ed = T.min(
+                T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
+                    seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -440,7 +433,7 @@ def main(BATCH: int = 1,
         H: int = 1,
         N_CTX: int = 512,
         D_HEAD: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
         dtype: str = "float16"):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
...
@@ -162,13 +162,10 @@ def flashattn(
            end = T.min(
                T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0

-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
                        logsum)
@@ -253,7 +250,7 @@ def main(batch: int = 1,
         seq_q: int = 256,
         seq_kv: int = 256,
         dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
         dtype: str = "float16",
         tune: bool = False):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
...
@@ -165,14 +165,11 @@ def flashattn(
            end = T.min(
                T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0

            for k in T.Pipelined(
-                    start[0],
+                    start,
                    end,
                    num_stages=num_stages,
                    order=[-1, 0, 3, 1, -1, 2],
@@ -263,7 +260,7 @@ def main(batch: int = 1,
         seq_q: int = 256,
         seq_kv: int = 256,
         dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
         dtype: str = "float16",
         tune: bool = False):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
...
@@ -25,10 +25,10 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def test_example_triton_sparse_gqa_decode_varlen_indice():
    example_triton_sparse_gqa_decode_varlen_indice.main(
-        batch=16,
-        heads=16,
-        heads_kv=8,
-        max_cache_seqlen=4096,
+        batch=8,
+        heads=8,
+        heads_kv=4,
+        max_cache_seqlen=2048,
        dim=128,
        dim_v=128,
        sparse_ratio=0.8,
@@ -40,7 +40,7 @@ def test_example_triton_sparse_gqa_decode_varlen_mask():
        batch=16,
        heads=16,
        heads_kv=8,
-        max_cache_seqlen=4096,
+        max_cache_seqlen=1024,
        dim=128,
        dim_v=128,
        sparse_ratio=0.8,
...
@@ -161,7 +161,9 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
    return x_fp8

-def main(M=8192, N=8192, BG=2, blk_m=8):
+def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
+    if batch_sizes is None:
+        batch_sizes = [2048, 6144]
    if dtype == "float":
        x = torch.randn(M, N, device="cuda", dtype=torch.float32)
    elif dtype == "float16":
@@ -170,7 +172,7 @@ def main(M=8192, N=8192, BG=2, blk_m=8):
        x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")
-    batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32)
+    batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
    M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
    print("batch_sizes:", batch_sizes)
...
@@ -4,11 +4,12 @@ import example_per_token_cast_to_fp8

def test_example_group_per_split_token_cast_to_fp8():
-    example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)
+    example_group_per_split_token_cast_to_fp8.main(
+        M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])

def test_example_per_token_cast_to_fp8():
-    example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)
+    example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8)

if __name__ == "__main__":
...
@@ -127,7 +127,7 @@ def mqa_attn_return_logits(
        index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
        index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
        s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
-        s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
+        s_reshaped = T.reshape(s, (block_N, block_Q, heads))
        logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
        weights = T.alloc_fragment([block_Q, heads], accum_dtype)
@@ -165,7 +165,7 @@ def mqa_attn_return_logits(
            for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
                s_reshaped[bn_i, bq_i,
-                           h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
+                           h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) *
                                   weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
            T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
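The indexer change swaps a second `[block_N, block_Q, heads]` fragment for `T.reshape(s, ...)`, so `s_reshaped` reinterprets the same storage as `s` and the ReLU-and-scale update reads and writes one buffer. A hedged NumPy sketch of the view semantics this relies on (assuming `T.reshape` aliases the fragment, as the rewritten indexing implies):

```python
import numpy as np

block_N, block_Q, heads = 4, 2, 3
s = np.zeros((block_N, block_Q * heads), dtype=np.float32)
s_reshaped = s.reshape(block_N, block_Q, heads)  # a view over the same memory

# Writing through the 3-D view updates the flat 2-D buffer in place, which is
# why the kernel no longer needs to copy from `s` into a separate fragment.
s_reshaped[1, 0, 2] = -7.0
assert s[1, 0 * heads + 2] == -7.0
```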
@@ -258,6 +258,8 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,

def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
+    # fix the random seed so the results are reproducible
+    torch.manual_seed(0)
    q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
    kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
    weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
...
# ruff: noqa
import tilelang.testing
-from topk_selector import test_topk_selector
-from fp8_lighting_indexer import test_fp8_lighting_indexer
-from sparse_mla_fwd import test_sparse_mla_fwd
-from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
-from sparse_mla_bwd import test_sparse_mla_bwd
+import topk_selector
+import fp8_lighting_indexer
+import sparse_mla_fwd
+import sparse_mla_fwd_pipelined
+import sparse_mla_bwd

def test_example_topk_selector():
-    test_topk_selector()
+    topk_selector.test_topk_selector()

def test_example_fp8_lighting_indexer():
-    test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)
+    fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)

@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd():
    # small shapes for testing
-    test_sparse_mla_fwd(
+    sparse_mla_fwd.test_sparse_mla_fwd(
        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@@ -28,15 +28,15 @@ def test_example_sparse_mla_fwd():
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd_pipelined():
    # small shapes for testing
-    test_sparse_mla_fwd_pipelined(
-        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
+    sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
+        S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)

@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd():
-    test_sparse_mla_bwd(
-        S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
+    sparse_mla_bwd.test_sparse_mla_bwd(
+        S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)

if __name__ == "__main__":
...
import tilelang.testing
import example_elementwise_add
+import example_elementwise_add_tma_1d

def test_example_elementwise_add():
    example_elementwise_add.main()

+def test_example_elementwise_add_tma_1d():
+    example_elementwise_add_tma_1d.main()
+
if __name__ == "__main__":
    tilelang.testing.main()
@@ -5,6 +5,8 @@ import tilelang.language as T
from tilelang.contrib import nvcc
import argparse

+tilelang.disable_cache()
+
@tilelang.jit(
    out_idx=[3, 4], pass_configs={
@@ -44,7 +46,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
-            T.fill(scores_max, -T.infinity(accum_dtype))
+            # Warning: with causal/varlen/unaligned sequence lengths, -inf leads to undefined behavior (NaNs) in the exp ops.
+            # Use a large negative number instead.
+            T.fill(scores_max, T.Cast(accum_dtype, -1e30))
            loop_range = (
                T.ceildiv(
                    (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
@@ -53,7 +57,7 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-                                                     -T.infinity(acc_s.dtype))
+                                                     T.Cast(accum_dtype, -1e30))
                else:
                    T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
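The warning comment above is the motivation for this hunk: when every score in a row is `-inf` (a padded row, or a causal row with no visible keys), the softmax computes `exp(-inf - (-inf))`, which is NaN, and the NaN propagates into the output. A finite sentinel such as `-1e30` keeps the arithmetic defined; fully masked rows still carry no meaningful weight and are masked or discarded later. A small PyTorch illustration of the failure mode (standalone, not from the kernel):

```python
import torch

# A fully masked row of scores, e.g. a padded query row in a varlen batch.
row_inf = torch.full((4,), float("-inf"))
print(torch.softmax(row_inf, dim=0))   # tensor([nan, nan, nan, nan])

# The same row with a large-but-finite sentinel stays finite.
row_big = torch.full((4,), -1e30)
print(torch.softmax(row_big, dim=0))   # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```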
@@ -265,7 +269,7 @@ def flashattn_bwd_atomic_add(batch,
@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
-def flashattn_bwd_split(batch,
+def flashattn_bwd_split_novarlen(batch,
                        heads,
                        seq_len,
                        dim_qk,
@@ -424,7 +428,7 @@ class _attention(torch.autograd.Function):
            kernel(q, k, v, do, lse, delta, dq, dk, dv)
            dq, dk, dv = mod_post(dq, dk, dv)
        else:
-            kernel = flashattn_bwd_split(
+            kernel = flashattn_bwd_split_novarlen(
                BATCH,
                H,
                N_CTX,
...
@@ -7,6 +7,8 @@ import argparse
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input

+# tilelang.disable_cache()
+
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    assert mode in ["full", "random", "third"]
@@ -29,6 +31,7 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
def flashattn_fwd(batch,
                  total_q,
                  total_kv,
+                  N_CTX,
                  heads,
                  max_seq_len,
                  dim_qk,
@@ -54,7 +57,7 @@ def flashattn_fwd(batch,
            cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
            cu_seqlens_k: T.Tensor([batch + 1], "int32"),  # type: ignore
            Output: T.Tensor(o_shape, dtype),  # type: ignore
-            lse: T.Tensor([total_q, heads], accum_dtype),  # type: ignore
+            lse: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
@@ -86,7 +89,9 @@ def flashattn_fwd(batch,
            T.fill(acc_o, 0.0)
            T.fill(logsum, 0.0)
-            T.fill(scores_max, -T.infinity(accum_dtype))
+            # Warning: with causal/varlen/unaligned sequence lengths, -inf leads to undefined behavior (NaNs) in the exp ops.
+            # Use a large negative number instead.
+            T.fill(scores_max, T.Cast(accum_dtype, -1e30))
            loop_range = T.ceildiv(k_current_seqlen, block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                for i, d in T.Parallel(block_N, dim_qk):
@@ -100,12 +105,12 @@ def flashattn_fwd(batch,
                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
                                                     (bx * block_M + i < q_current_seqlen and
                                                      k * block_N + j < k_current_seqlen), 0,
-                                                     -T.infinity(acc_s.dtype))
+                                                     T.Cast(accum_dtype, -1e30))
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            bx * block_M + i < q_current_seqlen and
-                            k * block_N + j < k_current_seqlen, 0, -T.infinity(acc_s.dtype))
+                            k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30))
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, d in T.Parallel(block_N, dim_v):
                    if k * block_N + i < k_current_seqlen:
@@ -135,7 +140,7 @@ def flashattn_fwd(batch,
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
                if bx * block_M + i < q_current_seqlen:
-                    lse[q_start_idx + bx * block_M + i, by] = logsum[i]
+                    lse[bz, by, bx * block_M + i] = logsum[i]

    return flash_fwd

@@ -144,7 +149,7 @@ def flashattn_fwd(batch,
    out_idx=[3], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
-def flashattn_bwd_preprocess(batch, heads, total_q, max_seq_len, dim_v):
+def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v):
    dtype = "float16"
    accum_dtype = "float"
    shape = [total_q, heads, dim_v]
@@ -155,7 +160,7 @@ def flashattn_bwd_preprocess(batch, heads, total_q, max_seq_len, dim_v):
            O: T.Tensor(shape, dtype),  # type: ignore
            dO: T.Tensor(shape, dtype),  # type: ignore
            cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
-            Delta: T.Tensor([total_q, heads], accum_dtype),  # type: ignore
+            Delta: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
@@ -183,14 +188,14 @@ def flashattn_bwd_preprocess(batch, heads, total_q, max_seq_len, dim_v):
            for i in T.Parallel(blk):
                if by * blk + i < q_current_seqlen:
-                    Delta[q_start_idx + by * blk + i, bx] = delta[i]
+                    Delta[bz, bx, by * blk + i] = delta[i]

    return flash_bwd_prep


def make_dq_layout(dQ):
-    # bshd -> bhld to use tma reduction instruction
-    return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])
+    # bshd -> bhsd to use tma reduction instruction
+    return T.Layout(dQ.shape, lambda l, h, d: [h, l, d])
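`make_dq_layout` now describes the unpadded gradient buffers directly: the mapping `(l, h, d) -> [h, l, d]` is a pure index permutation that stores each head's rows contiguously, which is what the comment means by enabling the TMA reduction instruction. A hedged PyTorch sketch of the permutation itself (illustrative only; it does not model TileLang's layout machinery):

```python
import torch

total_q, heads, dim = 6, 2, 4
dq_logical = torch.arange(total_q * heads * dim).reshape(total_q, heads, dim)

# (l, h, d) -> [h, l, d]: same elements, head dimension moved outermost.
dq_physical = dq_logical.permute(1, 0, 2).contiguous()

l, h, d = 3, 1, 2
assert dq_physical[h, l, d] == dq_logical[l, h, d]
```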
@tilelang.jit(
@@ -215,13 +220,13 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
            dV_out: T.Tensor(v_shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by):
-            # T.annotate_layout({dQ: make_dq_layout(dQ)})
+            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(dQ[bx * blk:(bx + 1) * blk, by, :], dQ_out[bx * blk:(bx + 1) * blk, by, :])
        with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by):
-            # T.annotate_layout({
-            #     dK: make_dq_layout(dK),
-            #     dV: make_dq_layout(dV),
-            # })
+            T.annotate_layout({
+                dK: make_dq_layout(dK),
+                dV: make_dq_layout(dV),
+            })
            T.copy(dK[bx * blk:(bx + 1) * blk, by, :], dK_out[bx * blk:(bx + 1) * blk, by, :])
            T.copy(dV[bx * blk:(bx + 1) * blk, by, :], dV_out[bx * blk:(bx + 1) * blk, by, :])
@@ -234,6 +239,7 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
def flashattn_bwd_atomic_add(batch,
                             total_q,
                             total_kv,
+                             N_CTX,
                             heads,
                             max_seq_len,
                             dim_qk,
@@ -260,8 +266,8 @@ def flashattn_bwd_atomic_add(batch,
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor(do_shape, dtype),  # type: ignore
-            lse: T.Tensor([total_q, heads], accum_dtype),  # type: ignore
-            Delta: T.Tensor([total_q, heads], accum_dtype),  # type: ignore
+            lse: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
+            Delta: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
            cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
            cu_seqlens_k: T.Tensor([batch + 1], "int32"),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
@@ -284,6 +290,9 @@ def flashattn_bwd_atomic_add(batch,
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
+            dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
+            dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
+            dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)

            q_start_idx = cu_seqlens_q[bz]
            k_start_idx = cu_seqlens_k[bz]
@@ -293,39 +302,32 @@ def flashattn_bwd_atomic_add(batch,
            k_current_seqlen = k_end_idx - k_start_idx

            T.annotate_layout({
-                # dQ: make_dq_layout(dQ),
-                # dK: make_dq_layout(dK),
-                # dV: make_dq_layout(dV),
+                dQ: make_dq_layout(dQ),
+                dK: make_dq_layout(dK),
+                dV: make_dq_layout(dV),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
            })

-            for i, d in T.Parallel(block_M, dim_qk):
-                if by * block_M + i < k_current_seqlen:
-                    K_shared[i, d] = K[k_start_idx + by * block_M + i, bx // groups, d]
-                    V_shared[i, d] = V[k_start_idx + by * block_M + i, bx // groups, d]
-                else:
-                    K_shared[i, d] = 0.0
-                    V_shared[i, d] = 0.0
+            T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :],
+                   K_shared)
+            T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :],
+                   V_shared)

            T.clear(dv)
            T.clear(dk)
-            loop_st = (T.floordiv(by * block_M, block_N) if is_causal else 0)
+            loop_st = T.min(
+                T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen,
+                                                              block_N)) if is_causal else 0
            loop_ed = T.ceildiv(q_current_seqlen, block_N)
            for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
-                for i, d in T.Parallel(block_N, dim_qk):
-                    if k_base * block_N + i < q_current_seqlen:
-                        q[i, d] = Q[q_start_idx + k_base * block_N + i, bx, d]
-                    else:
-                        q[i, d] = 0.0
+                T.copy(
+                    Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
+                    q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
-                for i in T.Parallel(block_N):
-                    if k_base * block_N + i < q_current_seqlen:
-                        lse_shared[i] = lse[q_start_idx + k_base * block_N + i, bx]
-                    else:
-                        lse_shared[i] = 0.0
+                T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
@@ -341,22 +343,16 @@ def flashattn_bwd_atomic_add(batch,
                            by * block_M + i < k_current_seqlen and
                            k_base * block_N + j < q_current_seqlen, qkT[i, j], 0)

-                for i, d in T.Parallel(block_N, dim_v):
-                    if k_base * block_N + i < q_current_seqlen:
-                        do[i, d] = dO[q_start_idx + k_base * block_N + i, bx, d]
-                    else:
-                        do[i, d] = 0.0
+                T.copy(
+                    dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
+                    do)
                T.clear(dsT)
                # dsT: (block_kv, block_q)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

-                for i in T.Parallel(block_N):
-                    if k_base * block_N + i < q_current_seqlen:
-                        delta[i] = Delta[q_start_idx + k_base * block_N + i, bx]
-                    else:
-                        delta[i] = 0.0
+                T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta)

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
@@ -364,22 +360,28 @@ def flashattn_bwd_atomic_add(batch,
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
+                T.copy(dq, dq_shared)
                T.atomic_add(
                    dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N,
                       bx, :],
-                    dq,
-                    memory_order="release")
+                    dq_shared,
+                    memory_order="relaxed",
+                    use_tma=True)

+            T.copy(dv, dv_shared)
            T.atomic_add(
                dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
                   bx // groups, :],
-                dv,
-                memory_order="release")
+                dv_shared,
+                memory_order="relaxed",
+                use_tma=True)
+            T.copy(dk, dk_shared)
            T.atomic_add(
                dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
                   bx // groups, :],
-                dk,
-                memory_order="release")
+                dk_shared,
+                memory_order="relaxed",
+                use_tma=True)

    return flash_bwd
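In the atomic-add backward path above, the `dq`/`dk`/`dv` fragments are now staged through shared memory and flushed with `T.atomic_add(..., memory_order="relaxed", use_tma=True)`. The atomics are what make the accumulation safe: several (query-block, kv-block) tiles contribute to the same gradient rows, so the updates must add rather than overwrite. A hedged PyTorch model of that requirement, with `index_add_` standing in for the hardware atomic add (shapes are illustrative):

```python
import torch

total_q, dim = 8, 4
dq = torch.zeros(total_q, dim)

# Two tiles whose row ranges overlap on rows 2..3, as different kv blocks can
# for the same query rows.
rows_a, grad_a = torch.tensor([0, 1, 2, 3]), torch.ones(4, dim)
rows_b, grad_b = torch.tensor([2, 3, 4, 5]), torch.ones(4, dim)

# Accumulation (the analogue of T.atomic_add) keeps both contributions;
# plain stores from each tile would silently drop one of them.
dq.index_add_(0, rows_a, grad_a)
dq.index_add_(0, rows_b, grad_b)
assert dq[2].eq(2.0).all() and dq[0].eq(1.0).all()
```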
@@ -390,6 +392,7 @@ def flashattn_bwd_atomic_add(batch,
def flashattn_bwd_split(batch,
                        total_q,
                        total_kv,
+                        N_CTX,
                        heads,
                        max_seq_len,
                        dim_qk,
@@ -418,8 +421,8 @@ def flashattn_bwd_split(batch,
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor(do_shape, dtype),  # type: ignore
-            lse: T.Tensor([total_q, heads], accum_dtype),  # type: ignore
-            Delta: T.Tensor([total_q, heads], accum_dtype),  # type: ignore
+            lse: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
+            Delta: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
            cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
            cu_seqlens_k: T.Tensor([batch + 1], "int32"),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
@@ -453,46 +456,41 @@ def flashattn_bwd_split(batch,
            k_current_seqlen = k_end_idx - k_start_idx

            T.annotate_layout({
-                # dQ: make_dq_layout(dQ),
+                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })

-            for i, d in T.Parallel(block_M, dim_qk):
-                if by * block_M + i < k_current_seqlen:
-                    K_shared[i, d] = K[k_start_idx + by * block_M + i, bx // groups, d]
-                    V_shared[i, d] = V[k_start_idx + by * block_M + i, bx // groups, d]
-                else:
-                    K_shared[i, d] = 0.0
-                    V_shared[i, d] = 0.0
+            T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :],
+                   K_shared)
+            T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :],
+                   V_shared)

            T.clear(dv)
            T.clear(dk)
-            loop_st = (T.floordiv(by * block_M, block_N) if is_causal else 0)
+            loop_st = T.min(
+                T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen,
+                                                              block_N)) if is_causal else 0
            loop_ed = T.ceildiv(q_current_seqlen, block_N)
            for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
-                for i, d in T.Parallel(block_N, dim_qk):
-                    if k_base * block_N + i < q_current_seqlen:
-                        q[i, d] = Q[q_start_idx + k_base * block_N + i, bx, d]
-                    else:
-                        q[i, d] = 0.0
+                # Note: the zero padding needed for varlen should be handled inside T.copy
+                T.copy(
+                    Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
+                    q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

-                for i, d in T.Parallel(block_N, dim_v):
-                    if k_base * block_N + i < q_current_seqlen:
-                        do[i, d] = dO[q_start_idx + k_base * block_N + i, bx, d]
-                    else:
-                        do[i, d] = 0.0
+                T.copy(
+                    dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
+                    do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

-                for i in T.Parallel(block_N):
-                    if k_base * block_N + i < q_current_seqlen:
-                        lse_shared[i] = lse[q_start_idx + k_base * block_N + i, bx]
-                    else:
-                        lse_shared[i] = 0.0
+                T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
@@ -508,11 +506,8 @@ def flashattn_bwd_split(batch,
                            k_base * block_N + j < q_current_seqlen, qkT[i, j], 0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

-                for i in T.Parallel(block_N):
-                    if k_base * block_N + i < q_current_seqlen:
-                        delta[i] = Delta[q_start_idx + k_base * block_N + i, bx]
-                    else:
-                        delta[i] = 0.0
+                T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta)

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
@@ -526,16 +521,18 @@ def flashattn_bwd_split(batch,
                    T.atomic_add(
                        dQ[q_start_idx + k_base * block_N + i, bx, j],
                        dq[i, j],
-                        memory_order="release")
+                        memory_order="relaxed")

            T.copy(dv, dv_shared)
-            for i, d in T.Parallel(block_M, dim_v):
-                if by * block_M + i < k_current_seqlen:
-                    dV[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dv[i, d]
+            T.copy(
+                dv_shared,
+                dV[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
+                   bx // groups, :])
            T.copy(dk, dk_shared)
-            for i, d in T.Parallel(block_M, dim_qk):
-                if by * block_M + i < k_current_seqlen:
-                    dK[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dk[i, d]
+            T.copy(
+                dk_shared,
+                dK[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
+                   bx // groups, :])

    return flash_bwd
@@ -571,12 +568,13 @@ class _attention(torch.autograd.Function):
        total_q = q_unpad.shape[0]
        total_kv = k_unpad.shape[0]

-        mod = flashattn_fwd(BATCH, total_q, total_kv, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal,
-                            block_M, block_N, groups)
+        mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V,
+                            causal, block_M, block_N, groups)
        o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k)
        o = pad_input(o_unpad, indices_q, BATCH, N_CTX)
        ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k,
                              cu_seqlens_q, cu_seqlens_k)
+        ctx.batch = BATCH
        ctx.causal = causal
        ctx.use_atomic = use_atomic
        ctx.max_seqlen_q = max_seqlen_q
@@ -588,7 +586,8 @@ class _attention(torch.autograd.Function):
    @staticmethod
    def backward(ctx, do):
        N_CTX = do.shape[1]
-        q, k, v, o, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        # lse_clone = lse.clone()
        do_unpad, _, _, _ = unpad_input(
            do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
        total_q, H, D_HEAD_QK = q.shape
@@ -604,7 +603,7 @@ class _attention(torch.autograd.Function):
        do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)]
        block_M = 128
        block_N = 32
-        mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, ctx.max_seqlen_q, D_HEAD_V)
+        mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V)
        mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V)
        delta = mod_prep(o, do, cu_seqlens_q)
@@ -613,6 +612,7 @@ class _attention(torch.autograd.Function):
                BATCH,
                total_q,
                total_kv,
+                N_CTX,
                H,
                ctx.max_seqlen_q,
                D_HEAD_QK,
@@ -626,13 +626,14 @@ class _attention(torch.autograd.Function):
            dq = torch.zeros_like(q, dtype=torch.float32)
            dk = torch.zeros_like(k, dtype=torch.float32)
            dv = torch.zeros_like(v, dtype=torch.float32)
-            kernel(q, k, v, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
+            kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
            dq, dk, dv = mod_post(dq, dk, dv)
        else:
            kernel = flashattn_bwd_split(
                BATCH,
                total_q,
                total_kv,
+                N_CTX,
                H,
                ctx.max_seqlen_q,
                D_HEAD_QK,
@@ -646,7 +647,7 @@ class _attention(torch.autograd.Function):
            dq = torch.zeros_like(q, dtype=torch.float32)
            dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device)
            dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device)
-            kernel(q, k, v, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
+            kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
            dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32),
                                torch.zeros_like(v, dtype=torch.float32))
            dk, dv = dk.sum(0), dv.sum(0)
@@ -739,12 +740,6 @@ def main(BATCH: int = 1,
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

-    torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
-    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
-    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
-    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
-    print('All checks passed.✅')
-
    def run():
        O_ref.backward(dO, retain_graph=True)
@@ -760,6 +755,15 @@ def main(BATCH: int = 1,
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))

+    torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
+    print("All checks passed.✅")
+    print(
+        "Note: in Nsight Compute this varlen kernel performs on par with the non-varlen kernel; "
+        "the TFLOPS reported above is slightly lower because the unpad operation is included in the benchmark."
+    )

if __name__ == "__main__":
    arch = nvcc.get_target_compute_version()
@@ -778,6 +782,8 @@ if __name__ == "__main__":
    parser.add_argument(
        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
    args = parser.parse_args()
+    # Can be set to True/False for testing
+    args.causal = True

    # Handle backward compatibility and logic
    if args.use_split:
@@ -785,8 +791,8 @@ if __name__ == "__main__":
    elif args.use_atomic:
        use_atomic = True
    else:
-        # Default: use split
-        use_atomic = False
+        # Default: use atomic
+        use_atomic = True

    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
         use_atomic)
@@ -24,21 +24,32 @@ def attention_ref(
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
-    dim = q.shape[-1]
-    scale = (1.0 / dim)**0.5
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    b, T, Hq, D = q.shape
+    S = k.shape[1]
+    scale = (1.0 / D)**0.5
+    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
    scores = torch.einsum("bthd,bshd->bhts", q, k)
+    left, right = window_size
+    left = S if left is None or left < 0 else int(left)
+    right = S if right is None or right < 0 else int(right)
+    t_idx = torch.arange(T, device=scores.device)[:, None]
+    s_idx = torch.arange(S, device=scores.device)[None, :]
+    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
+    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
    if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+        k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
+        visible_mask = visible_mask & k_keep
+    neg_inf = torch.finfo(scores.dtype).min
    scores = scores * scale
+    scores = scores.masked_fill(~visible_mask, neg_inf)
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+        q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
+        attention = attention.masked_fill(~q_keep, 0.0)
    output = torch.einsum("bhts,bshd->bthd", attention, v)
    if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
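The rewritten reference builds an explicit `(left, right)` sliding-window visibility mask with broadcasted `arange` indices and folds the key-padding mask into it before a single `masked_fill`. A tiny standalone check of the visibility convention the code above uses (key position `s` is visible to query `t` iff `t - left <= s <= t + right`):

```python
import torch

T_len, S_len, left, right = 4, 4, 1, 0   # causal attention with a 1-token lookback
t_idx = torch.arange(T_len)[:, None]
s_idx = torch.arange(S_len)[None, :]
visible = (s_idx >= t_idx - left) & (s_idx <= t_idx + right)

# Each query row t attends only to keys {t-1, t}: a banded lower-triangular mask.
expected = torch.tensor([
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [0, 1, 1, 0],
    [0, 0, 1, 1],
], dtype=torch.bool)
assert torch.equal(visible, expected)
```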
...@@ -91,53 +102,53 @@ def flashattn(batch_size, ...@@ -91,53 +102,53 @@ def flashattn(batch_size,
scores_sum = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
})
batch_idx = bz batch_idx = bz
head_idx = by head_idx = by
kv_head_idx = head_idx // groups kv_head_idx = head_idx // groups
q_start_idx = cu_seqlens_q[batch_idx] q_start_idx = cu_seqlens_q[batch_idx]
k_start_idx = cu_seqlens_k[batch_idx] kv_start_idx = cu_seqlens_k[batch_idx]
v_start_idx = cu_seqlens_k[batch_idx]
q_end_idx = cu_seqlens_q[batch_idx + 1] q_end_idx = cu_seqlens_q[batch_idx + 1]
k_end_idx = cu_seqlens_k[batch_idx + 1] k_end_idx = cu_seqlens_k[batch_idx + 1]
v_end_idx = cu_seqlens_k[batch_idx + 1]
q_current_seqlen = q_end_idx - q_start_idx q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx kv_current_seqlen = k_end_idx - kv_start_idx
v_current_seqlen = v_end_idx - v_start_idx
T.copy( T.copy(
Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
Q_shared) Q_shared)
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i >= q_current_seqlen:
Q_shared[i, d] = 0
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(k_current_seqlen, block_N) loop_range = (
T.min(
T.ceildiv(q_current_seqlen +
(bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
if is_causal else T.ceildiv(kv_current_seqlen, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy( T.copy(
K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N, K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
kv_head_idx, :], K_shared) kv_head_idx, :], K_shared)
for i, d in T.Parallel(block_N, dim):
if k * block_N + i >= k_current_seqlen:
K_shared[i, d] = 0
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and acc_s[i,
j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
(bx * block_M + i >= q_current_seqlen or (bx * block_M + i >= q_current_seqlen or
k * block_N + j >= k_current_seqlen), k * block_N + j >= kv_current_seqlen), -1e9, 0)
-T.infinity(acc_s.dtype), 0)
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
k * block_N + j >= k_current_seqlen), k * block_N + j >= kv_current_seqlen), -1e9,
-T.infinity(acc_s.dtype), 0) 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
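Inside the pipelined K loop above, `acc_s` is pre-seeded with a large negative bias wherever a score must not contribute: future keys under causal masking, and rows or columns that fall past this sequence's ragged length from `cu_seqlens`. A small host-side sketch of the same per-tile condition (tile indices `bx`/`k`, tile sizes `block_M`/`block_N`; names illustrative):

```
# Host-side sketch of the per-tile score bias seeded in the kernel above.
import torch

def tile_bias(bx, k, block_M, block_N, q_len, kv_len, is_causal, neg=-1e9):
    i = torch.arange(block_M)[:, None]   # row offset within the query tile
    j = torch.arange(block_N)[None, :]   # column offset within the key tile
    qi = bx * block_M + i                # absolute query position
    kj = k * block_N + j                 # absolute key position
    masked = (qi >= q_len) | (kj >= kv_len)   # beyond this sequence's length
    if is_causal:
        masked = masked | (qi < kj)           # future keys are hidden
    bias = torch.zeros(block_M, block_N)
    bias[masked] = neg
    return bias
```

The diff also swaps `-T.infinity(...)` for `-1e9`, presumably so a tile row that is entirely masked still yields finite exponentials in the running softmax instead of NaNs.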
...@@ -145,6 +156,9 @@ def flashattn(batch_size, ...@@ -145,6 +156,9 @@ def flashattn(batch_size,
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
...@@ -158,11 +172,8 @@ def flashattn(batch_size, ...@@ -158,11 +172,8 @@ def flashattn(batch_size,
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy( T.copy(
V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N, V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
kv_head_idx, :], V_shared) kv_head_idx, :], V_shared)
for i, d in T.Parallel(block_N, dim):
if k * block_N + i >= v_current_seqlen:
V_shared[i, d] = 0
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
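Each loop iteration folds a new K/V tile into the running result with the usual online-softmax update: take the running row maximum (the newly added `T.max(scores_max[i], scores_max_prev[i])` line keeps it monotone), rescale the previous accumulator and `logsum` by `exp(m_old - m_new)`, then add the current tile's contribution. A minimal single-row sketch of that update, using `exp()` for clarity where the kernel uses `exp2()` with `log2(e)` folded into `scale`; names are illustrative:

```
# Online-softmax update for one query row, keys arriving tile by tile.
import torch

def online_softmax_row(q, k_tiles, v_tiles, scale):
    m = torch.tensor(float("-inf"))           # running row max
    l = torch.tensor(0.0)                     # running softmax denominator
    acc = torch.zeros(v_tiles[0].shape[-1])   # running unnormalized output
    for K, V in zip(k_tiles, v_tiles):        # K: [n, d], V: [n, dv]
        s = (K @ q) * scale                   # scores for this tile
        m_new = torch.maximum(m, s.max())     # max never decreases
        alpha = torch.exp(m - m_new)          # rescale factor for old state
        p = torch.exp(s - m_new)              # unnormalized tile probabilities
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ V
        m = m_new
    return acc / l                            # equals softmax(q·Kᵀ·scale) @ V
```

The final division by `l` corresponds to the `acc_o[i, j] /= logsum[i]` epilogue later in the kernel.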
...@@ -191,8 +202,7 @@ def main(batch: int = 1, ...@@ -191,8 +202,7 @@ def main(batch: int = 1,
tilelang.testing.set_random_seed(0) tilelang.testing.set_random_seed(0)
causal = False if is_causal:
if causal:
total_flops *= 0.5 total_flops *= 0.5
tilelang.testing.set_random_seed(0) tilelang.testing.set_random_seed(0)
...@@ -201,9 +211,9 @@ def main(batch: int = 1, ...@@ -201,9 +211,9 @@ def main(batch: int = 1,
device = torch.device("cuda") device = torch.device("cuda")
head_kv = heads // groups head_kv = heads // groups
q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True) q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True) k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True) v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random") query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random") key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
...@@ -236,10 +246,10 @@ def main(batch: int = 1, ...@@ -236,10 +246,10 @@ def main(batch: int = 1,
heads, heads,
dim, dim,
is_causal, is_causal,
block_M=64, block_M=128,
block_N=64, block_N=128,
num_stages=1, num_stages=2,
threads=128) threads=256)
out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
out = output_pad_fn(out_unpad) out = output_pad_fn(out_unpad)
...@@ -255,7 +265,9 @@ def main(batch: int = 1, ...@@ -255,7 +265,9 @@ def main(batch: int = 1,
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅") print("All checks passed.✅")
latency = do_bench( latency = do_bench(
lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)) lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
_n_warmup=5,
_n_repeat=5)
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
......
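The driver in this example packs every sequence's valid tokens into one "unpadded" tensor and passes `cu_seqlens_q` / `cu_seqlens_k` (prefix sums of per-sequence lengths) so the kernel can slice each sequence, as in `q_start_idx = cu_seqlens_q[batch_idx]` above. The helpers it calls (`generate_random_padding_mask`, the unpad/pad functions) are outside this hunk, so the following is only an illustrative stand-in showing how such inputs are typically built from a boolean padding mask:

```
# Illustrative varlen preprocessing from a [batch, seqlen] bool mask (True = real token).
import torch

def unpad(x, padding_mask):
    # x: [batch, seqlen, heads, dim]
    seqlens = padding_mask.sum(dim=1, dtype=torch.int32)           # per-sequence lengths
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device=x.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)                  # prefix sums
    x_unpad = x[padding_mask]                                      # [total_tokens, heads, dim]
    return x_unpad, cu_seqlens, int(seqlens.max())

# Hypothetical usage mirroring the example's call site:
# q_unpad, cu_seqlens_q, max_seqlen_q = unpad(q, query_padding_mask)
# k_unpad, cu_seqlens_k, _            = unpad(k, key_padding_mask)
```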
...@@ -33,18 +33,30 @@ def test_example_gqa_bwd_wgmma_pipelined(): ...@@ -33,18 +33,30 @@ def test_example_gqa_bwd_wgmma_pipelined():
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_example_mha_bwd(): def test_example_mha_bwd():
example_mha_bwd.main(BATCH=1) example_mha_bwd.main(
BATCH=1,
H=16,
N_CTX=512,
D_HEAD=64,
causal=False,
)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_example_mha_bwd_bhsd(): def test_example_mha_bwd_bhsd():
example_mha_bwd_bhsd.main(BATCH=1) example_mha_bwd_bhsd.main(
BATCH=1,
H=16,
N_CTX=512,
D_HEAD=64,
causal=False,
)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_bwd_wgmma_pipelined(): def test_example_mha_bwd_wgmma_pipelined():
example_mha_bwd_wgmma_pipelined.main(BATCH=1) example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
...@@ -84,7 +96,7 @@ def test_example_mha_fwd_bshd(): ...@@ -84,7 +96,7 @@ def test_example_mha_fwd_bshd():
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_example_mha_fwd_varlen(): def test_example_mha_fwd_varlen():
example_mha_fwd_varlen.main() example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)
if __name__ == "__main__": if __name__ == "__main__":
......