OpenDAS / tilelang
"...composable_kernel_onnxruntime.git" did not exist on "82c8b9f8eeffc1b9a72dc5a84137ece88e8d5941"
Commit bbbf4207 (Unverified)
Merge branch 'main' into dcu

Authored by guchaoyang on Nov 14, 2025; committed by GitHub on Nov 14, 2025.
Parents: 8f4628e0, 5eb30a4f
Changes: 286
Showing 20 changed files with 271 additions and 254 deletions (+271 / -254).
docker/Dockerfile.cu128 (+5 / -2)
docker/Dockerfile.rocm (+1 / -1)
docs/compiler_internals/inject_fence_proxy.md (+3 / -3)
examples/attention_sink/benchmark_gqa_sink_fwd.py (+3 / -2)
examples/attention_sink/benchmark_mha_sink_fwd.py (+3 / -2)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (+9 / -15)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (+4 / -7)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (+8 / -15)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (+4 / -7)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (+4 / -7)
examples/blocksparse_attention/test_example_blocksparse_attention.py (+5 / -5)
examples/cast/example_group_per_split_token_cast_to_fp8.py (+4 / -2)
examples/cast/test_example_cast.py (+3 / -2)
examples/deepseek_v32/fp8_lighting_indexer.py (+4 / -2)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (+12 / -12)
examples/elementwise/test_example_elementwise.py (+0 / -5)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (+18 / -14)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (+112 / -106)
examples/flash_attention/example_gqa_fwd_varlen.py (+53 / -41)
examples/flash_attention/test_example_flash_attention.py (+16 / -4)
docker/Dockerfile.cu128

@@ -20,9 +20,12 @@ ENV LIBGL_ALWAYS_INDIRECT=1
 RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all
-RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
+RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev \
+    build-essential cmake libedit-dev libxml2-dev cython3
+RUN pip install cython
 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && ./install_cuda.sh
+    && cd TileLang && USE_CUDA=1 pip install -e . -v
 CMD bash

docker/Dockerfile.rocm

@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \
 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
-    conda run -n py_3.10 bash -c "cd tilelang && ./install_rocm.sh"
+    conda run -n py_3.10 bash -c "cd tilelang && USE_ROCM=1 pip install -e . -v"
 RUN conda init bash

docs/compiler_internals/inject_fence_proxy.md

@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
 ### Timeline View
 ```
-generic initialize_descriptor → generic shared-store → async wgmma
+generic initialize_wgmma_descriptor → generic shared-store → async wgmma
        │                              │                      │
        └─ generic proxy               ┴─ generic proxy       ┴─ async proxy
                                       │ fence inserted here ↑
@@ -53,7 +53,7 @@ def kernel():
     with T.Kernel(1):
         desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
         smem = T.decl_buffer((128,), "float16", scope="shared")
-        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
+        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
         smem[0] = T.float16(0)
         T.ptx_wgmma_ss(
             "float16",
@@ -83,7 +83,7 @@ def kernel():
     with T.Kernel(1):
         desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
         smem = T.decl_buffer((128,), "float16", scope="shared")
-        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
+        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
         smem[0] = T.float16(0)
         T.fence_proxy_async()
         T.ptx_wgmma_ss(

examples/attention_sink/benchmark_gqa_sink_fwd.py

@@ -5,6 +5,7 @@ import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
 from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
+from typing import Optional


 @triton.jit
@@ -94,7 +95,7 @@ def triton_kernel(
     Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


-def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
+def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
     bs, n_heads, seq_q, head_dim = Q.shape
     _, n_heads_kv, seq_kv, _ = K.shape
     BLOCK_M = 64
@@ -130,7 +131,7 @@ def main(
         seq_kv: int = 256,
         dim: int = 128,
         groups: int = 8,
-        window_size: int | None = None,
+        window_size: Optional[int] = None,
         dtype: str = "float16",
         tune: bool = False,
 ):

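A note on the annotation change above: `int | None` is valid at runtime only on Python 3.10+, so switching the signatures to `typing.Optional[int]` keeps these scripts importable on older interpreters. A minimal sketch of the compatible form (the helper below is hypothetical, not part of the benchmark):

```python
from typing import Optional


def clamp_window(window_size: Optional[int] = None, seq_len: int = 256) -> int:
    # Hypothetical helper for illustration: treat None as "no sliding window".
    return seq_len if window_size is None else min(window_size, seq_len)


print(clamp_window())     # 256
print(clamp_window(128))  # 128
```
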
examples/attention_sink/benchmark_mha_sink_fwd.py

@@ -5,6 +5,7 @@ import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
 from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
+from typing import Optional


 @triton.jit
@@ -93,7 +94,7 @@ def triton_kernel(
     Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


-def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
+def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
     bs, n_heads, seq_q, head_dim = Q.shape
     seq_kv = K.shape[2]
     BLOCK_M = 64
@@ -125,7 +126,7 @@ def main(batch: int = 1,
          seq_q: int = 256,
          seq_kv: int = 256,
          dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16",
          tune: bool = False):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

@@ -81,13 +81,10 @@ def flashattn_fwd(
             sinks[i] = Sinks[by]
         end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-        start = T.alloc_local([1], 'int32')
-        if window_size is not None:
-            start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-        else:
-            start[0] = 0
-        for k in T.Pipelined(start[0], end, num_stages=num_stages):
+        start = T.max(0,
+                      (bx * block_M - window_size) // block_N) if window_size is not None else 0
+        for k in T.Pipelined(start, end, num_stages=num_stages):
             T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
             for i, j in T.Parallel(block_M, block_N):
                 q_idx = bx * block_M + i
@@ -266,14 +263,11 @@ def flashattn_bwd(batch,
         T.clear(dk)
         loop_st = T.floordiv(by * block_M, block_N)
-        loop_ed = T.alloc_local([1], 'int32')
-        if window_size is not None:
-            loop_ed[0] = T.min(
-                T.ceildiv((by + 1) * block_M + window_size, block_N),
-                T.ceildiv(seq_len, block_N))
-        else:
-            loop_ed[0] = T.ceildiv(seq_len, block_N)
-        for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
+        loop_ed = T.min(
+            T.ceildiv((by + 1) * block_M + window_size, block_N),
+            T.ceildiv(seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
+        for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
             T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
             T.clear(qkT)
             T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -444,7 +438,7 @@ def main(BATCH: int = 1,
          N_CTX: int = 512,
          D_HEAD: int = 64,
          groups: int = 2,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16"):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
     if window_size is not None:

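A note on the windowing change above: the removed code stored the pipeline start in a one-element `int32` local and filled it in an `if`/`else`, while the added code folds the same decision into a conditional expression, which works because `window_size` is an ordinary Python value available when the kernel is traced. A plain-Python stand-in for that decision (not TileLang; the helper name is made up):

```python
def pipeline_start(bx: int, block_M: int, block_N: int, window_size=None) -> int:
    # The windowing decision is resolved while building the kernel, so no runtime
    # scalar buffer is needed; None means "start from block 0".
    return max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0


assert pipeline_start(bx=4, block_M=64, block_N=32) == 0
assert pipeline_start(bx=4, block_M=64, block_N=32, window_size=96) == 5  # (256 - 96) // 32
```
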
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

@@ -172,14 +172,11 @@ def flashattn(
             end = T.min(
                 T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0
             for k in T.Pipelined(
-                    start[0],
+                    start,
                     end,
                     num_stages=num_stages,
                     order=[-1, 0, 3, 1, -1, 2],
@@ -272,7 +269,7 @@ def main(
         seq_kv: int = 256,
         dim: int = 128,
         groups: int = 8,
-        window_size: int | None = None,
+        window_size: Optional[int] = None,
         dtype: str = "float16",
         tune: bool = False,
 ):

examples/attention_sink/example_mha_sink_bwd_bhsd.py

@@ -78,13 +78,10 @@ def flashattn_fwd(
             sinks[i] = Sinks[by]
         end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-        start = T.alloc_local([1], 'int32')
-        if window_size is not None:
-            start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-        else:
-            start[0] = 0
-        for k in T.Pipelined(start[0], end, num_stages=num_stages):
+        start = T.max(0,
+                      (bx * block_M - window_size) // block_N) if window_size is not None else 0
+        for k in T.Pipelined(start, end, num_stages=num_stages):
             T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
             for i, j in T.Parallel(block_M, block_N):
                 q_idx = bx * block_M + i
@@ -267,14 +264,10 @@ def flashattn_bwd(
         T.clear(dk)
         loop_st = T.floordiv(by * block_M, block_N)
-        loop_ed = T.alloc_local([1], 'int32')
-        if window_size is not None:
-            loop_ed[0] = T.min(
-                T.ceildiv((by + 1) * block_M + window_size, block_N),
-                T.ceildiv(seq_len, block_N))
-        else:
-            loop_ed[0] = T.ceildiv(seq_len, block_N)
-        for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
+        loop_ed = T.min(
+            T.ceildiv((by + 1) * block_M + window_size, block_N),
+            T.ceildiv(seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
+        for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
             T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
             T.clear(qkT)
             T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -440,7 +433,7 @@ def main(BATCH: int = 1,
          H: int = 1,
          N_CTX: int = 512,
          D_HEAD: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16"):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
     if window_size is not None:

examples/attention_sink/example_mha_sink_fwd_bhsd.py

@@ -162,13 +162,10 @@ def flashattn(
             end = T.min(
                 T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                 Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
                         logsum)
@@ -253,7 +250,7 @@ def main(batch: int = 1,
          seq_q: int = 256,
          seq_kv: int = 256,
          dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16",
          tune: bool = False):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

@@ -165,14 +165,11 @@ def flashattn(
             end = T.min(
                 T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0
             for k in T.Pipelined(
-                    start[0],
+                    start,
                     end,
                     num_stages=num_stages,
                     order=[-1, 0, 3, 1, -1, 2],
@@ -263,7 +260,7 @@ def main(batch: int = 1,
          seq_q: int = 256,
          seq_kv: int = 256,
          dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16",
          tune: bool = False):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]

examples/blocksparse_attention/test_example_blocksparse_attention.py

@@ -25,10 +25,10 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
 def test_example_triton_sparse_gqa_decode_varlen_indice():
     example_triton_sparse_gqa_decode_varlen_indice.main(
-        batch=16,
-        heads=16,
-        heads_kv=8,
-        max_cache_seqlen=4096,
+        batch=8,
+        heads=8,
+        heads_kv=4,
+        max_cache_seqlen=2048,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
@@ -40,7 +40,7 @@ def test_example_triton_sparse_gqa_decode_varlen_mask():
         batch=16,
         heads=16,
         heads_kv=8,
-        max_cache_seqlen=4096,
+        max_cache_seqlen=1024,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,

examples/cast/example_group_per_split_token_cast_to_fp8.py

@@ -161,7 +161,9 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
     return x_fp8


-def main(M=8192, N=8192, BG=2, blk_m=8):
+def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
+    if batch_sizes is None:
+        batch_sizes = [2048, 6144]
     if dtype == "float":
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
     elif dtype == "float16":
@@ -170,7 +172,7 @@ def main(M=8192, N=8192, BG=2, blk_m=8):
         x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
-    batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32)
+    batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
     M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
     print("batch_sizes:", batch_sizes)

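A note on the new `batch_sizes=None` parameter above: it follows the usual Python idiom of using `None` as a sentinel and materializing the default inside the function, so callers that pass nothing still get the old `[2048, 6144]` split while tests can inject smaller shapes. A small sketch of the pattern (hypothetical function, not the kernel code):

```python
def resolve_batch_sizes(batch_sizes=None):
    # A mutable default such as `batch_sizes=[2048, 6144]` would be shared across calls;
    # the None sentinel recreates the default list on every call instead.
    if batch_sizes is None:
        batch_sizes = [2048, 6144]
    return list(batch_sizes)


print(resolve_batch_sizes())            # [2048, 6144]
print(resolve_batch_sizes([128, 896]))  # [128, 896], what the updated test passes
```
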
examples/cast/test_example_cast.py

@@ -4,11 +4,12 @@ import example_per_token_cast_to_fp8
 def test_example_group_per_split_token_cast_to_fp8():
-    example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)
+    example_group_per_split_token_cast_to_fp8.main(
+        M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])


 def test_example_per_token_cast_to_fp8():
-    example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)
+    example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8)


 if __name__ == "__main__":

examples/deepseek_v32/fp8_lighting_indexer.py

@@ -127,7 +127,7 @@ def mqa_attn_return_logits(
             index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
             index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
             s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
-            s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
+            s_reshaped = T.reshape(s, (block_N, block_Q, heads))
             logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
             weights = T.alloc_fragment([block_Q, heads], accum_dtype)
@@ -165,7 +165,7 @@ def mqa_attn_return_logits(
             for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
                 s_reshaped[bn_i, bq_i,
-                           h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
+                           h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) *
                                    weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
             T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
@@ -258,6 +258,8 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
 def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
+    # initial random seed to make the performance reproducible
+    torch.manual_seed(0)
     q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
     kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
     weights = torch.randn(S, H, device="cuda", dtype=torch.float32)

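A note on the `T.reshape` change above: the kernel now indexes the existing `[block_N, block_Q * heads]` accumulator through a 3-D view instead of allocating a second fragment and copying into it. A rough PyTorch analogy of what a reshape-as-view buys (an illustration only, not the TileLang semantics themselves):

```python
import torch

s = torch.zeros(4, 3 * 8)        # stand-in for the flat [block_N, block_Q * heads] buffer
s_reshaped = s.reshape(4, 3, 8)  # viewed as [block_N, block_Q, heads]; contiguous, so no copy
s_reshaped[0, 0, 0] = 1.0
assert s[0, 0] == 1.0            # both names address the same storage
```
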
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

 # ruff: noqa
 import tilelang.testing
-from topk_selector import test_topk_selector
-from fp8_lighting_indexer import test_fp8_lighting_indexer
-from sparse_mla_fwd import test_sparse_mla_fwd
-from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
-from sparse_mla_bwd import test_sparse_mla_bwd
+import topk_selector
+import fp8_lighting_indexer
+import sparse_mla_fwd
+import sparse_mla_fwd_pipelined
+import sparse_mla_bwd


 def test_example_topk_selector():
-    test_topk_selector()
+    topk_selector.test_topk_selector()


 def test_example_fp8_lighting_indexer():
-    test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)
+    fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_fwd():
     # small shapes for testing
-    test_sparse_mla_fwd(
+    sparse_mla_fwd.test_sparse_mla_fwd(
         S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)

@@ -28,15 +28,15 @@ def test_example_sparse_mla_fwd():
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_fwd_pipelined():
     # small shapes for testing
-    test_sparse_mla_fwd_pipelined(
-        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
+    sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
+        S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_bwd():
-    test_sparse_mla_bwd(
-        S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
+    sparse_mla_bwd.test_sparse_mla_bwd(
+        S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)


 if __name__ == "__main__":

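A note on the import-style change above (this is my reading of the change, not stated in the diff): pytest collects every module-level `test_*` function it can see, so `from fp8_lighting_indexer import test_fp8_lighting_indexer` would make the imported function run a second time from this wrapper module; importing the module and calling through it avoids that. A minimal sketch of the resulting pattern:

```python
# hypothetical wrapper test module illustrating the pattern
import fp8_lighting_indexer  # only a module name enters this namespace, no extra test_* function


def test_example_fp8_lighting_indexer():
    # run the underlying test through the module, with the smaller shapes used above
    fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
```
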
examples/elementwise/test_example_elementwise.py

 import tilelang.testing
 import example_elementwise_add
-import example_elementwise_add_tma_1d


 def test_example_elementwise_add():
     example_elementwise_add.main()


-def test_example_elementwise_add_tma_1d():
-    example_elementwise_add_tma_1d.main()


 if __name__ == "__main__":
     tilelang.testing.main()

examples/flash_attention/example_gqa_bwd_tma_reduce.py

@@ -5,6 +5,8 @@ import tilelang.language as T
 from tilelang.contrib import nvcc
 import argparse

+tilelang.disable_cache()
+

 @tilelang.jit(
     out_idx=[3, 4], pass_configs={
@@ -44,7 +46,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
         T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
         T.fill(acc_o, 0)
         T.fill(logsum, 0)
-        T.fill(scores_max, -T.infinity(accum_dtype))
+        # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops
+        # We should set it to negative large number instead
+        T.fill(scores_max, T.Cast(accum_dtype, -1e30))
         loop_range = (
             T.ceildiv(
                 (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
@@ -53,7 +57,7 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
             if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
                     acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-                                                 -T.infinity(acc_s.dtype))
+                                                 T.Cast(accum_dtype, -1e30))
             else:
                 T.clear(acc_s)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -265,17 +269,17 @@ def flashattn_bwd_atomic_add(batch,
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_split(batch,
+def flashattn_bwd_split_novarlen(batch,
                         heads,
                         seq_len,
                         dim_qk,
                         dim_v,
                         is_causal,
                         block_M,
                         block_N,
                         threads=256,
                         num_stages=2,
                         groups=1):
     sm_scale = (1.0 / dim_qk)**0.5
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
@@ -424,7 +428,7 @@ class _attention(torch.autograd.Function):
             kernel(q, k, v, do, lse, delta, dq, dk, dv)
             dq, dk, dv = mod_post(dq, dk, dv)
         else:
-            kernel = flashattn_bwd_split(
+            kernel = flashattn_bwd_split_novarlen(
                 BATCH,
                 H,
                 N_CTX,

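The warning comment added above is easy to reproduce outside the kernel: when an entire row is masked, subtracting a `-inf` running maximum from `-inf` scores yields NaN, while a large finite negative value keeps the exponentials defined. A small PyTorch check:

```python
import torch

fully_masked = torch.full((4,), float("-inf"))
print(torch.exp(fully_masked - fully_masked.max()))  # tensor([nan, nan, nan, nan])

fully_masked = torch.full((4,), -1e30)
print(torch.exp(fully_masked - fully_masked.max()))  # tensor([1., 1., 1., 1.]) -- finite, safe to normalize
```
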
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py

(diff collapsed in the page view; contents not shown)

examples/flash_attention/example_gqa_fwd_varlen.py

@@ -24,21 +24,32 @@ def attention_ref(
     dtype_og = q.dtype
     if upcast:
         q, k, v = q.float(), k.float(), v.float()
-    dim = q.shape[-1]
-    scale = (1.0 / dim)**0.5
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    b, T, Hq, D = q.shape
+    S = k.shape[1]
+    scale = (1.0 / D)**0.5
+    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
     scores = torch.einsum("bthd,bshd->bhts", q, k)
+    left, right = window_size
+    left = S if left is None or left < 0 else int(left)
+    right = S if right is None or right < 0 else int(right)
+    t_idx = torch.arange(T, device=scores.device)[:, None]
+    s_idx = torch.arange(S, device=scores.device)[None, :]
+    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
+    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
     if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+        k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
+        visible_mask = visible_mask & k_keep
+    neg_inf = torch.finfo(scores.dtype).min
     scores = scores * scale
+    scores = scores.masked_fill(~visible_mask, neg_inf)
     attention = torch.softmax(scores, dim=-1).to(v.dtype)
     if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+        q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
+        attention = attention.masked_fill(~q_keep, 0.0)
     output = torch.einsum("bhts,bshd->bthd", attention, v)
     if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
@@ -91,53 +102,53 @@ def flashattn(batch_size,
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)
+            T.annotate_layout({
+                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
+                Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
+            })
             batch_idx = bz
             head_idx = by
             kv_head_idx = head_idx // groups
             q_start_idx = cu_seqlens_q[batch_idx]
-            k_start_idx = cu_seqlens_k[batch_idx]
-            v_start_idx = cu_seqlens_k[batch_idx]
+            kv_start_idx = cu_seqlens_k[batch_idx]
             q_end_idx = cu_seqlens_q[batch_idx + 1]
             k_end_idx = cu_seqlens_k[batch_idx + 1]
-            v_end_idx = cu_seqlens_k[batch_idx + 1]
             q_current_seqlen = q_end_idx - q_start_idx
-            k_current_seqlen = k_end_idx - k_start_idx
-            v_current_seqlen = v_end_idx - v_start_idx
+            kv_current_seqlen = k_end_idx - kv_start_idx
             T.copy(
                 Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
                 Q_shared)
-            for i, d in T.Parallel(block_M, dim):
-                if bx * block_M + i >= q_current_seqlen:
-                    Q_shared[i, d] = 0
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
-            loop_range = T.ceildiv(k_current_seqlen, block_N)
+            loop_range = (
+                T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N),
+                      T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 T.copy(
-                    K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
+                    K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], K_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= k_current_seqlen:
-                        K_shared[i, d] = 0
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
-                                                     (bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i,
+                              j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
+                                                  (bx * block_M + i >= q_current_seqlen or
+                                                   k * block_N + j >= kv_current_seqlen), -1e9, 0)
                 else:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
+                                                      k * block_N + j >= kv_current_seqlen), -1e9,
+                                                     0)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -145,6 +156,9 @@ def flashattn(batch_size,
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, block_N):
@@ -158,11 +172,8 @@ def flashattn(batch_size,
                     acc_o[i, j] *= scores_scale[i]
                 T.copy(
-                    V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
+                    V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], V_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= v_current_seqlen:
-                        V_shared[i, d] = 0
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@@ -191,8 +202,7 @@ def main(batch: int = 1,
     tilelang.testing.set_random_seed(0)
-    causal = False
-    if causal:
+    if is_causal:
         total_flops *= 0.5
     tilelang.testing.set_random_seed(0)
@@ -201,9 +211,9 @@ def main(batch: int = 1,
     device = torch.device("cuda")
     head_kv = heads // groups
-    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
-    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
-    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
+    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
+    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
+    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
     query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
     key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
@@ -236,10 +246,10 @@ def main(batch: int = 1,
         heads,
         dim,
         is_causal,
-        block_M=64,
-        block_N=64,
-        num_stages=1,
-        threads=128)
+        block_M=128,
+        block_N=128,
+        num_stages=2,
+        threads=256)
     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
@@ -255,7 +265,9 @@ def main(batch: int = 1,
     torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
     print("All checks passed.✅")
     latency = do_bench(
-        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
+        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
+        _n_warmup=5,
+        _n_repeat=5)
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))

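The reworked `attention_ref` above builds a single boolean visibility mask that combines the sliding window with the key-padding mask and applies it with the dtype's minimum rather than in-place `-inf` fills. The window part of that mask, extracted into a standalone sketch (same logic, hypothetical helper name):

```python
import torch


def window_visibility(T_len: int, S_len: int, left=None, right=None) -> torch.Tensor:
    # None or a negative bound means "unlimited" on that side, as in the updated reference.
    left = S_len if left is None or left < 0 else int(left)
    right = S_len if right is None or right < 0 else int(right)
    t_idx = torch.arange(T_len)[:, None]
    s_idx = torch.arange(S_len)[None, :]
    return (s_idx >= t_idx - left) & (s_idx <= t_idx + right)  # [T_len, S_len], True = may attend


print(window_visibility(4, 4, left=1, right=0).int())  # causal band: each row sees itself and one earlier token
```
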
examples/flash_attention/test_example_flash_attention.py

@@ -33,18 +33,30 @@ def test_example_gqa_bwd_wgmma_pipelined():
 @tilelang.testing.requires_cuda
 def test_example_mha_bwd():
-    example_mha_bwd.main(BATCH=1)
+    example_mha_bwd.main(
+        BATCH=1,
+        H=16,
+        N_CTX=512,
+        D_HEAD=64,
+        causal=False,
+    )


 @tilelang.testing.requires_cuda
 def test_example_mha_bwd_bhsd():
-    example_mha_bwd_bhsd.main(BATCH=1)
+    example_mha_bwd_bhsd.main(
+        BATCH=1,
+        H=16,
+        N_CTX=512,
+        D_HEAD=64,
+        causal=False,
+    )


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_mha_bwd_wgmma_pipelined():
-    example_mha_bwd_wgmma_pipelined.main(BATCH=1)
+    example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)


 @tilelang.testing.requires_cuda
@@ -84,7 +96,7 @@ def test_example_mha_fwd_bshd():
 @tilelang.testing.requires_cuda
 def test_example_mha_fwd_varlen():
-    example_mha_fwd_varlen.main()
+    example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)


 if __name__ == "__main__":