OpenDAS / tilelang · Commits

Commit bbbf4207 (unverified)
Merge branch 'main' into dcu
Authored Nov 14, 2025 by guchaoyang; committed by GitHub, Nov 14, 2025
Parents: 8f4628e0, 5eb30a4f
Changes: 286 · Showing 20 changed files with 271 additions and 254 deletions (+271 −254)
Changed files on this page:

- docker/Dockerfile.cu128 (+5 −2)
- docker/Dockerfile.rocm (+1 −1)
- docs/compiler_internals/inject_fence_proxy.md (+3 −3)
- examples/attention_sink/benchmark_gqa_sink_fwd.py (+3 −2)
- examples/attention_sink/benchmark_mha_sink_fwd.py (+3 −2)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (+9 −15)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (+4 −7)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (+8 −15)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py (+4 −7)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (+4 −7)
- examples/blocksparse_attention/test_example_blocksparse_attention.py (+5 −5)
- examples/cast/example_group_per_split_token_cast_to_fp8.py (+4 −2)
- examples/cast/test_example_cast.py (+3 −2)
- examples/deepseek_v32/fp8_lighting_indexer.py (+4 −2)
- examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (+12 −12)
- examples/elementwise/test_example_elementwise.py (+0 −5)
- examples/flash_attention/example_gqa_bwd_tma_reduce.py (+18 −14)
- examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (+112 −106)
- examples/flash_attention/example_gqa_fwd_varlen.py (+53 −41)
- examples/flash_attention/test_example_flash_attention.py (+16 −4)
docker/Dockerfile.cu128
@@ -20,9 +20,12 @@ ENV LIBGL_ALWAYS_INDIRECT=1
 RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all
-RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
+RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev \
+    build-essential cmake libedit-dev libxml2-dev cython3
+RUN pip install cython
 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && ./install_cuda.sh
+    && cd TileLang && USE_CUDA=1 pip install -e . -v
 CMD bash
docker/Dockerfile.rocm
@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \
 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
-    conda run -n py_3.10 bash -c "cd tilelang && ./install_rocm.sh"
+    conda run -n py_3.10 bash -c "cd tilelang && USE_ROCM=1 pip install -e . -v"
 RUN conda init bash
docs/compiler_internals/inject_fence_proxy.md
@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
 ### Timeline View
 ```
-generic initialize_descriptor → generic shared-store → async wgmma
+generic initialize_wgmma_descriptor → generic shared-store → async wgmma
        │                         │                       │
        └─ generic proxy          ┴─ generic proxy        ┴─ async proxy
                                  │ fence inserted here ↑
@@ -53,7 +53,7 @@ def kernel():
     with T.Kernel(1):
         desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
         smem = T.decl_buffer((128,), "float16", scope="shared")
-        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
+        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
         smem[0] = T.float16(0)
         T.ptx_wgmma_ss(
             "float16",
@@ -83,7 +83,7 @@ def kernel():
     with T.Kernel(1):
         desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
         smem = T.decl_buffer((128,), "float16", scope="shared")
-        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
+        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
         smem[0] = T.float16(0)
         T.fence_proxy_async()
         T.ptx_wgmma_ss(
examples/attention_sink/benchmark_gqa_sink_fwd.py
@@ -5,6 +5,7 @@ import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
 from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
+from typing import Optional


 @triton.jit
@@ -94,7 +95,7 @@ def triton_kernel(
     Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


-def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
+def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
     bs, n_heads, seq_q, head_dim = Q.shape
     _, n_heads_kv, seq_kv, _ = K.shape
     BLOCK_M = 64
@@ -130,7 +131,7 @@ def main(
         seq_kv: int = 256,
         dim: int = 128,
         groups: int = 8,
-        window_size: int | None = None,
+        window_size: Optional[int] = None,
         dtype: str = "float16",
         tune: bool = False,
 ):
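Reviewer note: the switch from `int | None` to `typing.Optional[int]` in these signatures keeps the annotations importable on Python versions older than 3.10, where the `X | Y` union syntax in an evaluated annotation raises a SyntaxError at import time. A minimal sketch of the equivalence (plain Python; the helper name `describe_window` is made up for illustration):

```python
from typing import Optional


# `Optional[int]` and `int | None` describe the same type; the typing form
# also parses on Python 3.8/3.9, where `int | None` in an annotation that is
# evaluated at import time fails before any code runs.
def describe_window(window_size: Optional[int] = None) -> str:
    # Falls back to full attention when no sliding window is requested.
    return "full" if window_size is None else f"window={window_size}"


print(describe_window())     # full
print(describe_window(128))  # window=128
```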
examples/attention_sink/benchmark_mha_sink_fwd.py
@@ -5,6 +5,7 @@ import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
 from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
+from typing import Optional


 @triton.jit
@@ -93,7 +94,7 @@ def triton_kernel(
     Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


-def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
+def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
     bs, n_heads, seq_q, head_dim = Q.shape
     seq_kv = K.shape[2]
     BLOCK_M = 64
@@ -125,7 +126,7 @@ def main(batch: int = 1,
          seq_q: int = 256,
          seq_kv: int = 256,
          dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16",
          tune: bool = False):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
...
examples/attention_sink/example_gqa_sink_bwd_bhsd.py
View file @
bbbf4207
...
...
@@ -81,13 +81,10 @@ def flashattn_fwd(
sinks
[
i
]
=
Sinks
[
by
]
end
=
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
((
bx
+
1
)
*
block_M
,
block_N
))
start
=
T
.
alloc_local
([
1
],
'int32'
)
if
window_size
is
not
None
:
start
[
0
]
=
T
.
max
(
0
,
(
bx
*
block_M
-
window_size
)
//
block_N
)
else
:
start
[
0
]
=
0
start
=
T
.
max
(
0
,
(
bx
*
block_M
-
window_size
)
//
block_N
)
if
window_size
is
not
None
else
0
for
k
in
T
.
Pipelined
(
start
[
0
]
,
end
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
start
,
end
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
bz
,
by
//
groups
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
:],
K_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
q_idx
=
bx
*
block_M
+
i
...
...
@@ -266,14 +263,11 @@ def flashattn_bwd(batch,
T
.
clear
(
dk
)
loop_st
=
T
.
floordiv
(
by
*
block_M
,
block_N
)
loop_ed
=
T
.
alloc_local
([
1
],
'int32'
)
if
window_size
is
not
None
:
loop_ed
[
0
]
=
T
.
min
(
T
.
ceildiv
((
by
+
1
)
*
block_M
+
window_size
,
block_N
),
T
.
ceildiv
(
seq_len
,
block_N
))
else
:
loop_ed
[
0
]
=
T
.
ceildiv
(
seq_len
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_st
,
loop_ed
[
0
],
num_stages
=
num_stages
):
loop_ed
=
T
.
min
(
T
.
ceildiv
((
by
+
1
)
*
block_M
+
window_size
,
block_N
),
T
.
ceildiv
(
seq_len
,
block_N
))
if
window_size
is
not
None
else
T
.
ceildiv
(
seq_len
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_st
,
loop_ed
,
num_stages
=
num_stages
):
T
.
copy
(
Q
[
bz
,
bx
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
:],
q
)
T
.
clear
(
qkT
)
T
.
gemm
(
K_shared
,
q
,
qkT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
...
...
@@ -444,7 +438,7 @@ def main(BATCH: int = 1,
N_CTX
:
int
=
512
,
D_HEAD
:
int
=
64
,
groups
:
int
=
2
,
window_size
:
int
|
None
=
None
,
window_size
:
Optional
[
int
]
=
None
,
dtype
:
str
=
"float16"
):
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
if
window_size
is
not
None
:
...
...
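Reviewer note: the recurring rewrite in these attention-sink kernels drops the one-element `T.alloc_local` buffer plus `if`/`else` in favour of a conditional expression. `window_size` is an ordinary Python value when the kernel is traced, so the choice can be made while the program is built and no per-thread local buffer is needed; `T.Pipelined` then receives a scalar bound instead of `start[0]`/`loop_ed[0]`. A plain-Python sketch of the same computation (illustrative names, not tilelang APIs):

```python
def loop_start(bx: int, block_M: int, block_N: int, window_size=None) -> int:
    """First K/V tile to visit, mirroring
    `start = T.max(0, (bx * block_M - window_size) // block_N)
         if window_size is not None else 0`
    from the diff, but with ordinary ints."""
    if window_size is None:
        return 0
    return max(0, (bx * block_M - window_size) // block_N)


# With a 256-wide sliding window, block_M=64, block_N=32:
print(loop_start(bx=10, block_M=64, block_N=32, window_size=256))  # 12
print(loop_start(bx=10, block_M=64, block_N=32))                   # 0
```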
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
@@ -172,14 +172,11 @@ def flashattn(
             end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0

             for k in T.Pipelined(
-                    start[0],
+                    start,
                     end,
                     num_stages=num_stages,
                     order=[-1, 0, 3, 1, -1, 2],
@@ -272,7 +269,7 @@ def main(
         seq_kv: int = 256,
         dim: int = 128,
         groups: int = 8,
-        window_size: int | None = None,
+        window_size: Optional[int] = None,
         dtype: str = "float16",
         tune: bool = False,
 ):
examples/attention_sink/example_mha_sink_bwd_bhsd.py
@@ -78,13 +78,10 @@ def flashattn_fwd(
                 sinks[i] = Sinks[by]
             end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0

-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                 T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
                 for i, j in T.Parallel(block_M, block_N):
                     q_idx = bx * block_M + i
@@ -267,14 +264,10 @@ def flashattn_bwd(
             T.clear(dk)
             loop_st = T.floordiv(by * block_M, block_N)
-            loop_ed = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                loop_ed[0] = T.min(
-                    T.ceildiv((by + 1) * block_M + window_size, block_N),
-                    T.ceildiv(seq_len, block_N))
-            else:
-                loop_ed[0] = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
+            loop_ed = T.min(
+                T.ceildiv((by + 1) * block_M + window_size, block_N),
+                T.ceildiv(seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                 T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                 T.clear(qkT)
                 T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -440,7 +433,7 @@ def main(BATCH: int = 1,
          H: int = 1,
          N_CTX: int = 512,
          D_HEAD: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16"):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
     if window_size is not None:
examples/attention_sink/example_mha_sink_fwd_bhsd.py
@@ -162,13 +162,10 @@ def flashattn(
             end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0

-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                 Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
                         logsum)
@@ -253,7 +250,7 @@ def main(batch: int = 1,
          seq_q: int = 256,
          seq_kv: int = 256,
          dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16",
          tune: bool = False):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
@@ -165,14 +165,11 @@ def flashattn(
             end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0

             for k in T.Pipelined(
-                    start[0],
+                    start,
                     end,
                     num_stages=num_stages,
                     order=[-1, 0, 3, 1, -1, 2],
@@ -263,7 +260,7 @@ def main(batch: int = 1,
          seq_q: int = 256,
          seq_kv: int = 256,
          dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16",
          tune: bool = False):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
examples/blocksparse_attention/test_example_blocksparse_attention.py
@@ -25,10 +25,10 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
 def test_example_triton_sparse_gqa_decode_varlen_indice():
     example_triton_sparse_gqa_decode_varlen_indice.main(
-        batch=16,
-        heads=16,
-        heads_kv=8,
-        max_cache_seqlen=4096,
+        batch=8,
+        heads=8,
+        heads_kv=4,
+        max_cache_seqlen=2048,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
@@ -40,7 +40,7 @@ def test_example_triton_sparse_gqa_decode_varlen_mask():
         batch=16,
         heads=16,
         heads_kv=8,
-        max_cache_seqlen=4096,
+        max_cache_seqlen=1024,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
examples/cast/example_group_per_split_token_cast_to_fp8.py
@@ -161,7 +161,9 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
     return x_fp8


-def main(M=8192, N=8192, BG=2, blk_m=8):
+def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
+    if batch_sizes is None:
+        batch_sizes = [2048, 6144]
     if dtype == "float":
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
     elif dtype == "float16":
@@ -170,7 +172,7 @@ def main(M=8192, N=8192, BG=2, blk_m=8):
         x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
-    batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32)
+    batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
     M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
     print("batch_sizes:", batch_sizes)
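Reviewer note: `main` now accepts the group sizes as a `batch_sizes` parameter with a `None` default and fills in `[2048, 6144]` inside the function, so callers such as `test_example_cast.py` can pass smaller groups. The `None` sentinel is the usual way to make a list-valued default overridable without sharing one mutable default object across calls; a standalone sketch of the idiom (`make_batch_sizes` is a hypothetical helper, and the CUDA device is omitted so the snippet runs anywhere):

```python
import torch


def make_batch_sizes(batch_sizes=None) -> torch.Tensor:
    # A None sentinel avoids the mutable-default-argument pitfall and keeps
    # the historical [2048, 6144] behaviour when nothing is passed.
    if batch_sizes is None:
        batch_sizes = [2048, 6144]
    return torch.tensor(batch_sizes, dtype=torch.int32)


print(make_batch_sizes())            # tensor([2048, 6144], dtype=torch.int32)
print(make_batch_sizes([128, 896]))  # tensor([128, 896], dtype=torch.int32)
```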
examples/cast/test_example_cast.py
@@ -4,11 +4,12 @@ import example_per_token_cast_to_fp8

 def test_example_group_per_split_token_cast_to_fp8():
-    example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)
+    example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])


 def test_example_per_token_cast_to_fp8():
-    example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)
+    example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8)


 if __name__ == "__main__":
examples/deepseek_v32/fp8_lighting_indexer.py
@@ -127,7 +127,7 @@ def mqa_attn_return_logits(
             index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
             index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
             s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
-            s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
+            s_reshaped = T.reshape(s, (block_N, block_Q, heads))
             logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
             weights = T.alloc_fragment([block_Q, heads], accum_dtype)
@@ -165,7 +165,7 @@ def mqa_attn_return_logits(
             for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
-                s_reshaped[bn_i, bq_i, h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
+                s_reshaped[bn_i, bq_i, h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) *
                                                weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
             T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
@@ -258,6 +258,8 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,

 def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
+    # initial random seed to make the performance reproducible
+    torch.manual_seed(0)
     q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
     kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
     weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
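Reviewer note: `s_reshaped` is now produced by `T.reshape(s, ...)` rather than allocated as a separate fragment, so the clamped, weighted update addresses the same storage as `s`, viewed as `[block_N, block_Q, heads]`, and the manual `bq_i * heads + h_i` indexing into `s` goes away. The analogous behaviour in plain PyTorch, where reshaping a contiguous tensor yields a view over the same memory, looks like this (the tilelang semantics are inferred from the diff, not from its documentation):

```python
import torch

block_N, block_Q, heads = 4, 2, 3
s = torch.arange(block_N * block_Q * heads, dtype=torch.float32)
s = s.reshape(block_N, block_Q * heads)

# View the flat [block_N, block_Q * heads] buffer as [block_N, block_Q, heads].
s_reshaped = s.reshape(block_N, block_Q, heads)

# Writing through the view updates the underlying buffer as well.
s_reshaped[0, 1, 2] = -1.0
print(s[0, 1 * heads + 2])  # tensor(-1.)
```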
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
 # ruff: noqa
 import tilelang.testing

-from topk_selector import test_topk_selector
-from fp8_lighting_indexer import test_fp8_lighting_indexer
-from sparse_mla_fwd import test_sparse_mla_fwd
-from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
-from sparse_mla_bwd import test_sparse_mla_bwd
+import topk_selector
+import fp8_lighting_indexer
+import sparse_mla_fwd
+import sparse_mla_fwd_pipelined
+import sparse_mla_bwd


 def test_example_topk_selector():
-    test_topk_selector()
+    topk_selector.test_topk_selector()


 def test_example_fp8_lighting_indexer():
-    test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)
+    fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_fwd():
     # small shapes for testing
-    test_sparse_mla_fwd(
+    sparse_mla_fwd.test_sparse_mla_fwd(
         S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@@ -28,15 +28,15 @@ def test_example_sparse_mla_fwd():
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_fwd_pipelined():
     # small shapes for testing
-    test_sparse_mla_fwd_pipelined(
-        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
+    sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
+        S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_bwd():
-    test_sparse_mla_bwd(
-        S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
+    sparse_mla_bwd.test_sparse_mla_bwd(
+        S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)


 if __name__ == "__main__":
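Reviewer note: the test wrapper now imports the example modules (`import topk_selector`) instead of re-importing their `test_*` functions. One plausible motivation, stated here as an assumption rather than a fact from the commit, is that a name-based collector such as pytest would otherwise pick up the re-exported `test_*` names in this file as duplicate tests. A self-contained sketch of the namespacing difference (the throwaway module below stands in for `topk_selector.py`):

```python
import types

# Build a stand-in for topk_selector.py with a single test function in it.
topk_selector = types.ModuleType("topk_selector")
exec("def test_topk_selector():\n    return 'ok'", topk_selector.__dict__)

# Old style: `from topk_selector import test_topk_selector` would bind a
# second `test_*` name in this module, visible to name-based collection.
# New style: keep only the module object and call through an attribute,
# exactly as the wrappers in this file now do.
def test_example_topk_selector():
    assert topk_selector.test_topk_selector() == "ok"


test_example_topk_selector()
```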
examples/elementwise/test_example_elementwise.py
 import tilelang.testing

 import example_elementwise_add
-import example_elementwise_add_tma_1d


 def test_example_elementwise_add():
     example_elementwise_add.main()


-def test_example_elementwise_add_tma_1d():
-    example_elementwise_add_tma_1d.main()


 if __name__ == "__main__":
     tilelang.testing.main()
examples/flash_attention/example_gqa_bwd_tma_reduce.py
@@ -5,6 +5,8 @@ import tilelang.language as T
 from tilelang.contrib import nvcc
 import argparse

+tilelang.disable_cache()
+

 @tilelang.jit(
     out_idx=[3, 4], pass_configs={
@@ -44,7 +46,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
             T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
-            T.fill(scores_max, -T.infinity(accum_dtype))
+            # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops
+            # We should set it to negative large number instead
+            T.fill(scores_max, T.Cast(accum_dtype, -1e30))
             loop_range = (
                 T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
@@ -53,7 +57,7 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
                         acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-                                                     -T.infinity(acc_s.dtype))
+                                                     T.Cast(accum_dtype, -1e30))
                 else:
                     T.clear(acc_s)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -265,7 +269,7 @@ def flashattn_bwd_atomic_add(batch,
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_split(batch,
+def flashattn_bwd_split_novarlen(batch,
                         heads,
                         seq_len,
                         dim_qk,
@@ -424,7 +428,7 @@ class _attention(torch.autograd.Function):
             kernel(q, k, v, do, lse, delta, dq, dk, dv)
             dq, dk, dv = mod_post(dq, dk, dv)
         else:
-            kernel = flashattn_bwd_split(
+            kernel = flashattn_bwd_split_novarlen(
                 BATCH,
                 H,
                 N_CTX,
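Reviewer note: the forward kernel now seeds `scores_max` (and the masked causal entries) with a large negative finite value instead of `-infinity`, for the reason given in the in-diff comment. The failure mode is easy to reproduce in plain PyTorch: when every position in a block row is masked, the running maximum becomes `-inf` and the softmax rescaling computes `-inf - (-inf)`, which is NaN, whereas a finite sentinel keeps the arithmetic defined:

```python
import torch

# A block row in which every score is masked out with -inf.
row = torch.full((4,), float("-inf"))
row_max = row.max()              # -inf
print(torch.exp(row - row_max))  # -inf - (-inf) is NaN -> tensor of nan

# The same row masked with a large negative finite value stays well defined.
row = torch.full((4,), -1e30)
row_max = row.max()              # -1e30
print(torch.exp(row - row_max))  # exp(0) -> tensor of ones
```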
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
(This diff is collapsed and not shown on this page.)
examples/flash_attention/example_gqa_fwd_varlen.py
@@ -24,21 +24,32 @@ def attention_ref(
     dtype_og = q.dtype
     if upcast:
         q, k, v = q.float(), k.float(), v.float()
-    dim = q.shape[-1]
-    scale = (1.0 / dim)**0.5
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    b, T, Hq, D = q.shape
+    S = k.shape[1]
+    scale = (1.0 / D)**0.5
+    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
     scores = torch.einsum("bthd,bshd->bhts", q, k)
+    left, right = window_size
+    left = S if left is None or left < 0 else int(left)
+    right = S if right is None or right < 0 else int(right)
+    t_idx = torch.arange(T, device=scores.device)[:, None]
+    s_idx = torch.arange(S, device=scores.device)[None, :]
+    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
+    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
     if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+        k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
+        visible_mask = visible_mask & k_keep
+    neg_inf = torch.finfo(scores.dtype).min
+    scores = scores * scale
+    scores = scores.masked_fill(~visible_mask, neg_inf)
     attention = torch.softmax(scores, dim=-1).to(v.dtype)
     if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+        q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
+        attention = attention.masked_fill(~q_keep, 0.0)
     output = torch.einsum("bhts,bshd->bthd", attention, v)
     if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
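Reviewer note: the rewritten `attention_ref` first derives an explicit visibility mask from `window_size = (left, right)` and only then folds in the key/query padding masks and the scale. The mask construction is just broadcast index arithmetic; a tiny standalone sketch mirroring the lines above (the sizes and window are made up for illustration):

```python
import torch

T_len, S_len = 4, 6   # query and key lengths
left, right = 2, 0    # look back at most 2 keys, no look-ahead

t_idx = torch.arange(T_len)[:, None]
s_idx = torch.arange(S_len)[None, :]

# Key s is visible from query t iff t - left <= s <= t + right.
visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
print(visible_ts.long())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0]])
```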
@@ -91,53 +102,53 @@ def flashattn(batch_size,
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)

             T.annotate_layout({
                 O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                 Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
             })

             batch_idx = bz
             head_idx = by
             kv_head_idx = head_idx // groups

             q_start_idx = cu_seqlens_q[batch_idx]
-            k_start_idx = cu_seqlens_k[batch_idx]
-            v_start_idx = cu_seqlens_k[batch_idx]
+            kv_start_idx = cu_seqlens_k[batch_idx]
             q_end_idx = cu_seqlens_q[batch_idx + 1]
             k_end_idx = cu_seqlens_k[batch_idx + 1]
-            v_end_idx = cu_seqlens_k[batch_idx + 1]
             q_current_seqlen = q_end_idx - q_start_idx
-            k_current_seqlen = k_end_idx - k_start_idx
-            v_current_seqlen = v_end_idx - v_start_idx
+            kv_current_seqlen = k_end_idx - kv_start_idx

             T.copy(Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
                    Q_shared)
             for i, d in T.Parallel(block_M, dim):
                 if bx * block_M + i >= q_current_seqlen:
                     Q_shared[i, d] = 0
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
-            loop_range = T.ceildiv(k_current_seqlen, block_N)
+            loop_range = (
+                T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N),
+                      T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
             for k in T.Pipelined(loop_range, num_stages=num_stages):
-                T.copy(K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
+                T.copy(K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                                kv_head_idx, :], K_shared)
                 for i, d in T.Parallel(block_N, dim):
                     if k * block_N + i >= k_current_seqlen:
                         K_shared[i, d] = 0
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
-                                                     (bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i, j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
+                                                     (bx * block_M + i >= q_current_seqlen or
+                                                      k * block_N + j >= kv_current_seqlen),
+                                                     -1e9, 0)
                 else:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
+                                                      k * block_N + j >= kv_current_seqlen),
+                                                     -1e9, 0)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -145,6 +156,9 @@ def flashattn(batch_size,
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
+
                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, block_N):
@@ -158,11 +172,8 @@ def flashattn(batch_size,
                     acc_o[i, j] *= scores_scale[i]
-                T.copy(V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
+                T.copy(V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                                kv_head_idx, :], V_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= v_current_seqlen:
-                        V_shared[i, d] = 0
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@@ -191,8 +202,7 @@ def main(batch: int = 1,
-    tilelang.testing.set_random_seed(0)
-    causal = False
-    if causal:
+    if is_causal:
         total_flops *= 0.5
+    tilelang.testing.set_random_seed(0)
@@ -201,9 +211,9 @@ def main(batch: int = 1,
     device = torch.device("cuda")
     head_kv = heads // groups
-    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
-    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
-    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
+    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
+    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
+    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
     query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
     key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
@@ -236,10 +246,10 @@ def main(batch: int = 1,
         heads,
         dim,
         is_causal,
-        block_M=64, block_N=64, num_stages=1, threads=128)
+        block_M=128, block_N=128, num_stages=2, threads=256)
     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
@@ -255,7 +265,9 @@ def main(batch: int = 1,
     torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
     print("All checks passed.✅")
     latency = do_bench(
-        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
+        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
+        _n_warmup=5,
+        _n_repeat=5)
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
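Reviewer note: the benchmark call now passes explicit `_n_warmup=5, _n_repeat=5` to `do_bench`, which shortens the measurement in exchange for noisier numbers. The general shape of such a warmup-then-measure harness, written out by hand (an illustrative sketch, not the actual `do_bench` implementation; a real GPU harness would also synchronize the device around the timed region):

```python
import time


def bench(fn, n_warmup: int = 5, n_repeat: int = 5) -> float:
    """Return the mean latency of fn() in milliseconds."""
    for _ in range(n_warmup):   # warm caches, JIT compilation, autotuning
        fn()
    start = time.perf_counter()
    for _ in range(n_repeat):
        fn()
    return (time.perf_counter() - start) / n_repeat * 1e3


print(f"{bench(lambda: sum(range(100_000))):.3f} ms")
```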
examples/flash_attention/test_example_flash_attention.py
@@ -33,18 +33,30 @@ def test_example_gqa_bwd_wgmma_pipelined():
 @tilelang.testing.requires_cuda
 def test_example_mha_bwd():
-    example_mha_bwd.main(BATCH=1)
+    example_mha_bwd.main(
+        BATCH=1,
+        H=16,
+        N_CTX=512,
+        D_HEAD=64,
+        causal=False,
+    )


 @tilelang.testing.requires_cuda
 def test_example_mha_bwd_bhsd():
-    example_mha_bwd_bhsd.main(BATCH=1)
+    example_mha_bwd_bhsd.main(
+        BATCH=1,
+        H=16,
+        N_CTX=512,
+        D_HEAD=64,
+        causal=False,
+    )


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_mha_bwd_wgmma_pipelined():
-    example_mha_bwd_wgmma_pipelined.main(BATCH=1)
+    example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)


 @tilelang.testing.requires_cuda
@@ -84,7 +96,7 @@ def test_example_mha_fwd_bshd():
 @tilelang.testing.requires_cuda
 def test_example_mha_fwd_varlen():
-    example_mha_fwd_varlen.main()
+    example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)


 if __name__ == "__main__":