"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "3274ca3094bf05d4fb9d6afa554a2bd71001b2d8"
Unverified Commit bbbf4207 authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
...@@ -20,9 +20,12 @@ ENV LIBGL_ALWAYS_INDIRECT=1 ...@@ -20,9 +20,12 @@ ENV LIBGL_ALWAYS_INDIRECT=1
RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev \
build-essential cmake libedit-dev libxml2-dev cython3
RUN pip install cython
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && ./install_cuda.sh && cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash CMD bash
...@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \ ...@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \ RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
conda run -n py_3.10 bash -c "cd tilelang && ./install_rocm.sh" conda run -n py_3.10 bash -c "cd tilelang && USE_ROCM=1 pip install -e . -v"
RUN conda init bash RUN conda init bash
......
...@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the ...@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
### Timeline View ### Timeline View
``` ```
generic initialize_descriptor → generic shared-store → async wgmma generic initialize_wgmma_descriptor → generic shared-store → async wgmma
│ │ │ │ │ │
└─ generic proxy ┴─ generic proxy ┴─ async proxy └─ generic proxy ┴─ generic proxy ┴─ async proxy
│ fence inserted here ↑ │ fence inserted here ↑
...@@ -53,7 +53,7 @@ def kernel(): ...@@ -53,7 +53,7 @@ def kernel():
with T.Kernel(1): with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor") desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared") smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32) T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0) smem[0] = T.float16(0)
T.ptx_wgmma_ss( T.ptx_wgmma_ss(
"float16", "float16",
...@@ -83,7 +83,7 @@ def kernel(): ...@@ -83,7 +83,7 @@ def kernel():
with T.Kernel(1): with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor") desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared") smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32) T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0) smem[0] = T.float16(0)
T.fence_proxy_async() T.fence_proxy_async()
T.ptx_wgmma_ss( T.ptx_wgmma_ss(
......
...@@ -5,6 +5,7 @@ import triton ...@@ -5,6 +5,7 @@ import triton
import triton.language as tl import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor from triton.tools.tensor_descriptor import TensorDescriptor
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional
@triton.jit @triton.jit
...@@ -94,7 +95,7 @@ def triton_kernel( ...@@ -94,7 +95,7 @@ def triton_kernel(
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc) Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor: def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape _, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64 BLOCK_M = 64
...@@ -130,7 +131,7 @@ def main( ...@@ -130,7 +131,7 @@ def main(
seq_kv: int = 256, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
groups: int = 8, groups: int = 8,
window_size: int | None = None, window_size: Optional[int] = None,
dtype: str = "float16", dtype: str = "float16",
tune: bool = False, tune: bool = False,
): ):
......
...@@ -5,6 +5,7 @@ import triton ...@@ -5,6 +5,7 @@ import triton
import triton.language as tl import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor from triton.tools.tensor_descriptor import TensorDescriptor
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional
@triton.jit @triton.jit
...@@ -93,7 +94,7 @@ def triton_kernel( ...@@ -93,7 +94,7 @@ def triton_kernel(
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc) Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor: def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape bs, n_heads, seq_q, head_dim = Q.shape
seq_kv = K.shape[2] seq_kv = K.shape[2]
BLOCK_M = 64 BLOCK_M = 64
...@@ -125,7 +126,7 @@ def main(batch: int = 1, ...@@ -125,7 +126,7 @@ def main(batch: int = 1,
seq_q: int = 256, seq_q: int = 256,
seq_kv: int = 256, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
window_size: int | None = None, window_size: Optional[int] = None,
dtype: str = "float16", dtype: str = "float16",
tune: bool = False): tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
......
...@@ -81,13 +81,10 @@ def flashattn_fwd( ...@@ -81,13 +81,10 @@ def flashattn_fwd(
sinks[i] = Sinks[by] sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32') start = T.max(0,
if window_size is not None: (bx * block_M - window_size) // block_N) if window_size is not None else 0
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages): for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i q_idx = bx * block_M + i
...@@ -266,14 +263,11 @@ def flashattn_bwd(batch, ...@@ -266,14 +263,11 @@ def flashattn_bwd(batch,
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32') loop_ed = T.min(
if window_size is not None: T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
loop_ed[0] = T.min( seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
T.ceildiv((by + 1) * block_M + window_size, block_N),
T.ceildiv(seq_len, block_N)) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -444,7 +438,7 @@ def main(BATCH: int = 1, ...@@ -444,7 +438,7 @@ def main(BATCH: int = 1,
N_CTX: int = 512, N_CTX: int = 512,
D_HEAD: int = 64, D_HEAD: int = 64,
groups: int = 2, groups: int = 2,
window_size: int | None = None, window_size: Optional[int] = None,
dtype: str = "float16"): dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
......
...@@ -172,14 +172,11 @@ def flashattn( ...@@ -172,14 +172,11 @@ def flashattn(
end = T.min( end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32') start = T.max(0, (bx * block_M + past_len - window_size) //
if window_size is not None: block_N) if window_size is not None else 0
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined( for k in T.Pipelined(
start[0], start,
end, end,
num_stages=num_stages, num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2], order=[-1, 0, 3, 1, -1, 2],
...@@ -272,7 +269,7 @@ def main( ...@@ -272,7 +269,7 @@ def main(
seq_kv: int = 256, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
groups: int = 8, groups: int = 8,
window_size: int | None = None, window_size: Optional[int] = None,
dtype: str = "float16", dtype: str = "float16",
tune: bool = False, tune: bool = False,
): ):
......
...@@ -78,13 +78,10 @@ def flashattn_fwd( ...@@ -78,13 +78,10 @@ def flashattn_fwd(
sinks[i] = Sinks[by] sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32') start = T.max(0,
if window_size is not None: (bx * block_M - window_size) // block_N) if window_size is not None else 0
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages): for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i q_idx = bx * block_M + i
...@@ -267,14 +264,10 @@ def flashattn_bwd( ...@@ -267,14 +264,10 @@ def flashattn_bwd(
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32') loop_ed = T.min(
if window_size is not None: T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
loop_ed[0] = T.min( seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
T.ceildiv((by + 1) * block_M + window_size, block_N), for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.ceildiv(seq_len, block_N))
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -440,7 +433,7 @@ def main(BATCH: int = 1, ...@@ -440,7 +433,7 @@ def main(BATCH: int = 1,
H: int = 1, H: int = 1,
N_CTX: int = 512, N_CTX: int = 512,
D_HEAD: int = 128, D_HEAD: int = 128,
window_size: int | None = None, window_size: Optional[int] = None,
dtype: str = "float16"): dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
......
...@@ -162,13 +162,10 @@ def flashattn( ...@@ -162,13 +162,10 @@ def flashattn(
end = T.min( end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32') start = T.max(0, (bx * block_M + past_len - window_size) //
if window_size is not None: block_N) if window_size is not None else 0
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages): for k in T.Pipelined(start, end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum) logsum)
...@@ -253,7 +250,7 @@ def main(batch: int = 1, ...@@ -253,7 +250,7 @@ def main(batch: int = 1,
seq_q: int = 256, seq_q: int = 256,
seq_kv: int = 256, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
window_size: int | None = None, window_size: Optional[int] = None,
dtype: str = "float16", dtype: str = "float16",
tune: bool = False): tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
......
...@@ -165,14 +165,11 @@ def flashattn( ...@@ -165,14 +165,11 @@ def flashattn(
end = T.min( end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32') start = T.max(0, (bx * block_M + past_len - window_size) //
if window_size is not None: block_N) if window_size is not None else 0
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined( for k in T.Pipelined(
start[0], start,
end, end,
num_stages=num_stages, num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2], order=[-1, 0, 3, 1, -1, 2],
...@@ -263,7 +260,7 @@ def main(batch: int = 1, ...@@ -263,7 +260,7 @@ def main(batch: int = 1,
seq_q: int = 256, seq_q: int = 256,
seq_kv: int = 256, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
window_size: int | None = None, window_size: Optional[int] = None,
dtype: str = "float16", dtype: str = "float16",
tune: bool = False): tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
......
...@@ -25,10 +25,10 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask(): ...@@ -25,10 +25,10 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def test_example_triton_sparse_gqa_decode_varlen_indice(): def test_example_triton_sparse_gqa_decode_varlen_indice():
example_triton_sparse_gqa_decode_varlen_indice.main( example_triton_sparse_gqa_decode_varlen_indice.main(
batch=16, batch=8,
heads=16, heads=8,
heads_kv=8, heads_kv=4,
max_cache_seqlen=4096, max_cache_seqlen=2048,
dim=128, dim=128,
dim_v=128, dim_v=128,
sparse_ratio=0.8, sparse_ratio=0.8,
...@@ -40,7 +40,7 @@ def test_example_triton_sparse_gqa_decode_varlen_mask(): ...@@ -40,7 +40,7 @@ def test_example_triton_sparse_gqa_decode_varlen_mask():
batch=16, batch=16,
heads=16, heads=16,
heads_kv=8, heads_kv=8,
max_cache_seqlen=4096, max_cache_seqlen=1024,
dim=128, dim=128,
dim_v=128, dim_v=128,
sparse_ratio=0.8, sparse_ratio=0.8,
......
...@@ -161,7 +161,9 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \ ...@@ -161,7 +161,9 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
return x_fp8 return x_fp8
def main(M=8192, N=8192, BG=2, blk_m=8): def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
if batch_sizes is None:
batch_sizes = [2048, 6144]
if dtype == "float": if dtype == "float":
x = torch.randn(M, N, device="cuda", dtype=torch.float32) x = torch.randn(M, N, device="cuda", dtype=torch.float32)
elif dtype == "float16": elif dtype == "float16":
...@@ -170,7 +172,7 @@ def main(M=8192, N=8192, BG=2, blk_m=8): ...@@ -170,7 +172,7 @@ def main(M=8192, N=8192, BG=2, blk_m=8):
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
else: else:
raise ValueError(f"Unsupported dtype: {dtype}") raise ValueError(f"Unsupported dtype: {dtype}")
batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32) batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
M_max = int(ceil_div(batch_sizes.max(), 128) * 128) M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
print("batch_sizes:", batch_sizes) print("batch_sizes:", batch_sizes)
......
...@@ -4,11 +4,12 @@ import example_per_token_cast_to_fp8 ...@@ -4,11 +4,12 @@ import example_per_token_cast_to_fp8
def test_example_group_per_split_token_cast_to_fp8(): def test_example_group_per_split_token_cast_to_fp8():
example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8) example_group_per_split_token_cast_to_fp8.main(
M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
def test_example_per_token_cast_to_fp8(): def test_example_per_token_cast_to_fp8():
example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8) example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -127,7 +127,7 @@ def mqa_attn_return_logits( ...@@ -127,7 +127,7 @@ def mqa_attn_return_logits(
index_k_shared = T.alloc_shared([block_N, index_dim], dtype) index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype) s_reshaped = T.reshape(s, (block_N, block_Q, heads))
logits = T.alloc_fragment([block_N, block_Q], accum_dtype) logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
weights = T.alloc_fragment([block_Q, heads], accum_dtype) weights = T.alloc_fragment([block_Q, heads], accum_dtype)
...@@ -165,7 +165,7 @@ def mqa_attn_return_logits( ...@@ -165,7 +165,7 @@ def mqa_attn_return_logits(
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
s_reshaped[bn_i, bq_i, s_reshaped[bn_i, bq_i,
h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) * h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) *
weights[bq_i, h_i]) * index_k_scale_fragment[bn_i] weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
...@@ -258,6 +258,8 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, ...@@ -258,6 +258,8 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
# initial random seed to make the performance reproducible
torch.manual_seed(0)
q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
weights = torch.randn(S, H, device="cuda", dtype=torch.float32) weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
......
# ruff: noqa # ruff: noqa
import tilelang.testing import tilelang.testing
from topk_selector import test_topk_selector import topk_selector
from fp8_lighting_indexer import test_fp8_lighting_indexer import fp8_lighting_indexer
from sparse_mla_fwd import test_sparse_mla_fwd import sparse_mla_fwd
from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined import sparse_mla_fwd_pipelined
from sparse_mla_bwd import test_sparse_mla_bwd import sparse_mla_bwd
def test_example_topk_selector(): def test_example_topk_selector():
test_topk_selector() topk_selector.test_topk_selector()
def test_example_fp8_lighting_indexer(): def test_example_fp8_lighting_indexer():
test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1) fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd(): def test_example_sparse_mla_fwd():
# small shapes for testing # small shapes for testing
test_sparse_mla_fwd( sparse_mla_fwd.test_sparse_mla_fwd(
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
...@@ -28,15 +28,15 @@ def test_example_sparse_mla_fwd(): ...@@ -28,15 +28,15 @@ def test_example_sparse_mla_fwd():
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd_pipelined(): def test_example_sparse_mla_fwd_pipelined():
# small shapes for testing # small shapes for testing
test_sparse_mla_fwd_pipelined( sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd(): def test_example_sparse_mla_bwd():
test_sparse_mla_bwd( sparse_mla_bwd.test_sparse_mla_bwd(
S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
if __name__ == "__main__": if __name__ == "__main__":
......
import tilelang.testing import tilelang.testing
import example_elementwise_add import example_elementwise_add
import example_elementwise_add_tma_1d
def test_example_elementwise_add(): def test_example_elementwise_add():
example_elementwise_add.main() example_elementwise_add.main()
def test_example_elementwise_add_tma_1d():
example_elementwise_add_tma_1d.main()
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -5,6 +5,8 @@ import tilelang.language as T ...@@ -5,6 +5,8 @@ import tilelang.language as T
from tilelang.contrib import nvcc from tilelang.contrib import nvcc
import argparse import argparse
tilelang.disable_cache()
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4], pass_configs={
...@@ -44,7 +46,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -44,7 +46,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops
# We should set it to negative large number instead
T.fill(scores_max, T.Cast(accum_dtype, -1e30))
loop_range = ( loop_range = (
T.ceildiv( T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
...@@ -53,7 +57,7 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -53,7 +57,7 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype)) T.Cast(accum_dtype, -1e30))
else: else:
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -265,17 +269,17 @@ def flashattn_bwd_atomic_add(batch, ...@@ -265,17 +269,17 @@ def flashattn_bwd_atomic_add(batch,
@tilelang.jit(pass_configs={ @tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) })
def flashattn_bwd_split(batch, def flashattn_bwd_split_novarlen(batch,
heads, heads,
seq_len, seq_len,
dim_qk, dim_qk,
dim_v, dim_v,
is_causal, is_causal,
block_M, block_M,
block_N, block_N,
threads=256, threads=256,
num_stages=2, num_stages=2,
groups=1): groups=1):
sm_scale = (1.0 / dim_qk)**0.5 sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
...@@ -424,7 +428,7 @@ class _attention(torch.autograd.Function): ...@@ -424,7 +428,7 @@ class _attention(torch.autograd.Function):
kernel(q, k, v, do, lse, delta, dq, dk, dv) kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, dk, dv = mod_post(dq, dk, dv) dq, dk, dv = mod_post(dq, dk, dv)
else: else:
kernel = flashattn_bwd_split( kernel = flashattn_bwd_split_novarlen(
BATCH, BATCH,
H, H,
N_CTX, N_CTX,
......
...@@ -24,21 +24,32 @@ def attention_ref( ...@@ -24,21 +24,32 @@ def attention_ref(
dtype_og = q.dtype dtype_og = q.dtype
if upcast: if upcast:
q, k, v = q.float(), k.float(), v.float() q, k, v = q.float(), k.float(), v.float()
dim = q.shape[-1] b, T, Hq, D = q.shape
scale = (1.0 / dim)**0.5 S = k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) scale = (1.0 / D)**0.5
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k) scores = torch.einsum("bthd,bshd->bhts", q, k)
left, right = window_size
left = S if left is None or left < 0 else int(left)
right = S if right is None or right < 0 else int(right)
t_idx = torch.arange(T, device=scores.device)[:, None]
s_idx = torch.arange(S, device=scores.device)[None, :]
visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
if key_padding_mask is not None: if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
visible_mask = visible_mask & k_keep
neg_inf = torch.finfo(scores.dtype).min
scores = scores * scale scores = scores * scale
scores = scores.masked_fill(~visible_mask, neg_inf)
attention = torch.softmax(scores, dim=-1).to(v.dtype) attention = torch.softmax(scores, dim=-1).to(v.dtype)
if query_padding_mask is not None: if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
attention = attention.masked_fill(~q_keep, 0.0)
output = torch.einsum("bhts,bshd->bthd", attention, v) output = torch.einsum("bhts,bshd->bthd", attention, v)
if query_padding_mask is not None: if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
...@@ -91,53 +102,53 @@ def flashattn(batch_size, ...@@ -91,53 +102,53 @@ def flashattn(batch_size,
scores_sum = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
})
batch_idx = bz batch_idx = bz
head_idx = by head_idx = by
kv_head_idx = head_idx // groups kv_head_idx = head_idx // groups
q_start_idx = cu_seqlens_q[batch_idx] q_start_idx = cu_seqlens_q[batch_idx]
k_start_idx = cu_seqlens_k[batch_idx] kv_start_idx = cu_seqlens_k[batch_idx]
v_start_idx = cu_seqlens_k[batch_idx]
q_end_idx = cu_seqlens_q[batch_idx + 1] q_end_idx = cu_seqlens_q[batch_idx + 1]
k_end_idx = cu_seqlens_k[batch_idx + 1] k_end_idx = cu_seqlens_k[batch_idx + 1]
v_end_idx = cu_seqlens_k[batch_idx + 1]
q_current_seqlen = q_end_idx - q_start_idx q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx kv_current_seqlen = k_end_idx - kv_start_idx
v_current_seqlen = v_end_idx - v_start_idx
T.copy( T.copy(
Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
Q_shared) Q_shared)
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i >= q_current_seqlen:
Q_shared[i, d] = 0
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(k_current_seqlen, block_N) loop_range = (
T.min(
T.ceildiv(q_current_seqlen +
(bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
if is_causal else T.ceildiv(kv_current_seqlen, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy( T.copy(
K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N, K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
kv_head_idx, :], K_shared) kv_head_idx, :], K_shared)
for i, d in T.Parallel(block_N, dim):
if k * block_N + i >= k_current_seqlen:
K_shared[i, d] = 0
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and acc_s[i,
(bx * block_M + i >= q_current_seqlen or j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
k * block_N + j >= k_current_seqlen), (bx * block_M + i >= q_current_seqlen or
-T.infinity(acc_s.dtype), 0) k * block_N + j >= kv_current_seqlen), -1e9, 0)
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
k * block_N + j >= k_current_seqlen), k * block_N + j >= kv_current_seqlen), -1e9,
-T.infinity(acc_s.dtype), 0) 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -145,6 +156,9 @@ def flashattn(batch_size, ...@@ -145,6 +156,9 @@ def flashattn(batch_size,
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
...@@ -158,11 +172,8 @@ def flashattn(batch_size, ...@@ -158,11 +172,8 @@ def flashattn(batch_size,
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy( T.copy(
V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N, V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
kv_head_idx, :], V_shared) kv_head_idx, :], V_shared)
for i, d in T.Parallel(block_N, dim):
if k * block_N + i >= v_current_seqlen:
V_shared[i, d] = 0
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
...@@ -191,8 +202,7 @@ def main(batch: int = 1, ...@@ -191,8 +202,7 @@ def main(batch: int = 1,
tilelang.testing.set_random_seed(0) tilelang.testing.set_random_seed(0)
causal = False if is_causal:
if causal:
total_flops *= 0.5 total_flops *= 0.5
tilelang.testing.set_random_seed(0) tilelang.testing.set_random_seed(0)
...@@ -201,9 +211,9 @@ def main(batch: int = 1, ...@@ -201,9 +211,9 @@ def main(batch: int = 1,
device = torch.device("cuda") device = torch.device("cuda")
head_kv = heads // groups head_kv = heads // groups
q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True) q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True) k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True) v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random") query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random") key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
...@@ -236,10 +246,10 @@ def main(batch: int = 1, ...@@ -236,10 +246,10 @@ def main(batch: int = 1,
heads, heads,
dim, dim,
is_causal, is_causal,
block_M=64, block_M=128,
block_N=64, block_N=128,
num_stages=1, num_stages=2,
threads=128) threads=256)
out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
out = output_pad_fn(out_unpad) out = output_pad_fn(out_unpad)
...@@ -255,7 +265,9 @@ def main(batch: int = 1, ...@@ -255,7 +265,9 @@ def main(batch: int = 1,
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅") print("All checks passed.✅")
latency = do_bench( latency = do_bench(
lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)) lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
_n_warmup=5,
_n_repeat=5)
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
......
...@@ -33,18 +33,30 @@ def test_example_gqa_bwd_wgmma_pipelined(): ...@@ -33,18 +33,30 @@ def test_example_gqa_bwd_wgmma_pipelined():
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_example_mha_bwd(): def test_example_mha_bwd():
example_mha_bwd.main(BATCH=1) example_mha_bwd.main(
BATCH=1,
H=16,
N_CTX=512,
D_HEAD=64,
causal=False,
)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_example_mha_bwd_bhsd(): def test_example_mha_bwd_bhsd():
example_mha_bwd_bhsd.main(BATCH=1) example_mha_bwd_bhsd.main(
BATCH=1,
H=16,
N_CTX=512,
D_HEAD=64,
causal=False,
)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_bwd_wgmma_pipelined(): def test_example_mha_bwd_wgmma_pipelined():
example_mha_bwd_wgmma_pipelined.main(BATCH=1) example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
...@@ -84,7 +96,7 @@ def test_example_mha_fwd_bshd(): ...@@ -84,7 +96,7 @@ def test_example_mha_fwd_bshd():
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_example_mha_fwd_varlen(): def test_example_mha_fwd_varlen():
example_mha_fwd_varlen.main() example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment