OpenDAS / tilelang · Commit 29051439 (Unverified)

Authored Dec 12, 2025 by Lei Wang; committed by GitHub on Dec 12, 2025.

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

Parent: e84b24bc

Changes: 467. Showing 20 changed files with 779 additions and 937 deletions (+779, -937).
.pre-commit-config.yaml  +2 -12
benchmark/blocksparse_attention/benchmark_library_dense_fmha.py  +5 -8
benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py  +31 -41
benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py  +10 -15
benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py  +10 -30
benchmark/mamba2/benchmark_mamba_chunk_scan.py  +114 -93
benchmark/matmul/benchmark_matmul.py  +8 -8
benchmark/matmul/benchmark_matmul_intrinsic.py  +17 -23
benchmark/matmul/benchmark_matmul_sp.py  +24 -28
benchmark/matmul_fp8/benchmark_matmul.py  +7 -8
docs/conf.py  +11 -20
examples/amd/example_amd_flash_attn_bwd.py  +96 -110
examples/amd/example_amd_flash_attn_fwd.py  +42 -63
examples/analyze/example_conv_analyze.py  +14 -31
examples/analyze/example_gemm_analyze.py  +3 -3
examples/attention_sink/benchmark_gqa_sink_fwd.py  +20 -28
examples/attention_sink/benchmark_mha_sink_fwd.py  +29 -35
examples/attention_sink/example_gqa_sink_bwd_bhsd.py  +136 -150
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py  +75 -92
examples/attention_sink/example_mha_sink_bwd_bhsd.py  +125 -139
.pre-commit-config.yaml

@@ -39,19 +39,9 @@ repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.14.7  # sync with requirements-lint.txt
     hooks:
+      - id: ruff-format
       - id: ruff-check
         args: [--fix, --exit-non-zero-on-fix]
-  - repo: https://github.com/google/yapf
-    rev: v0.43.0  # sync with requirements-lint.txt
-    hooks:
-      - id: yapf
-        name: yapf-multiproc-bugfix
-        # yapf is not multiprocess safe, so we run a dummy yapf first.
-        args: [--in-place, docs/conf.py]
-        always_run: true
-        pass_filenames: false
-      - id: yapf
-        args: [--recursive, --in-place]
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1  # sync with requirements-lint.txt
     hooks:

@@ -62,4 +52,4 @@ repos:
     ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$|
     ^.+\.svg$|
     ^.*\brequirements\b.*\.txt$
-  )
\ No newline at end of file
+  )
benchmark/blocksparse_attention/benchmark_library_dense_fmha.py

@@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
     dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True

@@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)

         import flash_attn
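Aside (not part of the diff): the hunks above show get_sparse_attn_mask_from_topk only in fragments. Below is a minimal self-contained sketch of the same top-k block-mask construction, assuming PyTorch; the trailing return is an assumption, since the hunk cuts off before the end of the function.

import torch

def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    # x holds block-level scores of shape [bsz, num_head, downsample_len, downsample_len].
    bsz, num_head, downsample_len, _ = x.shape
    # Column indices of the top-k scoring blocks in each block row.
    sparse_index = torch.topk(x, topk, dim=-1).indices
    # Start from an all-False boolean mask and mark the selected blocks.
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        # Keep the last two block rows fully dense.
        dense_mask[:, :, -2:, :] = True
    return dense_mask  # assumed return value, not shown in the hunk

x = torch.randn(1, 2, 4, 4)
mask = get_sparse_attn_mask_from_topk(x, topk=2)
print(mask.shape, mask.sum(dim=-1))  # every block row keeps exactly 2 blocks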
benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py

@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
     dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True

@@ -39,7 +36,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_N = 64
     num_stages = 2
     threads = 128
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
     block_mask_shape = [batch, heads, downsample_len, downsample_len]

@@ -48,7 +45,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_mask_dtype = "bool"

     def kernel_func(block_M, block_N, num_stages, threads):
-
         @T.macro
         def MMA0(
             K: T.Tensor(shape, dtype),

@@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
             T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
             if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
                     acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
             else:
                 T.clear(acc_s)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@@ -79,18 +74,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
             T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
             T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

         @T.macro
         def Softmax(
             acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
             acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
             scores_max: T.FragmentBuffer([block_M], accum_dtype),
             scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
             scores_scale: T.FragmentBuffer([block_M], accum_dtype),
             scores_sum: T.FragmentBuffer([block_M], accum_dtype),
             logsum: T.FragmentBuffer([block_M], accum_dtype),
         ):
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))

@@ -116,22 +111,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
         @T.macro
         def Rescale(
             acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
             scores_scale: T.FragmentBuffer([block_M], accum_dtype),
         ):
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] *= scores_scale[i]

         @T.prim_func
         def main(
             Q: T.Tensor(shape, dtype),
             K: T.Tensor(shape, dtype),
             V: T.Tensor(shape, dtype),
             BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
             Output: T.Tensor(shape, dtype),
         ):
             with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim], dtype)

@@ -146,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                 logsum = T.alloc_fragment([block_M], accum_dtype)
                 block_mask = T.alloc_local([downsample_len], block_mask_dtype)

                 T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))

@@ -155,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                     block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                 loop_range = (T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))

                 for k in T.Pipelined(loop_range, num_stages=num_stages):
                     if block_mask[k]:
                         MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                         Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                         Rescale(acc_o, scores_scale)
                         MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
                 T.copy(acc_o, O_shared)
                 T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

     return main

@@ -177,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)

         # Create sparse mask (downsampled to block level)
         downsample_factor = BLOCK
         downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
+        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
         x_ds[:, :, :, 0] = 100
         block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

         program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
         kernel = tilelang.compile(program, out_idx=4)

         def benchmark_fn():
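Aside (not part of the diff): the kernel above folds log2(e) ≈ 1.44269504 into the softmax scale so that the exponentials can later be taken base 2. A tiny plain-Python check of the identity behind that choice:

import math

x = -3.7  # an arbitrary attention score
# exp(x) == 2 ** (x * log2(e)), so pre-scaling scores by log2(e)
# lets a kernel use a base-2 exponential instead of exp.
assert math.isclose(math.exp(x), 2.0 ** (x * 1.44269504), rel_tol=1e-6)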
benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py

@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
     dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True

@@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
         sm_scale = 1.0 / (D_HEAD**0.5)

         # Create sparse mask (downsampled to block level)
         downsample_factor = BLOCK
         downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
+        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
         x_ds[:, :, :, 0] = 100
         block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

         def benchmark_fn():
             # Compute reference
             # Expand block mask to full attention matrix
-            full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
+            full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
             full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
             full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal
             # PyTorch reference implementation
-            attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
-            attn = attn.masked_fill(~full_mask, float('-inf'))
+            attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
+            attn = attn.masked_fill(~full_mask, float("-inf"))
             attn = F.softmax(attn, dim=-1)
-            ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+            ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
             return ref_output

         ref_latency = do_bench(
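Aside (not part of the diff): the reference path above expands the block-level mask to a token-level mask with torch.kron. A small standalone illustration of that expansion, with made-up sizes:

import torch

# A 2x2 block mask: keep blocks (0,0), (1,0), (1,1); drop block (0,1).
block_mask = torch.tensor([[True, False],
                           [True, True]])
BLOCK = 3  # each block covers a 3x3 patch of the full attention matrix

# kron repeats every block entry over a BLOCK x BLOCK patch, giving a 6x6 mask.
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK)).bool()

print(full_mask.shape)  # torch.Size([6, 6])
print(full_mask[0])     # tensor([ True,  True,  True, False, False, False])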
benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py

@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
     dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True

@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-
     mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
     if mask_val == True:

@@ -72,8 +68,7 @@ def _fwd_kernel_inner(
         # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
         if LAST_K_BLOCK:
-            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))
+            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))

         m_ij = tl.maximum(m_i, tl.max(qk, 1))
         qk -= m_ij[:, None]

@@ -153,7 +148,7 @@ def _fwd_kernel(
     v_ptrs = V + off_v
     mask_ptrs = block_mask_ptr + start_m * stride_bmm
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

@@ -191,24 +186,12 @@ def _fwd_kernel(
     acc = acc * l_recip
     acc = acc.to(Out.dtype.element_ty)
     off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)


 def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
     assert q.shape[-1] == k.shape[-1] == v.shape[-1]
     assert k.shape[2] == v.shape[2]
     o = out if out is not None else torch.empty_like(q).contiguous()

@@ -253,7 +236,6 @@ def _forward(ctx,
 class _sparse_attention(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
         # shape constraints

@@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
         sm_scale = 1.0 / (D_HEAD**0.5)

         # Create sparse mask (downsampled to block level)
         downsample_factor = BLOCK
         downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
+        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
         x_ds[:, :, :, 0] = 100
         block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
benchmark/mamba2/benchmark_mamba_chunk_scan.py

@@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
     dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
     decay = torch.exp(dt_segment_sum)
     scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
     causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
     scores_decay = scores_decay.masked_fill(~causal_mask, 0)
-    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
+    out = torch.einsum("bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
     state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
-    out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
+    out_prev = (
+        torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype))
+        * state_decay_out
+    )
     out = out + out_prev
     out = rearrange(out, "b c l h p -> b (c l) h p")
     if D is not None:

@@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
 def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
-
     @helion.kernel()
     def helion_mamba2_chunk_scan_kernel(
         cb: torch.Tensor,

@@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
         dtype = cb.dtype
         accum_dtype = torch.float32
-        assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype)
+        assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype
         out = torch.empty_like(x)

@@ -127,11 +126,10 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
         for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
             [nheads, chunk_size, headdim, batch, nchunks],
             block_size=[1, block_m, block_n, 1, 1],
         ):
             acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
             dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32)
             scale_m_local = torch.exp2(dA_cumsum_local_m * p)
             C_local = C[

@@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
                 tile_m,
                 tile_k,
             ]
             dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
             cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p)
             dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
             cb_local = (cb_local * dt_local[None, :]).to(dtype)
             pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]

@@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
             acc_o = hl.dot(cb_local, x_local, acc=acc_o)
             D_local = D[tile_h.begin].to(torch.float32)
             x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32)
             acc_o += x_residual * D_local
             out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype)

     return out

@@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
 def get_configs():
     iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5])
     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]

@@ -198,19 +187,21 @@ def get_configs():
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
 )
 def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M=64, block_N=64, block_K=64, block_Dstate=128, num_stages=2, threads=128):
     dtype = "float16"
     accum_dtype = "float"
     nchunks = T.ceildiv(seqlen, chunk_size)

@@ -218,20 +209,20 @@ def chunk_scan_fwd(batch,
     @T.prim_func
     def main(
         cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),  # type: ignore
         x: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
         dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
         dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
         C: T.Tensor((batch, seqlen, ngroups, dstate), dtype),  # type: ignore
         prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),  # type: ignore
         D: T.Tensor((nheads), dtype),  # type: ignore
         Output: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
     ):
         with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as (bz, bx, by):
             acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
             acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
             cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")

@@ -257,27 +248,32 @@ def chunk_scan_fwd(batch,
             m_idx = bx // T.ceildiv(headdim, block_N)
             n_idx = bx % T.ceildiv(headdim, block_N)
             T.annotate_layout(
                 {
                     acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
                     cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
                     x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared),
                 }
             )
             T.no_set_max_nreg()
             T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared)
             T.copy(dA_cs_m_shared, dA_cs_m_local)
             T.clear(acc_o)
             for i in T.Parallel(block_M):
                 scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
             T.copy(C[batch_idx, chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
             T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared)
             T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
             for i, j in T.Parallel(block_M, block_N):
                 acc_o[i, j] *= scale_m_local[i]

@@ -286,34 +282,47 @@ def chunk_scan_fwd(batch,
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 T.copy(cb[batch_idx, chunk_idx, bz // (nheads // ngroups), m_idx * block_M : (m_idx + 1) * block_M, k * block_K : (k + 1) * block_K], cb_shared)
                 T.copy(cb_shared, cb_local)
                 T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared)
                 T.copy(dA_cs_k_shared, dA_cs_k_local)
                 for i, j in T.Parallel(block_M, block_K):
                     cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
                 T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
                 T.copy(dt_shared, dt_local)
                 for i, j in T.Parallel(block_M, block_K):
                     cb_local[i, j] *= dt_local[j]
                 for i, j in T.Parallel(block_M, block_K):
                     cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0)
                 T.copy(x[batch_idx, chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, bz, n_idx * block_N : (n_idx + 1) * block_N], x_shared)
                 T.gemm(cb_local, x_shared, acc_o)
             D_local[0] = D[bz]
             T.copy(x[batch_idx, chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, bz, n_idx * block_N : (n_idx + 1) * block_N], x_residual_shared)
             T.copy(x_residual_shared, x_residual_local)
             for i, j in T.Parallel(block_M, block_N):
                 acc_o[i, j] += x_residual_local[i, j] * D_local[0]

@@ -321,24 +330,37 @@ def chunk_scan_fwd(batch,
             T.copy(acc_o, acc_o_shared)
             T.copy(acc_o_shared, Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, bz, n_idx * block_N : (n_idx + 1) * block_N])

     return main


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=80, help='heads')
-    parser.add_argument('--groups', type=int, default=1, help='groups')
-    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-    parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
-    parser.add_argument('--dim', type=int, default=64, help='dim')
-    parser.add_argument('--dstate', type=int, default=128, help='dstate')
-    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=80, help="heads")
+    parser.add_argument("--groups", type=int, default=1, help="groups")
+    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+    parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
+    parser.add_argument("--dim", type=int, default=64, help="dim")
+    parser.add_argument("--dstate", type=int, default=128, help="dstate")
+    parser.add_argument("--tune", action="store_true", help="tune configs")
     args = parser.parse_args()
-    batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
+    batch, heads, groups, seq_len, chunk_size, dim, dstate = (
+        args.batch,
+        args.heads,
+        args.groups,
+        args.seq_len,
+        args.chunk_size,
+        args.dim,
+        args.dstate,
+    )
     nchunks = math.ceil(seq_len / chunk_size)
     total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate

@@ -360,8 +382,7 @@ if __name__ == "__main__":
     D = torch.randn(heads).half().cuda()
     print("Benchmarking Triton...")
     triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
     print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}")
     print("Benchmarking Helion...")
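Aside (not part of the diff): get_configs above builds the autotuning search space by zipping the parameter names of iter_params with every value combination from itertools.product. A small standalone illustration of that pattern, with the knob lists shortened so the output stays readable:

import itertools

iter_params = dict(block_M=[64, 128], num_stages=[1, 2, 3])  # two knobs instead of five

# One dict per point in the Cartesian product of the value lists.
configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]

print(len(configs))  # 6
print(configs[0])    # {'block_M': 64, 'num_stages': 1}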
benchmark/matmul/benchmark_matmul.py

@@ -6,6 +6,7 @@ import tilelang
 import tilelang.language as T
 from tilelang.autotuner import autotune
 from tilelang import jit
+
 # Configure logger
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)

@@ -101,9 +102,7 @@ def get_configs(args, kwargs):
             policy=[T.GemmWarpPolicy.Square],
             enable_rasteration=[True, False],
         )
         return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]

     return configs

@@ -112,7 +111,9 @@ def get_configs(args, kwargs):
     warmup=3,
     rep=20,
 )
-@jit(out_idx=[2],)
+@jit(
+    out_idx=[2],
+)
 def matmul(
     M,
     N,

@@ -159,9 +160,9 @@ def matmul(
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
         B: T.Tensor((N, K), dtype),
         C: T.Tensor((M, N), dtype),
     ):
         """
         The compiled TVM function for block-level matrix multiplication.

@@ -176,7 +177,6 @@ def matmul(
         # Bind x-dimension to block index in N,
         # y-dimension to block index in M.
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
-
             # Allocate shared memory for A sub-block of shape (block_M, block_K)
             A_shared = T.alloc_shared((block_M, block_K), dtype)
             # Allocate shared memory for B sub-block of shape (block_N, block_K)
benchmark/matmul/benchmark_matmul_intrinsic.py

@@ -6,7 +6,8 @@ import tilelang as tl
 import tilelang.language as T
 from tilelang.intrinsics import get_swizzle_layout
-from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter,)
+from tilelang.intrinsics.mma_macro_generator import (
+    TensorCoreIntrinEmitter,
+)
 from tilelang.transform import simplify_prim_func
 from tilelang.autotuner import autotune
 import itertools

@@ -103,12 +104,11 @@ def tl_matmul(
     @T.prim_func
     def main(
         A: T.Tensor(A_shape, in_dtype),
         B: T.Tensor(B_shape, in_dtype),
         C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)

@@ -116,10 +116,12 @@ def tl_matmul(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-            T.annotate_layout({
-                A_shared: make_swizzle_layout(A_shared),
-                B_shared: make_swizzle_layout(B_shared),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: make_swizzle_layout(A_shared),
+                    B_shared: make_swizzle_layout(B_shared),
+                }
+            )

             # Improve L2 Cache
             T.use_swizzle(panel_size=10, enable=enable_rasteration)

@@ -127,7 +129,6 @@ def tl_matmul(
             T.clear(C_local)
             for ko in T.Pipelined((K // block_K), num_stages=stage):
-
                 # Load A into shared memory
                 for i, k in T.Parallel(block_M, block_K):
                     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

@@ -137,7 +138,6 @@ def tl_matmul(
                     B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
                 for ki in T.serial(0, (block_K // micro_size_k)):
-
                     # Load A into fragment
                     mma_emitter.ldmatrix_a(A_local, A_shared, ki)

@@ -223,7 +223,6 @@ def get_configs(args, kwargs):
         for config in configs:
             print(config)
     else:
-
         iter_params = dict(
             block_row_warps=[1, 2, 4],
             block_col_warps=[1, 2, 4],

@@ -233,9 +232,7 @@ def get_configs(args, kwargs):
             stage=[0, 2],
             enable_rasteration=[True, False],
         )
         return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]

     return configs

@@ -247,7 +244,9 @@ def get_configs(args, kwargs):
     ref_prog=ref_program,
     skip_check=True,
 )
-@tl.jit(out_idx=[2],)
+@tl.jit(
+    out_idx=[2],
+)
 def matmul(
     M,
     N,

@@ -291,13 +290,8 @@ if __name__ == "__main__":
     parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
     parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
     parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
     parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces")
     parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
     args = parser.parse_args()
     M, N, K = args.m, args.n, args.k
benchmark/matmul/benchmark_matmul_sp.py

@@ -70,7 +70,8 @@ def get_configs(M, N, K):
             thread_num,
             policy,
             enable_rasterization,
-        ))
+        )
+    )
     configs = [
         {

@@ -81,7 +82,8 @@ def get_configs(M, N, K):
             "thread_num": c[4],
             "policy": c[5],
             "enable_rasterization": c[6],  # keep param name for backward-compat
-        } for c in _configs
+        }
+        for c in _configs
     ]
     return configs

@@ -126,7 +128,9 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
     warmup=3,
     rep=20,
 )
-@jit(out_idx=[2],)
+@jit(
+    out_idx=[2],
+)
 def kernel(
     block_M=None,
     block_N=None,

@@ -165,10 +169,10 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
     @T.prim_func
     def main(
         A_sparse: T.Tensor((M, K // 2), in_dtype),
         E: T.Tensor((M, K // e_factor), e_dtype),
         B: T.Tensor((K, N), in_dtype),
         C: T.Tensor((M, N), accum_dtype),
     ):
         """
         The compiled TVM function for block-level matrix multiplication.

@@ -182,9 +186,7 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
         """
         # Bind x-dimension to block index in N,
         # y-dimension to block index in M.
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
             # Allocate shared memory for A sub-block of shape (block_M, block_K)
             A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
             # Allocate shared memory for B sub-block of shape (block_N, block_K)

@@ -201,12 +203,12 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
             T.disable_warp_group_reg_alloc()
             T.use_swizzle(panel_size=10, enable=enable_rasterization)
             T.annotate_layout(
                 {
                     E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
                     E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
                 }
             )
             # Loop over sub-blocks in K dimension, pipelined by num_stages
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 # Load a sub-block of A from global memory into A_shared

@@ -241,18 +243,13 @@ if __name__ == "__main__":
     parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
     parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
     parser.add_argument("--disable_cache", action="store_true")
     parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
     parser.add_argument(
         "--bench_torch_sparse",
         type=str,
-        choices=['cutlass', 'cusparselt'],
+        choices=["cutlass", "cusparselt"],
         default=None,
-        help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported")
+        help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported",
+    )
     args = parser.parse_args()

@@ -274,7 +271,8 @@ if __name__ == "__main__":
     if args.bench_torch_sparse is not None:
         from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
-        if args.bench_torch_sparse == 'cutlass':
+        if args.bench_torch_sparse == "cutlass":
             SparseSemiStructuredTensor._FORCE_CUTLASS = True
         A_sp = to_sparse_semi_structured(A, transposed=False)
         torch_sparse_latency = do_bench(lambda: A_sp @ B)

@@ -285,8 +283,6 @@ if __name__ == "__main__":
     print(f"Best config: {best_config}")
     if args.bench_torch_sparse is not None:
         print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")
     print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")
benchmark/matmul_fp8/benchmark_matmul.py

@@ -104,9 +104,7 @@ def get_configs(args, kwargs):
             policy=[T.GemmWarpPolicy.Square],
             enable_rasteration=[True, False],
         )
         return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]

     return configs

@@ -116,7 +114,9 @@ def get_configs(args, kwargs):
     warmup=3,
     rep=20,
 )
-@jit(out_idx=[2],)
+@jit(
+    out_idx=[2],
+)
 def matmul(
     M,
     N,

@@ -164,9 +164,9 @@ def matmul(
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
         B: T.Tensor((N, K), dtype),
         C: T.Tensor((M, N), dtype),
     ):
         """
         The compiled TVM function for block-level matrix multiplication.

@@ -181,7 +181,6 @@ def matmul(
         # Bind x-dimension to block index in N,
         # y-dimension to block index in M.
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
-
             # Allocate shared memory for A sub-block of shape (block_M, block_K)
             A_shared = T.alloc_shared((block_M, block_K), dtype)
             # Allocate shared memory for B sub-block of shape (block_N, block_K)
docs/conf.py

@@ -20,33 +20,27 @@ extensions = [
     "autoapi.extension",
 ]

-autoapi_type = 'python'
-autoapi_dirs = ['../tilelang']
+autoapi_type = "python"
+autoapi_dirs = ["../tilelang"]
 autoapi_options = [
-    'members',
-    'undoc-members',
-    'show-inheritance',
-    'show-module-summary',
-    'special-members',
+    "members",
+    "undoc-members",
+    "show-inheritance",
+    "show-module-summary",
+    "special-members",
 ]
 autoapi_keep_files = False  # Useful for debugging the generated rst files
 autoapi_generate_api_docs = True
-autodoc_typehints = 'description'
+autodoc_typehints = "description"
 autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"]

-source_suffix = {
-    '.rst': 'restructuredtext',
-    '.md': 'markdown',
-}
+source_suffix = {".rst": "restructuredtext", ".md": "markdown"}

-myst_enable_extensions = [
-    "colon_fence",
-    "deflist",
-]
+myst_enable_extensions = ["colon_fence", "deflist"]

 redirects = {"get_started/try_out": "../index.html#getting-started"}

@@ -66,10 +60,7 @@ html_css_files = ["custom.css"]
 footer_copyright = "© 2025-2026 TileLang"
 footer_note = " "

-html_theme_options = {
-    "light_logo": "img/logo-v2.png",
-    "dark_logo": "img/logo-v2.png",
-}
+html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"}

 header_links = [
     ("Home", "https://github.com/tile-ai/tilelang"),
examples/amd/example_amd_flash_attn_bwd.py

@@ -11,22 +11,20 @@ import time
 def ref_program(Q, K, V, is_causal, groups=1):
     assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
     assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
     dim = Q.size(-1)
     K_ref = K.repeat_interleave(groups, dim=2)
     V_ref = V.repeat_interleave(groups, dim=2)
-    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref)
+    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref)
     scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
     if is_causal:
         seq_len = Q.size(1)
         mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
         mask = mask.unsqueeze(0).unsqueeze(0)
-        scores = scores.masked_fill(mask == 0, float('-inf'))
+        scores = scores.masked_fill(mask == 0, float("-inf"))
     attention_weights = F.softmax(scores, dim=-1)
-    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref)
+    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref)
     lse = torch.logsumexp(scores, dim=-1).float()
     return output, lse
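Aside (not part of the diff): ref_program returns the log-sum-exp of the scores next to the output, and the backward kernel further down rebuilds the softmax probabilities from it (the P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j]) line). A tiny PyTorch check of that identity:

import torch

scores = torch.randn(4, 6)                         # one row of attention logits per query
lse = torch.logsumexp(scores, dim=-1)              # what ref_program stores alongside the output
probs_from_lse = torch.exp(scores - lse[:, None])
probs_reference = torch.softmax(scores, dim=-1)

# Both ways of materializing the softmax agree.
assert torch.allclose(probs_from_lse, probs_reference, atol=1e-6)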
@@ -45,23 +43,23 @@ def get_fwd_configs():
valid_configs
=
[]
for
m
,
n
,
s
,
t
,
stages
,
r
,
k
,
p
,
qkw
,
vw
in
itertools
.
product
(
block_M
,
block_N
,
num_split_q
,
threads
,
num_stages
,
enable_rasterization
,
k_pack
,
panel_size
,
qk_coalesced_width
,
v_coalesced_width
):
valid_configs
.
append
({
"block_
M
"
:
m
,
"block_N
"
:
n
,
"num_split_q
"
:
s
,
"threads"
:
t
,
"num_stages"
:
stages
,
"enable_rasterization
"
:
r
,
"k_pack
"
:
k
,
"panel_size"
:
p
,
"qk
_coalesced_width"
:
qk
w
,
"v_coalesced_width"
:
vw
,
}
)
for
m
,
n
,
s
,
t
,
stages
,
r
,
k
,
p
,
qkw
,
vw
in
itertools
.
product
(
block_M
,
block_N
,
num_split_q
,
threads
,
num_stages
,
enable_rasterization
,
k_pack
,
panel_size
,
qk_coalesced_width
,
v_coalesced_width
):
valid_configs
.
append
(
{
"block_M"
:
m
,
"block_
N
"
:
n
,
"num_split_q
"
:
s
,
"threads
"
:
t
,
"num_stages"
:
stages
,
"enable_rasterization"
:
r
,
"k_pack
"
:
k
,
"panel_size
"
:
p
,
"qk_coalesced_width"
:
qkw
,
"v
_coalesced_width"
:
v
w
,
}
)
return
valid_configs
...
...
@@ -85,7 +83,7 @@ def fast_flashattn(
qk_coalesced_width
:
int
,
v_coalesced_width
:
int
,
):
scale
=
(
1.0
/
dim
)
**
0.5
scale
=
(
1.0
/
dim
)
**
0.5
head_kv
=
heads
//
groups
q_shape
=
[
batch
,
seq_len
,
heads
,
dim
]
kv_shape
=
[
batch
,
seq_len
,
head_kv
,
dim
]
...
...
@@ -97,11 +95,11 @@ def fast_flashattn(
@
T
.
prim_func
def
main
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
LSE
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
),
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
LSE
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
),
):
with
T
.
Kernel
(
num_split_q
,
batch
*
heads
,
threads
=
threads
)
as
(
b_split
,
byz_combined
):
T
.
use_swizzle
(
panel_size
,
enable
=
enable_rasterization
)
...
...
@@ -135,33 +133,21 @@ def fast_flashattn(
m_prev
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scale_factor
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
T
.
copy
(
Q
[
bz
,
q_block_offset
:
q_block_offset
+
block_M
,
by
,
:],
Q_shared
,
coalesced_width
=
vec_size
)
T
.
copy
(
Q
[
bz
,
q_block_offset
:
q_block_offset
+
block_M
,
by
,
:],
Q_shared
,
coalesced_width
=
vec_size
)
loop_end_k
=
(
T
.
ceildiv
(
q_block_offset
+
block_M
,
block_N
)
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
))
loop_end_k
=
T
.
ceildiv
(
q_block_offset
+
block_M
,
block_N
)
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
row_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
for
k
in
T
.
Pipelined
(
loop_end_k
,
num_stages
=
num_stages
):
kv_idx
=
k
*
block_N
T
.
copy
(
K
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
K_shared
,
coalesced_width
=
vec_size
)
T
.
copy
(
V
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
V_shared
,
coalesced_width
=
v_vec_size
)
T
.
copy
(
K
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
K_shared
,
coalesced_width
=
vec_size
)
T
.
copy
(
V
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
V_shared
,
coalesced_width
=
v_vec_size
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_block_offset
+
i
>=
kv_idx
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_block_offset
+
i
>=
kv_idx
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
...
...
@@ -216,8 +202,7 @@ def fast_flashattn(
for
i
in
T
.
Parallel
(
block_M
):
if
q_block_offset
+
i
<
seq_len
:
lse_val
=
T
.
if_then_else
(
l_i
[
i
]
>
0
,
T
.
log
(
l_i
[
i
])
+
m_i
[
i
],
-
T
.
infinity
(
accum_dtype
))
lse_val
=
T
.
if_then_else
(
l_i
[
i
]
>
0
,
T
.
log
(
l_i
[
i
])
+
m_i
[
i
],
-
T
.
infinity
(
accum_dtype
))
LSE
[
bz
,
by
,
q_block_offset
+
i
]
=
lse_val
bx_loop_var
=
current_bx
+
num_split_q
...
...
@@ -234,16 +219,17 @@ def get_bwd_configs():
panel_size
=
[
7
,
8
,
9
,
10
]
configs
=
[]
for
m
,
n
,
stages
,
t
,
r
,
p
in
itertools
.
product
(
block_M
,
block_N
,
num_stages
,
threads
,
enable_rasterization
,
panel_size
):
configs
.
append
({
"block_M"
:
m
,
"block_N"
:
n
,
"num_stages"
:
stages
,
"threads"
:
t
,
"enable_rasterization"
:
r
,
"panel_size"
:
p
,
})
for
m
,
n
,
stages
,
t
,
r
,
p
in
itertools
.
product
(
block_M
,
block_N
,
num_stages
,
threads
,
enable_rasterization
,
panel_size
):
configs
.
append
(
{
"block_M"
:
m
,
"block_N"
:
n
,
"num_stages"
:
stages
,
"threads"
:
t
,
"enable_rasterization"
:
r
,
"panel_size"
:
p
,
}
)
return
configs
...
...
@@ -256,8 +242,7 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
blk
=
32
@
T
.
prim_func
def
flash_bwd_prep
(
O
:
T
.
Tensor
(
shape
,
dtype
),
dO
:
T
.
Tensor
(
shape
,
dtype
),
Delta
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
)):
def
flash_bwd_prep
(
O
:
T
.
Tensor
(
shape
,
dtype
),
dO
:
T
.
Tensor
(
shape
,
dtype
),
Delta
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
)):
with
T
.
Kernel
(
batch
,
heads
,
T
.
ceildiv
(
seq_len
,
blk
))
as
(
bz
,
bx
,
by
):
o
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
do
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
...
...
@@ -265,21 +250,33 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
delta
=
T
.
alloc_fragment
([
blk
],
accum_dtype
)
T
.
clear
(
acc
)
for
k
in
range
(
T
.
ceildiv
(
dim
,
blk
)):
T
.
copy
(
O
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
dO
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
do
)
T
.
copy
(
O
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
dO
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
do
)
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
acc
[
i
,
j
]
+=
o
[
i
,
j
]
*
do
[
i
,
j
]
T
.
reduce_sum
(
acc
,
delta
,
1
)
T
.
copy
(
delta
,
Delta
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
])
T
.
copy
(
delta
,
Delta
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
])
return
flash_bwd_prep
@
tilelang
.
autotune
(
configs
=
get_bwd_configs
(),
cache_input_tensors
=
True
)
@
tilelang
.
jit
def
flashattn_bwd
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
groups
,
block_M
:
int
,
block_N
:
int
,
num_stages
:
int
,
threads
:
int
,
enable_rasterization
:
bool
,
panel_size
:
int
):
sm_scale
=
(
1.0
/
dim
)
**
0.5
def
flashattn_bwd
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
groups
,
block_M
:
int
,
block_N
:
int
,
num_stages
:
int
,
threads
:
int
,
enable_rasterization
:
bool
,
panel_size
:
int
,
):
sm_scale
=
(
1.0
/
dim
)
**
0.5
head_kv
=
heads
//
groups
q_shape
=
[
batch
,
seq_len
,
heads
,
dim
]
kv_shape
=
[
batch
,
seq_len
,
head_kv
,
dim
]
...
...
@@ -287,14 +284,17 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
    accum_dtype = "float"

    @T.prim_func
-    def flash_bwd_kernel(Q: T.Tensor(q_shape, dtype), K: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype), dO: T.Tensor(q_shape, dtype), lse: T.Tensor([batch, heads, seq_len], accum_dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype), dQ: T.Tensor(q_shape, accum_dtype), dK: T.Tensor(kv_shape, accum_dtype), dV: T.Tensor(kv_shape, accum_dtype)):
+    def flash_bwd_kernel(
+        Q: T.Tensor(q_shape, dtype),
+        K: T.Tensor(kv_shape, dtype),
+        V: T.Tensor(kv_shape, dtype),
+        dO: T.Tensor(q_shape, dtype),
+        lse: T.Tensor([batch, heads, seq_len], accum_dtype),
+        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
+        dQ: T.Tensor(q_shape, accum_dtype),
+        dK: T.Tensor(kv_shape, accum_dtype),
+        dV: T.Tensor(kv_shape, accum_dtype),
+    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            T.use_swizzle(panel_size, enable=enable_rasterization)
...
@@ -315,8 +315,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
            dk = T.alloc_fragment([block_M, dim], accum_dtype)
            dq = T.alloc_fragment([block_N, dim], accum_dtype)
-            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
-            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
+            T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
+            T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
...
@@ -324,22 +324,21 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
-                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q_shared)
+                T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared)
                T.clear(qkT)
                T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
-                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
+                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0)
-                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do_shared)
+                T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared)
                T.clear(dP)
                T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...
@@ -347,7 +346,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
                T.copy(P_acc, p_cast)
                T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow)
-                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta_shared)
+                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared)
                for i, j in T.Parallel(block_M, block_N):
                    p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale
...
@@ -378,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
    def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.copy(
-                dQ_in[bz, bx * blk:(bx + 1) * blk, by, :],
-                dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
+                dQ_in[bz, bx * blk : (bx + 1) * blk, by, :],
+                dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
            )

    return flash_bwd_post
...
@@ -446,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100):
    return np.median(times)


def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
    device = "cuda"
    dtype = torch.float16

    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    print(f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}")

    flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 5 * flops_per_gemm
...
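For a sense of scale, plugging the argparse defaults from the bottom of this file into the FLOP model above gives roughly 5.4 GFLOPs per iteration (a worked arithmetic check only):

batch, heads, seq_len, dim = 1, 8, 1024, 64   # argparse defaults below
flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 5 * flops_per_gemm
print(f"{flops_per_gemm:.3e}")  # 1.074e+09
print(f"{total_flops:.3e}")     # 5.369e+09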
@@ -517,22 +508,19 @@ def main(batch: int = 1,
    o_ref.backward(dO)

    print("Verifying backward pass correctness...")
    dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
    if dq_close:
        print("dQ is correct.")
    else:
        print("dQ mismatch detected.")

    dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
    if dk_close:
        print("dK is correct.")
    else:
        print("dK mismatch detected.")

    dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
    if dv_close:
        print("dV is correct.")
    else:
...
@@ -553,9 +541,7 @@ def main(batch: int = 1,
    torch.cuda.synchronize()

    ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100)
    print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops")

    def run_complete_fwd_bwd():
        o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v)
...
@@ -593,12 +579,12 @@ def main(batch: int = 1,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=8, help='heads')
-    parser.add_argument('--seq_len', type=int, default=1024, help='sequence length')
-    parser.add_argument('--dim', type=int, default=64, help='dim')
-    parser.add_argument('--is_causal', action='store_true', help='causal')
-    parser.add_argument('--groups', type=int, default=1, help='groups')
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=8, help="heads")
+    parser.add_argument("--seq_len", type=int, default=1024, help="sequence length")
+    parser.add_argument("--dim", type=int, default=64, help="dim")
+    parser.add_argument("--is_causal", action="store_true", help="causal")
+    parser.add_argument("--groups", type=int, default=1, help="groups")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
examples/amd/example_amd_flash_attn_fwd.py
View file @
29051439
...
@@ -13,10 +13,10 @@ def supply_tensors_gpu(params):
    """Supply function that creates tensors on GPU for ROCm/HIP."""
    tensors = []
    for param in params:
-        if hasattr(param, 'shape') and hasattr(param, 'dtype'):
+        if hasattr(param, "shape") and hasattr(param, "dtype"):
            # Force creation on GPU device
            shape = [int(s) for s in param.shape]
-            tensor = torch.randn(shape, dtype=param.dtype, device='cuda')
+            tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
            tensors.append(tensor)
        else:
            tensors.append(param)
...
@@ -24,22 +24,20 @@
def ref_program(Q, K, V, is_causal, groups=1):
    assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
-    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
-        scores = scores.masked_fill(mask == 0, float('-inf'))
+        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
-    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
...
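ref_program above handles grouped-query attention by expanding the K/V heads with repeat_interleave before running plain multi-head attention. A self-contained sketch of that shape contract (sizes are arbitrary and for illustration only):

import torch

batch, seq_len, dim = 2, 16, 8
heads, groups = 8, 4          # query heads and GQA group count
head_kv = heads // groups     # K/V carry only heads // groups heads

Q = torch.randn(batch, seq_len, heads, dim)
K = torch.randn(batch, seq_len, head_kv, dim)
V = torch.randn(batch, seq_len, head_kv, dim)

# Same check and expansion as ref_program above.
assert Q.size(2) == K.size(2) * groups
K_full = K.repeat_interleave(groups, dim=2)
V_full = V.repeat_interleave(groups, dim=2)
print(K_full.shape, V_full.shape)  # both torch.Size([2, 16, 8, 8])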
@@ -58,23 +56,23 @@ def get_configs():
    valid_configs = []
    for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width):
-        valid_configs.append({
-            "block_M": m,
-            "block_N": n,
-            "num_split_q": s,
-            "threads": t,
-            "num_stages": stages,
-            "enable_rasterization": r,
-            "k_pack": k,
-            "panel_size": p,
-            "qk_coalesced_width": qkw,
-            "v_coalesced_width": vw,
-        })
+        valid_configs.append(
+            {
+                "block_M": m,
+                "block_N": n,
+                "num_split_q": s,
+                "threads": t,
+                "num_stages": stages,
+                "enable_rasterization": r,
+                "k_pack": k,
+                "panel_size": p,
+                "qk_coalesced_width": qkw,
+                "v_coalesced_width": vw,
+            }
+        )
    return valid_configs
...
@@ -98,7 +96,7 @@ def fast_flashattn(
    qk_coalesced_width: int,
    v_coalesced_width: int,
):
-    scale = (1.0 / dim)**0.5
+    scale = (1.0 / dim) ** 0.5
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
...
@@ -110,10 +108,10 @@ def fast_flashattn(
    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(kv_shape, dtype),
        V: T.Tensor(kv_shape, dtype),
        Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
            T.use_swizzle(panel_size, enable=enable_rasterization)
...
@@ -147,32 +145,21 @@ def fast_flashattn(
            m_prev = T.alloc_fragment([block_M], accum_dtype)
            scale_factor = T.alloc_fragment([block_M], accum_dtype)
-            T.copy(Q[bz, q_block_offset:q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
+            T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
            loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            row_sum = T.alloc_fragment([block_M], accum_dtype)
            for k in T.Pipelined(loop_end_k, num_stages=num_stages):
                kv_idx = k * block_N
-                T.copy(K[bz, kv_idx:kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
-                T.copy(V[bz, kv_idx:kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
+                T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
+                T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                T.gemm(
...
@@ -222,13 +209,7 @@ def fast_flashattn(
    return main


def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
...
@@ -250,18 +231,16 @@ def main(batch: int = 1,
    print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")

    latency = profiler.do_bench(warmup=100)
    print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=8, help='heads')
-    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--is_causal', action='store_true', help='causal')
-    parser.add_argument('--groups', type=int, default=1, help='groups')
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=8, help="heads")
+    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--is_causal", action="store_true", help="causal")
+    parser.add_argument("--groups", type=int, default=1, help="groups")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
examples/analyze/example_conv_analyze.py
View file @
29051439
...
@@ -25,22 +25,7 @@ def check_hopper():
    return False


def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
...
@@ -50,13 +35,11 @@ def kernel(N,
    @T.prim_func
    def conv(
        data: T.Tensor((N, H, W, C), dtype),
        kernel: T.Tensor((KH, KW, C, F), dtype),
        out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...
@@ -65,11 +48,13 @@ def kernel(N,
            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

-            T.annotate_layout({
-                out_shared: make_swizzled_layout(out_shared),
-                data_shared: make_swizzled_layout(data_shared),
-                kernel_shared: make_swizzled_layout(kernel_shared),
-            })
+            T.annotate_layout(
+                {
+                    out_shared: make_swizzled_layout(out_shared),
+                    data_shared: make_swizzled_layout(data_shared),
+                    kernel_shared: make_swizzled_layout(kernel_shared),
+                }
+            )

            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...
@@ -81,10 +66,8 @@ def kernel(N,
                    m = by * block_M + i
                    access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                    access_w = m % OW * S + k // C % KW * D - P
-                    in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W))
+                    in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)

                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)
...
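The index arithmetic above is the implicit-GEMM (im2col) view of the convolution: each GEMM row m enumerates a flattened (n, oh, ow) output position and each reduction index k enumerates a flattened (kh, kw, c) filter tap. A small Python sketch that unpacks the same formulas, with illustrative sizes only:

N, C, H, W, F, K, S, D, P = 1, 3, 8, 8, 4, 3, 1, 1, 1
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1

def input_coords(m, k):
    """Map GEMM indices (m, k) back to the input element they read."""
    n = m // (OH * OW)
    access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
    access_w = m % OW * S + k // C % KW * D - P
    c = k % C
    in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
    return n, access_h, access_w, c, in_bound

print(OH, OW)              # 8 8 with these sizes
print(input_coords(0, 0))  # (0, -1, -1, 0, False): a padding tap, so the kernel loads 0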
examples/analyze/example_gemm_analyze.py
View file @
29051439
...
@@ -20,9 +20,9 @@ def kernel(
    @T.prim_func
    def matmul(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
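Note that B in the signature above is laid out as (N, K) while C is (M, N), which suggests the GEMM being analyzed is C = A @ B.T. A short reference check under that assumption (shapes illustrative, semantics assumed from the layout rather than stated in the diff):

import torch

M, N, K = 128, 128, 64
A = torch.randn(M, K, dtype=torch.float16)
B = torch.randn(N, K, dtype=torch.float16)   # (N, K), not (K, N)
C_ref = (A.float() @ B.float().T).half()     # assumed semantics given the layout
print(C_ref.shape)                           # torch.Size([128, 128])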
examples/attention_sink/benchmark_gqa_sink_fwd.py
View file @
29051439
...
@@ -51,8 +51,7 @@ def triton_kernel(
    q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])

    if BANDWIDTH:
        lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
    else:
        lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
...
@@ -120,7 +119,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
        BANDWIDTH=window_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
-        start_q=seq_kv - seq_q)
+        start_q=seq_kv - seq_q,
+    )
    return o
...
@@ -137,12 +137,11 @@ def main(
):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
-        print('Using sliding window attention.')
+        print("Using sliding window attention.")
        assert window_size <= seq_q
        flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
    else:
-        print('Using full attention.')
+        print("Using full attention.")
        flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
    total_flops = 2 * flops_per_matmul
...
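The lo/hi bounds in the Triton kernel above restrict each query block to its sliding-window band: lo = max(0, start_q + start_m * BLOCK_M - BANDWIDTH) and hi = start_q + (start_m + 1) * BLOCK_M. A worked example with illustrative block sizes:

BLOCK_M, BANDWIDTH, start_q = 128, 512, 0

def key_range(start_m):
    lo = max(0, start_q + start_m * BLOCK_M - BANDWIDTH)
    hi = start_q + (start_m + 1) * BLOCK_M
    return lo, hi

print(key_range(0))   # (0, 128): early blocks attend to everything up to their own position
print(key_range(10))  # (768, 1408): later blocks only see the trailing BANDWIDTH keys plus their own block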
@@ -170,15 +169,14 @@ def main(
        block_N=block_N,
        num_stages=num_stages,
        threads=threads,
-        dtype=dtype)
+        dtype=dtype,
+    )

    Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
    if torch.allclose(
        triton_program(Q, K, V, sinks, window_size),
        ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
        rtol=1e-2,
        atol=1e-2):
        print("Checks for triton passed.✅")
    else:
        print("Checks for triton failed.❌")
...
@@ -198,20 +196,14 @@ def main(
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=64, help='heads')
-    parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
-    parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--groups', type=int, default=8, help='groups')
-    parser.add_argument('--window_size', type=int, default=None, help='window size (default: None, which means full attention)')
-    parser.add_argument('--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
-    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=64, help="heads")
+    parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
+    parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--groups", type=int, default=8, help="groups")
+    parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
+    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
+    parser.add_argument("--tune", action="store_true", help="tune configs")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)
examples/attention_sink/benchmark_mha_sink_fwd.py
View file @
29051439
...
@@ -50,8 +50,7 @@ def triton_kernel(
    q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])

    if BANDWIDTH:
        lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
    else:
        lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
...
@@ -117,26 +116,28 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
        BANDWIDTH=window_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
-        start_q=seq_kv - seq_q)
+        start_q=seq_kv - seq_q,
+    )
    return o


-def main(batch: int = 1, heads: int = 32, seq_q: int = 256, seq_kv: int = 256, dim: int = 128, window_size: Optional[int] = None, dtype: str = "float16", tune: bool = False):
+def main(
+    batch: int = 1,
+    heads: int = 32,
+    seq_q: int = 256,
+    seq_kv: int = 256,
+    dim: int = 128,
+    window_size: Optional[int] = None,
+    dtype: str = "float16",
+    tune: bool = False,
+):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
-        print('Using sliding window attention.')
+        print("Using sliding window attention.")
        assert window_size <= seq_q
        flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
    else:
-        print('Using full attention.')
+        print("Using full attention.")
        flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
    total_flops = 2 * flops_per_matmul
...
@@ -163,15 +164,14 @@ def main(batch: int = 1,
        block_N=block_N,
        num_stages=num_stages,
        threads=threads,
-        dtype=dtype)
+        dtype=dtype,
+    )

    Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
    torch.testing.assert_close(
        kernel(Q, K, V, sinks),
        ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
        rtol=1e-2,
        atol=1e-2)
    print("All checks passed.✅")

    latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
...
@@ -184,19 +184,13 @@ def main(batch: int = 1,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
-    parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--window_size', type=int, default=None, help='window size (default: None, which means full attention)')
-    parser.add_argument('--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
-    parser.add_argument('--tune', action='store_true', help='tune')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=32, help="heads")
+    parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
+    parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
+    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
+    parser.add_argument("--tune", action="store_true", help="tune")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py
View file @
29051439
This diff is collapsed.
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
View file @
29051439
...
@@ -23,9 +23,11 @@ def get_configs():
    rep=100,
)
@tilelang.jit(
-    out_idx=[3], pass_configs={
+    out_idx=[3],
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
+    },
+)
def flashattn(
    batch,
    heads,
...
@@ -41,12 +43,11 @@ def flashattn(
    threads=256,
    dtype: str = "float16",
):
    if window_size is not None:
        assert window_size % block_N == 0, "window_size must be divisible by block_N"
    if sm_scale is None:
-        sm_scale = (1.0 / dim)**0.5
+        sm_scale = (1.0 / dim) ** 0.5
    scale = sm_scale * 1.44269504  # log2(e)

    head_kv = heads // groups
...
@@ -68,13 +69,12 @@ def flashattn(
        by: T.int32,
        bz: T.int32,
    ):
-        T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
+        T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared)
        for i, j in T.Parallel(block_M, block_N):
            q_idx = bx * block_M + i + past_len
            k_idx = k * block_N + j
            if window_size is not None:
                acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
            else:
                acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...
@@ -89,18 +89,18 @@ def flashattn(
        by: T.int32,
        bz: T.int32,
    ):
-        T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared)
+        T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        scores_max: T.FragmentBuffer([block_M], accum_dtype),
        scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
        scores_scale: T.FragmentBuffer([block_M], accum_dtype),
        scores_sum: T.FragmentBuffer([block_M], accum_dtype),
        logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
...
@@ -112,8 +112,7 @@ def flashattn(
        # NOTE(wt): check_inf is necessary for sliding window attention.
        for i in T.Parallel(block_M):
            if window_size is not None:
                scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
...
@@ -128,19 +127,19 @@ def flashattn(
    @T.macro
    def Rescale(
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(kv_shape, dtype),
        V: T.Tensor(kv_shape, dtype),
        Output: T.Tensor(q_shape, dtype),
        Sinks: T.Tensor([heads], dtype),
    ):
        with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
...
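The Softmax and Rescale macros above implement the standard online-softmax update written in base 2: the running row max is refreshed, previously accumulated output is rescaled by exp2(m_prev * scale - m_new * scale), and the row sum is carried along. A block-level PyTorch sketch of the same recurrence (an illustration of the math, not of the tiled kernel):

import torch

def online_softmax_step(scores_blk, m_prev, l_prev, acc_prev, v_blk, scale):
    # scores_blk: [M, N] raw Q @ K^T scores for one K/V block; scale = sm_scale * log2(e).
    m_new = torch.maximum(m_prev, scores_blk.amax(dim=-1))
    rescale = torch.exp2(m_prev * scale - m_new * scale)
    p = torch.exp2(scores_blk * scale - (m_new * scale).unsqueeze(-1))
    l_new = l_prev * rescale + p.sum(dim=-1)
    acc_new = acc_prev * rescale.unsqueeze(-1) + p @ v_blk
    return m_new, l_new, acc_new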
@@ -157,58 +156,58 @@ def flashattn(
            logsum = T.alloc_fragment([block_M], accum_dtype)
            sinks = T.alloc_fragment([block_M], dtype)

-            T.annotate_layout({
-                Q_shared: make_swizzled_layout(Q_shared),
-                K_shared: make_swizzled_layout(K_shared),
-                V_shared: make_swizzled_layout(V_shared),
-                O_shared: make_swizzled_layout(O_shared),
-            })
+            T.annotate_layout(
+                {
+                    Q_shared: make_swizzled_layout(Q_shared),
+                    K_shared: make_swizzled_layout(K_shared),
+                    V_shared: make_swizzled_layout(V_shared),
+                    O_shared: make_swizzled_layout(O_shared),
+                }
+            )

-            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+            T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            for i in T.Parallel(block_M):
                sinks[i] = Sinks[by]

            end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
            start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
            for k in T.Pipelined(
-                    start,
-                    end,
-                    num_stages=num_stages,
-                    order=[-1, 0, 3, 1, -1, 2],
-                    stage=[-1, 0, 0, 1, -1, 1],
-                    group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
+                start,
+                end,
+                num_stages=num_stages,
+                order=[-1, 0, 3, 1, -1, 2],
+                stage=[-1, 0, 0, 1, -1, 1],
+                group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
+            ):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i in T.Parallel(block_M):
                logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale)  # The only change for attention sink
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
-            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+            T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

    return main


# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
-def ref_program(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, sinks: torch.Tensor, sliding_window: Optional[int] = None, dtype: torch.dtype = torch.float16) -> torch.Tensor:
+def ref_program(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    sinks: torch.Tensor,
+    sliding_window: Optional[int] = None,
+    dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
    key = key.transpose(1, 2).contiguous()
    value = value.transpose(1, 2).contiguous()
    batch_size, num_keys, num_key_value_heads, head_dim = key.shape
...
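The line tagged "The only change for attention sink" folds the per-head sink logit into the softmax denominator after the K/V loop: logsum gains an extra exp2(sink * log2(e) - m * scale) term, so the sink competes for probability mass without contributing a value vector. A row-level PyTorch sketch of that normalization (shapes and the single-head view are illustrative only):

import torch

def attention_with_sink(q, k, v, sink, sm_scale):
    # q: [M, D], k, v: [N, D], sink: a scalar logit for this head.
    scores = (q @ k.T) * sm_scale
    m = torch.maximum(scores.amax(dim=-1), torch.tensor(float(sink)))
    p = torch.exp(scores - m[:, None])
    denom = p.sum(dim=-1) + torch.exp(sink - m)   # the sink only enters the denominator
    return (p @ v) / denom[:, None]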
@@ -244,23 +243,15 @@ def ref_program(query: torch.Tensor,
    output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
    output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
    return output.transpose(1, 2).contiguous()


def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
-    key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
-    value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
-    sinks = torch.randn([H], dtype=dtype, device='cuda')
+    query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
+    key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda")
+    value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda")
+    sinks = torch.randn([H], dtype=dtype, device="cuda")
    return query, key, value, sinks
...
@@ -277,12 +268,11 @@ def main(
):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
-        print('Using sliding window attention.')
+        print("Using sliding window attention.")
        assert window_size <= seq_q
        flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
    else:
-        print('Using full attention.')
+        print("Using full attention.")
        flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
    total_flops = 2 * flops_per_matmul
...
@@ -310,15 +300,14 @@ def main(
        block_N=block_N,
        num_stages=num_stages,
        threads=threads,
-        dtype=dtype)
+        dtype=dtype,
+    )

    Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
    torch.testing.assert_close(
        kernel(Q, K, V, sinks),
        ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
        rtol=1e-2,
        atol=1e-2)
    print("All checks passed.✅")

    # Benchmark tilelang
...
@@ -329,20 +318,14 @@ def main(
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=64, help='heads')
-    parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
-    parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--groups', type=int, default=8, help='groups')
-    parser.add_argument('--window_size', type=int, default=None, help='window size (default: None, which means full attention)')
-    parser.add_argument('--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
-    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=64, help="heads")
+    parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
+    parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--groups", type=int, default=8, help="groups")
+    parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
+    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
+    parser.add_argument("--tune", action="store_true", help="tune configs")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)
examples/attention_sink/example_mha_sink_bwd_bhsd.py
View file @
29051439
This diff is collapsed.