OpenDAS / tilelang · Commits

Commit 29051439 (Unverified)
Authored Dec 12, 2025 by Lei Wang; committed by GitHub on Dec 12, 2025
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
Parent: e84b24bc
Changes: 467
Showing 20 changed files with 779 additions and 937 deletions (+779 / -937)
.pre-commit-config.yaml  (+2 / -12)
benchmark/blocksparse_attention/benchmark_library_dense_fmha.py  (+5 / -8)
benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py  (+31 / -41)
benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py  (+10 / -15)
benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py  (+10 / -30)
benchmark/mamba2/benchmark_mamba_chunk_scan.py  (+114 / -93)
benchmark/matmul/benchmark_matmul.py  (+8 / -8)
benchmark/matmul/benchmark_matmul_intrinsic.py  (+17 / -23)
benchmark/matmul/benchmark_matmul_sp.py  (+24 / -28)
benchmark/matmul_fp8/benchmark_matmul.py  (+7 / -8)
docs/conf.py  (+11 / -20)
examples/amd/example_amd_flash_attn_bwd.py  (+96 / -110)
examples/amd/example_amd_flash_attn_fwd.py  (+42 / -63)
examples/analyze/example_conv_analyze.py  (+14 / -31)
examples/analyze/example_gemm_analyze.py  (+3 / -3)
examples/attention_sink/benchmark_gqa_sink_fwd.py  (+20 / -28)
examples/attention_sink/benchmark_mha_sink_fwd.py  (+29 / -35)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py  (+136 / -150)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py  (+75 / -92)
examples/attention_sink/example_mha_sink_bwd_bhsd.py  (+125 / -139)
.pre-commit-config.yaml

@@ -39,19 +39,9 @@ repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.14.7  # sync with requirements-lint.txt
     hooks:
+      - id: ruff-format
       - id: ruff-check
        args: [--fix, --exit-non-zero-on-fix]
-  - repo: https://github.com/google/yapf
-    rev: v0.43.0  # sync with requirements-lint.txt
-    hooks:
-      - id: yapf
-        name: yapf-multiproc-bugfix
-        # yapf is not multiprocess safe, so we run a dummy yapf first.
-        args: [--in-place, docs/conf.py]
-        always_run: true
-        pass_filenames: false
-      - id: yapf
-        args: [--recursive, --in-place]
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1  # sync with requirements-lint.txt
     hooks:
@@ -62,4 +52,4 @@ repos:
       ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$|
       ^.+\.svg$|
       ^.*\brequirements\b.*\.txt$
-    )
\ No newline at end of file
+    )
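
With the yapf hooks removed, formatting is driven entirely by ruff through pre-commit. The snippet below is a minimal sketch, not part of this commit, of how the same two hooks could be reproduced locally; it assumes ruff is installed (for example via requirements-lint.txt) and available on PATH.

# Minimal sketch (assumption: ruff is installed and on PATH; this helper is illustrative only).
import subprocess


def run_format_and_lint() -> None:
    # Mirrors the ruff-format hook: rewrite files in place.
    subprocess.run(["ruff", "format", "."], check=True)
    # Mirrors the ruff-check hook with its configured args.
    subprocess.run(["ruff", "check", "--fix", "--exit-non-zero-on-fix", "."], check=True)


if __name__ == "__main__":
    run_format_and_lint()

Running "pre-commit run --all-files" after pulling this change should apply the same hooks.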
benchmark/blocksparse_attention/benchmark_library_dense_fmha.py

@@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape  # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)

         import flash_attn
benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py

@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape  # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -39,7 +36,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_N = 64
     num_stages = 2
     threads = 128
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
     block_mask_shape = [batch, heads, downsample_len, downsample_len]
@@ -48,7 +45,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_mask_dtype = "bool"

     def kernel_func(block_M, block_N, num_stages, threads):
-
         @T.macro
         def MMA0(
             K: T.Tensor(shape, dtype),
@@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
+            T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
             if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-                                                 -T.infinity(acc_s.dtype))
+                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
             else:
                 T.clear(acc_s)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -79,18 +74,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
+            T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
             T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

         @T.macro
         def Softmax(
-                acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
-                acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
-                scores_max: T.FragmentBuffer([block_M], accum_dtype),
-                scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
-                scores_scale: T.FragmentBuffer([block_M], accum_dtype),
-                scores_sum: T.FragmentBuffer([block_M], accum_dtype),
-                logsum: T.FragmentBuffer([block_M], accum_dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            scores_max: T.FragmentBuffer([block_M], accum_dtype),
+            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+            logsum: T.FragmentBuffer([block_M], accum_dtype),
         ):
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))
@@ -116,22 +111,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
         @T.macro
         def Rescale(
-                acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
-                scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
         ):
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] *= scores_scale[i]

         @T.prim_func
         def main(
-                Q: T.Tensor(shape, dtype),
-                K: T.Tensor(shape, dtype),
-                V: T.Tensor(shape, dtype),
-                BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
-                Output: T.Tensor(shape, dtype),
+            Q: T.Tensor(shape, dtype),
+            K: T.Tensor(shape, dtype),
+            V: T.Tensor(shape, dtype),
+            BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
+            Output: T.Tensor(shape, dtype),
         ):
-            with T.Kernel(
-                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
+            with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim], dtype)
@@ -146,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                 logsum = T.alloc_fragment([block_M], accum_dtype)
                 block_mask = T.alloc_local([downsample_len], block_mask_dtype)

-                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+                T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
@@ -155,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                     block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                 loop_range = (
-                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
-                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
+                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
+                    if is_causal
+                    else T.ceildiv(seq_len, block_N)
+                )

                 for k in T.Pipelined(loop_range, num_stages=num_stages):
                     if block_mask[k]:
                         MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
-                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
-                                scores_sum, logsum)
+                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                         Rescale(acc_o, scores_scale)
                         MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
                 T.copy(acc_o, O_shared)
-                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+                T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

     return main
@@ -177,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)

         # Create sparse mask (downsampled to block level)
         downsample_factor = BLOCK
         downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-        x_ds = torch.randn(
-            [BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
+        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
         x_ds[:, :, :, 0] = 100
         block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

-        program = blocksparse_flashattn(
-            BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
+        program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
         kernel = tilelang.compile(program, out_idx=4)

         def benchmark_fn():
benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py

@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape  # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
         sm_scale = 1.0 / (D_HEAD**0.5)

         # Create sparse mask (downsampled to block level)
         downsample_factor = BLOCK
         downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-        x_ds = torch.randn(
-            [BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
+        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
         x_ds[:, :, :, 0] = 100
         block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

         def benchmark_fn():
             # Compute reference
             # Expand block mask to full attention matrix
-            full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
+            full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
             full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
             full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

             # PyTorch reference implementation
-            attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
-            attn = attn.masked_fill(~full_mask, float('-inf'))
+            attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
+            attn = attn.masked_fill(~full_mask, float("-inf"))
             attn = F.softmax(attn, dim=-1)
-            ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+            ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
             return ref_output

         ref_latency = do_bench(
benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py

@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape  # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-
     mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
     if mask_val == True:
@@ -72,8 +68,7 @@ def _fwd_kernel_inner(
         # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
         if LAST_K_BLOCK:
-            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
-                           float('-inf'))
+            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))

         m_ij = tl.maximum(m_i, tl.max(qk, 1))
         qk -= m_ij[:, None]
@@ -153,7 +148,7 @@ def _fwd_kernel(
     v_ptrs = V + off_v
     mask_ptrs = block_mask_ptr + start_m * stride_bmm

-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -191,24 +186,12 @@ def _fwd_kernel(
     acc = acc * l_recip
     acc = acc.to(Out.dtype.element_ty)

-    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
-        None, :] * stride_od
+    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)


-def _forward(ctx,
-             q,
-             k,
-             v,
-             block_sparse_mask,
-             sm_scale,
-             BLOCK_M=64,
-             BLOCK_N=64,
-             num_warps=None,
-             num_stages=1,
-             out=None):
+def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
     assert q.shape[-1] == k.shape[-1] == v.shape[-1]
     assert k.shape[2] == v.shape[2]
     o = out if out is not None else torch.empty_like(q).contiguous()
@@ -253,7 +236,6 @@ def _forward(ctx,
 class _sparse_attention(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
         # shape constraints
@@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
         sm_scale = 1.0 / (D_HEAD**0.5)

         # Create sparse mask (downsampled to block level)
         downsample_factor = BLOCK
         downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-        x_ds = torch.randn(
-            [BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
+        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
         x_ds[:, :, :, 0] = 100
         block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
benchmark/mamba2/benchmark_mamba_chunk_scan.py

@@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
     dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
     decay = torch.exp(dt_segment_sum)
     scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
-    causal_mask = torch.tril(
-        torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
+    causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
     scores_decay = scores_decay.masked_fill(~causal_mask, 0)
-    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
-                       rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
+    out = torch.einsum(
+        "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
+    )
     state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
-    out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
-                            prev_states.to(C.dtype)) * state_decay_out
+    out_prev = (
+        torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype))
+        * state_decay_out
+    )
     out = out + out_prev
     out = rearrange(out, "b c l h p -> b (c l) h p")
     if D is not None:
@@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
 def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
-
     @helion.kernel()
     def helion_mamba2_chunk_scan_kernel(
         cb: torch.Tensor,
@@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
         dtype = cb.dtype
         accum_dtype = torch.float32

-        assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype ==
-                dtype)
+        assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype

         out = torch.empty_like(x)
@@ -127,11 +126,10 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
         for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
             [nheads, chunk_size, headdim, batch, nchunks],
-                block_size=[1, block_m, block_n, 1, 1],
+            block_size=[1, block_m, block_n, 1, 1],
         ):
             acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
-            dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
-                                          tile_m].to(torch.float32)
+            dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32)
             scale_m_local = torch.exp2(dA_cumsum_local_m * p)
             C_local = C[
@@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
                 tile_m,
                 tile_k,
             ]
-            dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
-                                          tile_k].to(torch.float32)
-            cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p -
-                                   dA_cumsum_local_k[None, :] * p)
+            dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
+            cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p)
             dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
             cb_local = (cb_local * dt_local[None, :]).to(dtype)
             pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]
@@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
             acc_o = hl.dot(cb_local, x_local, acc=acc_o)

             D_local = D[tile_h.begin].to(torch.float32)
-            x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
-                           tile_n].to(torch.float32)
+            x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32)
             acc_o += x_residual * D_local

-            out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
-                tile_n] = acc_o.to(dtype=dtype)
+            out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype)

         return out
@@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
 def get_configs():
-    iter_params = dict(
-        block_M=[64, 128, 256],
-        block_N=[32, 64],
-        block_K=[64, 128, 256],
-        block_Dstate=[128],
-        num_stages=[1, 2, 3, 4, 5])
+    iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5])
     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@@ -198,19 +187,21 @@ def get_configs():
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
 )
-def chunk_scan_fwd(batch,
-                   seqlen,
-                   chunk_size,
-                   ngroups,
-                   nheads,
-                   headdim,
-                   dstate,
-                   block_M=64,
-                   block_N=64,
-                   block_K=64,
-                   block_Dstate=128,
-                   num_stages=2,
-                   threads=128):
+def chunk_scan_fwd(
+    batch,
+    seqlen,
+    chunk_size,
+    ngroups,
+    nheads,
+    headdim,
+    dstate,
+    block_M=64,
+    block_N=64,
+    block_K=64,
+    block_Dstate=128,
+    num_stages=2,
+    threads=128,
+):
     dtype = "float16"
     accum_dtype = "float"
     nchunks = T.ceildiv(seqlen, chunk_size)
@@ -218,20 +209,20 @@ def chunk_scan_fwd(batch,
     @T.prim_func
     def main(
-            cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),  # type: ignore
-            x: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
-            dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
-            dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
-            C: T.Tensor((batch, seqlen, ngroups, dstate), dtype),  # type: ignore
-            prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),  # type: ignore
-            D: T.Tensor((nheads), dtype),  # type: ignore
-            Output: T.Tensor((batch, seqlen, nheads, headdim), dtype)  # type: ignore
+        cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),  # type: ignore
+        x: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
+        dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
+        dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
+        C: T.Tensor((batch, seqlen, ngroups, dstate), dtype),  # type: ignore
+        prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),  # type: ignore
+        D: T.Tensor((nheads), dtype),  # type: ignore
+        Output: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
     ):
-        with T.Kernel(
-                nheads,
-                T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
-                batch * nchunks,
-                threads=threads) as (bz, bx, by):
+        with T.Kernel(
+            nheads,
+            T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
+            batch * nchunks,
+            threads=threads,
+        ) as (
+            bz,
+            bx,
+            by,
+        ):
             acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
             acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
             cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
@@ -257,27 +248,32 @@ def chunk_scan_fwd(batch,
             m_idx = bx // T.ceildiv(headdim, block_N)
             n_idx = bx % T.ceildiv(headdim, block_N)

-            T.annotate_layout({
-                acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
-                cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
-                x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
-            })
+            T.annotate_layout(
+                {
+                    acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
+                    cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
+                    x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared),
+                }
+            )
             T.no_set_max_nreg()

-            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
-                   dA_cs_m_shared)
+            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared)
             T.copy(dA_cs_m_shared, dA_cs_m_local)
             T.clear(acc_o)

             for i in T.Parallel(block_M):
                 scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
             T.copy(
-                C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
-                  (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
-            T.copy(
-                prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
-                            0:block_Dstate], prev_state_shared)
+                C[
+                    batch_idx,
+                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                    bz // (nheads // ngroups),
+                    0:block_Dstate,
+                ],
+                C_shared,
+            )
+            T.copy(
+                prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate],
+                prev_state_shared,
+            )
             T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
             for i, j in T.Parallel(block_M, block_N):
                 acc_o[i, j] *= scale_m_local[i]
@@ -286,34 +282,47 @@ def chunk_scan_fwd(batch,
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 T.copy(
-                    cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
-                       m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
-                    cb_shared)
+                    cb[
+                        batch_idx,
+                        chunk_idx,
+                        bz // (nheads // ngroups),
+                        m_idx * block_M : (m_idx + 1) * block_M,
+                        k * block_K : (k + 1) * block_K,
+                    ],
+                    cb_shared,
+                )
                 T.copy(cb_shared, cb_local)
-                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
-                       dA_cs_k_shared)
+                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared)
                 T.copy(dA_cs_k_shared, dA_cs_k_local)
                 for i, j in T.Parallel(block_M, block_K):
-                    cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p -
-                                                             dA_cs_k_local[j] * p)
-                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
+                    cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
+                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
                 T.copy(dt_shared, dt_local)
                 for i, j in T.Parallel(block_M, block_K):
                     cb_local[i, j] *= dt_local[j]
                 for i, j in T.Parallel(block_M, block_K):
-                    cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
-                                                    cb_local[i, j], 0)
+                    cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0)
                 T.copy(
-                    x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
-                      (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
+                    x[
+                        batch_idx,
+                        chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
+                        bz,
+                        n_idx * block_N : (n_idx + 1) * block_N,
+                    ],
+                    x_shared,
+                )
                 T.gemm(cb_local, x_shared, acc_o)

             D_local[0] = D[bz]
             T.copy(
-                x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
-                  (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
-                x_residual_shared)
+                x[
+                    batch_idx,
+                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                    bz,
+                    n_idx * block_N : (n_idx + 1) * block_N,
+                ],
+                x_residual_shared,
+            )
             T.copy(x_residual_shared, x_residual_local)
             for i, j in T.Parallel(block_M, block_N):
                 acc_o[i, j] += x_residual_local[i, j] * D_local[0]
@@ -321,24 +330,37 @@ def chunk_scan_fwd(batch,
             T.copy(acc_o, acc_o_shared)
             T.copy(
                 acc_o_shared,
-                Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
-                       (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
+                Output[
+                    batch_idx,
+                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                    bz,
+                    n_idx * block_N : (n_idx + 1) * block_N,
+                ],
+            )

     return main


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=80, help='heads')
-    parser.add_argument('--groups', type=int, default=1, help='groups')
-    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-    parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
-    parser.add_argument('--dim', type=int, default=64, help='dim')
-    parser.add_argument('--dstate', type=int, default=128, help='dstate')
-    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=80, help="heads")
+    parser.add_argument("--groups", type=int, default=1, help="groups")
+    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+    parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
+    parser.add_argument("--dim", type=int, default=64, help="dim")
+    parser.add_argument("--dstate", type=int, default=128, help="dstate")
+    parser.add_argument("--tune", action="store_true", help="tune configs")
     args = parser.parse_args()

-    batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
+    batch, heads, groups, seq_len, chunk_size, dim, dstate = (
+        args.batch,
+        args.heads,
+        args.groups,
+        args.seq_len,
+        args.chunk_size,
+        args.dim,
+        args.dstate,
+    )
     nchunks = math.ceil(seq_len / chunk_size)
     total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
@@ -360,8 +382,7 @@ if __name__ == "__main__":
     D = torch.randn(heads).half().cuda()

     print("Benchmarking Triton...")
-    triton_latency = do_bench(
-        lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
+    triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
     print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}")
     print("Benchmarking Helion...")
benchmark/matmul/benchmark_matmul.py

@@ -6,6 +6,7 @@ import tilelang
 import tilelang.language as T
 from tilelang.autotuner import autotune
 from tilelang import jit
+
 # Configure logger
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -101,9 +102,7 @@ def get_configs(args, kwargs):
         policy=[T.GemmWarpPolicy.Square],
         enable_rasteration=[True, False],
     )
-    return [{
-        k: v for k, v in zip(iter_params, values)
-    } for values in itertools.product(*iter_params.values())]
+    return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]

     return configs
@@ -112,7 +111,9 @@ def get_configs(args, kwargs):
     warmup=3,
     rep=20,
 )
-@jit(out_idx=[2],)
+@jit(
+    out_idx=[2],
+)
 def matmul(
     M,
     N,
@@ -159,9 +160,9 @@ def matmul(
     @T.prim_func
     def main(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((N, K), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((N, K), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         """
         The compiled TVM function for block-level matrix multiplication.
@@ -176,7 +177,6 @@ def matmul(
         # Bind x-dimension to block index in N,
        # y-dimension to block index in M.
-
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            # Allocate shared memory for A sub-block of shape (block_M, block_K)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            # Allocate shared memory for B sub-block of shape (block_N, block_K)
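
The "@jit(out_idx=[2],)" change above is the pattern behind many of the reflowed calls in this commit: ruff format, like Black, treats a trailing comma inside brackets as a request to keep the call expanded one element per line (the "magic trailing comma"). The snippet below is illustrative only and not from the commit; configure() is a hypothetical stand-in for the decorated call.

def configure(out_idx=None):
    # Hypothetical stand-in used only to illustrate the formatting rule.
    return {"out_idx": out_idx}


# Written with a trailing comma, ruff format keeps the call expanded:
cfg_expanded = configure(
    out_idx=[2],
)

# Without the trailing comma, the same call collapses onto one line:
cfg_collapsed = configure(out_idx=[2])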
benchmark/matmul/benchmark_matmul_intrinsic.py

@@ -6,7 +6,8 @@ import tilelang as tl
 import tilelang.language as T
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
-    TensorCoreIntrinEmitter,)
+    TensorCoreIntrinEmitter,
+)
 from tilelang.transform import simplify_prim_func
 from tilelang.autotuner import autotune
 import itertools
@@ -103,12 +104,11 @@ def tl_matmul(
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
-
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
@@ -116,10 +116,12 @@ def tl_matmul(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

-            T.annotate_layout({
-                A_shared: make_swizzle_layout(A_shared),
-                B_shared: make_swizzle_layout(B_shared),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: make_swizzle_layout(A_shared),
+                    B_shared: make_swizzle_layout(B_shared),
+                }
+            )

             # Improve L2 Cache
             T.use_swizzle(panel_size=10, enable=enable_rasteration)
@@ -127,7 +129,6 @@ def tl_matmul(
             T.clear(C_local)

-
             for ko in T.Pipelined((K // block_K), num_stages=stage):
                 # Load A into shared memory
                 for i, k in T.Parallel(block_M, block_K):
                     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
@@ -137,7 +138,6 @@ def tl_matmul(
                     B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

-
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mma_emitter.ldmatrix_a(A_local, A_shared, ki)
@@ -223,7 +223,6 @@ def get_configs(args, kwargs):
         for config in configs:
             print(config)
     else:
-
         iter_params = dict(
             block_row_warps=[1, 2, 4],
             block_col_warps=[1, 2, 4],
@@ -233,9 +232,7 @@ def get_configs(args, kwargs):
             stage=[0, 2],
             enable_rasteration=[True, False],
         )
-        return [{
-            k: v for k, v in zip(iter_params, values)
-        } for values in itertools.product(*iter_params.values())]
+        return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]

     return configs
@@ -247,7 +244,9 @@ def get_configs(args, kwargs):
     ref_prog=ref_program,
     skip_check=True,
 )
-@tl.jit(out_idx=[2],)
+@tl.jit(
+    out_idx=[2],
+)
 def matmul(
     M,
     N,
@@ -291,13 +290,8 @@ if __name__ == "__main__":
     parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
     parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
     parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
-    parser.add_argument(
-        "--with_roller",
-        type=bool,
-        default=False,
-        help="Whether to use roller to deduce search spaces")
-    parser.add_argument(
-        "--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
+    parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces")
+    parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
     args = parser.parse_args()
     M, N, K = args.m, args.n, args.k
benchmark/matmul/benchmark_matmul_sp.py

@@ -70,7 +70,8 @@ def get_configs(M, N, K):
                 thread_num,
                 policy,
                 enable_rasterization,
-        ))
+            )
+    )
     configs = [
         {
@@ -81,7 +82,8 @@ def get_configs(M, N, K):
             "thread_num": c[4],
             "policy": c[5],
             "enable_rasterization": c[6],  # keep param name for backward-compat
-        } for c in _configs
+        }
+        for c in _configs
     ]
     return configs
@@ -126,7 +128,9 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
     warmup=3,
     rep=20,
 )
-@jit(out_idx=[2],)
+@jit(
+    out_idx=[2],
+)
 def kernel(
     block_M=None,
     block_N=None,
@@ -165,10 +169,10 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
     @T.prim_func
     def main(
-            A_sparse: T.Tensor((M, K // 2), in_dtype),
-            E: T.Tensor((M, K // e_factor), e_dtype),
-            B: T.Tensor((K, N), in_dtype),
-            C: T.Tensor((M, N), accum_dtype),
+        A_sparse: T.Tensor((M, K // 2), in_dtype),
+        E: T.Tensor((M, K // e_factor), e_dtype),
+        B: T.Tensor((K, N), in_dtype),
+        C: T.Tensor((M, N), accum_dtype),
     ):
         """
         The compiled TVM function for block-level matrix multiplication.
@@ -182,9 +186,7 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
         """
         # Bind x-dimension to block index in N,
         # y-dimension to block index in M.
-        with T.Kernel(
-                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
             # Allocate shared memory for A sub-block of shape (block_M, block_K)
             A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
             # Allocate shared memory for B sub-block of shape (block_N, block_K)
@@ -201,12 +203,12 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
             T.disable_warp_group_reg_alloc()
             T.use_swizzle(panel_size=10, enable=enable_rasterization)
-            T.annotate_layout({
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
-            })
+            T.annotate_layout(
+                {
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
+                }
+            )
             # Loop over sub-blocks in K dimension, pipelined by num_stages
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 # Load a sub-block of A from global memory into A_shared
@@ -241,18 +243,13 @@ if __name__ == "__main__":
     parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
     parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
     parser.add_argument("--disable_cache", action="store_true")
-    parser.add_argument(
-        "--accum_dtype",
-        type=str,
-        default="float",
-        choices=["float", "float16"],
-        help="Accumulation datatype")
+    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
     parser.add_argument(
         "--bench_torch_sparse",
         type=str,
-        choices=['cutlass', 'cusparselt'],
+        choices=["cutlass", "cusparselt"],
         default=None,
-        help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported"
+        help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported",
     )
     args = parser.parse_args()
@@ -274,7 +271,8 @@ if __name__ == "__main__":
     if args.bench_torch_sparse is not None:
         from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
-        if args.bench_torch_sparse == 'cutlass':
+
+        if args.bench_torch_sparse == "cutlass":
             SparseSemiStructuredTensor._FORCE_CUTLASS = True
             A_sp = to_sparse_semi_structured(A, transposed=False)
             torch_sparse_latency = do_bench(lambda: A_sp @ B)
@@ -285,8 +283,6 @@ if __name__ == "__main__":
     print(f"Best config: {best_config}")
     if args.bench_torch_sparse is not None:
-        print(
-            f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}"
-        )
+        print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")
     print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")
benchmark/matmul_fp8/benchmark_matmul.py

@@ -104,9 +104,7 @@ def get_configs(args, kwargs):
         policy=[T.GemmWarpPolicy.Square],
         enable_rasteration=[True, False],
     )
-    return [{
-        k: v for k, v in zip(iter_params, values)
-    } for values in itertools.product(*iter_params.values())]
+    return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]

     return configs
@@ -116,7 +114,9 @@ def get_configs(args, kwargs):
     warmup=3,
     rep=20,
 )
-@jit(out_idx=[2],)
+@jit(
+    out_idx=[2],
+)
 def matmul(
     M,
     N,
@@ -164,9 +164,9 @@ def matmul(
     @T.prim_func
     def main(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((N, K), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((N, K), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         """
         The compiled TVM function for block-level matrix multiplication.
@@ -181,7 +181,6 @@ def matmul(
        # Bind x-dimension to block index in N,
        # y-dimension to block index in M.
-
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            # Allocate shared memory for A sub-block of shape (block_M, block_K)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            # Allocate shared memory for B sub-block of shape (block_N, block_K)
docs/conf.py

@@ -20,33 +20,27 @@ extensions = [
     "autoapi.extension",
 ]

-autoapi_type = 'python'
-autoapi_dirs = ['../tilelang']
+autoapi_type = "python"
+autoapi_dirs = ["../tilelang"]
 autoapi_options = [
-    'members',
-    'undoc-members',
-    'show-inheritance',
-    'show-module-summary',
-    'special-members',
+    "members",
+    "undoc-members",
+    "show-inheritance",
+    "show-module-summary",
+    "special-members",
 ]
 autoapi_keep_files = False  # Useful for debugging the generated rst files
 autoapi_generate_api_docs = True
-autodoc_typehints = 'description'
+autodoc_typehints = "description"
 autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"]

-source_suffix = {
-    '.rst': 'restructuredtext',
-    '.md': 'markdown',
-}
+source_suffix = {".rst": "restructuredtext", ".md": "markdown"}

-myst_enable_extensions = [
-    "colon_fence",
-    "deflist",
-]
+myst_enable_extensions = ["colon_fence", "deflist"]

 redirects = {"get_started/try_out": "../index.html#getting-started"}
@@ -66,10 +60,7 @@ html_css_files = ["custom.css"]
 footer_copyright = "© 2025-2026 TileLang"
 footer_note = " "

-html_theme_options = {
-    "light_logo": "img/logo-v2.png",
-    "dark_logo": "img/logo-v2.png",
-}
+html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"}

 header_links = [
     ("Home", "https://github.com/tile-ai/tilelang"),
examples/amd/example_amd_flash_attn_bwd.py

@@ -11,22 +11,20 @@ import time
 def ref_program(Q, K, V, is_causal, groups=1):
-    assert Q.size(2) == K.size(
-        2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
-    assert Q.size(2) == V.size(
-        2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
+    assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
+    assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"

     dim = Q.size(-1)
     K_ref = K.repeat_interleave(groups, dim=2)
     V_ref = V.repeat_interleave(groups, dim=2)
-    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref)
+    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref)
     scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
     if is_causal:
         seq_len = Q.size(1)
         mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
         mask = mask.unsqueeze(0).unsqueeze(0)
-        scores = scores.masked_fill(mask == 0, float('-inf'))
+        scores = scores.masked_fill(mask == 0, float("-inf"))
     attention_weights = F.softmax(scores, dim=-1)
-    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref)
+    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref)
     lse = torch.logsumexp(scores, dim=-1).float()
     return output, lse
@@ -45,23 +43,23 @@ def get_fwd_configs():
     valid_configs = []

-    for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
-                                                                  threads, num_stages,
-                                                                  enable_rasterization, k_pack,
-                                                                  panel_size, qk_coalesced_width,
-                                                                  v_coalesced_width):
-        valid_configs.append({
-            "block_M": m,
-            "block_N": n,
-            "num_split_q": s,
-            "threads": t,
-            "num_stages": stages,
-            "enable_rasterization": r,
-            "k_pack": k,
-            "panel_size": p,
-            "qk_coalesced_width": qkw,
-            "v_coalesced_width": vw,
-        })
+    for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
+        block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width
+    ):
+        valid_configs.append(
+            {
+                "block_M": m,
+                "block_N": n,
+                "num_split_q": s,
+                "threads": t,
+                "num_stages": stages,
+                "enable_rasterization": r,
+                "k_pack": k,
+                "panel_size": p,
+                "qk_coalesced_width": qkw,
+                "v_coalesced_width": vw,
+            }
+        )
     return valid_configs
@@ -85,7 +83,7 @@ def fast_flashattn(
     qk_coalesced_width: int,
     v_coalesced_width: int,
 ):
-    scale = (1.0 / dim)**0.5
+    scale = (1.0 / dim) ** 0.5
     head_kv = heads // groups
     q_shape = [batch, seq_len, heads, dim]
     kv_shape = [batch, seq_len, head_kv, dim]
@@ -97,11 +95,11 @@ def fast_flashattn(
     @T.prim_func
     def main(
-            Q: T.Tensor(q_shape, dtype),
-            K: T.Tensor(kv_shape, dtype),
-            V: T.Tensor(kv_shape, dtype),
-            Output: T.Tensor(q_shape, dtype),
-            LSE: T.Tensor([batch, heads, seq_len], accum_dtype),
+        Q: T.Tensor(q_shape, dtype),
+        K: T.Tensor(kv_shape, dtype),
+        V: T.Tensor(kv_shape, dtype),
+        Output: T.Tensor(q_shape, dtype),
+        LSE: T.Tensor([batch, heads, seq_len], accum_dtype),
     ):
         with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
             T.use_swizzle(panel_size, enable=enable_rasterization)
@@ -135,33 +133,21 @@ def fast_flashattn(
                 m_prev = T.alloc_fragment([block_M], accum_dtype)
                 scale_factor = T.alloc_fragment([block_M], accum_dtype)

-                T.copy(
-                    Q[bz, q_block_offset:q_block_offset + block_M, by, :],
-                    Q_shared,
-                    coalesced_width=vec_size)
+                T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)

-                loop_end_k = (
-                    T.ceildiv(q_block_offset + block_M, block_N)
-                    if is_causal else T.ceildiv(seq_len, block_N))
+                loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
                 row_sum = T.alloc_fragment([block_M], accum_dtype)

                 for k in T.Pipelined(loop_end_k, num_stages=num_stages):
                     kv_idx = k * block_N

-                    T.copy(
-                        K[bz, kv_idx:kv_idx + block_N, by // groups, :],
-                        K_shared,
-                        coalesced_width=vec_size)
-                    T.copy(
-                        V[bz, kv_idx:kv_idx + block_N, by // groups, :],
-                        V_shared,
-                        coalesced_width=v_vec_size)
+                    T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
+                    T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)

                     if is_causal:
                         for i, j in T.Parallel(block_M, block_N):
-                            acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0,
-                                                         -T.infinity(acc_s.dtype))
+                            acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
                     else:
                         T.clear(acc_s)

                     T.gemm(
@@ -216,8 +202,7 @@ def fast_flashattn(
                 for i in T.Parallel(block_M):
                     if q_block_offset + i < seq_len:
-                        lse_val = T.if_then_else(l_i[i] > 0,
-                                                 T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype))
+                        lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype))
                         LSE[bz, by, q_block_offset + i] = lse_val

                 bx_loop_var = current_bx + num_split_q
@@ -234,16 +219,17 @@ def get_bwd_configs():
     panel_size = [7, 8, 9, 10]

     configs = []
-    for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads,
-                                                   enable_rasterization, panel_size):
-        configs.append({
-            "block_M": m,
-            "block_N": n,
-            "num_stages": stages,
-            "threads": t,
-            "enable_rasterization": r,
-            "panel_size": p,
-        })
+    for m, n, stages, t, r, p in itertools.product(
+        block_M, block_N, num_stages, threads, enable_rasterization, panel_size
+    ):
+        configs.append(
+            {
+                "block_M": m,
+                "block_N": n,
+                "num_stages": stages,
+                "threads": t,
+                "enable_rasterization": r,
+                "panel_size": p,
+            }
+        )
     return configs
@@ -256,8 +242,7 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
     blk = 32

     @T.prim_func
-    def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype),
-                       Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
+    def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
         with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by):
             o = T.alloc_fragment([blk, blk], dtype)
             do = T.alloc_fragment([blk, blk], dtype)
@@ -265,21 +250,33 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
             delta = T.alloc_fragment([blk], accum_dtype)
             T.clear(acc)
             for k in range(T.ceildiv(dim, blk)):
-                T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
-                T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
+                T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
+                T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
                 for i, j in T.Parallel(blk, blk):
                     acc[i, j] += o[i, j] * do[i, j]
             T.reduce_sum(acc, delta, 1)
-            T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
+            T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])

     return flash_bwd_prep


 @tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True)
 @tilelang.jit
-def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, block_N: int,
-                  num_stages: int, threads: int, enable_rasterization: bool, panel_size: int):
-    sm_scale = (1.0 / dim)**0.5
+def flashattn_bwd(
+    batch,
+    heads,
+    seq_len,
+    dim,
+    is_causal,
+    groups,
+    block_M: int,
+    block_N: int,
+    num_stages: int,
+    threads: int,
+    enable_rasterization: bool,
+    panel_size: int,
+):
+    sm_scale = (1.0 / dim) ** 0.5
     head_kv = heads // groups
     q_shape = [batch, seq_len, heads, dim]
     kv_shape = [batch, seq_len, head_kv, dim]
@@ -287,14 +284,17 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
     accum_dtype = "float"

     @T.prim_func
-    def flash_bwd_kernel(Q: T.Tensor(q_shape, dtype), K: T.Tensor(kv_shape, dtype),
-                         V: T.Tensor(kv_shape, dtype), dO: T.Tensor(q_shape, dtype),
-                         lse: T.Tensor([batch, heads, seq_len], accum_dtype),
-                         Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
-                         dQ: T.Tensor(q_shape, accum_dtype), dK: T.Tensor(kv_shape, accum_dtype),
-                         dV: T.Tensor(kv_shape, accum_dtype)):
+    def flash_bwd_kernel(
+        Q: T.Tensor(q_shape, dtype),
+        K: T.Tensor(kv_shape, dtype),
+        V: T.Tensor(kv_shape, dtype),
+        dO: T.Tensor(q_shape, dtype),
+        lse: T.Tensor([batch, heads, seq_len], accum_dtype),
+        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
+        dQ: T.Tensor(q_shape, accum_dtype),
+        dK: T.Tensor(kv_shape, accum_dtype),
+        dV: T.Tensor(kv_shape, accum_dtype),
+    ):
         with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
             T.use_swizzle(panel_size, enable=enable_rasterization)
@@ -315,8 +315,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
             dk = T.alloc_fragment([block_M, dim], accum_dtype)
             dq = T.alloc_fragment([block_N, dim], accum_dtype)

-            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
-            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
+            T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
+            T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
             T.clear(dv)
             T.clear(dk)
@@ -324,22 +324,21 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
             loop_ed = T.ceildiv(seq_len, block_N)

             for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
-                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q_shared)
+                T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared)
                 T.clear(qkT)
                 T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
-                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
+                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
                 for i, j in T.Parallel(block_M, block_N):
                     P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j])

                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
-                        P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j,
-                                                     P_acc[i, j], 0.0)
+                        P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0)

-                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do_shared)
+                T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared)
                 T.clear(dP)
                 T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -347,7 +346,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
                 T.copy(P_acc, p_cast)
                 T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow)

-                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta_shared)
+                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared)

                 for i, j in T.Parallel(block_M, block_N):
                     p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale
@@ -378,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
     def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)):
         with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
-            T.copy(dQ_in[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk:(bx + 1) * blk, by, :])
+            T.copy(dQ_in[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out
[
bz
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
by
,
:],
)
return
flash_bwd_post
...
...
@@ -446,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100):
return
np
.
median
(
times
)
def
main
(
batch
:
int
=
1
,
heads
:
int
=
8
,
seq_len
:
int
=
4096
,
dim
:
int
=
128
,
is_causal
:
bool
=
False
,
groups
:
int
=
1
):
def
main
(
batch
:
int
=
1
,
heads
:
int
=
8
,
seq_len
:
int
=
4096
,
dim
:
int
=
128
,
is_causal
:
bool
=
False
,
groups
:
int
=
1
):
device
=
"cuda"
dtype
=
torch
.
float16
torch
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed
(
42
)
print
(
f
"Test configuration: batch=
{
batch
}
, heads=
{
heads
}
, seq_len=
{
seq_len
}
, dim=
{
dim
}
, is_causal=
{
is_causal
}
, groups=
{
groups
}
"
)
print
(
f
"Test configuration: batch=
{
batch
}
, heads=
{
heads
}
, seq_len=
{
seq_len
}
, dim=
{
dim
}
, is_causal=
{
is_causal
}
, groups=
{
groups
}
"
)
flops_per_gemm
=
2.0
*
batch
*
heads
*
seq_len
*
seq_len
*
dim
total_flops
=
5
*
flops_per_gemm
...
...
@@ -517,22 +508,19 @@ def main(batch: int = 1,
o_ref
.
backward
(
dO
)
print
(
"Verifying backward pass correctness..."
)
dq_close
,
dq_max_diff
,
dq_mean_diff
=
debug_tensor_comparison
(
dQ_tl
,
q_ref
.
grad
,
"dQ"
,
rtol
=
0.05
,
atol
=
0.05
)
dq_close
,
dq_max_diff
,
dq_mean_diff
=
debug_tensor_comparison
(
dQ_tl
,
q_ref
.
grad
,
"dQ"
,
rtol
=
0.05
,
atol
=
0.05
)
if
dq_close
:
print
(
"dQ is correct."
)
else
:
print
(
"dQ mismatch detected."
)
dk_close
,
dk_max_diff
,
dk_mean_diff
=
debug_tensor_comparison
(
dK_tl
.
to
(
torch
.
float16
),
k_ref
.
grad
,
"dK"
,
rtol
=
0.05
,
atol
=
0.05
)
dk_close
,
dk_max_diff
,
dk_mean_diff
=
debug_tensor_comparison
(
dK_tl
.
to
(
torch
.
float16
),
k_ref
.
grad
,
"dK"
,
rtol
=
0.05
,
atol
=
0.05
)
if
dk_close
:
print
(
"dK is correct."
)
else
:
print
(
"dK mismatch detected."
)
dv_close
,
dv_max_diff
,
dv_mean_diff
=
debug_tensor_comparison
(
dV_tl
.
to
(
torch
.
float16
),
v_ref
.
grad
,
"dV"
,
rtol
=
0.05
,
atol
=
0.05
)
dv_close
,
dv_max_diff
,
dv_mean_diff
=
debug_tensor_comparison
(
dV_tl
.
to
(
torch
.
float16
),
v_ref
.
grad
,
"dV"
,
rtol
=
0.05
,
atol
=
0.05
)
if
dv_close
:
print
(
"dV is correct."
)
else
:
...
...
@@ -553,9 +541,7 @@ def main(batch: int = 1,
torch
.
cuda
.
synchronize
()
ref_latency
=
benchmark_function
(
run_reference_fwd_bwd
,
warmup
=
10
,
repeat
=
100
)
print
(
f
"Reference PyTorch Forward+Backward:
{
ref_latency
:.
2
f
}
ms |
{
total_flops
/
ref_latency
*
1e-9
:.
2
f
}
TFlops"
)
print
(
f
"Reference PyTorch Forward+Backward:
{
ref_latency
:.
2
f
}
ms |
{
total_flops
/
ref_latency
*
1e-9
:.
2
f
}
TFlops"
)
def
run_complete_fwd_bwd
():
o_tl_bench
,
lse_tl_bench
=
fwd_kernel
(
q
,
k
,
v
)
...
...
@@ -593,12 +579,12 @@ def main(batch: int = 1,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
1
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
8
,
help
=
'
heads
'
)
parser
.
add_argument
(
'
--seq_len
'
,
type
=
int
,
default
=
1024
,
help
=
'
sequence length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
64
,
help
=
'
dim
'
)
parser
.
add_argument
(
'
--is_causal
'
,
action
=
'
store_true
'
,
help
=
'
causal
'
)
parser
.
add_argument
(
'
--groups
'
,
type
=
int
,
default
=
1
,
help
=
'
groups
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
1
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
8
,
help
=
"
heads
"
)
parser
.
add_argument
(
"
--seq_len
"
,
type
=
int
,
default
=
1024
,
help
=
"
sequence length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
64
,
help
=
"
dim
"
)
parser
.
add_argument
(
"
--is_causal
"
,
action
=
"
store_true
"
,
help
=
"
causal
"
)
parser
.
add_argument
(
"
--groups
"
,
type
=
int
,
default
=
1
,
help
=
"
groups
"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_len
,
args
.
dim
,
args
.
is_causal
,
args
.
groups
)
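A note on how the backward example reports performance: latency is the median of repeated timed runs (np.median(times) above) and the TFlops figure is total_flops / latency_ms * 1e-9. The sketch below reproduces that reporting with a plain CUDA-event timer; median_latency_ms and the stand-in matmul workload are illustrative names written for this note, not code from the repository.

import numpy as np
import torch


def median_latency_ms(fn, warmup=10, repeat=100):
    # Median of per-call CUDA-event timings, mirroring the warmup/repeat/median
    # pattern above (assumed behaviour, not the repo's benchmark_function).
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    times = []
    for _ in range(repeat):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    return float(np.median(times))


batch, heads, seq_len, dim = 1, 8, 1024, 64
flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 5 * flops_per_gemm  # same GEMM count the example charges for fwd+bwd
x = torch.randn(seq_len, seq_len, device="cuda", dtype=torch.float16)
latency = median_latency_ms(lambda: x @ x)  # stand-in workload
print(f"{latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")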
examples/amd/example_amd_flash_attn_fwd.py  View file @ 29051439
...
@@ -13,10 +13,10 @@ def supply_tensors_gpu(params):
    """Supply function that creates tensors on GPU for ROCm/HIP."""
    tensors = []
    for param in params:
-        if hasattr(param, 'shape') and hasattr(param, 'dtype'):
+        if hasattr(param, "shape") and hasattr(param, "dtype"):
            # Force creation on GPU device
            shape = [int(s) for s in param.shape]
-            tensor = torch.randn(shape, dtype=param.dtype, device='cuda')
+            tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
            tensors.append(tensor)
        else:
            tensors.append(param)
...
@@ -24,22 +24,20 @@ def supply_tensors_gpu(params):
def ref_program(Q, K, V, is_causal, groups=1):
    assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
-    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
-        scores = scores.masked_fill(mask == 0, float('-inf'))
+        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
-    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
...
@@ -58,23 +56,23 @@ def get_configs():
    valid_configs = []
    for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "num_split_q": s,
            "threads": t,
            "num_stages": stages,
            "enable_rasterization": r,
            "k_pack": k,
            "panel_size": p,
            "qk_coalesced_width": qkw,
            "v_coalesced_width": vw,
        })
    return valid_configs
...
@@ -98,7 +96,7 @@ def fast_flashattn(
    qk_coalesced_width: int,
    v_coalesced_width: int,
):
    scale = (1.0 / dim) ** 0.5
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
...
@@ -110,10 +108,10 @@ def fast_flashattn(
    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(kv_shape, dtype),
        V: T.Tensor(kv_shape, dtype),
        Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
            T.use_swizzle(panel_size, enable=enable_rasterization)
...
@@ -147,32 +145,21 @@ def fast_flashattn(
            m_prev = T.alloc_fragment([block_M], accum_dtype)
            scale_factor = T.alloc_fragment([block_M], accum_dtype)
            T.copy(Q[bz, q_block_offset:q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
            loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            row_sum = T.alloc_fragment([block_M], accum_dtype)
            for k in T.Pipelined(loop_end_k, num_stages=num_stages):
                kv_idx = k * block_N
                T.copy(K[bz, kv_idx:kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
                T.copy(V[bz, kv_idx:kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                T.gemm(
...
@@ -222,13 +209,7 @@ def fast_flashattn(
    return main


def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
...
@@ -250,18 +231,16 @@ def main(batch: int = 1,
    print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")

    latency = profiler.do_bench(warmup=100)
    print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=8, help='heads')
-    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--is_causal', action='store_true', help='causal')
-    parser.add_argument('--groups', type=int, default=1, help='groups')
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=8, help="heads")
+    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--is_causal", action="store_true", help="causal")
+    parser.add_argument("--groups", type=int, default=1, help="groups")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
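For readers unfamiliar with the grouped-query layout in ref_program above: each group of query heads shares one key/value head, and repeat_interleave expands K and V to match Q before the two einsum contractions. A minimal PyTorch sketch on tiny, made-up shapes:

import torch
import torch.nn.functional as F

batch, seq_len, heads, dim, groups = 1, 4, 8, 16, 2
Q = torch.randn(batch, seq_len, heads, dim)
K = torch.randn(batch, seq_len, heads // groups, dim)  # heads // groups KV heads
V = torch.randn(batch, seq_len, heads // groups, dim)

# Expand the shared KV heads so every query head has a matching key/value head.
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)

scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) / dim**0.5
out = torch.einsum("bhqk,bkhd->bqhd", F.softmax(scores, dim=-1), V)
print(out.shape)  # torch.Size([1, 4, 8, 16])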
examples/analyze/example_conv_analyze.py  View file @ 29051439
...
@@ -25,22 +25,7 @@ def check_hopper():
    return False


def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
...
@@ -50,13 +35,11 @@ def kernel(N,
    @T.prim_func
    def conv(
        data: T.Tensor((N, H, W, C), dtype),
        kernel: T.Tensor((KH, KW, C, F), dtype),
        out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...
@@ -65,11 +48,13 @@ def kernel(N,
            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

            T.annotate_layout({
                out_shared: make_swizzled_layout(out_shared),
                data_shared: make_swizzled_layout(data_shared),
                kernel_shared: make_swizzled_layout(kernel_shared),
            })
            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...
@@ -81,10 +66,8 @@ def kernel(N,
                    m = by * block_M + i
                    access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                    access_w = m % OW * S + k // C % KW * D - P
-                    in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W))
-                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
+                    in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
+                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)
...
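The conv kernel above treats the output as a flat (N * OH * OW, F) matrix and gathers inputs with the access_h / access_w arithmetic in the last hunk. The NumPy sketch below re-derives the output size and replays that index mapping on the host; the gather helper and the chosen sizes are illustrative, not code from the repository.

import numpy as np

N, C, H, W, F, K, S, D, P = 1, 3, 8, 8, 4, 3, 1, 1, 1
KH = KW = K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
print(OH, OW)  # 8 8 with this stride/padding/dilation


def gather(data, m, k):
    # Same mapping as the kernel: flat output row m and reduction index k
    # select one input element, with out-of-bounds taps reading as zero.
    access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
    access_w = m % OW * S + k // C % KW * D - P
    if 0 <= access_h < H and 0 <= access_w < W:
        return data[m // (OH * OW), access_h, access_w, k % C]
    return 0.0


data = np.random.rand(N, H, W, C).astype(np.float32)
print(gather(data, m=5, k=7))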
examples/analyze/example_gemm_analyze.py  View file @ 29051439
...
@@ -20,9 +20,9 @@ def kernel(
    @T.prim_func
    def matmul(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
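The matmul kernel above launches one block per output tile, so the grid is (ceildiv(N, block_N), ceildiv(M, block_M)). A short sketch of that arithmetic (ceildiv here is a hypothetical helper, not the T.ceildiv intrinsic):

def ceildiv(a, b):
    return (a + b - 1) // b


M, N, block_M, block_N = 1024, 1024, 128, 128
print((ceildiv(N, block_N), ceildiv(M, block_M)))  # (8, 8) output tiles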
examples/attention_sink/benchmark_gqa_sink_fwd.py  View file @ 29051439
...
@@ -51,8 +51,7 @@ def triton_kernel(
    q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
    if BANDWIDTH:
        lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
    else:
        lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
...
@@ -120,7 +119,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
        BANDWIDTH=window_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
-        start_q=seq_kv - seq_q)
+        start_q=seq_kv - seq_q,
+    )
    return o
...
@@ -137,12 +137,11 @@ def main(
):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
-        print('Using sliding window attention.')
+        print("Using sliding window attention.")
        assert window_size <= seq_q
        flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
    else:
-        print('Using full attention.')
+        print("Using full attention.")
        flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
    total_flops = 2 * flops_per_matmul
...
@@ -170,15 +169,14 @@ def main(
        block_N=block_N,
        num_stages=num_stages,
        threads=threads,
-        dtype=dtype)
+        dtype=dtype,
+    )
    Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
    if torch.allclose(
        triton_program(Q, K, V, sinks, window_size),
        ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
        rtol=1e-2,
        atol=1e-2,
    ):
        print("Checks for triton passed.✅")
    else:
        print("Checks for triton failed.❌")
...
@@ -198,20 +196,14 @@ def main(
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=64, help='heads')
-    parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
-    parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--groups', type=int, default=8, help='groups')
-    parser.add_argument('--window_size', type=int, default=None, help='window size (default: None, which means full attention)')
-    parser.add_argument('--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
-    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=64, help="heads")
+    parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
+    parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--groups", type=int, default=8, help="groups")
+    parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
+    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
+    parser.add_argument("--tune", action="store_true", help="tune configs")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)
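The TFlops number printed by this benchmark comes from a rough FLOP count: full causal attention charges about half of seq_q * seq_kv, a sliding window caps the effective key length at window_size, and the total is doubled for the two matmuls. A minimal sketch of that estimate, using a hypothetical attention_tflops helper:

def attention_tflops(batch, heads, seq_q, seq_kv, dim, latency_ms, window_size=None):
    if window_size is not None:
        flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim
    else:
        flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
    total_flops = 2 * flops_per_matmul  # one QK^T matmul and one PV matmul
    return total_flops / latency_ms * 1e-9


print(attention_tflops(1, 64, 2048, 2048, 128, latency_ms=1.0))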
examples/attention_sink/benchmark_mha_sink_fwd.py  View file @ 29051439
...
@@ -50,8 +50,7 @@ def triton_kernel(
    q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
    if BANDWIDTH:
        lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
    else:
        lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
...
@@ -117,26 +116,28 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
        BANDWIDTH=window_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
-        start_q=seq_kv - seq_q)
+        start_q=seq_kv - seq_q,
+    )
    return o


-def main(batch: int = 1, heads: int = 32, seq_q: int = 256, seq_kv: int = 256, dim: int = 128, window_size: Optional[int] = None, dtype: str = "float16", tune: bool = False):
+def main(
+    batch: int = 1,
+    heads: int = 32,
+    seq_q: int = 256,
+    seq_kv: int = 256,
+    dim: int = 128,
+    window_size: Optional[int] = None,
+    dtype: str = "float16",
+    tune: bool = False,
+):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
-        print('Using sliding window attention.')
+        print("Using sliding window attention.")
        assert window_size <= seq_q
        flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
    else:
-        print('Using full attention.')
+        print("Using full attention.")
        flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
    total_flops = 2 * flops_per_matmul
...
@@ -163,15 +164,14 @@ def main(batch: int = 1,
        block_N=block_N,
        num_stages=num_stages,
        threads=threads,
-        dtype=dtype)
+        dtype=dtype,
+    )
    Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
    torch.testing.assert_close(
        kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
    )
    print("All checks passed.✅")

    latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
...
@@ -184,19 +184,13 @@ def main(batch: int = 1,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
-    parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--window_size', type=int, default=None, help='window size (default: None, which means full attention)')
-    parser.add_argument('--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
-    parser.add_argument('--tune', action='store_true', help='tune')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=32, help="heads")
+    parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
+    parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
+    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
+    parser.add_argument("--tune", action="store_true", help="tune")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
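Both Triton benchmarks above restrict the key/value loop of a query block to the range [lo, hi) derived from BANDWIDTH (the window size). The sketch below evaluates that bound on the host; kv_range is a hypothetical helper written only to illustrate the arithmetic.

def kv_range(start_m, BLOCK_M, start_q, bandwidth=None):
    # Same bound as the kernels: a window limits how far back a block looks.
    if bandwidth:
        lo = max(0, start_q + start_m * BLOCK_M - bandwidth)
    else:
        lo = 0
    hi = start_q + (start_m + 1) * BLOCK_M
    return lo, hi


# With seq_q == seq_kv (so start_q == 0), query block 3 of size 64 scans all of
# [0, 256) under full attention but only [64, 256) with a 128-wide window.
print(kv_range(3, 64, start_q=0))                 # (0, 256)
print(kv_range(3, 64, start_q=0, bandwidth=128))  # (64, 256)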
examples/attention_sink/example_gqa_sink_bwd_bhsd.py  View file @ 29051439
This diff is collapsed. Click to expand it.
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py  View file @ 29051439
...
@@ -23,9 +23,11 @@ def get_configs():
    rep=100,
)
@tilelang.jit(
-    out_idx=[3], pass_configs={
+    out_idx=[3],
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
+    },
+)
def flashattn(
    batch,
    heads,
...
@@ -41,12 +43,11 @@ def flashattn(
    threads=256,
    dtype: str = "float16",
):
    if window_size is not None:
        assert window_size % block_N == 0, "window_size must be divisible by block_N"
    if sm_scale is None:
        sm_scale = (1.0 / dim) ** 0.5
    scale = sm_scale * 1.44269504  # log2(e)
    head_kv = heads // groups
...
@@ -68,13 +69,12 @@ def flashattn(
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
        for i, j in T.Parallel(block_M, block_N):
            q_idx = bx * block_M + i + past_len
            k_idx = k * block_N + j
            if window_size is not None:
                acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
            else:
                acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...
@@ -89,18 +89,18 @@ def flashattn(
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        scores_max: T.FragmentBuffer([block_M], accum_dtype),
        scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
        scores_scale: T.FragmentBuffer([block_M], accum_dtype),
        scores_sum: T.FragmentBuffer([block_M], accum_dtype),
        logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
...
@@ -112,8 +112,7 @@ def flashattn(
        # NOTE(wt): check_inf is necessary for sliding window attention.
        for i in T.Parallel(block_M):
            if window_size is not None:
                scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
...
@@ -128,19 +127,19 @@ def flashattn(
    @T.macro
    def Rescale(
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(kv_shape, dtype),
        V: T.Tensor(kv_shape, dtype),
        Output: T.Tensor(q_shape, dtype),
        Sinks: T.Tensor([heads], dtype),
    ):
        with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
...
@@ -157,58 +156,58 @@ def flashattn(
            logsum = T.alloc_fragment([block_M], accum_dtype)
            sinks = T.alloc_fragment([block_M], dtype)

            T.annotate_layout({
                Q_shared: make_swizzled_layout(Q_shared),
                K_shared: make_swizzled_layout(K_shared),
                V_shared: make_swizzled_layout(V_shared),
                O_shared: make_swizzled_layout(O_shared),
            })

            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            for i in T.Parallel(block_M):
                sinks[i] = Sinks[by]

            end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
            start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0

-            for k in T.Pipelined(start, end, num_stages=num_stages, order=[-1, 0, 3, 1, -1, 2], stage=[-1, 0, 0, 1, -1, 1], group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
+            for k in T.Pipelined(
+                start,
+                end,
+                num_stages=num_stages,
+                order=[-1, 0, 3, 1, -1, 2],
+                stage=[-1, 0, 0, 1, -1, 1],
+                group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
+            ):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i in T.Parallel(block_M):
                logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale)  # The only change for attention sink
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

    return main


# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
-def ref_program(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, sinks: torch.Tensor, sliding_window: Optional[int] = None, dtype: torch.dtype = torch.float16) -> torch.Tensor:
+def ref_program(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    sinks: torch.Tensor,
+    sliding_window: Optional[int] = None,
+    dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
    key = key.transpose(1, 2).contiguous()
    value = value.transpose(1, 2).contiguous()
    batch_size, num_keys, num_key_value_heads, head_dim = key.shape
...
@@ -244,23 +243,15 @@ def ref_program(query: torch.Tensor,
    output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
    output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
    return output.transpose(1, 2).contiguous()


def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
-    key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
-    value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
-    sinks = torch.randn([H], dtype=dtype, device='cuda')
+    query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
+    key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda")
+    value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda")
+    sinks = torch.randn([H], dtype=dtype, device="cuda")
    return query, key, value, sinks
...
@@ -277,12 +268,11 @@ def main(
):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
-        print('Using sliding window attention.')
+        print("Using sliding window attention.")
        assert window_size <= seq_q
        flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
    else:
-        print('Using full attention.')
+        print("Using full attention.")
        flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
    total_flops = 2 * flops_per_matmul
...
@@ -310,15 +300,14 @@ def main(
        block_N=block_N,
        num_stages=num_stages,
        threads=threads,
-        dtype=dtype)
+        dtype=dtype,
+    )
    Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
    torch.testing.assert_close(
        kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
    )
    print("All checks passed.✅")

    # Benchmark tilelang
...
@@ -329,20 +318,14 @@ def main(
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=64, help='heads')
-    parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
-    parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--groups', type=int, default=8, help='groups')
-    parser.add_argument('--window_size', type=int, default=None, help='window size (default: None, which means full attention)')
-    parser.add_argument('--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
-    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=64, help="heads")
+    parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
+    parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--groups", type=int, default=8, help="groups")
+    parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
+    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
+    parser.add_argument("--tune", action="store_true", help="tune configs")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)
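The line marked "The only change for attention sink" folds one extra exp2 term, built from the per-head sink logit, into the softmax denominator. The PyTorch sketch below (tiny made-up shapes, single head) shows why that is equivalent to appending a sink column to the scores that contributes no value:

import torch

scores = torch.randn(4, 6)    # one head: 4 queries x 6 keys
sink = torch.tensor([[0.3]])  # per-head sink logit
v = torch.randn(6, 8)         # values

# Path 1: append the sink as one extra logit column, softmax, then drop that column.
aug = torch.cat([scores, sink.expand(4, 1)], dim=-1)
p1 = torch.softmax(aug, dim=-1)[:, :6] @ v

# Path 2: fold exp(sink - row_max) into the softmax denominator, as the kernel does.
m = scores.max(dim=-1, keepdim=True).values
e = torch.exp(scores - m)
denom = e.sum(dim=-1, keepdim=True) + torch.exp(sink - m)
p2 = (e / denom) @ v

print(torch.allclose(p1, p2, atol=1e-6))  # True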
examples/attention_sink/example_mha_sink_bwd_bhsd.py  View file @ 29051439
This diff is collapsed. Click to expand it.