Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
467
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
779 additions
and
937 deletions
+779
-937
.pre-commit-config.yaml
.pre-commit-config.yaml
+2
-12
benchmark/blocksparse_attention/benchmark_library_dense_fmha.py
...ark/blocksparse_attention/benchmark_library_dense_fmha.py
+5
-8
benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py
...ksparse_attention/benchmark_tilelang_block_sparse_fmha.py
+31
-41
benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py
...locksparse_attention/benchmark_torch_block_sparse_fmha.py
+10
-15
benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py
...ocksparse_attention/benchmark_triton_block_sparse_fmha.py
+10
-30
benchmark/mamba2/benchmark_mamba_chunk_scan.py
benchmark/mamba2/benchmark_mamba_chunk_scan.py
+114
-93
benchmark/matmul/benchmark_matmul.py
benchmark/matmul/benchmark_matmul.py
+8
-8
benchmark/matmul/benchmark_matmul_intrinsic.py
benchmark/matmul/benchmark_matmul_intrinsic.py
+17
-23
benchmark/matmul/benchmark_matmul_sp.py
benchmark/matmul/benchmark_matmul_sp.py
+24
-28
benchmark/matmul_fp8/benchmark_matmul.py
benchmark/matmul_fp8/benchmark_matmul.py
+7
-8
docs/conf.py
docs/conf.py
+11
-20
examples/amd/example_amd_flash_attn_bwd.py
examples/amd/example_amd_flash_attn_bwd.py
+96
-110
examples/amd/example_amd_flash_attn_fwd.py
examples/amd/example_amd_flash_attn_fwd.py
+42
-63
examples/analyze/example_conv_analyze.py
examples/analyze/example_conv_analyze.py
+14
-31
examples/analyze/example_gemm_analyze.py
examples/analyze/example_gemm_analyze.py
+3
-3
examples/attention_sink/benchmark_gqa_sink_fwd.py
examples/attention_sink/benchmark_gqa_sink_fwd.py
+20
-28
examples/attention_sink/benchmark_mha_sink_fwd.py
examples/attention_sink/benchmark_mha_sink_fwd.py
+29
-35
examples/attention_sink/example_gqa_sink_bwd_bhsd.py
examples/attention_sink/example_gqa_sink_bwd_bhsd.py
+136
-150
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
...tention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
+75
-92
examples/attention_sink/example_mha_sink_bwd_bhsd.py
examples/attention_sink/example_mha_sink_bwd_bhsd.py
+125
-139
No files found.
.pre-commit-config.yaml
View file @
29051439
...
@@ -39,19 +39,9 @@ repos:
...
@@ -39,19 +39,9 @@ repos:
-
repo
:
https://github.com/astral-sh/ruff-pre-commit
-
repo
:
https://github.com/astral-sh/ruff-pre-commit
rev
:
v0.14.7
# sync with requirements-lint.txt
rev
:
v0.14.7
# sync with requirements-lint.txt
hooks
:
hooks
:
-
id
:
ruff-format
-
id
:
ruff-check
-
id
:
ruff-check
args
:
[
--fix
,
--exit-non-zero-on-fix
]
args
:
[
--fix
,
--exit-non-zero-on-fix
]
-
repo
:
https://github.com/google/yapf
rev
:
v0.43.0
# sync with requirements-lint.txt
hooks
:
-
id
:
yapf
name
:
yapf-multiproc-bugfix
# yapf is not multiprocess safe, so we run a dummy yapf first.
args
:
[
--in-place
,
docs/conf.py
]
always_run
:
true
pass_filenames
:
false
-
id
:
yapf
args
:
[
--recursive
,
--in-place
]
-
repo
:
https://github.com/codespell-project/codespell
-
repo
:
https://github.com/codespell-project/codespell
rev
:
v2.4.1
# sync with requirements-lint.txt
rev
:
v2.4.1
# sync with requirements-lint.txt
hooks
:
hooks
:
...
...
benchmark/blocksparse_attention/benchmark_library_dense_fmha.py
View file @
29051439
...
@@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
...
@@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
# N_CTX = downsample_len * BLOCK
# N_CTX = downsample_len * BLOCK
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
if
use_dense_for_last_block
:
if
use_dense_for_last_block
:
dense_mask
[:,
:,
-
2
:,
:]
=
True
dense_mask
[:,
:,
-
2
:,
:]
=
True
...
@@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
...
@@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def
benchmark_topk_sparse_attention
():
def
benchmark_topk_sparse_attention
():
from
benchmark_configs
import
configs
from
benchmark_configs
import
configs
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
# Config
# Config
for
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
TOPK
,
BLOCK
in
configs
:
for
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
TOPK
,
BLOCK
in
configs
:
# Create inputs
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
import
flash_attn
import
flash_attn
...
...
benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py
View file @
29051439
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
# N_CTX = downsample_len * BLOCK
# N_CTX = downsample_len * BLOCK
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
if
use_dense_for_last_block
:
if
use_dense_for_last_block
:
dense_mask
[:,
:,
-
2
:,
:]
=
True
dense_mask
[:,
:,
-
2
:,
:]
=
True
...
@@ -39,7 +36,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -39,7 +36,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_N
=
64
block_N
=
64
num_stages
=
2
num_stages
=
2
threads
=
128
threads
=
128
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
block_mask_shape
=
[
batch
,
heads
,
downsample_len
,
downsample_len
]
block_mask_shape
=
[
batch
,
heads
,
downsample_len
,
downsample_len
]
...
@@ -48,7 +45,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -48,7 +45,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask_dtype
=
"bool"
block_mask_dtype
=
"bool"
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
@
T
.
macro
@
T
.
macro
def
MMA0
(
def
MMA0
(
K
:
T
.
Tensor
(
shape
,
dtype
),
K
:
T
.
Tensor
(
shape
,
dtype
),
...
@@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
bz
:
T
.
int32
,
):
):
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
if
is_causal
:
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
else
:
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
...
@@ -79,7 +74,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -79,7 +74,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
bz
:
T
.
int32
,
):
):
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
@
T
.
macro
...
@@ -130,8 +125,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -130,8 +125,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
BlockSparseMask
:
T
.
Tensor
(
block_mask_shape
,
block_mask_dtype
),
BlockSparseMask
:
T
.
Tensor
(
block_mask_shape
,
block_mask_dtype
),
Output
:
T
.
Tensor
(
shape
,
dtype
),
Output
:
T
.
Tensor
(
shape
,
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
...
@@ -146,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -146,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
block_mask
=
T
.
alloc_local
([
downsample_len
],
block_mask_dtype
)
block_mask
=
T
.
alloc_local
([
downsample_len
],
block_mask_dtype
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -155,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -155,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask
[
vj
]
=
BlockSparseMask
[
bz
,
by
,
bx
,
vj
]
block_mask
[
vj
]
=
BlockSparseMask
[
bz
,
by
,
bx
,
vj
]
loop_range
=
(
loop_range
=
(
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
(
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
(
(
bx
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
(
bx
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
)
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
if
block_mask
[
k
]:
if
block_mask
[
k
]:
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
scores_sum
,
logsum
)
Rescale
(
acc_o
,
scores_scale
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
main
return
main
...
@@ -177,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -177,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
def
benchmark_topk_sparse_attention
():
def
benchmark_topk_sparse_attention
():
from
benchmark_configs
import
configs
from
benchmark_configs
import
configs
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
# Config
# Config
for
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
TOPK
,
BLOCK
in
configs
:
for
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
TOPK
,
BLOCK
in
configs
:
# Create inputs
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
# Create sparse mask (downsampled to block level)
# Create sparse mask (downsampled to block level)
downsample_factor
=
BLOCK
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
x_ds
[:,
:,
:,
0
]
=
100
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
program
=
blocksparse_flashattn
(
program
=
blocksparse_flashattn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
downsample_len
,
is_causal
=
True
)
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
downsample_len
,
is_causal
=
True
)
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
4
)
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
4
)
def
benchmark_fn
():
def
benchmark_fn
():
...
...
benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py
View file @
29051439
...
@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
...
@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
# N_CTX = downsample_len * BLOCK
# N_CTX = downsample_len * BLOCK
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
if
use_dense_for_last_block
:
if
use_dense_for_last_block
:
dense_mask
[:,
:,
-
2
:,
:]
=
True
dense_mask
[:,
:,
-
2
:,
:]
=
True
...
@@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
...
@@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def
benchmark_topk_sparse_attention
():
def
benchmark_topk_sparse_attention
():
from
benchmark_configs
import
configs
from
benchmark_configs
import
configs
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
# Config
# Config
for
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
TOPK
,
BLOCK
in
configs
:
for
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
TOPK
,
BLOCK
in
configs
:
# Create inputs
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
# Create sparse mask (downsampled to block level)
# Create sparse mask (downsampled to block level)
downsample_factor
=
BLOCK
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
x_ds
[:,
:,
:,
0
]
=
100
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
def
benchmark_fn
():
def
benchmark_fn
():
# Compute reference
# Compute reference
# Expand block mask to full attention matrix
# Expand block mask to full attention matrix
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
))
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
))
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
# PyTorch reference implementation
# PyTorch reference implementation
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
'
-inf
'
))
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
return
ref_output
return
ref_output
ref_latency
=
do_bench
(
ref_latency
=
do_bench
(
...
...
benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py
View file @
29051439
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
# N_CTX = downsample_len * BLOCK
# N_CTX = downsample_len * BLOCK
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
if
use_dense_for_last_block
:
if
use_dense_for_last_block
:
dense_mask
[:,
:,
-
2
:,
:]
=
True
dense_mask
[:,
:,
-
2
:,
:]
=
True
...
@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
...
@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
BLOCK_M
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
):
mask_val
=
tl
.
load
(
block_mask_ptr
+
k_block_col_idx
*
stride_bmask_n
)
mask_val
=
tl
.
load
(
block_mask_ptr
+
k_block_col_idx
*
stride_bmask_n
)
if
mask_val
==
True
:
if
mask_val
==
True
:
...
@@ -72,8 +68,7 @@ def _fwd_kernel_inner(
...
@@ -72,8 +68,7 @@ def _fwd_kernel_inner(
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if
LAST_K_BLOCK
:
if
LAST_K_BLOCK
:
qk
+=
tl
.
where
(
offs_m
[:,
None
]
+
past_len
>=
(
start_n
+
offs_n
[
None
,
:]),
0
,
qk
+=
tl
.
where
(
offs_m
[:,
None
]
+
past_len
>=
(
start_n
+
offs_n
[
None
,
:]),
0
,
float
(
"-inf"
))
float
(
'-inf'
))
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
qk
-=
m_ij
[:,
None
]
qk
-=
m_ij
[:,
None
]
...
@@ -153,7 +148,7 @@ def _fwd_kernel(
...
@@ -153,7 +148,7 @@ def _fwd_kernel(
v_ptrs
=
V
+
off_v
v_ptrs
=
V
+
off_v
mask_ptrs
=
block_mask_ptr
+
start_m
*
stride_bmm
mask_ptrs
=
block_mask_ptr
+
start_m
*
stride_bmm
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
'
inf
'
)
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"
inf
"
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
...
@@ -191,24 +186,12 @@ def _fwd_kernel(
...
@@ -191,24 +186,12 @@ def _fwd_kernel(
acc
=
acc
*
l_recip
acc
=
acc
*
l_recip
acc
=
acc
.
to
(
Out
.
dtype
.
element_ty
)
acc
=
acc
.
to
(
Out
.
dtype
.
element_ty
)
off_o
=
off_z
*
stride_oz
+
off_h
*
stride_oh
+
offs_m
[:,
None
]
*
stride_om
+
offs_d
[
off_o
=
off_z
*
stride_oz
+
off_h
*
stride_oh
+
offs_m
[:,
None
]
*
stride_om
+
offs_d
[
None
,
:]
*
stride_od
None
,
:]
*
stride_od
out_ptrs
=
Out
+
off_o
out_ptrs
=
Out
+
off_o
tl
.
store
(
out_ptrs
,
acc
,
mask
=
offs_m
[:,
None
]
<
N_CTX
)
tl
.
store
(
out_ptrs
,
acc
,
mask
=
offs_m
[:,
None
]
<
N_CTX
)
def
_forward
(
ctx
,
def
_forward
(
ctx
,
q
,
k
,
v
,
block_sparse_mask
,
sm_scale
,
BLOCK_M
=
64
,
BLOCK_N
=
64
,
num_warps
=
None
,
num_stages
=
1
,
out
=
None
):
q
,
k
,
v
,
block_sparse_mask
,
sm_scale
,
BLOCK_M
=
64
,
BLOCK_N
=
64
,
num_warps
=
None
,
num_stages
=
1
,
out
=
None
):
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
assert
k
.
shape
[
2
]
==
v
.
shape
[
2
]
assert
k
.
shape
[
2
]
==
v
.
shape
[
2
]
o
=
out
if
out
is
not
None
else
torch
.
empty_like
(
q
).
contiguous
()
o
=
out
if
out
is
not
None
else
torch
.
empty_like
(
q
).
contiguous
()
...
@@ -253,7 +236,6 @@ def _forward(ctx,
...
@@ -253,7 +236,6 @@ def _forward(ctx,
class
_sparse_attention
(
torch
.
autograd
.
Function
):
class
_sparse_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
block_sparse_dense
,
sm_scale
):
def
forward
(
ctx
,
q
,
k
,
v
,
block_sparse_dense
,
sm_scale
):
# shape constraints
# shape constraints
...
@@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply
...
@@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply
def
benchmark_topk_sparse_attention
():
def
benchmark_topk_sparse_attention
():
from
benchmark_configs
import
configs
from
benchmark_configs
import
configs
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
# Config
# Config
for
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
TOPK
,
BLOCK
in
configs
:
for
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
TOPK
,
BLOCK
in
configs
:
# Create inputs
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
# Create sparse mask (downsampled to block level)
# Create sparse mask (downsampled to block level)
downsample_factor
=
BLOCK
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
x_ds
[:,
:,
:,
0
]
=
100
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
...
...
benchmark/mamba2/benchmark_mamba_chunk_scan.py
View file @
29051439
...
@@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
...
@@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
dt_segment_sum
=
dA_cumsum
[:,
:,
:,
:,
None
]
-
dA_cumsum
[:,
:,
:,
None
,
:]
dt_segment_sum
=
dA_cumsum
[:,
:,
:,
:,
None
]
-
dA_cumsum
[:,
:,
:,
None
,
:]
decay
=
torch
.
exp
(
dt_segment_sum
)
decay
=
torch
.
exp
(
dt_segment_sum
)
scores_decay
=
cb
*
rearrange
(
decay
,
"b h c l s -> b c h l s"
)
scores_decay
=
cb
*
rearrange
(
decay
,
"b h c l s -> b c h l s"
)
causal_mask
=
torch
.
tril
(
causal_mask
=
torch
.
tril
(
torch
.
ones
(
chunk_size
,
chunk_size
,
device
=
x
.
device
,
dtype
=
bool
),
diagonal
=
0
)
torch
.
ones
(
chunk_size
,
chunk_size
,
device
=
x
.
device
,
dtype
=
bool
),
diagonal
=
0
)
scores_decay
=
scores_decay
.
masked_fill
(
~
causal_mask
,
0
)
scores_decay
=
scores_decay
.
masked_fill
(
~
causal_mask
,
0
)
out
=
torch
.
einsum
(
'bchls,bhcs,bcshp->bclhp'
,
scores_decay
.
to
(
x
.
dtype
),
dt
.
to
(
x
.
dtype
),
out
=
torch
.
einsum
(
rearrange
(
x
,
"b (c s) h p -> b c s h p"
,
c
=
nchunks
))
"bchls,bhcs,bcshp->bclhp"
,
scores_decay
.
to
(
x
.
dtype
),
dt
.
to
(
x
.
dtype
),
rearrange
(
x
,
"b (c s) h p -> b c s h p"
,
c
=
nchunks
)
)
state_decay_out
=
torch
.
exp
(
rearrange
(
dA_cumsum
,
"b h c l -> b c l h 1"
))
state_decay_out
=
torch
.
exp
(
rearrange
(
dA_cumsum
,
"b h c l -> b c l h 1"
))
out_prev
=
torch
.
einsum
(
'bclhn,bchpn->bclhp'
,
rearrange
(
out_prev
=
(
C
,
"b (c l) h n -> b c l h n"
,
c
=
nchunks
),
prev_states
.
to
(
C
.
dtype
))
*
state_decay_out
torch
.
einsum
(
"bclhn,bchpn->bclhp"
,
rearrange
(
C
,
"b (c l) h n -> b c l h n"
,
c
=
nchunks
),
prev_states
.
to
(
C
.
dtype
))
*
state_decay_out
)
out
=
out
+
out_prev
out
=
out
+
out_prev
out
=
rearrange
(
out
,
"b c l h p -> b (c l) h p"
)
out
=
rearrange
(
out
,
"b c l h p -> b (c l) h p"
)
if
D
is
not
None
:
if
D
is
not
None
:
...
@@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
...
@@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
def
chunk_scan_helion
(
cb
,
x
,
dt
,
dA_cumsum
,
C
,
states
,
D
):
def
chunk_scan_helion
(
cb
,
x
,
dt
,
dA_cumsum
,
C
,
states
,
D
):
@
helion
.
kernel
()
@
helion
.
kernel
()
def
helion_mamba2_chunk_scan_kernel
(
def
helion_mamba2_chunk_scan_kernel
(
cb
:
torch
.
Tensor
,
cb
:
torch
.
Tensor
,
...
@@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
...
@@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
dtype
=
cb
.
dtype
dtype
=
cb
.
dtype
accum_dtype
=
torch
.
float32
accum_dtype
=
torch
.
float32
assert
(
x
.
dtype
==
dt
.
dtype
==
dA_cumsum
.
dtype
==
C
.
dtype
==
prev_states
.
dtype
==
D
.
dtype
==
assert
x
.
dtype
==
dt
.
dtype
==
dA_cumsum
.
dtype
==
C
.
dtype
==
prev_states
.
dtype
==
D
.
dtype
==
dtype
dtype
)
out
=
torch
.
empty_like
(
x
)
out
=
torch
.
empty_like
(
x
)
...
@@ -130,8 +129,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
...
@@ -130,8 +129,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
block_size
=
[
1
,
block_m
,
block_n
,
1
,
1
],
block_size
=
[
1
,
block_m
,
block_n
,
1
,
1
],
):
):
acc_o
=
hl
.
zeros
([
tile_m
,
tile_n
],
dtype
=
accum_dtype
)
acc_o
=
hl
.
zeros
([
tile_m
,
tile_n
],
dtype
=
accum_dtype
)
dA_cumsum_local_m
=
dA_cumsum
[
tile_b
.
begin
,
tile_h
.
begin
,
tile_c
.
begin
,
dA_cumsum_local_m
=
dA_cumsum
[
tile_b
.
begin
,
tile_h
.
begin
,
tile_c
.
begin
,
tile_m
].
to
(
torch
.
float32
)
tile_m
].
to
(
torch
.
float32
)
scale_m_local
=
torch
.
exp2
(
dA_cumsum_local_m
*
p
)
scale_m_local
=
torch
.
exp2
(
dA_cumsum_local_m
*
p
)
C_local
=
C
[
C_local
=
C
[
...
@@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
...
@@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
tile_m
,
tile_m
,
tile_k
,
tile_k
,
]
]
dA_cumsum_local_k
=
dA_cumsum
[
tile_b
.
begin
,
tile_h
.
begin
,
tile_c
.
begin
,
dA_cumsum_local_k
=
dA_cumsum
[
tile_b
.
begin
,
tile_h
.
begin
,
tile_c
.
begin
,
tile_k
].
to
(
torch
.
float32
)
tile_k
].
to
(
torch
.
float32
)
cb_local
*=
torch
.
exp2
(
dA_cumsum_local_m
[:,
None
]
*
p
-
dA_cumsum_local_k
[
None
,
:]
*
p
)
cb_local
*=
torch
.
exp2
(
dA_cumsum_local_m
[:,
None
]
*
p
-
dA_cumsum_local_k
[
None
,
:]
*
p
)
dt_local
=
dt
[
tile_b
.
begin
,
tile_h
.
begin
,
tile_c
.
begin
,
tile_k
].
to
(
torch
.
float32
)
dt_local
=
dt
[
tile_b
.
begin
,
tile_h
.
begin
,
tile_c
.
begin
,
tile_k
].
to
(
torch
.
float32
)
cb_local
=
(
cb_local
*
dt_local
[
None
,
:]).
to
(
dtype
)
cb_local
=
(
cb_local
*
dt_local
[
None
,
:]).
to
(
dtype
)
pred
=
(
tile_m
.
index
+
0
)[:,
None
]
>=
(
tile_k
.
index
+
0
)[
None
,
:]
pred
=
(
tile_m
.
index
+
0
)[:,
None
]
>=
(
tile_k
.
index
+
0
)[
None
,
:]
...
@@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
...
@@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
acc_o
=
hl
.
dot
(
cb_local
,
x_local
,
acc
=
acc_o
)
acc_o
=
hl
.
dot
(
cb_local
,
x_local
,
acc
=
acc_o
)
D_local
=
D
[
tile_h
.
begin
].
to
(
torch
.
float32
)
D_local
=
D
[
tile_h
.
begin
].
to
(
torch
.
float32
)
x_residual
=
x
[
tile_b
.
begin
,
tile_c
.
begin
*
chunk_size
+
tile_m
.
index
,
tile_h
.
begin
,
x_residual
=
x
[
tile_b
.
begin
,
tile_c
.
begin
*
chunk_size
+
tile_m
.
index
,
tile_h
.
begin
,
tile_n
].
to
(
torch
.
float32
)
tile_n
].
to
(
torch
.
float32
)
acc_o
+=
x_residual
*
D_local
acc_o
+=
x_residual
*
D_local
out
[
tile_b
.
begin
,
tile_c
.
begin
*
chunk_size
+
tile_m
.
index
,
tile_h
.
begin
,
out
[
tile_b
.
begin
,
tile_c
.
begin
*
chunk_size
+
tile_m
.
index
,
tile_h
.
begin
,
tile_n
]
=
acc_o
.
to
(
dtype
=
dtype
)
tile_n
]
=
acc_o
.
to
(
dtype
=
dtype
)
return
out
return
out
...
@@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
...
@@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
def
get_configs
():
def
get_configs
():
iter_params
=
dict
(
iter_params
=
dict
(
block_M
=
[
64
,
128
,
256
],
block_N
=
[
32
,
64
],
block_K
=
[
64
,
128
,
256
],
block_Dstate
=
[
128
],
num_stages
=
[
1
,
2
,
3
,
4
,
5
])
block_M
=
[
64
,
128
,
256
],
block_N
=
[
32
,
64
],
block_K
=
[
64
,
128
,
256
],
block_Dstate
=
[
128
],
num_stages
=
[
1
,
2
,
3
,
4
,
5
])
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
...
@@ -198,7 +187,8 @@ def get_configs():
...
@@ -198,7 +187,8 @@ def get_configs():
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
},
},
)
)
def
chunk_scan_fwd
(
batch
,
def
chunk_scan_fwd
(
batch
,
seqlen
,
seqlen
,
chunk_size
,
chunk_size
,
ngroups
,
ngroups
,
...
@@ -210,7 +200,8 @@ def chunk_scan_fwd(batch,
...
@@ -210,7 +200,8 @@ def chunk_scan_fwd(batch,
block_K
=
64
,
block_K
=
64
,
block_Dstate
=
128
,
block_Dstate
=
128
,
num_stages
=
2
,
num_stages
=
2
,
threads
=
128
):
threads
=
128
,
):
dtype
=
"float16"
dtype
=
"float16"
accum_dtype
=
"float"
accum_dtype
=
"float"
nchunks
=
T
.
ceildiv
(
seqlen
,
chunk_size
)
nchunks
=
T
.
ceildiv
(
seqlen
,
chunk_size
)
...
@@ -225,13 +216,13 @@ def chunk_scan_fwd(batch,
...
@@ -225,13 +216,13 @@ def chunk_scan_fwd(batch,
C
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
# type: ignore
C
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
# type: ignore
prev_states
:
T
.
Tensor
((
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
),
# type: ignore
prev_states
:
T
.
Tensor
((
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
),
# type: ignore
D
:
T
.
Tensor
((
nheads
),
dtype
),
# type: ignore
D
:
T
.
Tensor
((
nheads
),
dtype
),
# type: ignore
Output
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
)
# type: ignore
Output
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
),
# type: ignore
):
with
T
.
Kernel
(
nheads
,
T
.
ceildiv
(
chunk_size
,
block_M
)
*
T
.
ceildiv
(
headdim
,
block_N
),
batch
*
nchunks
,
threads
=
threads
)
as
(
bz
,
bx
,
by
,
):
):
with
T
.
Kernel
(
nheads
,
T
.
ceildiv
(
chunk_size
,
block_M
)
*
T
.
ceildiv
(
headdim
,
block_N
),
batch
*
nchunks
,
threads
=
threads
)
as
(
bz
,
bx
,
by
):
acc_o
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
acc_o
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
acc_o_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
acc_o_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
cb_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
,
scope
=
"shared.dyn"
)
cb_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
,
scope
=
"shared.dyn"
)
...
@@ -257,27 +248,32 @@ def chunk_scan_fwd(batch,
...
@@ -257,27 +248,32 @@ def chunk_scan_fwd(batch,
m_idx
=
bx
//
T
.
ceildiv
(
headdim
,
block_N
)
m_idx
=
bx
//
T
.
ceildiv
(
headdim
,
block_N
)
n_idx
=
bx
%
T
.
ceildiv
(
headdim
,
block_N
)
n_idx
=
bx
%
T
.
ceildiv
(
headdim
,
block_N
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
acc_o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
acc_o_shared
),
acc_o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
acc_o_shared
),
cb_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
cb_shared
),
cb_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
cb_shared
),
x_residual_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_residual_shared
)
x_residual_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_residual_shared
),
})
}
)
T
.
no_set_max_nreg
()
T
.
no_set_max_nreg
()
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
m_idx
*
block_M
:(
m_idx
+
1
)
*
block_M
],
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
],
dA_cs_m_shared
)
dA_cs_m_shared
)
T
.
copy
(
dA_cs_m_shared
,
dA_cs_m_local
)
T
.
copy
(
dA_cs_m_shared
,
dA_cs_m_local
)
T
.
clear
(
acc_o
)
T
.
clear
(
acc_o
)
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
scale_m_local
[
i
]
=
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
)
scale_m_local
[
i
]
=
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
)
T
.
copy
(
T
.
copy
(
C
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
C
[
(
m_idx
+
1
)
*
block_M
,
bz
//
(
nheads
//
ngroups
),
0
:
block_Dstate
],
C_shared
)
batch_idx
,
T
.
copy
(
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
prev_states
[
batch_idx
,
chunk_idx
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
,
bz
//
(
nheads
//
ngroups
),
0
:
block_Dstate
],
prev_state_shared
)
0
:
block_Dstate
,
],
C_shared
,
)
T
.
copy
(
prev_states
[
batch_idx
,
chunk_idx
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
0
:
block_Dstate
],
prev_state_shared
)
T
.
gemm
(
C_shared
,
prev_state_shared
,
acc_o
,
transpose_B
=
True
)
T
.
gemm
(
C_shared
,
prev_state_shared
,
acc_o
,
transpose_B
=
True
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_o
[
i
,
j
]
*=
scale_m_local
[
i
]
acc_o
[
i
,
j
]
*=
scale_m_local
[
i
]
...
@@ -286,34 +282,47 @@ def chunk_scan_fwd(batch,
...
@@ -286,34 +282,47 @@ def chunk_scan_fwd(batch,
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
cb
[
batch_idx
,
chunk_idx
,
bz
//
(
nheads
//
ngroups
),
cb
[
m_idx
*
block_M
:(
m_idx
+
1
)
*
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
batch_idx
,
cb_shared
)
chunk_idx
,
bz
//
(
nheads
//
ngroups
),
m_idx
*
block_M
:
(
m_idx
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
,
],
cb_shared
,
)
T
.
copy
(
cb_shared
,
cb_local
)
T
.
copy
(
cb_shared
,
cb_local
)
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
dA_cs_k_shared
)
dA_cs_k_shared
)
T
.
copy
(
dA_cs_k_shared
,
dA_cs_k_local
)
T
.
copy
(
dA_cs_k_shared
,
dA_cs_k_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
cb_local
[
i
,
cb_local
[
i
,
j
]
=
cb_local
[
i
,
j
]
*
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
-
dA_cs_k_local
[
j
]
*
p
)
j
]
=
cb_local
[
i
,
T
.
copy
(
dt
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
dt_shared
)
j
]
*
T
.
exp2
(
dA_cs_m_local
[
i
]
*
p
-
dA_cs_k_local
[
j
]
*
p
)
T
.
copy
(
dt
[
batch_idx
,
bz
,
chunk_idx
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
dt_shared
)
T
.
copy
(
dt_shared
,
dt_local
)
T
.
copy
(
dt_shared
,
dt_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
cb_local
[
i
,
j
]
*=
dt_local
[
j
]
cb_local
[
i
,
j
]
*=
dt_local
[
j
]
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_K
):
cb_local
[
i
,
j
]
=
T
.
if_then_else
(
m_idx
*
block_M
+
i
>=
k
*
block_K
+
j
,
cb_local
[
i
,
j
]
=
T
.
if_then_else
(
m_idx
*
block_M
+
i
>=
k
*
block_K
+
j
,
cb_local
[
i
,
j
],
0
)
cb_local
[
i
,
j
],
0
)
T
.
copy
(
T
.
copy
(
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
x
[
(
k
+
1
)
*
block_K
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
],
x_shared
)
batch_idx
,
chunk_idx
*
chunk_size
+
k
*
block_K
:
chunk_idx
*
chunk_size
+
(
k
+
1
)
*
block_K
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
x_shared
,
)
T
.
gemm
(
cb_local
,
x_shared
,
acc_o
)
T
.
gemm
(
cb_local
,
x_shared
,
acc_o
)
D_local
[
0
]
=
D
[
bz
]
D_local
[
0
]
=
D
[
bz
]
T
.
copy
(
T
.
copy
(
x
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
x
[
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
],
batch_idx
,
x_residual_shared
)
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
x_residual_shared
,
)
T
.
copy
(
x_residual_shared
,
x_residual_local
)
T
.
copy
(
x_residual_shared
,
x_residual_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_o
[
i
,
j
]
+=
x_residual_local
[
i
,
j
]
*
D_local
[
0
]
acc_o
[
i
,
j
]
+=
x_residual_local
[
i
,
j
]
*
D_local
[
0
]
...
@@ -321,24 +330,37 @@ def chunk_scan_fwd(batch,
...
@@ -321,24 +330,37 @@ def chunk_scan_fwd(batch,
T
.
copy
(
acc_o
,
acc_o_shared
)
T
.
copy
(
acc_o
,
acc_o_shared
)
T
.
copy
(
T
.
copy
(
acc_o_shared
,
acc_o_shared
,
Output
[
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
Output
[
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:(
n_idx
+
1
)
*
block_N
])
batch_idx
,
chunk_idx
*
chunk_size
+
m_idx
*
block_M
:
chunk_idx
*
chunk_size
+
(
m_idx
+
1
)
*
block_M
,
bz
,
n_idx
*
block_N
:
(
n_idx
+
1
)
*
block_N
,
],
)
return
main
return
main
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
8
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
8
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
80
,
help
=
'
heads
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
80
,
help
=
"
heads
"
)
parser
.
add_argument
(
'
--groups
'
,
type
=
int
,
default
=
1
,
help
=
'
groups
'
)
parser
.
add_argument
(
"
--groups
"
,
type
=
int
,
default
=
1
,
help
=
"
groups
"
)
parser
.
add_argument
(
'
--seq_len
'
,
type
=
int
,
default
=
4096
,
help
=
'
sequence length
'
)
parser
.
add_argument
(
"
--seq_len
"
,
type
=
int
,
default
=
4096
,
help
=
"
sequence length
"
)
parser
.
add_argument
(
'
--chunk_size
'
,
type
=
int
,
default
=
256
,
help
=
'
chunk size
'
)
parser
.
add_argument
(
"
--chunk_size
"
,
type
=
int
,
default
=
256
,
help
=
"
chunk size
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
64
,
help
=
'
dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
64
,
help
=
"
dim
"
)
parser
.
add_argument
(
'
--dstate
'
,
type
=
int
,
default
=
128
,
help
=
'
dstate
'
)
parser
.
add_argument
(
"
--dstate
"
,
type
=
int
,
default
=
128
,
help
=
"
dstate
"
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
batch
,
heads
,
groups
,
seq_len
,
chunk_size
,
dim
,
dstate
=
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
seq_len
,
args
.
chunk_size
,
args
.
dim
,
args
.
dstate
batch
,
heads
,
groups
,
seq_len
,
chunk_size
,
dim
,
dstate
=
(
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
seq_len
,
args
.
chunk_size
,
args
.
dim
,
args
.
dstate
,
)
nchunks
=
math
.
ceil
(
seq_len
/
chunk_size
)
nchunks
=
math
.
ceil
(
seq_len
/
chunk_size
)
total_flops
=
2
*
batch
*
seq_len
*
chunk_size
*
heads
*
dim
*
0.5
+
2
*
batch
*
seq_len
*
heads
*
dim
*
dstate
total_flops
=
2
*
batch
*
seq_len
*
chunk_size
*
heads
*
dim
*
0.5
+
2
*
batch
*
seq_len
*
heads
*
dim
*
dstate
...
@@ -360,8 +382,7 @@ if __name__ == "__main__":
...
@@ -360,8 +382,7 @@ if __name__ == "__main__":
D
=
torch
.
randn
(
heads
).
half
().
cuda
()
D
=
torch
.
randn
(
heads
).
half
().
cuda
()
print
(
"Benchmarking Triton..."
)
print
(
"Benchmarking Triton..."
)
triton_latency
=
do_bench
(
triton_latency
=
do_bench
(
lambda
:
chunk_scan_triton
(
cb
,
x
,
dt
,
dA_cumsum
,
C
,
states
,
D
),
_n_warmup
=
10
,
_n_repeat
=
10
)
lambda
:
chunk_scan_triton
(
cb
,
x
,
dt
,
dA_cumsum
,
C
,
states
,
D
),
_n_warmup
=
10
,
_n_repeat
=
10
)
print
(
f
"Triton TFlops:
{
total_flops
/
triton_latency
*
1e-9
}
"
)
print
(
f
"Triton TFlops:
{
total_flops
/
triton_latency
*
1e-9
}
"
)
print
(
"Benchmarking Helion..."
)
print
(
"Benchmarking Helion..."
)
...
...
benchmark/matmul/benchmark_matmul.py
View file @
29051439
...
@@ -6,6 +6,7 @@ import tilelang
...
@@ -6,6 +6,7 @@ import tilelang
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.autotuner
import
autotune
from
tilelang.autotuner
import
autotune
from
tilelang
import
jit
from
tilelang
import
jit
# Configure logger
# Configure logger
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
logger
.
setLevel
(
logging
.
DEBUG
)
logger
.
setLevel
(
logging
.
DEBUG
)
...
@@ -101,9 +102,7 @@ def get_configs(args, kwargs):
...
@@ -101,9 +102,7 @@ def get_configs(args, kwargs):
policy
=
[
T
.
GemmWarpPolicy
.
Square
],
policy
=
[
T
.
GemmWarpPolicy
.
Square
],
enable_rasteration
=
[
True
,
False
],
enable_rasteration
=
[
True
,
False
],
)
)
return
[{
return
[{
k
:
v
for
k
,
v
in
zip
(
iter_params
,
values
)}
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
k
:
v
for
k
,
v
in
zip
(
iter_params
,
values
)
}
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
return
configs
return
configs
...
@@ -112,7 +111,9 @@ def get_configs(args, kwargs):
...
@@ -112,7 +111,9 @@ def get_configs(args, kwargs):
warmup
=
3
,
warmup
=
3
,
rep
=
20
,
rep
=
20
,
)
)
@
jit
(
out_idx
=
[
2
],)
@
jit
(
out_idx
=
[
2
],
)
def
matmul
(
def
matmul
(
M
,
M
,
N
,
N
,
...
@@ -176,7 +177,6 @@ def matmul(
...
@@ -176,7 +177,6 @@ def matmul(
# Bind x-dimension to block index in N,
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
# y-dimension to block index in M.
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
...
...
benchmark/matmul/benchmark_matmul_intrinsic.py
View file @
29051439
...
@@ -6,7 +6,8 @@ import tilelang as tl
...
@@ -6,7 +6,8 @@ import tilelang as tl
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.transform
import
simplify_prim_func
from
tilelang.transform
import
simplify_prim_func
from
tilelang.autotuner
import
autotune
from
tilelang.autotuner
import
autotune
import
itertools
import
itertools
...
@@ -108,7 +109,6 @@ def tl_matmul(
...
@@ -108,7 +109,6 @@ def tl_matmul(
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
@@ -116,10 +116,12 @@ def tl_matmul(
...
@@ -116,10 +116,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
A_shared
:
make_swizzle_layout
(
A_shared
),
A_shared
:
make_swizzle_layout
(
A_shared
),
B_shared
:
make_swizzle_layout
(
B_shared
),
B_shared
:
make_swizzle_layout
(
B_shared
),
})
}
)
# Improve L2 Cache
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasteration
)
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasteration
)
...
@@ -127,7 +129,6 @@ def tl_matmul(
...
@@ -127,7 +129,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
@@ -137,7 +138,6 @@ def tl_matmul(
...
@@ -137,7 +138,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
A_local
,
A_shared
,
ki
)
mma_emitter
.
ldmatrix_a
(
A_local
,
A_shared
,
ki
)
...
@@ -223,7 +223,6 @@ def get_configs(args, kwargs):
...
@@ -223,7 +223,6 @@ def get_configs(args, kwargs):
for
config
in
configs
:
for
config
in
configs
:
print
(
config
)
print
(
config
)
else
:
else
:
iter_params
=
dict
(
iter_params
=
dict
(
block_row_warps
=
[
1
,
2
,
4
],
block_row_warps
=
[
1
,
2
,
4
],
block_col_warps
=
[
1
,
2
,
4
],
block_col_warps
=
[
1
,
2
,
4
],
...
@@ -233,9 +232,7 @@ def get_configs(args, kwargs):
...
@@ -233,9 +232,7 @@ def get_configs(args, kwargs):
stage
=
[
0
,
2
],
stage
=
[
0
,
2
],
enable_rasteration
=
[
True
,
False
],
enable_rasteration
=
[
True
,
False
],
)
)
return
[{
return
[{
k
:
v
for
k
,
v
in
zip
(
iter_params
,
values
)}
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
k
:
v
for
k
,
v
in
zip
(
iter_params
,
values
)
}
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
return
configs
return
configs
...
@@ -247,7 +244,9 @@ def get_configs(args, kwargs):
...
@@ -247,7 +244,9 @@ def get_configs(args, kwargs):
ref_prog
=
ref_program
,
ref_prog
=
ref_program
,
skip_check
=
True
,
skip_check
=
True
,
)
)
@
tl
.
jit
(
out_idx
=
[
2
],)
@
tl
.
jit
(
out_idx
=
[
2
],
)
def
matmul
(
def
matmul
(
M
,
M
,
N
,
N
,
...
@@ -291,13 +290,8 @@ if __name__ == "__main__":
...
@@ -291,13 +290,8 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--with_roller"
,
type
=
bool
,
default
=
False
,
help
=
"Whether to use roller to deduce search spaces"
)
"--with_roller"
,
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
choices
=
[
"float16"
,
"int8"
],
help
=
"Input data type"
)
type
=
bool
,
default
=
False
,
help
=
"Whether to use roller to deduce search spaces"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
choices
=
[
"float16"
,
"int8"
],
help
=
"Input data type"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
M
,
N
,
K
=
args
.
m
,
args
.
n
,
args
.
k
M
,
N
,
K
=
args
.
m
,
args
.
n
,
args
.
k
...
...
benchmark/matmul/benchmark_matmul_sp.py
View file @
29051439
...
@@ -70,7 +70,8 @@ def get_configs(M, N, K):
...
@@ -70,7 +70,8 @@ def get_configs(M, N, K):
thread_num
,
thread_num
,
policy
,
policy
,
enable_rasterization
,
enable_rasterization
,
))
)
)
configs
=
[
configs
=
[
{
{
...
@@ -81,7 +82,8 @@ def get_configs(M, N, K):
...
@@ -81,7 +82,8 @@ def get_configs(M, N, K):
"thread_num"
:
c
[
4
],
"thread_num"
:
c
[
4
],
"policy"
:
c
[
5
],
"policy"
:
c
[
5
],
"enable_rasterization"
:
c
[
6
],
# keep param name for backward-compat
"enable_rasterization"
:
c
[
6
],
# keep param name for backward-compat
}
for
c
in
_configs
}
for
c
in
_configs
]
]
return
configs
return
configs
...
@@ -126,7 +128,9 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
...
@@ -126,7 +128,9 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
warmup
=
3
,
warmup
=
3
,
rep
=
20
,
rep
=
20
,
)
)
@
jit
(
out_idx
=
[
2
],)
@
jit
(
out_idx
=
[
2
],
)
def
kernel
(
def
kernel
(
block_M
=
None
,
block_M
=
None
,
block_N
=
None
,
block_N
=
None
,
...
@@ -182,9 +186,7 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
...
@@ -182,9 +186,7 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
"""
"""
# Bind x-dimension to block index in N,
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
# y-dimension to block index in M.
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
in_dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
in_dtype
)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
...
@@ -201,12 +203,12 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
...
@@ -201,12 +203,12 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
T
.
disable_warp_group_reg_alloc
()
T
.
disable_warp_group_reg_alloc
()
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasterization
)
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasterization
)
T
.
annotate_layout
(
{
T
.
annotate_layout
(
E
:
{
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
block_k
=
block_K
),
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
block_k
=
block_K
),
E_shared
:
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
block_k
=
block_K
),
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
block_k
=
block_K
),
}
}
)
)
# Loop over sub-blocks in K dimension, pipelined by num_stages
# Loop over sub-blocks in K dimension, pipelined by num_stages
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
# Load a sub-block of A from global memory into A_shared
# Load a sub-block of A from global memory into A_shared
...
@@ -241,18 +243,13 @@ if __name__ == "__main__":
...
@@ -241,18 +243,13 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--disable_cache"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--disable_cache"
,
action
=
"store_true"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--bench_torch_sparse"
,
"--bench_torch_sparse"
,
type
=
str
,
type
=
str
,
choices
=
[
'
cutlass
'
,
'
cusparselt
'
],
choices
=
[
"
cutlass
"
,
"
cusparselt
"
],
default
=
None
,
default
=
None
,
help
=
"Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported"
help
=
"Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported"
,
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -274,7 +271,8 @@ if __name__ == "__main__":
...
@@ -274,7 +271,8 @@ if __name__ == "__main__":
if
args
.
bench_torch_sparse
is
not
None
:
if
args
.
bench_torch_sparse
is
not
None
:
from
torch.sparse
import
to_sparse_semi_structured
,
SparseSemiStructuredTensor
from
torch.sparse
import
to_sparse_semi_structured
,
SparseSemiStructuredTensor
if
args
.
bench_torch_sparse
==
'cutlass'
:
if
args
.
bench_torch_sparse
==
"cutlass"
:
SparseSemiStructuredTensor
.
_FORCE_CUTLASS
=
True
SparseSemiStructuredTensor
.
_FORCE_CUTLASS
=
True
A_sp
=
to_sparse_semi_structured
(
A
,
transposed
=
False
)
A_sp
=
to_sparse_semi_structured
(
A
,
transposed
=
False
)
torch_sparse_latency
=
do_bench
(
lambda
:
A_sp
@
B
)
torch_sparse_latency
=
do_bench
(
lambda
:
A_sp
@
B
)
...
@@ -285,8 +283,6 @@ if __name__ == "__main__":
...
@@ -285,8 +283,6 @@ if __name__ == "__main__":
print
(
f
"Best config:
{
best_config
}
"
)
print
(
f
"Best config:
{
best_config
}
"
)
if
args
.
bench_torch_sparse
is
not
None
:
if
args
.
bench_torch_sparse
is
not
None
:
print
(
print
(
f
"Torch sparse (
{
args
.
bench_torch_sparse
}
) TFlops:
{
total_flops
/
torch_sparse_latency
*
1e-9
:.
3
f
}
"
)
f
"Torch sparse (
{
args
.
bench_torch_sparse
}
) TFlops:
{
total_flops
/
torch_sparse_latency
*
1e-9
:.
3
f
}
"
)
print
(
f
"Reference Dense TFlops:
{
total_flops
/
ref_latency
*
1e-9
:.
3
f
}
"
)
print
(
f
"Reference Dense TFlops:
{
total_flops
/
ref_latency
*
1e-9
:.
3
f
}
"
)
benchmark/matmul_fp8/benchmark_matmul.py
View file @
29051439
...
@@ -104,9 +104,7 @@ def get_configs(args, kwargs):
...
@@ -104,9 +104,7 @@ def get_configs(args, kwargs):
policy
=
[
T
.
GemmWarpPolicy
.
Square
],
policy
=
[
T
.
GemmWarpPolicy
.
Square
],
enable_rasteration
=
[
True
,
False
],
enable_rasteration
=
[
True
,
False
],
)
)
return
[{
return
[{
k
:
v
for
k
,
v
in
zip
(
iter_params
,
values
)}
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
k
:
v
for
k
,
v
in
zip
(
iter_params
,
values
)
}
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
return
configs
return
configs
...
@@ -116,7 +114,9 @@ def get_configs(args, kwargs):
...
@@ -116,7 +114,9 @@ def get_configs(args, kwargs):
warmup
=
3
,
warmup
=
3
,
rep
=
20
,
rep
=
20
,
)
)
@
jit
(
out_idx
=
[
2
],)
@
jit
(
out_idx
=
[
2
],
)
def
matmul
(
def
matmul
(
M
,
M
,
N
,
N
,
...
@@ -181,7 +181,6 @@ def matmul(
...
@@ -181,7 +181,6 @@ def matmul(
# Bind x-dimension to block index in N,
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
# y-dimension to block index in M.
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
...
...
docs/conf.py
View file @
29051439
...
@@ -20,33 +20,27 @@ extensions = [
...
@@ -20,33 +20,27 @@ extensions = [
"autoapi.extension"
,
"autoapi.extension"
,
]
]
autoapi_type
=
'
python
'
autoapi_type
=
"
python
"
autoapi_dirs
=
[
'
../tilelang
'
]
autoapi_dirs
=
[
"
../tilelang
"
]
autoapi_options
=
[
autoapi_options
=
[
'
members
'
,
"
members
"
,
'
undoc-members
'
,
"
undoc-members
"
,
'
show-inheritance
'
,
"
show-inheritance
"
,
'
show-module-summary
'
,
"
show-module-summary
"
,
'
special-members
'
,
"
special-members
"
,
]
]
autoapi_keep_files
=
False
# Useful for debugging the generated rst files
autoapi_keep_files
=
False
# Useful for debugging the generated rst files
autoapi_generate_api_docs
=
True
autoapi_generate_api_docs
=
True
autodoc_typehints
=
'
description
'
autodoc_typehints
=
"
description
"
autoapi_ignore
=
[
"*language/ast*"
,
"*version*"
,
"*libinfo*"
,
"*parser*"
]
autoapi_ignore
=
[
"*language/ast*"
,
"*version*"
,
"*libinfo*"
,
"*parser*"
]
source_suffix
=
{
source_suffix
=
{
".rst"
:
"restructuredtext"
,
".md"
:
"markdown"
}
'.rst'
:
'restructuredtext'
,
'.md'
:
'markdown'
,
}
myst_enable_extensions
=
[
myst_enable_extensions
=
[
"colon_fence"
,
"deflist"
]
"colon_fence"
,
"deflist"
,
]
redirects
=
{
"get_started/try_out"
:
"../index.html#getting-started"
}
redirects
=
{
"get_started/try_out"
:
"../index.html#getting-started"
}
...
@@ -66,10 +60,7 @@ html_css_files = ["custom.css"]
...
@@ -66,10 +60,7 @@ html_css_files = ["custom.css"]
footer_copyright
=
"© 2025-2026 TileLang"
footer_copyright
=
"© 2025-2026 TileLang"
footer_note
=
" "
footer_note
=
" "
html_theme_options
=
{
html_theme_options
=
{
"light_logo"
:
"img/logo-v2.png"
,
"dark_logo"
:
"img/logo-v2.png"
}
"light_logo"
:
"img/logo-v2.png"
,
"dark_logo"
:
"img/logo-v2.png"
,
}
header_links
=
[
header_links
=
[
(
"Home"
,
"https://github.com/tile-ai/tilelang"
),
(
"Home"
,
"https://github.com/tile-ai/tilelang"
),
...
...
examples/amd/example_amd_flash_attn_bwd.py
View file @
29051439
...
@@ -11,22 +11,20 @@ import time
...
@@ -11,22 +11,20 @@ import time
def
ref_program
(
Q
,
K
,
V
,
is_causal
,
groups
=
1
):
def
ref_program
(
Q
,
K
,
V
,
is_causal
,
groups
=
1
):
assert
Q
.
size
(
assert
Q
.
size
(
2
)
==
K
.
size
(
2
)
*
groups
,
f
"Q heads
{
Q
.
size
(
2
)
}
K heads
{
K
.
size
(
2
)
}
groups
{
groups
}
"
2
)
==
K
.
size
(
2
)
*
groups
,
f
"Q heads
{
Q
.
size
(
2
)
}
K heads
{
K
.
size
(
2
)
}
groups
{
groups
}
"
assert
Q
.
size
(
2
)
==
V
.
size
(
2
)
*
groups
,
f
"Q heads
{
Q
.
size
(
2
)
}
V heads
{
V
.
size
(
2
)
}
groups
{
groups
}
"
assert
Q
.
size
(
2
)
==
V
.
size
(
2
)
*
groups
,
f
"Q heads
{
Q
.
size
(
2
)
}
V heads
{
V
.
size
(
2
)
}
groups
{
groups
}
"
dim
=
Q
.
size
(
-
1
)
dim
=
Q
.
size
(
-
1
)
K_ref
=
K
.
repeat_interleave
(
groups
,
dim
=
2
)
K_ref
=
K
.
repeat_interleave
(
groups
,
dim
=
2
)
V_ref
=
V
.
repeat_interleave
(
groups
,
dim
=
2
)
V_ref
=
V
.
repeat_interleave
(
groups
,
dim
=
2
)
scores
=
torch
.
einsum
(
'
bqhd,bkhd->bhqk
'
,
Q
,
K_ref
)
scores
=
torch
.
einsum
(
"
bqhd,bkhd->bhqk
"
,
Q
,
K_ref
)
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
if
is_causal
:
if
is_causal
:
seq_len
=
Q
.
size
(
1
)
seq_len
=
Q
.
size
(
1
)
mask
=
torch
.
tril
(
torch
.
ones
(
seq_len
,
seq_len
,
device
=
scores
.
device
))
mask
=
torch
.
tril
(
torch
.
ones
(
seq_len
,
seq_len
,
device
=
scores
.
device
))
mask
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
mask
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
"
-inf
"
))
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
output
=
torch
.
einsum
(
'
bhqk,bkhd->bqhd
'
,
attention_weights
,
V_ref
)
output
=
torch
.
einsum
(
"
bhqk,bkhd->bqhd
"
,
attention_weights
,
V_ref
)
lse
=
torch
.
logsumexp
(
scores
,
dim
=-
1
).
float
()
lse
=
torch
.
logsumexp
(
scores
,
dim
=-
1
).
float
()
return
output
,
lse
return
output
,
lse
...
@@ -45,12 +43,11 @@ def get_fwd_configs():
...
@@ -45,12 +43,11 @@ def get_fwd_configs():
valid_configs
=
[]
valid_configs
=
[]
for
m
,
n
,
s
,
t
,
stages
,
r
,
k
,
p
,
qkw
,
vw
in
itertools
.
product
(
block_M
,
block_N
,
num_split_q
,
for
m
,
n
,
s
,
t
,
stages
,
r
,
k
,
p
,
qkw
,
vw
in
itertools
.
product
(
threads
,
num_stages
,
block_M
,
block_N
,
num_split_q
,
threads
,
num_stages
,
enable_rasterization
,
k_pack
,
panel_size
,
qk_coalesced_width
,
v_coalesced_width
enable_rasterization
,
k_pack
,
):
panel_size
,
qk_coalesced_width
,
valid_configs
.
append
(
v_coalesced_width
):
{
valid_configs
.
append
({
"block_M"
:
m
,
"block_M"
:
m
,
"block_N"
:
n
,
"block_N"
:
n
,
"num_split_q"
:
s
,
"num_split_q"
:
s
,
...
@@ -61,7 +58,8 @@ def get_fwd_configs():
...
@@ -61,7 +58,8 @@ def get_fwd_configs():
"panel_size"
:
p
,
"panel_size"
:
p
,
"qk_coalesced_width"
:
qkw
,
"qk_coalesced_width"
:
qkw
,
"v_coalesced_width"
:
vw
,
"v_coalesced_width"
:
vw
,
})
}
)
return
valid_configs
return
valid_configs
...
@@ -85,7 +83,7 @@ def fast_flashattn(
...
@@ -85,7 +83,7 @@ def fast_flashattn(
qk_coalesced_width
:
int
,
qk_coalesced_width
:
int
,
v_coalesced_width
:
int
,
v_coalesced_width
:
int
,
):
):
scale
=
(
1.0
/
dim
)
**
0.5
scale
=
(
1.0
/
dim
)
**
0.5
head_kv
=
heads
//
groups
head_kv
=
heads
//
groups
q_shape
=
[
batch
,
seq_len
,
heads
,
dim
]
q_shape
=
[
batch
,
seq_len
,
heads
,
dim
]
kv_shape
=
[
batch
,
seq_len
,
head_kv
,
dim
]
kv_shape
=
[
batch
,
seq_len
,
head_kv
,
dim
]
...
@@ -135,33 +133,21 @@ def fast_flashattn(
...
@@ -135,33 +133,21 @@ def fast_flashattn(
m_prev
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
m_prev
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scale_factor
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scale_factor
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
T
.
copy
(
T
.
copy
(
Q
[
bz
,
q_block_offset
:
q_block_offset
+
block_M
,
by
,
:],
Q_shared
,
coalesced_width
=
vec_size
)
Q
[
bz
,
q_block_offset
:
q_block_offset
+
block_M
,
by
,
:],
Q_shared
,
coalesced_width
=
vec_size
)
loop_end_k
=
(
loop_end_k
=
T
.
ceildiv
(
q_block_offset
+
block_M
,
block_N
)
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
T
.
ceildiv
(
q_block_offset
+
block_M
,
block_N
)
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
))
row_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
row_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
for
k
in
T
.
Pipelined
(
loop_end_k
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_end_k
,
num_stages
=
num_stages
):
kv_idx
=
k
*
block_N
kv_idx
=
k
*
block_N
T
.
copy
(
T
.
copy
(
K
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
K_shared
,
coalesced_width
=
vec_size
)
K
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
T
.
copy
(
V
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
V_shared
,
coalesced_width
=
v_vec_size
)
K_shared
,
coalesced_width
=
vec_size
)
T
.
copy
(
V
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
V_shared
,
coalesced_width
=
v_vec_size
)
if
is_causal
:
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_block_offset
+
i
>=
kv_idx
+
j
,
0
,
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_block_offset
+
i
>=
kv_idx
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
else
:
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
T
.
gemm
(
...
@@ -216,8 +202,7 @@ def fast_flashattn(
...
@@ -216,8 +202,7 @@ def fast_flashattn(
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
if
q_block_offset
+
i
<
seq_len
:
if
q_block_offset
+
i
<
seq_len
:
lse_val
=
T
.
if_then_else
(
l_i
[
i
]
>
0
,
lse_val
=
T
.
if_then_else
(
l_i
[
i
]
>
0
,
T
.
log
(
l_i
[
i
])
+
m_i
[
i
],
-
T
.
infinity
(
accum_dtype
))
T
.
log
(
l_i
[
i
])
+
m_i
[
i
],
-
T
.
infinity
(
accum_dtype
))
LSE
[
bz
,
by
,
q_block_offset
+
i
]
=
lse_val
LSE
[
bz
,
by
,
q_block_offset
+
i
]
=
lse_val
bx_loop_var
=
current_bx
+
num_split_q
bx_loop_var
=
current_bx
+
num_split_q
...
@@ -234,16 +219,17 @@ def get_bwd_configs():
...
@@ -234,16 +219,17 @@ def get_bwd_configs():
panel_size
=
[
7
,
8
,
9
,
10
]
panel_size
=
[
7
,
8
,
9
,
10
]
configs
=
[]
configs
=
[]
for
m
,
n
,
stages
,
t
,
r
,
p
in
itertools
.
product
(
block_M
,
block_N
,
num_stages
,
threads
,
for
m
,
n
,
stages
,
t
,
r
,
p
in
itertools
.
product
(
block_M
,
block_N
,
num_stages
,
threads
,
enable_rasterization
,
panel_size
):
enable_rasterization
,
panel_size
):
configs
.
append
(
configs
.
append
(
{
{
"block_M"
:
m
,
"block_M"
:
m
,
"block_N"
:
n
,
"block_N"
:
n
,
"num_stages"
:
stages
,
"num_stages"
:
stages
,
"threads"
:
t
,
"threads"
:
t
,
"enable_rasterization"
:
r
,
"enable_rasterization"
:
r
,
"panel_size"
:
p
,
"panel_size"
:
p
,
})
}
)
return
configs
return
configs
...
@@ -256,8 +242,7 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
...
@@ -256,8 +242,7 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
blk
=
32
blk
=
32
@
T
.
prim_func
@
T
.
prim_func
def
flash_bwd_prep
(
O
:
T
.
Tensor
(
shape
,
dtype
),
dO
:
T
.
Tensor
(
shape
,
dtype
),
def
flash_bwd_prep
(
O
:
T
.
Tensor
(
shape
,
dtype
),
dO
:
T
.
Tensor
(
shape
,
dtype
),
Delta
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
)):
Delta
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
)):
with
T
.
Kernel
(
batch
,
heads
,
T
.
ceildiv
(
seq_len
,
blk
))
as
(
bz
,
bx
,
by
):
with
T
.
Kernel
(
batch
,
heads
,
T
.
ceildiv
(
seq_len
,
blk
))
as
(
bz
,
bx
,
by
):
o
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
o
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
do
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
do
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
...
@@ -265,21 +250,33 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
...
@@ -265,21 +250,33 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
delta
=
T
.
alloc_fragment
([
blk
],
accum_dtype
)
delta
=
T
.
alloc_fragment
([
blk
],
accum_dtype
)
T
.
clear
(
acc
)
T
.
clear
(
acc
)
for
k
in
range
(
T
.
ceildiv
(
dim
,
blk
)):
for
k
in
range
(
T
.
ceildiv
(
dim
,
blk
)):
T
.
copy
(
O
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
O
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
dO
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
do
)
T
.
copy
(
dO
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
do
)
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
acc
[
i
,
j
]
+=
o
[
i
,
j
]
*
do
[
i
,
j
]
acc
[
i
,
j
]
+=
o
[
i
,
j
]
*
do
[
i
,
j
]
T
.
reduce_sum
(
acc
,
delta
,
1
)
T
.
reduce_sum
(
acc
,
delta
,
1
)
T
.
copy
(
delta
,
Delta
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
])
T
.
copy
(
delta
,
Delta
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
])
return
flash_bwd_prep
return
flash_bwd_prep
@
tilelang
.
autotune
(
configs
=
get_bwd_configs
(),
cache_input_tensors
=
True
)
@
tilelang
.
autotune
(
configs
=
get_bwd_configs
(),
cache_input_tensors
=
True
)
@
tilelang
.
jit
@
tilelang
.
jit
def
flashattn_bwd
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
groups
,
block_M
:
int
,
block_N
:
int
,
def
flashattn_bwd
(
num_stages
:
int
,
threads
:
int
,
enable_rasterization
:
bool
,
panel_size
:
int
):
batch
,
sm_scale
=
(
1.0
/
dim
)
**
0.5
heads
,
seq_len
,
dim
,
is_causal
,
groups
,
block_M
:
int
,
block_N
:
int
,
num_stages
:
int
,
threads
:
int
,
enable_rasterization
:
bool
,
panel_size
:
int
,
):
sm_scale
=
(
1.0
/
dim
)
**
0.5
head_kv
=
heads
//
groups
head_kv
=
heads
//
groups
q_shape
=
[
batch
,
seq_len
,
heads
,
dim
]
q_shape
=
[
batch
,
seq_len
,
heads
,
dim
]
kv_shape
=
[
batch
,
seq_len
,
head_kv
,
dim
]
kv_shape
=
[
batch
,
seq_len
,
head_kv
,
dim
]
...
@@ -287,14 +284,17 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
...
@@ -287,14 +284,17 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
accum_dtype
=
"float"
accum_dtype
=
"float"
@
T
.
prim_func
@
T
.
prim_func
def
flash_bwd_kernel
(
Q
:
T
.
Tensor
(
q_shape
,
def
flash_bwd_kernel
(
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
dO
:
T
.
Tensor
(
q_shape
,
dtype
),
lse
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
accum_dtype
),
dO
:
T
.
Tensor
(
q_shape
,
dtype
),
Delta
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
lse
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
),
accum_dtype
),
dQ
:
T
.
Tensor
(
q_shape
,
accum_dtype
),
Delta
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
),
dK
:
T
.
Tensor
(
kv_shape
,
accum_dtype
),
dV
:
T
.
Tensor
(
kv_shape
,
accum_dtype
)):
dQ
:
T
.
Tensor
(
q_shape
,
accum_dtype
),
dK
:
T
.
Tensor
(
kv_shape
,
accum_dtype
),
dV
:
T
.
Tensor
(
kv_shape
,
accum_dtype
),
):
with
T
.
Kernel
(
heads
,
T
.
ceildiv
(
seq_len
,
block_M
),
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
heads
,
T
.
ceildiv
(
seq_len
,
block_M
),
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
T
.
use_swizzle
(
panel_size
,
enable
=
enable_rasterization
)
T
.
use_swizzle
(
panel_size
,
enable
=
enable_rasterization
)
...
@@ -315,8 +315,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
...
@@ -315,8 +315,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
dk
=
T
.
alloc_fragment
([
block_M
,
dim
],
accum_dtype
)
dk
=
T
.
alloc_fragment
([
block_M
,
dim
],
accum_dtype
)
dq
=
T
.
alloc_fragment
([
block_N
,
dim
],
accum_dtype
)
dq
=
T
.
alloc_fragment
([
block_N
,
dim
],
accum_dtype
)
T
.
copy
(
K
[
bz
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
bx
//
groups
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
bx
//
groups
,
:],
K_shared
)
T
.
copy
(
V
[
bz
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
bx
//
groups
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
bx
//
groups
,
:],
V_shared
)
T
.
clear
(
dv
)
T
.
clear
(
dv
)
T
.
clear
(
dk
)
T
.
clear
(
dk
)
...
@@ -324,22 +324,21 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
...
@@ -324,22 +324,21 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
loop_ed
=
T
.
ceildiv
(
seq_len
,
block_N
)
loop_ed
=
T
.
ceildiv
(
seq_len
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_st
,
loop_ed
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_st
,
loop_ed
,
num_stages
=
num_stages
):
T
.
copy
(
Q
[
bz
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
bx
,
:],
q_shared
)
T
.
copy
(
Q
[
bz
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
bx
,
:],
q_shared
)
T
.
clear
(
qkT
)
T
.
clear
(
qkT
)
T
.
gemm
(
K_shared
,
q_shared
,
qkT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
K_shared
,
q_shared
,
qkT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
lse
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
lse_shared
)
T
.
copy
(
lse
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
lse_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
P_acc
[
i
,
j
]
=
T
.
exp
(
qkT
[
i
,
j
]
*
sm_scale
-
lse_shared
[
j
])
P_acc
[
i
,
j
]
=
T
.
exp
(
qkT
[
i
,
j
]
*
sm_scale
-
lse_shared
[
j
])
if
is_causal
:
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
P_acc
[
i
,
j
]
=
T
.
if_then_else
(
by
*
block_M
+
i
<=
k
*
block_N
+
j
,
P_acc
[
i
,
j
]
=
T
.
if_then_else
(
by
*
block_M
+
i
<=
k
*
block_N
+
j
,
P_acc
[
i
,
j
],
0.0
)
P_acc
[
i
,
j
],
0.0
)
T
.
copy
(
dO
[
bz
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
bx
,
:],
do_shared
)
T
.
copy
(
dO
[
bz
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
bx
,
:],
do_shared
)
T
.
clear
(
dP
)
T
.
clear
(
dP
)
T
.
gemm
(
V_shared
,
do_shared
,
dP
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
V_shared
,
do_shared
,
dP
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
...
@@ -347,7 +346,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
...
@@ -347,7 +346,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
T
.
copy
(
P_acc
,
p_cast
)
T
.
copy
(
P_acc
,
p_cast
)
T
.
gemm
(
p_cast
,
do_shared
,
dv
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
p_cast
,
do_shared
,
dv
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
Delta
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
delta_shared
)
T
.
copy
(
Delta
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
delta_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
p_cast
[
i
,
j
]
=
P_acc
[
i
,
j
]
*
(
dP
[
i
,
j
]
-
delta_shared
[
j
])
*
sm_scale
p_cast
[
i
,
j
]
=
P_acc
[
i
,
j
]
*
(
dP
[
i
,
j
]
-
delta_shared
[
j
])
*
sm_scale
...
@@ -378,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
...
@@ -378,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
def
flash_bwd_post
(
dQ_in
:
T
.
Tensor
(
shape
,
accum_dtype
),
dQ_out
:
T
.
Tensor
(
shape
,
dtype
)):
def
flash_bwd_post
(
dQ_in
:
T
.
Tensor
(
shape
,
accum_dtype
),
dQ_out
:
T
.
Tensor
(
shape
,
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
blk
),
heads
,
batch
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
blk
),
heads
,
batch
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
T
.
copy
(
T
.
copy
(
dQ_in
[
bz
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
by
,
:],
dQ_in
[
bz
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
by
,
:],
dQ_out
[
bz
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
by
,
:],
dQ_out
[
bz
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
by
,
:],
)
)
return
flash_bwd_post
return
flash_bwd_post
...
@@ -446,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100):
...
@@ -446,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100):
return
np
.
median
(
times
)
return
np
.
median
(
times
)
def
main
(
batch
:
int
=
1
,
def
main
(
batch
:
int
=
1
,
heads
:
int
=
8
,
seq_len
:
int
=
4096
,
dim
:
int
=
128
,
is_causal
:
bool
=
False
,
groups
:
int
=
1
):
heads
:
int
=
8
,
seq_len
:
int
=
4096
,
dim
:
int
=
128
,
is_causal
:
bool
=
False
,
groups
:
int
=
1
):
device
=
"cuda"
device
=
"cuda"
dtype
=
torch
.
float16
dtype
=
torch
.
float16
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed
(
42
)
print
(
print
(
f
"Test configuration: batch=
{
batch
}
, heads=
{
heads
}
, seq_len=
{
seq_len
}
, dim=
{
dim
}
, is_causal=
{
is_causal
}
, groups=
{
groups
}
"
)
f
"Test configuration: batch=
{
batch
}
, heads=
{
heads
}
, seq_len=
{
seq_len
}
, dim=
{
dim
}
, is_causal=
{
is_causal
}
, groups=
{
groups
}
"
)
flops_per_gemm
=
2.0
*
batch
*
heads
*
seq_len
*
seq_len
*
dim
flops_per_gemm
=
2.0
*
batch
*
heads
*
seq_len
*
seq_len
*
dim
total_flops
=
5
*
flops_per_gemm
total_flops
=
5
*
flops_per_gemm
...
@@ -517,22 +508,19 @@ def main(batch: int = 1,
...
@@ -517,22 +508,19 @@ def main(batch: int = 1,
o_ref
.
backward
(
dO
)
o_ref
.
backward
(
dO
)
print
(
"Verifying backward pass correctness..."
)
print
(
"Verifying backward pass correctness..."
)
dq_close
,
dq_max_diff
,
dq_mean_diff
=
debug_tensor_comparison
(
dq_close
,
dq_max_diff
,
dq_mean_diff
=
debug_tensor_comparison
(
dQ_tl
,
q_ref
.
grad
,
"dQ"
,
rtol
=
0.05
,
atol
=
0.05
)
dQ_tl
,
q_ref
.
grad
,
"dQ"
,
rtol
=
0.05
,
atol
=
0.05
)
if
dq_close
:
if
dq_close
:
print
(
"dQ is correct."
)
print
(
"dQ is correct."
)
else
:
else
:
print
(
"dQ mismatch detected."
)
print
(
"dQ mismatch detected."
)
dk_close
,
dk_max_diff
,
dk_mean_diff
=
debug_tensor_comparison
(
dk_close
,
dk_max_diff
,
dk_mean_diff
=
debug_tensor_comparison
(
dK_tl
.
to
(
torch
.
float16
),
k_ref
.
grad
,
"dK"
,
rtol
=
0.05
,
atol
=
0.05
)
dK_tl
.
to
(
torch
.
float16
),
k_ref
.
grad
,
"dK"
,
rtol
=
0.05
,
atol
=
0.05
)
if
dk_close
:
if
dk_close
:
print
(
"dK is correct."
)
print
(
"dK is correct."
)
else
:
else
:
print
(
"dK mismatch detected."
)
print
(
"dK mismatch detected."
)
dv_close
,
dv_max_diff
,
dv_mean_diff
=
debug_tensor_comparison
(
dv_close
,
dv_max_diff
,
dv_mean_diff
=
debug_tensor_comparison
(
dV_tl
.
to
(
torch
.
float16
),
v_ref
.
grad
,
"dV"
,
rtol
=
0.05
,
atol
=
0.05
)
dV_tl
.
to
(
torch
.
float16
),
v_ref
.
grad
,
"dV"
,
rtol
=
0.05
,
atol
=
0.05
)
if
dv_close
:
if
dv_close
:
print
(
"dV is correct."
)
print
(
"dV is correct."
)
else
:
else
:
...
@@ -553,9 +541,7 @@ def main(batch: int = 1,
...
@@ -553,9 +541,7 @@ def main(batch: int = 1,
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
ref_latency
=
benchmark_function
(
run_reference_fwd_bwd
,
warmup
=
10
,
repeat
=
100
)
ref_latency
=
benchmark_function
(
run_reference_fwd_bwd
,
warmup
=
10
,
repeat
=
100
)
print
(
print
(
f
"Reference PyTorch Forward+Backward:
{
ref_latency
:.
2
f
}
ms |
{
total_flops
/
ref_latency
*
1e-9
:.
2
f
}
TFlops"
)
f
"Reference PyTorch Forward+Backward:
{
ref_latency
:.
2
f
}
ms |
{
total_flops
/
ref_latency
*
1e-9
:.
2
f
}
TFlops"
)
def
run_complete_fwd_bwd
():
def
run_complete_fwd_bwd
():
o_tl_bench
,
lse_tl_bench
=
fwd_kernel
(
q
,
k
,
v
)
o_tl_bench
,
lse_tl_bench
=
fwd_kernel
(
q
,
k
,
v
)
...
@@ -593,12 +579,12 @@ def main(batch: int = 1,
...
@@ -593,12 +579,12 @@ def main(batch: int = 1,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
1
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
1
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
8
,
help
=
'
heads
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
8
,
help
=
"
heads
"
)
parser
.
add_argument
(
'
--seq_len
'
,
type
=
int
,
default
=
1024
,
help
=
'
sequence length
'
)
parser
.
add_argument
(
"
--seq_len
"
,
type
=
int
,
default
=
1024
,
help
=
"
sequence length
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
64
,
help
=
'
dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
64
,
help
=
"
dim
"
)
parser
.
add_argument
(
'
--is_causal
'
,
action
=
'
store_true
'
,
help
=
'
causal
'
)
parser
.
add_argument
(
"
--is_causal
"
,
action
=
"
store_true
"
,
help
=
"
causal
"
)
parser
.
add_argument
(
'
--groups
'
,
type
=
int
,
default
=
1
,
help
=
'
groups
'
)
parser
.
add_argument
(
"
--groups
"
,
type
=
int
,
default
=
1
,
help
=
"
groups
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_len
,
args
.
dim
,
args
.
is_causal
,
args
.
groups
)
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_len
,
args
.
dim
,
args
.
is_causal
,
args
.
groups
)
examples/amd/example_amd_flash_attn_fwd.py
View file @
29051439
...
@@ -13,10 +13,10 @@ def supply_tensors_gpu(params):
...
@@ -13,10 +13,10 @@ def supply_tensors_gpu(params):
"""Supply function that creates tensors on GPU for ROCm/HIP."""
"""Supply function that creates tensors on GPU for ROCm/HIP."""
tensors
=
[]
tensors
=
[]
for
param
in
params
:
for
param
in
params
:
if
hasattr
(
param
,
'
shape
'
)
and
hasattr
(
param
,
'
dtype
'
):
if
hasattr
(
param
,
"
shape
"
)
and
hasattr
(
param
,
"
dtype
"
):
# Force creation on GPU device
# Force creation on GPU device
shape
=
[
int
(
s
)
for
s
in
param
.
shape
]
shape
=
[
int
(
s
)
for
s
in
param
.
shape
]
tensor
=
torch
.
randn
(
shape
,
dtype
=
param
.
dtype
,
device
=
'
cuda
'
)
tensor
=
torch
.
randn
(
shape
,
dtype
=
param
.
dtype
,
device
=
"
cuda
"
)
tensors
.
append
(
tensor
)
tensors
.
append
(
tensor
)
else
:
else
:
tensors
.
append
(
param
)
tensors
.
append
(
param
)
...
@@ -24,22 +24,20 @@ def supply_tensors_gpu(params):
...
@@ -24,22 +24,20 @@ def supply_tensors_gpu(params):
def
ref_program
(
Q
,
K
,
V
,
is_causal
,
groups
=
1
):
def
ref_program
(
Q
,
K
,
V
,
is_causal
,
groups
=
1
):
assert
Q
.
size
(
assert
Q
.
size
(
2
)
==
K
.
size
(
2
)
*
groups
,
f
"Q heads
{
Q
.
size
(
2
)
}
K heads
{
K
.
size
(
2
)
}
groups
{
groups
}
"
2
)
==
K
.
size
(
2
)
*
groups
,
f
"Q heads
{
Q
.
size
(
2
)
}
K heads
{
K
.
size
(
2
)
}
groups
{
groups
}
"
assert
Q
.
size
(
2
)
==
V
.
size
(
2
)
*
groups
,
f
"Q heads
{
Q
.
size
(
2
)
}
V heads
{
V
.
size
(
2
)
}
groups
{
groups
}
"
assert
Q
.
size
(
2
)
==
V
.
size
(
2
)
*
groups
,
f
"Q heads
{
Q
.
size
(
2
)
}
V heads
{
V
.
size
(
2
)
}
groups
{
groups
}
"
dim
=
Q
.
size
(
-
1
)
dim
=
Q
.
size
(
-
1
)
K
=
K
.
repeat_interleave
(
groups
,
dim
=
2
)
K
=
K
.
repeat_interleave
(
groups
,
dim
=
2
)
V
=
V
.
repeat_interleave
(
groups
,
dim
=
2
)
V
=
V
.
repeat_interleave
(
groups
,
dim
=
2
)
scores
=
torch
.
einsum
(
'
bqhd,bkhd->bhqk
'
,
Q
,
K
)
scores
=
torch
.
einsum
(
"
bqhd,bkhd->bhqk
"
,
Q
,
K
)
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
if
is_causal
:
if
is_causal
:
seq_len
=
Q
.
size
(
1
)
seq_len
=
Q
.
size
(
1
)
mask
=
torch
.
tril
(
torch
.
ones
(
seq_len
,
seq_len
,
device
=
scores
.
device
))
mask
=
torch
.
tril
(
torch
.
ones
(
seq_len
,
seq_len
,
device
=
scores
.
device
))
mask
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
mask
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
"
-inf
"
))
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
output
=
torch
.
einsum
(
'
bhqk,bkhd->bqhd
'
,
attention_weights
,
V
)
output
=
torch
.
einsum
(
"
bhqk,bkhd->bqhd
"
,
attention_weights
,
V
)
return
output
return
output
...
@@ -58,12 +56,11 @@ def get_configs():
...
@@ -58,12 +56,11 @@ def get_configs():
valid_configs
=
[]
valid_configs
=
[]
for
m
,
n
,
s
,
t
,
stages
,
r
,
k
,
p
,
qkw
,
vw
in
itertools
.
product
(
block_M
,
block_N
,
num_split_q
,
for
m
,
n
,
s
,
t
,
stages
,
r
,
k
,
p
,
qkw
,
vw
in
itertools
.
product
(
threads
,
num_stages
,
block_M
,
block_N
,
num_split_q
,
threads
,
num_stages
,
enable_rasterization
,
k_pack
,
panel_size
,
qk_coalesced_width
,
v_coalesced_width
enable_rasterization
,
k_pack
,
):
panel_size
,
qk_coalesced_width
,
valid_configs
.
append
(
v_coalesced_width
):
{
valid_configs
.
append
({
"block_M"
:
m
,
"block_M"
:
m
,
"block_N"
:
n
,
"block_N"
:
n
,
"num_split_q"
:
s
,
"num_split_q"
:
s
,
...
@@ -74,7 +71,8 @@ def get_configs():
...
@@ -74,7 +71,8 @@ def get_configs():
"panel_size"
:
p
,
"panel_size"
:
p
,
"qk_coalesced_width"
:
qkw
,
"qk_coalesced_width"
:
qkw
,
"v_coalesced_width"
:
vw
,
"v_coalesced_width"
:
vw
,
})
}
)
return
valid_configs
return
valid_configs
...
@@ -98,7 +96,7 @@ def fast_flashattn(
...
@@ -98,7 +96,7 @@ def fast_flashattn(
qk_coalesced_width
:
int
,
qk_coalesced_width
:
int
,
v_coalesced_width
:
int
,
v_coalesced_width
:
int
,
):
):
scale
=
(
1.0
/
dim
)
**
0.5
scale
=
(
1.0
/
dim
)
**
0.5
head_kv
=
heads
//
groups
head_kv
=
heads
//
groups
q_shape
=
[
batch
,
seq_len
,
heads
,
dim
]
q_shape
=
[
batch
,
seq_len
,
heads
,
dim
]
kv_shape
=
[
batch
,
seq_len
,
head_kv
,
dim
]
kv_shape
=
[
batch
,
seq_len
,
head_kv
,
dim
]
...
@@ -147,32 +145,21 @@ def fast_flashattn(
...
@@ -147,32 +145,21 @@ def fast_flashattn(
m_prev
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
m_prev
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scale_factor
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scale_factor
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
T
.
copy
(
T
.
copy
(
Q
[
bz
,
q_block_offset
:
q_block_offset
+
block_M
,
by
,
:],
Q_shared
,
coalesced_width
=
vec_size
)
Q
[
bz
,
q_block_offset
:
q_block_offset
+
block_M
,
by
,
:],
Q_shared
,
coalesced_width
=
vec_size
)
loop_end_k
=
T
.
ceildiv
(
q_block_offset
+
block_M
,
loop_end_k
=
T
.
ceildiv
(
q_block_offset
+
block_M
,
block_N
)
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
block_N
)
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
row_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
row_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
for
k
in
T
.
Pipelined
(
loop_end_k
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_end_k
,
num_stages
=
num_stages
):
kv_idx
=
k
*
block_N
kv_idx
=
k
*
block_N
T
.
copy
(
T
.
copy
(
K
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
K_shared
,
coalesced_width
=
vec_size
)
K
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
T
.
copy
(
V
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
V_shared
,
coalesced_width
=
v_vec_size
)
K_shared
,
coalesced_width
=
vec_size
)
T
.
copy
(
V
[
bz
,
kv_idx
:
kv_idx
+
block_N
,
by
//
groups
,
:],
V_shared
,
coalesced_width
=
v_vec_size
)
if
is_causal
:
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_block_offset
+
i
>=
kv_idx
+
j
,
0
,
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_block_offset
+
i
>=
kv_idx
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
else
:
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
T
.
gemm
(
...
@@ -222,13 +209,7 @@ def fast_flashattn(
...
@@ -222,13 +209,7 @@ def fast_flashattn(
return
main
return
main
def
main
(
batch
:
int
=
1
,
def
main
(
batch
:
int
=
1
,
heads
:
int
=
8
,
seq_len
:
int
=
4096
,
dim
:
int
=
128
,
is_causal
:
bool
=
False
,
groups
:
int
=
1
):
heads
:
int
=
8
,
seq_len
:
int
=
4096
,
dim
:
int
=
128
,
is_causal
:
bool
=
False
,
groups
:
int
=
1
):
flops_per_matmul
=
2.0
*
batch
*
heads
*
seq_len
*
seq_len
*
dim
flops_per_matmul
=
2.0
*
batch
*
heads
*
seq_len
*
seq_len
*
dim
total_flops
=
2
*
flops_per_matmul
total_flops
=
2
*
flops_per_matmul
if
is_causal
:
if
is_causal
:
...
@@ -250,18 +231,16 @@ def main(batch: int = 1,
...
@@ -250,18 +231,16 @@ def main(batch: int = 1,
print
(
f
"Reference (PyTorch):
{
latency
:.
2
f
}
ms |
{
total_flops
/
latency
*
1e-9
:.
2
f
}
TFlops"
)
print
(
f
"Reference (PyTorch):
{
latency
:.
2
f
}
ms |
{
total_flops
/
latency
*
1e-9
:.
2
f
}
TFlops"
)
latency
=
profiler
.
do_bench
(
warmup
=
100
)
latency
=
profiler
.
do_bench
(
warmup
=
100
)
print
(
print
(
f
"Fast Flash Attention V2 (Tile-lang):
{
latency
:.
2
f
}
ms |
{
total_flops
/
latency
*
1e-9
:.
2
f
}
TFlops"
)
f
"Fast Flash Attention V2 (Tile-lang):
{
latency
:.
2
f
}
ms |
{
total_flops
/
latency
*
1e-9
:.
2
f
}
TFlops"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
1
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
1
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
8
,
help
=
'
heads
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
8
,
help
=
"
heads
"
)
parser
.
add_argument
(
'
--seq_len
'
,
type
=
int
,
default
=
4096
,
help
=
'
sequence length
'
)
parser
.
add_argument
(
"
--seq_len
"
,
type
=
int
,
default
=
4096
,
help
=
"
sequence length
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
128
,
help
=
'
dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
128
,
help
=
"
dim
"
)
parser
.
add_argument
(
'
--is_causal
'
,
action
=
'
store_true
'
,
help
=
'
causal
'
)
parser
.
add_argument
(
"
--is_causal
"
,
action
=
"
store_true
"
,
help
=
"
causal
"
)
parser
.
add_argument
(
'
--groups
'
,
type
=
int
,
default
=
1
,
help
=
'
groups
'
)
parser
.
add_argument
(
"
--groups
"
,
type
=
int
,
default
=
1
,
help
=
"
groups
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_len
,
args
.
dim
,
args
.
is_causal
,
args
.
groups
)
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_len
,
args
.
dim
,
args
.
is_causal
,
args
.
groups
)
examples/analyze/example_conv_analyze.py
View file @
29051439
...
@@ -25,22 +25,7 @@ def check_hopper():
...
@@ -25,22 +25,7 @@ def check_hopper():
return
False
return
False
def
kernel
(
N
,
def
kernel
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
KH
,
KW
=
K
,
K
KH
,
KW
=
K
,
K
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
...
@@ -54,9 +39,7 @@ def kernel(N,
...
@@ -54,9 +39,7 @@ def kernel(N,
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -65,11 +48,13 @@ def kernel(N,
...
@@ -65,11 +48,13 @@ def kernel(N,
kernel_flat
=
T
.
Tensor
((
KH
*
KW
*
C
,
F
),
dtype
,
kernel
.
data
)
kernel_flat
=
T
.
Tensor
((
KH
*
KW
*
C
,
F
),
dtype
,
kernel
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
out_shared
:
make_swizzled_layout
(
out_shared
),
out_shared
:
make_swizzled_layout
(
out_shared
),
data_shared
:
make_swizzled_layout
(
data_shared
),
data_shared
:
make_swizzled_layout
(
data_shared
),
kernel_shared
:
make_swizzled_layout
(
kernel_shared
),
kernel_shared
:
make_swizzled_layout
(
kernel_shared
),
})
}
)
T
.
clear
(
out_local
)
T
.
clear
(
out_local
)
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
...
@@ -81,10 +66,8 @@ def kernel(N,
...
@@ -81,10 +66,8 @@ def kernel(N,
m
=
by
*
block_M
+
i
m
=
by
*
block_M
+
i
access_h
=
m
%
(
OH
*
OW
)
//
OW
*
S
+
k
//
(
KW
*
C
)
*
D
-
P
access_h
=
m
%
(
OH
*
OW
)
//
OW
*
S
+
k
//
(
KW
*
C
)
*
D
-
P
access_w
=
m
%
OW
*
S
+
k
//
C
%
KW
*
D
-
P
access_w
=
m
%
OW
*
S
+
k
//
C
%
KW
*
D
-
P
in_bound
=
((
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
in_bound
=
(
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
)
(
access_w
<
W
))
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
T
.
copy
(
kernel_flat
[
k_iter
*
block_K
,
bx
*
block_N
],
kernel_shared
)
T
.
copy
(
kernel_flat
[
k_iter
*
block_K
,
bx
*
block_N
],
kernel_shared
)
T
.
gemm
(
data_shared
,
kernel_shared
,
out_local
)
T
.
gemm
(
data_shared
,
kernel_shared
,
out_local
)
...
...
examples/analyze/example_gemm_analyze.py
View file @
29051439
examples/attention_sink/benchmark_gqa_sink_fwd.py
View file @
29051439
...
@@ -51,8 +51,7 @@ def triton_kernel(
...
@@ -51,8 +51,7 @@ def triton_kernel(
q
=
Q
.
load
([
off_z
,
off_h
,
start_m
*
BLOCK_M
,
0
]).
reshape
([
BLOCK_M
,
HEAD_DIM
])
q
=
Q
.
load
([
off_z
,
off_h
,
start_m
*
BLOCK_M
,
0
]).
reshape
([
BLOCK_M
,
HEAD_DIM
])
if
BANDWIDTH
:
if
BANDWIDTH
:
lo
,
hi
=
tl
.
maximum
(
0
,
start_q
+
start_m
*
BLOCK_M
-
lo
,
hi
=
tl
.
maximum
(
0
,
start_q
+
start_m
*
BLOCK_M
-
BANDWIDTH
),
start_q
+
(
start_m
+
1
)
*
BLOCK_M
BANDWIDTH
),
start_q
+
(
start_m
+
1
)
*
BLOCK_M
else
:
else
:
lo
,
hi
=
0
,
start_q
+
(
start_m
+
1
)
*
BLOCK_M
lo
,
hi
=
0
,
start_q
+
(
start_m
+
1
)
*
BLOCK_M
...
@@ -120,7 +119,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
...
@@ -120,7 +119,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
BANDWIDTH
=
window_size
,
BANDWIDTH
=
window_size
,
BLOCK_M
=
BLOCK_M
,
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
BLOCK_N
=
BLOCK_N
,
start_q
=
seq_kv
-
seq_q
)
start_q
=
seq_kv
-
seq_q
,
)
return
o
return
o
...
@@ -137,12 +137,11 @@ def main(
...
@@ -137,12 +137,11 @@ def main(
):
):
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
if
window_size
is
not
None
:
if
window_size
is
not
None
:
print
(
'
Using sliding window attention.
'
)
print
(
"
Using sliding window attention.
"
)
assert
window_size
<=
seq_q
assert
window_size
<=
seq_q
flops_per_matmul
=
2.0
*
batch
*
heads
*
min
(
flops_per_matmul
=
2.0
*
batch
*
heads
*
min
(
window_size
,
seq_kv
//
2
)
*
seq_q
*
dim
# just a rough estimation
window_size
,
seq_kv
//
2
)
*
seq_q
*
dim
# just a rough estimation
else
:
else
:
print
(
'
Using full attention.
'
)
print
(
"
Using full attention.
"
)
flops_per_matmul
=
2.0
*
batch
*
heads
*
seq_q
*
seq_kv
*
dim
*
0.5
flops_per_matmul
=
2.0
*
batch
*
heads
*
seq_q
*
seq_kv
*
dim
*
0.5
total_flops
=
2
*
flops_per_matmul
total_flops
=
2
*
flops_per_matmul
...
@@ -170,15 +169,14 @@ def main(
...
@@ -170,15 +169,14 @@ def main(
block_N
=
block_N
,
block_N
=
block_N
,
num_stages
=
num_stages
,
num_stages
=
num_stages
,
threads
=
threads
,
threads
=
threads
,
dtype
=
dtype
)
dtype
=
dtype
,
)
Q
,
K
,
V
,
sinks
=
gen_inputs
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
groups
,
dtype
=
torch_dtype
)
Q
,
K
,
V
,
sinks
=
gen_inputs
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
groups
,
dtype
=
torch_dtype
)
if
torch
.
allclose
(
if
torch
.
allclose
(
triton_program
(
Q
,
K
,
V
,
sinks
,
window_size
),
triton_program
(
Q
,
K
,
V
,
sinks
,
window_size
),
ref_program
(
Q
,
K
,
V
,
sinks
,
window_size
,
dtype
=
torch_dtype
),
rtol
=
1e-2
,
atol
=
1e-2
ref_program
(
Q
,
K
,
V
,
sinks
,
window_size
,
dtype
=
torch_dtype
),
):
rtol
=
1e-2
,
atol
=
1e-2
):
print
(
"Checks for triton passed.✅"
)
print
(
"Checks for triton passed.✅"
)
else
:
else
:
print
(
"Checks for triton failed.❌"
)
print
(
"Checks for triton failed.❌"
)
...
@@ -198,20 +196,14 @@ def main(
...
@@ -198,20 +196,14 @@ def main(
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
1
,
help
=
'batch size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
1
,
help
=
"batch size"
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
64
,
help
=
'heads'
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
64
,
help
=
"heads"
)
parser
.
add_argument
(
'--seq_q'
,
type
=
int
,
default
=
2048
,
help
=
'sequence length of query'
)
parser
.
add_argument
(
"--seq_q"
,
type
=
int
,
default
=
2048
,
help
=
"sequence length of query"
)
parser
.
add_argument
(
'--seq_kv'
,
type
=
int
,
default
=
2048
,
help
=
'sequence length of key/value'
)
parser
.
add_argument
(
"--seq_kv"
,
type
=
int
,
default
=
2048
,
help
=
"sequence length of key/value"
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
'--groups'
,
type
=
int
,
default
=
8
,
help
=
'groups'
)
parser
.
add_argument
(
"--groups"
,
type
=
int
,
default
=
8
,
help
=
"groups"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--window_size"
,
type
=
int
,
default
=
None
,
help
=
"window size (default: None, which means full attention)"
)
'--window_size'
,
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
help
=
"dtype, can be float16 or bfloat16"
)
type
=
int
,
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
,
help
=
"tune configs"
)
default
=
None
,
help
=
'window size (default: None, which means full attention)'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
"float16"
,
help
=
"dtype, can be float16 or bfloat16"
)
parser
.
add_argument
(
'--tune'
,
action
=
'store_true'
,
help
=
'tune configs'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_q
,
args
.
seq_kv
,
args
.
dim
,
args
.
groups
,
args
.
window_size
,
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_q
,
args
.
seq_kv
,
args
.
dim
,
args
.
groups
,
args
.
window_size
,
args
.
dtype
,
args
.
tune
)
args
.
dtype
,
args
.
tune
)
examples/attention_sink/benchmark_mha_sink_fwd.py
View file @
29051439
...
@@ -50,8 +50,7 @@ def triton_kernel(
...
@@ -50,8 +50,7 @@ def triton_kernel(
q
=
Q
.
load
([
off_z
,
off_h
,
start_m
*
BLOCK_M
,
0
]).
reshape
([
BLOCK_M
,
HEAD_DIM
])
q
=
Q
.
load
([
off_z
,
off_h
,
start_m
*
BLOCK_M
,
0
]).
reshape
([
BLOCK_M
,
HEAD_DIM
])
if
BANDWIDTH
:
if
BANDWIDTH
:
lo
,
hi
=
tl
.
maximum
(
0
,
start_q
+
start_m
*
BLOCK_M
-
lo
,
hi
=
tl
.
maximum
(
0
,
start_q
+
start_m
*
BLOCK_M
-
BANDWIDTH
),
start_q
+
(
start_m
+
1
)
*
BLOCK_M
BANDWIDTH
),
start_q
+
(
start_m
+
1
)
*
BLOCK_M
else
:
else
:
lo
,
hi
=
0
,
start_q
+
(
start_m
+
1
)
*
BLOCK_M
lo
,
hi
=
0
,
start_q
+
(
start_m
+
1
)
*
BLOCK_M
...
@@ -117,26 +116,28 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
...
@@ -117,26 +116,28 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
BANDWIDTH
=
window_size
,
BANDWIDTH
=
window_size
,
BLOCK_M
=
BLOCK_M
,
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
BLOCK_N
=
BLOCK_N
,
start_q
=
seq_kv
-
seq_q
)
start_q
=
seq_kv
-
seq_q
,
)
return
o
return
o
def
main
(
batch
:
int
=
1
,
def
main
(
batch
:
int
=
1
,
heads
:
int
=
32
,
heads
:
int
=
32
,
seq_q
:
int
=
256
,
seq_q
:
int
=
256
,
seq_kv
:
int
=
256
,
seq_kv
:
int
=
256
,
dim
:
int
=
128
,
dim
:
int
=
128
,
window_size
:
Optional
[
int
]
=
None
,
window_size
:
Optional
[
int
]
=
None
,
dtype
:
str
=
"float16"
,
dtype
:
str
=
"float16"
,
tune
:
bool
=
False
):
tune
:
bool
=
False
,
):
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
if
window_size
is
not
None
:
if
window_size
is
not
None
:
print
(
'
Using sliding window attention.
'
)
print
(
"
Using sliding window attention.
"
)
assert
window_size
<=
seq_q
assert
window_size
<=
seq_q
flops_per_matmul
=
2.0
*
batch
*
heads
*
min
(
flops_per_matmul
=
2.0
*
batch
*
heads
*
min
(
window_size
,
seq_kv
//
2
)
*
seq_q
*
dim
# just a rough estimation
window_size
,
seq_kv
//
2
)
*
seq_q
*
dim
# just a rough estimation
else
:
else
:
print
(
'
Using full attention.
'
)
print
(
"
Using full attention.
"
)
flops_per_matmul
=
2.0
*
batch
*
heads
*
seq_q
*
seq_kv
*
dim
*
0.5
flops_per_matmul
=
2.0
*
batch
*
heads
*
seq_q
*
seq_kv
*
dim
*
0.5
total_flops
=
2
*
flops_per_matmul
total_flops
=
2
*
flops_per_matmul
...
@@ -163,15 +164,14 @@ def main(batch: int = 1,
...
@@ -163,15 +164,14 @@ def main(batch: int = 1,
block_N
=
block_N
,
block_N
=
block_N
,
num_stages
=
num_stages
,
num_stages
=
num_stages
,
threads
=
threads
,
threads
=
threads
,
dtype
=
dtype
)
dtype
=
dtype
,
)
Q
,
K
,
V
,
sinks
=
gen_inputs
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
dtype
=
torch_dtype
)
Q
,
K
,
V
,
sinks
=
gen_inputs
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
dtype
=
torch_dtype
)
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
kernel
(
Q
,
K
,
V
,
sinks
),
kernel
(
Q
,
K
,
V
,
sinks
),
ref_program
(
Q
,
K
,
V
,
sinks
,
window_size
,
dtype
=
torch_dtype
),
rtol
=
1e-2
,
atol
=
1e-2
ref_program
(
Q
,
K
,
V
,
sinks
,
window_size
,
dtype
=
torch_dtype
),
)
rtol
=
1e-2
,
atol
=
1e-2
)
print
(
"All checks passed.✅"
)
print
(
"All checks passed.✅"
)
latency
=
do_bench
(
lambda
:
triton_program
(
Q
,
K
,
V
,
sinks
,
window_size
),
warmup
=
500
)
latency
=
do_bench
(
lambda
:
triton_program
(
Q
,
K
,
V
,
sinks
,
window_size
),
warmup
=
500
)
...
@@ -184,19 +184,13 @@ def main(batch: int = 1,
...
@@ -184,19 +184,13 @@ def main(batch: int = 1,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
8
,
help
=
'batch size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
8
,
help
=
"batch size"
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
'--seq_q'
,
type
=
int
,
default
=
4096
,
help
=
'sequence length of query'
)
parser
.
add_argument
(
"--seq_q"
,
type
=
int
,
default
=
4096
,
help
=
"sequence length of query"
)
parser
.
add_argument
(
'--seq_kv'
,
type
=
int
,
default
=
4096
,
help
=
'sequence length of key/value'
)
parser
.
add_argument
(
"--seq_kv"
,
type
=
int
,
default
=
4096
,
help
=
"sequence length of key/value"
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--window_size"
,
type
=
int
,
default
=
None
,
help
=
"window size (default: None, which means full attention)"
)
'--window_size'
,
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
help
=
"dtype, can be float16 or bfloat16"
)
type
=
int
,
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
,
help
=
"tune"
)
default
=
None
,
help
=
'window size (default: None, which means full attention)'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
"float16"
,
help
=
"dtype, can be float16 or bfloat16"
)
parser
.
add_argument
(
'--tune'
,
action
=
'store_true'
,
help
=
'tune'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_q
,
args
.
seq_kv
,
args
.
dim
,
args
.
window_size
,
args
.
dtype
,
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_q
,
args
.
seq_kv
,
args
.
dim
,
args
.
window_size
,
args
.
dtype
,
args
.
tune
)
args
.
tune
)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py
View file @
29051439
...
@@ -20,9 +20,11 @@ def get_bwd_configs():
...
@@ -20,9 +20,11 @@ def get_bwd_configs():
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
3
,
4
],
pass_configs
=
{
out_idx
=
[
3
,
4
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn_fwd
(
def
flashattn_fwd
(
batch
,
batch
,
heads
,
heads
,
...
@@ -35,13 +37,13 @@ def flashattn_fwd(
...
@@ -35,13 +37,13 @@ def flashattn_fwd(
block_N
=
64
,
block_N
=
64
,
num_stages
=
1
,
num_stages
=
1
,
threads
=
128
,
threads
=
128
,
dtype
:
str
=
"float16"
):
dtype
:
str
=
"float16"
,
):
if
window_size
is
not
None
:
if
window_size
is
not
None
:
assert
window_size
%
block_N
==
0
,
"window_size must be divisible by block_N"
assert
window_size
%
block_N
==
0
,
"window_size must be divisible by block_N"
if
sm_scale
is
None
:
if
sm_scale
is
None
:
sm_scale
=
(
1.0
/
dim
)
**
0.5
sm_scale
=
(
1.0
/
dim
)
**
0.5
scale
=
sm_scale
*
1.44269504
# log2(e)
scale
=
sm_scale
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
head_kv
=
heads
//
groups
...
@@ -73,7 +75,7 @@ def flashattn_fwd(
...
@@ -73,7 +75,7 @@ def flashattn_fwd(
sinks
=
T
.
alloc_fragment
([
heads
],
dtype
)
sinks
=
T
.
alloc_fragment
([
heads
],
dtype
)
T
.
annotate_layout
({
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
)})
T
.
annotate_layout
({
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
)})
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -81,22 +83,20 @@ def flashattn_fwd(
...
@@ -81,22 +83,20 @@ def flashattn_fwd(
sinks
[
i
]
=
Sinks
[
by
]
sinks
[
i
]
=
Sinks
[
by
]
end
=
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
((
bx
+
1
)
*
block_M
,
block_N
))
end
=
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
((
bx
+
1
)
*
block_M
,
block_N
))
start
=
T
.
max
(
0
,
start
=
T
.
max
(
0
,
(
bx
*
block_M
-
window_size
)
//
block_N
)
if
window_size
is
not
None
else
0
(
bx
*
block_M
-
window_size
)
//
block_N
)
if
window_size
is
not
None
else
0
for
k
in
T
.
Pipelined
(
start
,
end
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
start
,
end
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
bz
,
by
//
groups
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
//
groups
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
q_idx
=
bx
*
block_M
+
i
q_idx
=
bx
*
block_M
+
i
k_idx
=
k
*
block_N
+
j
k_idx
=
k
*
block_N
+
j
if
window_size
is
not
None
:
if
window_size
is
not
None
:
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
and
q_idx
<
k_idx
+
window_size
,
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
and
q_idx
<
k_idx
+
window_size
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
else
:
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
V
[
bz
,
by
//
groups
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
//
groups
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
...
@@ -106,8 +106,7 @@ def flashattn_fwd(
...
@@ -106,8 +106,7 @@ def flashattn_fwd(
# NOTE(wt): check_inf is necessary for sliding window attention.
# NOTE(wt): check_inf is necessary for sliding window attention.
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
if
window_size
is
not
None
:
if
window_size
is
not
None
:
scores_max
[
i
]
=
T
.
if_then_else
(
scores_max
[
i
]
==
-
T
.
infinity
(
accum_dtype
),
0
,
scores_max
[
i
]
=
T
.
if_then_else
(
scores_max
[
i
]
==
-
T
.
infinity
(
accum_dtype
),
0
,
scores_max
[
i
])
scores_max
[
i
])
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
...
@@ -124,22 +123,23 @@ def flashattn_fwd(
...
@@ -124,22 +123,23 @@ def flashattn_fwd(
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
+=
T
.
exp2
(
sinks
[
i
]
*
1.44269504
-
logsum
[
i
]
+=
T
.
exp2
(
sinks
[
i
]
*
1.44269504
-
scores_max
[
i
]
*
scale
)
# The only change for attention sink
scores_max
[
i
]
*
scale
)
# The only change for attention sink
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
acc_o
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
lse
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
])
T
.
copy
(
logsum
,
lse
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
])
return
flash_fwd
return
flash_fwd
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
out_idx
=
[
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn_bwd_preprocess
(
batch
,
heads
,
seq_len
,
dim
,
dtype
:
str
=
"float16"
):
def
flashattn_bwd_preprocess
(
batch
,
heads
,
seq_len
,
dim
,
dtype
:
str
=
"float16"
):
accum_dtype
=
"float"
accum_dtype
=
"float"
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
...
@@ -158,26 +158,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16")
...
@@ -158,26 +158,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16")
delta
=
T
.
alloc_fragment
([
blk
],
accum_dtype
)
delta
=
T
.
alloc_fragment
([
blk
],
accum_dtype
)
T
.
clear
(
acc
)
T
.
clear
(
acc
)
for
k
in
range
(
T
.
ceildiv
(
dim
,
blk
)):
for
k
in
range
(
T
.
ceildiv
(
dim
,
blk
)):
T
.
copy
(
O
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
,
k
*
blk
:
(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
O
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
,
k
*
blk
:
(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
dO
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
,
k
*
blk
:
(
k
+
1
)
*
blk
],
do
)
T
.
copy
(
dO
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
,
k
*
blk
:
(
k
+
1
)
*
blk
],
do
)
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
acc
[
i
,
j
]
+=
o
[
i
,
j
]
*
do
[
i
,
j
]
acc
[
i
,
j
]
+=
o
[
i
,
j
]
*
do
[
i
,
j
]
T
.
reduce_sum
(
acc
,
delta
,
1
)
T
.
reduce_sum
(
acc
,
delta
,
1
)
T
.
copy
(
delta
,
Delta
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
])
T
.
copy
(
delta
,
Delta
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
])
return
flash_bwd_prep
return
flash_bwd_prep
def
make_dq_layout
(
dQ
):
def
make_dq_layout
(
dQ
):
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return
T
.
Layout
(
dQ
.
shape
,
return
T
.
Layout
(
dQ
.
shape
,
lambda
b
,
h
,
l
,
d
:
[
b
,
h
,
l
//
8
,
d
//
8
,
(
d
%
2
),
4
*
(
l
%
8
)
+
(
d
%
8
)
//
2
])
lambda
b
,
h
,
l
,
d
:
[
b
,
h
,
l
//
8
,
d
//
8
,
(
d
%
2
),
4
*
(
l
%
8
)
+
(
d
%
8
)
//
2
])
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
1
],
pass_configs
=
{
out_idx
=
[
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn_bwd_postprocess
(
batch
,
heads
,
seq_len
,
dim
,
dtype
:
str
=
"float16"
):
def
flashattn_bwd_postprocess
(
batch
,
heads
,
seq_len
,
dim
,
dtype
:
str
=
"float16"
):
accum_dtype
=
"float"
accum_dtype
=
"float"
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
...
@@ -191,26 +192,21 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"
...
@@ -191,26 +192,21 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
blk
),
heads
,
batch
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
blk
),
heads
,
batch
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
T
.
annotate_layout
({
dQ
:
make_dq_layout
(
dQ
)})
T
.
annotate_layout
({
dQ
:
make_dq_layout
(
dQ
)})
T
.
copy
(
T
.
copy
(
dQ
[
bz
,
by
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
:],
dQ
[
bz
,
by
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
:],
dQ_out
[
bz
,
by
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
:],
dQ_out
[
bz
,
by
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
:],
)
)
return
flash_bwd_post
return
flash_bwd_post
@
tilelang
.
jit
(
pass_configs
=
{
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
}
def
flashattn_bwd
(
batch
,
)
heads
,
def
flashattn_bwd
(
batch
,
heads
,
seq_len
,
dim
,
groups
,
window_size
=
None
,
sm_scale
=
None
,
dtype
=
"float16"
):
# None for full attention
seq_len
,
dim
,
groups
,
window_size
=
None
,
sm_scale
=
None
,
dtype
=
"float16"
):
# None for full attention
if
sm_scale
is
None
:
if
sm_scale
is
None
:
sm_scale
=
(
1.0
/
dim
)
**
0.5
sm_scale
=
(
1.0
/
dim
)
**
0.5
scale
=
sm_scale
*
1.44269504
# log2(e)
scale
=
sm_scale
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
head_kv
=
heads
//
groups
...
@@ -253,44 +249,47 @@ def flashattn_bwd(batch,
...
@@ -253,44 +249,47 @@ def flashattn_bwd(batch,
dv_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
accum_dtype
)
dv_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
accum_dtype
)
dk_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
accum_dtype
)
dk_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
dQ
:
make_dq_layout
(
dQ
),
dQ
:
make_dq_layout
(
dQ
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
})
}
T
.
copy
(
K
[
bz
,
bx
//
groups
,
by
*
block_M
:(
by
+
1
)
*
block_M
,
:],
K_shared
)
)
T
.
copy
(
V
[
bz
,
bx
//
groups
,
by
*
block_M
:(
by
+
1
)
*
block_M
,
:],
V_shared
)
T
.
copy
(
K
[
bz
,
bx
//
groups
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:],
K_shared
)
T
.
copy
(
V
[
bz
,
bx
//
groups
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:],
V_shared
)
T
.
clear
(
dv
)
T
.
clear
(
dv
)
T
.
clear
(
dk
)
T
.
clear
(
dk
)
loop_st
=
T
.
floordiv
(
by
*
block_M
,
block_N
)
loop_st
=
T
.
floordiv
(
by
*
block_M
,
block_N
)
loop_ed
=
T
.
min
(
loop_ed
=
(
T
.
ceildiv
((
by
+
1
)
*
block_M
+
window_size
,
block_N
),
T
.
ceildiv
(
T
.
min
(
T
.
ceildiv
((
by
+
1
)
*
block_M
+
window_size
,
block_N
),
T
.
ceildiv
(
seq_len
,
block_N
))
seq_len
,
block_N
))
if
window_size
is
not
None
else
T
.
ceildiv
(
seq_len
,
block_N
)
if
window_size
is
not
None
else
T
.
ceildiv
(
seq_len
,
block_N
)
)
for
k
in
T
.
Pipelined
(
loop_st
,
loop_ed
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_st
,
loop_ed
,
num_stages
=
num_stages
):
T
.
copy
(
Q
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
q
)
T
.
copy
(
Q
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
q
)
T
.
clear
(
qkT
)
T
.
clear
(
qkT
)
T
.
gemm
(
K_shared
,
q
,
qkT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
K_shared
,
q
,
qkT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
lse
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
lse_shared
)
T
.
copy
(
lse
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
lse_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
qkT
[
i
,
j
]
=
T
.
exp2
(
qkT
[
i
,
j
]
*
scale
-
lse_shared
[
j
])
qkT
[
i
,
j
]
=
T
.
exp2
(
qkT
[
i
,
j
]
*
scale
-
lse_shared
[
j
])
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
if
window_size
is
not
None
:
if
window_size
is
not
None
:
qkT
[
i
,
j
]
=
T
.
if_then_else
(
qkT
[
i
,
j
]
=
T
.
if_then_else
(
by
*
block_M
+
i
<=
k
*
block_N
+
j
and
by
*
block_M
+
i
<=
k
*
block_N
+
j
and
by
*
block_M
+
i
>
k
*
block_N
+
j
-
window_size
,
qkT
[
i
,
j
],
0
by
*
block_M
+
i
>
k
*
block_N
+
j
-
window_size
,
qkT
[
i
,
j
],
0
)
)
else
:
else
:
qkT
[
i
,
j
]
=
T
.
if_then_else
(
by
*
block_M
+
i
<=
k
*
block_N
+
j
,
qkT
[
i
,
j
],
qkT
[
i
,
j
]
=
T
.
if_then_else
(
by
*
block_M
+
i
<=
k
*
block_N
+
j
,
qkT
[
i
,
j
],
0
)
0
)
T
.
copy
(
dO
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
dst
=
do
)
T
.
copy
(
dO
[
bz
,
bx
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
:],
dst
=
do
)
T
.
clear
(
dsT
)
T
.
clear
(
dsT
)
T
.
gemm
(
V_shared
,
do
,
dsT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
V_shared
,
do
,
dsT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
qkT
,
qkT_cast
)
T
.
copy
(
qkT
,
qkT_cast
)
T
.
gemm
(
qkT_cast
,
do
,
dv
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
qkT_cast
,
do
,
dv
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
Delta
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
delta
)
T
.
copy
(
Delta
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
delta
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
dsT_cast
[
i
,
j
]
=
qkT
[
i
,
j
]
*
(
dsT
[
i
,
j
]
-
delta
[
j
])
*
sm_scale
dsT_cast
[
i
,
j
]
=
qkT
[
i
,
j
]
*
(
dsT
[
i
,
j
]
-
delta
[
j
])
*
sm_scale
...
@@ -299,12 +298,12 @@ def flashattn_bwd(batch,
...
@@ -299,12 +298,12 @@ def flashattn_bwd(batch,
T
.
copy
(
dsT_cast
,
dsT_shared
)
T
.
copy
(
dsT_cast
,
dsT_shared
)
T
.
clear
(
dq
)
T
.
clear
(
dq
)
T
.
gemm
(
dsT_shared
,
K_shared
,
dq
,
transpose_A
=
True
)
T
.
gemm
(
dsT_shared
,
K_shared
,
dq
,
transpose_A
=
True
)
T
.
atomic_add
(
dQ
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
dq
)
T
.
atomic_add
(
dQ
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
dq
)
T
.
copy
(
dv
,
dv_shared
)
T
.
copy
(
dv
,
dv_shared
)
T
.
atomic_add
(
dV
[
bz
,
bx
//
groups
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:],
dv_shared
)
T
.
atomic_add
(
dV
[
bz
,
bx
//
groups
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:],
dv_shared
)
T
.
copy
(
dk
,
dk_shared
)
T
.
copy
(
dk
,
dk_shared
)
T
.
atomic_add
(
dK
[
bz
,
bx
//
groups
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:],
dk_shared
)
T
.
atomic_add
(
dK
[
bz
,
bx
//
groups
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:],
dk_shared
)
return
flash_bwd
return
flash_bwd
...
@@ -328,21 +327,18 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"
...
@@ -328,21 +327,18 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"
dsink_fragment
=
T
.
alloc_fragment
([
block
],
dtype
)
dsink_fragment
=
T
.
alloc_fragment
([
block
],
dtype
)
sink
[
0
]
=
Sinks
[
bx
]
sink
[
0
]
=
Sinks
[
bx
]
T
.
copy
(
lse
[
bz
,
bx
,
by
*
block
:
(
by
+
1
)
*
block
],
lse_fragment
)
T
.
copy
(
lse
[
bz
,
bx
,
by
*
block
:
(
by
+
1
)
*
block
],
lse_fragment
)
T
.
copy
(
Delta
[
bz
,
bx
,
by
*
block
:
(
by
+
1
)
*
block
],
delta_fragment
)
T
.
copy
(
Delta
[
bz
,
bx
,
by
*
block
:
(
by
+
1
)
*
block
],
delta_fragment
)
for
i
in
T
.
Parallel
(
block
):
for
i
in
T
.
Parallel
(
block
):
dsink_fragment
[
i
]
=
-
T
.
exp2
(
Sinks
[
bx
]
*
1.44269504
-
dsink_fragment
[
i
]
=
-
T
.
exp2
(
Sinks
[
bx
]
*
1.44269504
-
lse_fragment
[
i
])
*
delta_fragment
[
i
]
lse_fragment
[
i
])
*
delta_fragment
[
i
]
T
.
copy
(
dsink_fragment
,
dsinks
[
bz
,
bx
,
by
*
block
:
(
by
+
1
)
*
block
])
T
.
copy
(
dsink_fragment
,
dsinks
[
bz
,
bx
,
by
*
block
:(
by
+
1
)
*
block
])
return
flash_bwd_dsink
return
flash_bwd_dsink
class
_attention
(
torch
.
autograd
.
Function
):
class
_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
sinks
,
window_size
,
groups
):
def
forward
(
ctx
,
q
,
k
,
v
,
sinks
,
window_size
,
groups
):
def
maybe_contiguous
(
x
):
def
maybe_contiguous
(
x
):
if
x
.
stride
(
-
1
)
!=
1
:
if
x
.
stride
(
-
1
)
!=
1
:
return
x
.
contiguous
()
return
x
.
contiguous
()
...
@@ -388,13 +384,14 @@ attention = _attention.apply
...
@@ -388,13 +384,14 @@ attention = _attention.apply
# Adapted and optimized from
# Adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def
ref_program
(
query
:
torch
.
Tensor
,
def
ref_program
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
sinks
:
torch
.
Tensor
,
sinks
:
torch
.
Tensor
,
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
dtype
:
torch
.
dtype
=
torch
.
float16
)
->
torch
.
Tensor
:
dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
key
=
key
.
transpose
(
1
,
2
).
contiguous
()
key
=
key
.
transpose
(
1
,
2
).
contiguous
()
value
=
value
.
transpose
(
1
,
2
).
contiguous
()
value
=
value
.
transpose
(
1
,
2
).
contiguous
()
batch_size
,
num_keys
,
num_key_value_heads
,
head_dim
=
key
.
shape
batch_size
,
num_keys
,
num_key_value_heads
,
head_dim
=
key
.
shape
...
@@ -430,32 +427,31 @@ def ref_program(query: torch.Tensor,
...
@@ -430,32 +427,31 @@ def ref_program(query: torch.Tensor,
output
=
torch
.
einsum
(
"bhmqk,bkhmd->bqhmd"
,
scores
,
value
.
float
())
output
=
torch
.
einsum
(
"bhmqk,bkhmd->bqhmd"
,
scores
,
value
.
float
())
output
=
output
.
reshape
(
batch_size
,
num_queries
,
num_key_value_heads
*
num_key_value_groups
,
output
=
output
.
reshape
(
batch_size
,
num_queries
,
num_key_value_heads
*
num_key_value_groups
,
head_dim
).
to
(
dtype
)
head_dim
).
to
(
dtype
)
return
output
.
transpose
(
1
,
2
).
contiguous
()
return
output
.
transpose
(
1
,
2
).
contiguous
()
def
main
(
BATCH
:
int
=
1
,
def
main
(
BATCH
:
int
=
1
,
H
:
int
=
8
,
H
:
int
=
8
,
N_CTX
:
int
=
512
,
N_CTX
:
int
=
512
,
D_HEAD
:
int
=
64
,
D_HEAD
:
int
=
64
,
groups
:
int
=
2
,
groups
:
int
=
2
,
window_size
:
Optional
[
int
]
=
None
,
window_size
:
Optional
[
int
]
=
None
,
dtype
:
str
=
"float16"
):
dtype
:
str
=
"float16"
,
):
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
if
window_size
is
not
None
:
if
window_size
is
not
None
:
print
(
'
Using sliding window attention.
'
)
print
(
"
Using sliding window attention.
"
)
assert
window_size
<=
N_CTX
assert
window_size
<=
N_CTX
flops_per_matmul
=
2.0
*
BATCH
*
H
*
min
(
flops_per_matmul
=
2.0
*
BATCH
*
H
*
min
(
window_size
,
N_CTX
//
2
)
*
N_CTX
*
D_HEAD
# just a rough estimation
window_size
,
N_CTX
//
2
)
*
N_CTX
*
D_HEAD
# just a rough estimation
else
:
else
:
print
(
'
Using full attention.
'
)
print
(
"
Using full attention.
"
)
flops_per_matmul
=
2.0
*
BATCH
*
H
*
N_CTX
*
N_CTX
*
D_HEAD
*
0.5
flops_per_matmul
=
2.0
*
BATCH
*
H
*
N_CTX
*
N_CTX
*
D_HEAD
*
0.5
total_flops
=
5
*
flops_per_matmul
total_flops
=
5
*
flops_per_matmul
Q
=
(
torch
.
randn
(
BATCH
,
H
,
N_CTX
,
D_HEAD
,
dtype
=
torch_dtype
,
device
=
"cuda"
).
requires_grad_
())
Q
=
torch
.
randn
(
BATCH
,
H
,
N_CTX
,
D_HEAD
,
dtype
=
torch_dtype
,
device
=
"cuda"
).
requires_grad_
()
K
=
torch
.
randn
(
K
=
torch
.
randn
(
BATCH
,
H
//
groups
,
N_CTX
,
D_HEAD
,
dtype
=
torch_dtype
,
device
=
"cuda"
).
requires_grad_
()
BATCH
,
H
//
groups
,
N_CTX
,
D_HEAD
,
dtype
=
torch_dtype
,
device
=
"cuda"
).
requires_grad_
()
V
=
torch
.
randn_like
(
K
).
requires_grad_
()
V
=
torch
.
randn_like
(
K
).
requires_grad_
()
sinks
=
torch
.
randn
(
H
,
dtype
=
torch_dtype
,
device
=
"cuda"
).
requires_grad_
()
sinks
=
torch
.
randn
(
H
,
dtype
=
torch_dtype
,
device
=
"cuda"
).
requires_grad_
()
dO
=
torch
.
randn_like
(
Q
)
dO
=
torch
.
randn_like
(
Q
)
...
@@ -479,16 +475,11 @@ def main(BATCH: int = 1,
...
@@ -479,16 +475,11 @@ def main(BATCH: int = 1,
"float16"
:
(
1e-2
,
1e-2
),
"float16"
:
(
1e-2
,
1e-2
),
"bfloat16"
:
(
2e-2
,
2e-2
),
"bfloat16"
:
(
2e-2
,
2e-2
),
}[
dtype
]
}[
dtype
]
assert
torch
.
allclose
(
O
,
O_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
'O max err:
{
(
O
-
O_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
O
,
O_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
"O max err:
{
(
O
-
O_ref
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
assert
torch
.
allclose
(
dV
,
dV_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
"dV max err:
{
(
dV
-
dV_ref
).
abs
().
max
()
}
"
dV
,
dV_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
'dV max err:
{
(
dV
-
dV_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
dK
,
dK_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
"dK max err:
{
(
dK
-
dK_ref
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
assert
torch
.
allclose
(
dQ
,
dQ_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
"dq max err:
{
(
dQ
-
dQ_ref
).
abs
().
max
()
}
"
dK
,
dK_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
'dK max err:
{
(
dK
-
dK_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
dsinks
,
dsinks_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
"dsinks max err:
{
(
dsinks
-
dsinks_ref
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
dQ
,
dQ_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
'dq max err:
{
(
dQ
-
dQ_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
dsinks
,
dsinks_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
'dsinks max err:
{
(
dsinks
-
dsinks_ref
).
abs
().
max
()
}
'
print
(
"All checks passed for tilelang kernels.✅"
)
print
(
"All checks passed for tilelang kernels.✅"
)
...
@@ -509,17 +500,12 @@ def main(BATCH: int = 1,
...
@@ -509,17 +500,12 @@ def main(BATCH: int = 1,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
1
,
help
=
'Batch size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
1
,
help
=
"Batch size"
)
parser
.
add_argument
(
'--h'
,
type
=
int
,
default
=
64
,
help
=
'Number of heads'
)
parser
.
add_argument
(
"--h"
,
type
=
int
,
default
=
64
,
help
=
"Number of heads"
)
parser
.
add_argument
(
'--n_ctx'
,
type
=
int
,
default
=
4096
,
help
=
'Context size'
)
parser
.
add_argument
(
"--n_ctx"
,
type
=
int
,
default
=
4096
,
help
=
"Context size"
)
parser
.
add_argument
(
'--d_head'
,
type
=
int
,
default
=
128
,
help
=
'Head dimension'
)
parser
.
add_argument
(
"--d_head"
,
type
=
int
,
default
=
128
,
help
=
"Head dimension"
)
parser
.
add_argument
(
'--groups'
,
type
=
int
,
default
=
8
,
help
=
'Groups'
)
parser
.
add_argument
(
"--groups"
,
type
=
int
,
default
=
8
,
help
=
"Groups"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--window_size"
,
type
=
int
,
default
=
None
,
help
=
"window size (default: None, which means full attention)"
)
'--window_size'
,
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
help
=
"dtype, can be float16 or bfloat16"
)
type
=
int
,
default
=
None
,
help
=
'window size (default: None, which means full attention)'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
"float16"
,
help
=
"dtype, can be float16 or bfloat16"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
h
,
args
.
n_ctx
,
args
.
d_head
,
args
.
groups
,
args
.
window_size
,
args
.
dtype
)
main
(
args
.
batch
,
args
.
h
,
args
.
n_ctx
,
args
.
d_head
,
args
.
groups
,
args
.
window_size
,
args
.
dtype
)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
View file @
29051439
...
@@ -23,9 +23,11 @@ def get_configs():
...
@@ -23,9 +23,11 @@ def get_configs():
rep
=
100
,
rep
=
100
,
)
)
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
3
],
pass_configs
=
{
out_idx
=
[
3
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn
(
def
flashattn
(
batch
,
batch
,
heads
,
heads
,
...
@@ -41,12 +43,11 @@ def flashattn(
...
@@ -41,12 +43,11 @@ def flashattn(
threads
=
256
,
threads
=
256
,
dtype
:
str
=
"float16"
,
dtype
:
str
=
"float16"
,
):
):
if
window_size
is
not
None
:
if
window_size
is
not
None
:
assert
window_size
%
block_N
==
0
,
"window_size must be divisible by block_N"
assert
window_size
%
block_N
==
0
,
"window_size must be divisible by block_N"
if
sm_scale
is
None
:
if
sm_scale
is
None
:
sm_scale
=
(
1.0
/
dim
)
**
0.5
sm_scale
=
(
1.0
/
dim
)
**
0.5
scale
=
sm_scale
*
1.44269504
# log2(e)
scale
=
sm_scale
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
head_kv
=
heads
//
groups
...
@@ -68,13 +69,12 @@ def flashattn(
...
@@ -68,13 +69,12 @@ def flashattn(
by
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
bz
:
T
.
int32
,
):
):
T
.
copy
(
K
[
bz
,
by
//
groups
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
//
groups
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
q_idx
=
bx
*
block_M
+
i
+
past_len
q_idx
=
bx
*
block_M
+
i
+
past_len
k_idx
=
k
*
block_N
+
j
k_idx
=
k
*
block_N
+
j
if
window_size
is
not
None
:
if
window_size
is
not
None
:
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
and
q_idx
<
k_idx
+
window_size
,
0
,
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
and
q_idx
<
k_idx
+
window_size
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
else
:
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
...
@@ -89,7 +89,7 @@ def flashattn(
...
@@ -89,7 +89,7 @@ def flashattn(
by
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
bz
:
T
.
int32
,
):
):
T
.
copy
(
V
[
bz
,
by
//
groups
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
//
groups
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
@
T
.
macro
...
@@ -112,8 +112,7 @@ def flashattn(
...
@@ -112,8 +112,7 @@ def flashattn(
# NOTE(wt): check_inf is necessary for sliding window attention.
# NOTE(wt): check_inf is necessary for sliding window attention.
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
if
window_size
is
not
None
:
if
window_size
is
not
None
:
scores_max
[
i
]
=
T
.
if_then_else
(
scores_max
[
i
]
==
-
T
.
infinity
(
accum_dtype
),
0
,
scores_max
[
i
]
=
T
.
if_then_else
(
scores_max
[
i
]
==
-
T
.
infinity
(
accum_dtype
),
0
,
scores_max
[
i
])
scores_max
[
i
])
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
@@ -157,25 +156,25 @@ def flashattn(
...
@@ -157,25 +156,25 @@ def flashattn(
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
sinks
=
T
.
alloc_fragment
([
block_M
],
dtype
)
sinks
=
T
.
alloc_fragment
([
block_M
],
dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
Q_shared
:
make_swizzled_layout
(
Q_shared
),
Q_shared
:
make_swizzled_layout
(
Q_shared
),
K_shared
:
make_swizzled_layout
(
K_shared
),
K_shared
:
make_swizzled_layout
(
K_shared
),
V_shared
:
make_swizzled_layout
(
V_shared
),
V_shared
:
make_swizzled_layout
(
V_shared
),
O_shared
:
make_swizzled_layout
(
O_shared
),
O_shared
:
make_swizzled_layout
(
O_shared
),
})
}
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
sinks
[
i
]
=
Sinks
[
by
]
sinks
[
i
]
=
Sinks
[
by
]
end
=
T
.
min
(
end
=
T
.
min
(
T
.
ceildiv
(
seq_kv
,
block_N
),
T
.
ceildiv
((
bx
+
1
)
*
block_M
+
past_len
,
block_N
))
T
.
ceildiv
(
seq_kv
,
block_N
),
T
.
ceildiv
((
bx
+
1
)
*
block_M
+
past_len
,
block_N
))
start
=
T
.
max
(
0
,
(
bx
*
block_M
+
past_len
-
window_size
)
//
start
=
T
.
max
(
0
,
(
bx
*
block_M
+
past_len
-
window_size
)
//
block_N
)
if
window_size
is
not
None
else
0
block_N
)
if
window_size
is
not
None
else
0
for
k
in
T
.
Pipelined
(
for
k
in
T
.
Pipelined
(
start
,
start
,
...
@@ -183,32 +182,32 @@ def flashattn(
...
@@ -183,32 +182,32 @@ def flashattn(
num_stages
=
num_stages
,
num_stages
=
num_stages
,
order
=
[
-
1
,
0
,
3
,
1
,
-
1
,
2
],
order
=
[
-
1
,
0
,
3
,
1
,
-
1
,
2
],
stage
=
[
-
1
,
0
,
0
,
1
,
-
1
,
1
],
stage
=
[
-
1
,
0
,
0
,
1
,
-
1
,
1
],
group
=
[[
0
],
[
1
,
2
],
[
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
],
[
12
],
[
13
],
[
14
]]):
group
=
[[
0
],
[
1
,
2
],
[
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
],
[
12
],
[
13
],
[
14
]],
):
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
logsum
)
Rescale
(
acc_o
,
scores_scale
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
+=
T
.
exp2
(
sinks
[
i
]
*
1.44269504
-
logsum
[
i
]
+=
T
.
exp2
(
sinks
[
i
]
*
1.44269504
-
scores_max
[
i
]
*
scale
)
# The only change for attention sink
scores_max
[
i
]
*
scale
)
# The only change for attention sink
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
main
return
main
# Following functions are adapted and optimized from
# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def
ref_program
(
query
:
torch
.
Tensor
,
def
ref_program
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
sinks
:
torch
.
Tensor
,
sinks
:
torch
.
Tensor
,
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
dtype
:
torch
.
dtype
=
torch
.
float16
)
->
torch
.
Tensor
:
dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
key
=
key
.
transpose
(
1
,
2
).
contiguous
()
key
=
key
.
transpose
(
1
,
2
).
contiguous
()
value
=
value
.
transpose
(
1
,
2
).
contiguous
()
value
=
value
.
transpose
(
1
,
2
).
contiguous
()
batch_size
,
num_keys
,
num_key_value_heads
,
head_dim
=
key
.
shape
batch_size
,
num_keys
,
num_key_value_heads
,
head_dim
=
key
.
shape
...
@@ -244,23 +243,15 @@ def ref_program(query: torch.Tensor,
...
@@ -244,23 +243,15 @@ def ref_program(query: torch.Tensor,
output
=
torch
.
einsum
(
"bhmqk,bkhmd->bqhmd"
,
scores
,
value
.
float
())
output
=
torch
.
einsum
(
"bhmqk,bkhmd->bqhmd"
,
scores
,
value
.
float
())
output
=
output
.
reshape
(
batch_size
,
num_queries
,
num_key_value_heads
*
num_key_value_groups
,
output
=
output
.
reshape
(
batch_size
,
num_queries
,
num_key_value_heads
*
num_key_value_groups
,
head_dim
).
to
(
dtype
)
head_dim
).
to
(
dtype
)
return
output
.
transpose
(
1
,
2
).
contiguous
()
return
output
.
transpose
(
1
,
2
).
contiguous
()
def
gen_inputs
(
def
gen_inputs
(
B
,
H
,
Sq
,
Skv
,
D
,
groups
,
dtype
=
torch
.
float16
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
B
,
query
=
torch
.
randn
([
B
,
H
,
Sq
,
D
],
dtype
=
dtype
,
device
=
"cuda"
)
H
,
key
=
torch
.
randn
([
B
,
H
//
groups
,
Skv
,
D
],
dtype
=
dtype
,
device
=
"cuda"
)
Sq
,
value
=
torch
.
randn
([
B
,
H
//
groups
,
Skv
,
D
],
dtype
=
dtype
,
device
=
"cuda"
)
Skv
,
sinks
=
torch
.
randn
([
H
],
dtype
=
dtype
,
device
=
"cuda"
)
D
,
groups
,
dtype
=
torch
.
float16
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
query
=
torch
.
randn
([
B
,
H
,
Sq
,
D
],
dtype
=
dtype
,
device
=
'cuda'
)
key
=
torch
.
randn
([
B
,
H
//
groups
,
Skv
,
D
],
dtype
=
dtype
,
device
=
'cuda'
)
value
=
torch
.
randn
([
B
,
H
//
groups
,
Skv
,
D
],
dtype
=
dtype
,
device
=
'cuda'
)
sinks
=
torch
.
randn
([
H
],
dtype
=
dtype
,
device
=
'cuda'
)
return
query
,
key
,
value
,
sinks
return
query
,
key
,
value
,
sinks
...
@@ -277,12 +268,11 @@ def main(
...
@@ -277,12 +268,11 @@ def main(
):
):
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
if
window_size
is
not
None
:
if
window_size
is
not
None
:
print
(
'
Using sliding window attention.
'
)
print
(
"
Using sliding window attention.
"
)
assert
window_size
<=
seq_q
assert
window_size
<=
seq_q
flops_per_matmul
=
2.0
*
batch
*
heads
*
min
(
flops_per_matmul
=
2.0
*
batch
*
heads
*
min
(
window_size
,
seq_kv
//
2
)
*
seq_q
*
dim
# just a rough estimation
window_size
,
seq_kv
//
2
)
*
seq_q
*
dim
# just a rough estimation
else
:
else
:
print
(
'
Using full attention.
'
)
print
(
"
Using full attention.
"
)
flops_per_matmul
=
2.0
*
batch
*
heads
*
seq_q
*
seq_kv
*
dim
*
0.5
flops_per_matmul
=
2.0
*
batch
*
heads
*
seq_q
*
seq_kv
*
dim
*
0.5
total_flops
=
2
*
flops_per_matmul
total_flops
=
2
*
flops_per_matmul
...
@@ -310,15 +300,14 @@ def main(
...
@@ -310,15 +300,14 @@ def main(
block_N
=
block_N
,
block_N
=
block_N
,
num_stages
=
num_stages
,
num_stages
=
num_stages
,
threads
=
threads
,
threads
=
threads
,
dtype
=
dtype
)
dtype
=
dtype
,
)
Q
,
K
,
V
,
sinks
=
gen_inputs
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
groups
,
dtype
=
torch_dtype
)
Q
,
K
,
V
,
sinks
=
gen_inputs
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
groups
,
dtype
=
torch_dtype
)
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
kernel
(
Q
,
K
,
V
,
sinks
),
kernel
(
Q
,
K
,
V
,
sinks
),
ref_program
(
Q
,
K
,
V
,
sinks
,
window_size
,
dtype
=
torch_dtype
),
rtol
=
1e-2
,
atol
=
1e-2
ref_program
(
Q
,
K
,
V
,
sinks
,
window_size
,
dtype
=
torch_dtype
),
)
rtol
=
1e-2
,
atol
=
1e-2
)
print
(
"All checks passed.✅"
)
print
(
"All checks passed.✅"
)
# Benchmark tilelang
# Benchmark tilelang
...
@@ -329,20 +318,14 @@ def main(
...
@@ -329,20 +318,14 @@ def main(
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
1
,
help
=
'batch size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
1
,
help
=
"batch size"
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
64
,
help
=
'heads'
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
64
,
help
=
"heads"
)
parser
.
add_argument
(
'--seq_q'
,
type
=
int
,
default
=
2048
,
help
=
'sequence length of query'
)
parser
.
add_argument
(
"--seq_q"
,
type
=
int
,
default
=
2048
,
help
=
"sequence length of query"
)
parser
.
add_argument
(
'--seq_kv'
,
type
=
int
,
default
=
2048
,
help
=
'sequence length of key/value'
)
parser
.
add_argument
(
"--seq_kv"
,
type
=
int
,
default
=
2048
,
help
=
"sequence length of key/value"
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
'--groups'
,
type
=
int
,
default
=
8
,
help
=
'groups'
)
parser
.
add_argument
(
"--groups"
,
type
=
int
,
default
=
8
,
help
=
"groups"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--window_size"
,
type
=
int
,
default
=
None
,
help
=
"window size (default: None, which means full attention)"
)
'--window_size'
,
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
help
=
"dtype, can be float16 or bfloat16"
)
type
=
int
,
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
,
help
=
"tune configs"
)
default
=
None
,
help
=
'window size (default: None, which means full attention)'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
"float16"
,
help
=
"dtype, can be float16 or bfloat16"
)
parser
.
add_argument
(
'--tune'
,
action
=
'store_true'
,
help
=
'tune configs'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_q
,
args
.
seq_kv
,
args
.
dim
,
args
.
groups
,
args
.
window_size
,
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_q
,
args
.
seq_kv
,
args
.
dim
,
args
.
groups
,
args
.
window_size
,
args
.
dtype
,
args
.
tune
)
args
.
dtype
,
args
.
tune
)
examples/attention_sink/example_mha_sink_bwd_bhsd.py
View file @
29051439
...
@@ -20,9 +20,11 @@ def get_bwd_configs():
...
@@ -20,9 +20,11 @@ def get_bwd_configs():
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
3
,
4
],
pass_configs
=
{
out_idx
=
[
3
,
4
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn_fwd
(
def
flashattn_fwd
(
batch
,
batch
,
heads
,
heads
,
...
@@ -34,13 +36,13 @@ def flashattn_fwd(
...
@@ -34,13 +36,13 @@ def flashattn_fwd(
block_N
=
64
,
block_N
=
64
,
num_stages
=
1
,
num_stages
=
1
,
threads
=
128
,
threads
=
128
,
dtype
:
str
=
"float16"
):
dtype
:
str
=
"float16"
,
):
if
window_size
is
not
None
:
if
window_size
is
not
None
:
assert
window_size
%
block_N
==
0
,
"window_size must be divisible by block_N"
assert
window_size
%
block_N
==
0
,
"window_size must be divisible by block_N"
if
sm_scale
is
None
:
if
sm_scale
is
None
:
sm_scale
=
(
1.0
/
dim
)
**
0.5
sm_scale
=
(
1.0
/
dim
)
**
0.5
scale
=
sm_scale
*
1.44269504
# log2(e)
scale
=
sm_scale
*
1.44269504
# log2(e)
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
...
@@ -70,7 +72,7 @@ def flashattn_fwd(
...
@@ -70,7 +72,7 @@ def flashattn_fwd(
sinks
=
T
.
alloc_fragment
([
heads
],
dtype
)
sinks
=
T
.
alloc_fragment
([
heads
],
dtype
)
T
.
annotate_layout
({
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
)})
T
.
annotate_layout
({
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
)})
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -78,22 +80,20 @@ def flashattn_fwd(
...
@@ -78,22 +80,20 @@ def flashattn_fwd(
sinks
[
i
]
=
Sinks
[
by
]
sinks
[
i
]
=
Sinks
[
by
]
end
=
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
((
bx
+
1
)
*
block_M
,
block_N
))
end
=
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
((
bx
+
1
)
*
block_M
,
block_N
))
start
=
T
.
max
(
0
,
start
=
T
.
max
(
0
,
(
bx
*
block_M
-
window_size
)
//
block_N
)
if
window_size
is
not
None
else
0
(
bx
*
block_M
-
window_size
)
//
block_N
)
if
window_size
is
not
None
else
0
for
k
in
T
.
Pipelined
(
start
,
end
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
start
,
end
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
q_idx
=
bx
*
block_M
+
i
q_idx
=
bx
*
block_M
+
i
k_idx
=
k
*
block_N
+
j
k_idx
=
k
*
block_N
+
j
if
window_size
is
not
None
:
if
window_size
is
not
None
:
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
and
q_idx
<
k_idx
+
window_size
,
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
and
q_idx
<
k_idx
+
window_size
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
else
:
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
...
@@ -103,8 +103,7 @@ def flashattn_fwd(
...
@@ -103,8 +103,7 @@ def flashattn_fwd(
# NOTE(wt): check_inf is necessary for sliding window attention.
# NOTE(wt): check_inf is necessary for sliding window attention.
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
if
window_size
is
not
None
:
if
window_size
is
not
None
:
scores_max
[
i
]
=
T
.
if_then_else
(
scores_max
[
i
]
==
-
T
.
infinity
(
accum_dtype
),
0
,
scores_max
[
i
]
=
T
.
if_then_else
(
scores_max
[
i
]
==
-
T
.
infinity
(
accum_dtype
),
0
,
scores_max
[
i
])
scores_max
[
i
])
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
...
@@ -121,22 +120,23 @@ def flashattn_fwd(
...
@@ -121,22 +120,23 @@ def flashattn_fwd(
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
+=
T
.
exp2
(
sinks
[
i
]
*
1.44269504
-
logsum
[
i
]
+=
T
.
exp2
(
sinks
[
i
]
*
1.44269504
-
scores_max
[
i
]
*
scale
)
# The only change for attention sink
scores_max
[
i
]
*
scale
)
# The only change for attention sink
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
acc_o
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
lse
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
])
T
.
copy
(
logsum
,
lse
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
])
return
flash_fwd
return
flash_fwd
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
out_idx
=
[
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn_bwd_preprocess
(
batch
,
heads
,
seq_len
,
dim
,
dtype
:
str
=
"float16"
):
def
flashattn_bwd_preprocess
(
batch
,
heads
,
seq_len
,
dim
,
dtype
:
str
=
"float16"
):
accum_dtype
=
"float"
accum_dtype
=
"float"
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
...
@@ -155,26 +155,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16")
...
@@ -155,26 +155,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16")
delta
=
T
.
alloc_fragment
([
blk
],
accum_dtype
)
delta
=
T
.
alloc_fragment
([
blk
],
accum_dtype
)
T
.
clear
(
acc
)
T
.
clear
(
acc
)
for
k
in
range
(
T
.
ceildiv
(
dim
,
blk
)):
for
k
in
range
(
T
.
ceildiv
(
dim
,
blk
)):
T
.
copy
(
O
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
,
k
*
blk
:
(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
O
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
,
k
*
blk
:
(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
dO
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
,
k
*
blk
:
(
k
+
1
)
*
blk
],
do
)
T
.
copy
(
dO
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
,
k
*
blk
:
(
k
+
1
)
*
blk
],
do
)
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
acc
[
i
,
j
]
+=
o
[
i
,
j
]
*
do
[
i
,
j
]
acc
[
i
,
j
]
+=
o
[
i
,
j
]
*
do
[
i
,
j
]
T
.
reduce_sum
(
acc
,
delta
,
1
)
T
.
reduce_sum
(
acc
,
delta
,
1
)
T
.
copy
(
delta
,
Delta
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
])
T
.
copy
(
delta
,
Delta
[
bz
,
bx
,
by
*
blk
:
(
by
+
1
)
*
blk
])
return
flash_bwd_prep
return
flash_bwd_prep
def
make_dq_layout
(
dQ
):
def
make_dq_layout
(
dQ
):
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return
T
.
Layout
(
dQ
.
shape
,
return
T
.
Layout
(
dQ
.
shape
,
lambda
b
,
h
,
l
,
d
:
[
b
,
h
,
l
//
8
,
d
//
8
,
(
d
%
2
),
4
*
(
l
%
8
)
+
(
d
%
8
)
//
2
])
lambda
b
,
h
,
l
,
d
:
[
b
,
h
,
l
//
8
,
d
//
8
,
(
d
%
2
),
4
*
(
l
%
8
)
+
(
d
%
8
)
//
2
])
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
1
],
pass_configs
=
{
out_idx
=
[
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn_bwd_postprocess
(
batch
,
heads
,
seq_len
,
dim
,
dtype
:
str
=
"float16"
):
def
flashattn_bwd_postprocess
(
batch
,
heads
,
seq_len
,
dim
,
dtype
:
str
=
"float16"
):
accum_dtype
=
"float"
accum_dtype
=
"float"
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
...
@@ -188,16 +189,18 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"
...
@@ -188,16 +189,18 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
blk
),
heads
,
batch
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
blk
),
heads
,
batch
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
T
.
annotate_layout
({
dQ
:
make_dq_layout
(
dQ
)})
T
.
annotate_layout
({
dQ
:
make_dq_layout
(
dQ
)})
T
.
copy
(
T
.
copy
(
dQ
[
bz
,
by
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
:],
dQ
[
bz
,
by
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
:],
dQ_out
[
bz
,
by
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
:],
dQ_out
[
bz
,
by
,
bx
*
blk
:
(
bx
+
1
)
*
blk
,
:],
)
)
return
flash_bwd_post
return
flash_bwd_post
@
tilelang
.
jit
(
pass_configs
=
{
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
}
)
def
flashattn_bwd
(
def
flashattn_bwd
(
batch
,
batch
,
heads
,
heads
,
...
@@ -207,11 +210,10 @@ def flashattn_bwd(
...
@@ -207,11 +210,10 @@ def flashattn_bwd(
sm_scale
=
None
,
sm_scale
=
None
,
dtype
:
str
=
"float16"
,
dtype
:
str
=
"float16"
,
):
):
block_M
,
block_N
,
num_stages
,
threads
=
get_bwd_configs
()
block_M
,
block_N
,
num_stages
,
threads
=
get_bwd_configs
()
if
sm_scale
is
None
:
if
sm_scale
is
None
:
sm_scale
=
(
1.0
/
dim
)
**
0.5
sm_scale
=
(
1.0
/
dim
)
**
0.5
scale
=
sm_scale
*
1.44269504
# log2(e)
scale
=
sm_scale
*
1.44269504
# log2(e)
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
...
@@ -254,43 +256,46 @@ def flashattn_bwd(
...
@@ -254,43 +256,46 @@ def flashattn_bwd(
dv_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
dv_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
dk_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
dk_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
dQ
:
make_dq_layout
(
dQ
),
dQ
:
make_dq_layout
(
dQ
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
})
}
T
.
copy
(
K
[
bz
,
bx
,
by
*
block_M
:(
by
+
1
)
*
block_M
,
:],
K_shared
)
)
T
.
copy
(
V
[
bz
,
bx
,
by
*
block_M
:(
by
+
1
)
*
block_M
,
:],
V_shared
)
T
.
copy
(
K
[
bz
,
bx
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:],
K_shared
)
T
.
copy
(
V
[
bz
,
bx
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:],
V_shared
)
T
.
clear
(
dv
)
T
.
clear
(
dv
)
T
.
clear
(
dk
)
T
.
clear
(
dk
)
loop_st
=
T
.
floordiv
(
by
*
block_M
,
block_N
)
loop_st
=
T
.
floordiv
(
by
*
block_M
,
block_N
)
loop_ed
=
T
.
min
(
loop_ed
=
(
T
.
ceildiv
((
by
+
1
)
*
block_M
+
window_size
,
block_N
),
T
.
ceildiv
(
T
.
min
(
T
.
ceildiv
((
by
+
1
)
*
block_M
+
window_size
,
block_N
),
T
.
ceildiv
(
seq_len
,
block_N
))
seq_len
,
block_N
))
if
window_size
is
not
None
else
T
.
ceildiv
(
seq_len
,
block_N
)
if
window_size
is
not
None
else
T
.
ceildiv
(
seq_len
,
block_N
)
)
for
k
in
T
.
Pipelined
(
loop_st
,
loop_ed
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_st
,
loop_ed
,
num_stages
=
num_stages
):
T
.
copy
(
Q
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
q
)
T
.
copy
(
Q
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
q
)
T
.
clear
(
qkT
)
T
.
clear
(
qkT
)
T
.
gemm
(
K_shared
,
q
,
qkT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
K_shared
,
q
,
qkT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
lse
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
lse_shared
)
T
.
copy
(
lse
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
lse_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
qkT
[
i
,
j
]
=
T
.
exp2
(
qkT
[
i
,
j
]
*
scale
-
lse_shared
[
j
])
qkT
[
i
,
j
]
=
T
.
exp2
(
qkT
[
i
,
j
]
*
scale
-
lse_shared
[
j
])
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
if
window_size
is
not
None
:
if
window_size
is
not
None
:
qkT
[
i
,
j
]
=
T
.
if_then_else
(
qkT
[
i
,
j
]
=
T
.
if_then_else
(
by
*
block_M
+
i
<=
k
*
block_N
+
j
and
by
*
block_M
+
i
<=
k
*
block_N
+
j
and
by
*
block_M
+
i
>
k
*
block_N
+
j
-
window_size
,
qkT
[
i
,
j
],
0
by
*
block_M
+
i
>
k
*
block_N
+
j
-
window_size
,
qkT
[
i
,
j
],
0
)
)
else
:
else
:
qkT
[
i
,
j
]
=
T
.
if_then_else
(
by
*
block_M
+
i
<=
k
*
block_N
+
j
,
qkT
[
i
,
j
],
qkT
[
i
,
j
]
=
T
.
if_then_else
(
by
*
block_M
+
i
<=
k
*
block_N
+
j
,
qkT
[
i
,
j
],
0
)
0
)
T
.
copy
(
dO
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
dst
=
do
)
T
.
copy
(
dO
[
bz
,
bx
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
:],
dst
=
do
)
T
.
clear
(
dsT
)
T
.
clear
(
dsT
)
T
.
gemm
(
V_shared
,
do
,
dsT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
V_shared
,
do
,
dsT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
qkT
,
qkT_cast
)
T
.
copy
(
qkT
,
qkT_cast
)
T
.
gemm
(
qkT_cast
,
B
=
do
,
C
=
dv
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
qkT_cast
,
B
=
do
,
C
=
dv
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
Delta
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
delta
)
T
.
copy
(
Delta
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
],
delta
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
dsT_cast
[
i
,
j
]
=
qkT
[
i
,
j
]
*
(
dsT
[
i
,
j
]
-
delta
[
j
])
*
sm_scale
dsT_cast
[
i
,
j
]
=
qkT
[
i
,
j
]
*
(
dsT
[
i
,
j
]
-
delta
[
j
])
*
sm_scale
...
@@ -299,12 +304,12 @@ def flashattn_bwd(
...
@@ -299,12 +304,12 @@ def flashattn_bwd(
T
.
copy
(
dsT_cast
,
dsT_shared
)
T
.
copy
(
dsT_cast
,
dsT_shared
)
T
.
clear
(
dq
)
T
.
clear
(
dq
)
T
.
gemm
(
dsT_shared
,
K_shared
,
dq
,
transpose_A
=
True
)
T
.
gemm
(
dsT_shared
,
K_shared
,
dq
,
transpose_A
=
True
)
T
.
atomic_add
(
dQ
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
dq
)
T
.
atomic_add
(
dQ
[
bz
,
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
dq
)
T
.
copy
(
dv
,
dv_shared
)
T
.
copy
(
dv
,
dv_shared
)
T
.
copy
(
dk
,
dk_shared
)
T
.
copy
(
dk
,
dk_shared
)
T
.
copy
(
dv_shared
,
dV
[
bz
,
bx
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:])
T
.
copy
(
dv_shared
,
dV
[
bz
,
bx
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:])
T
.
copy
(
dk_shared
,
dK
[
bz
,
bx
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:])
T
.
copy
(
dk_shared
,
dK
[
bz
,
bx
,
by
*
block_M
:
(
by
+
1
)
*
block_M
,
:])
return
flash_bwd
return
flash_bwd
...
@@ -328,18 +333,16 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"
...
@@ -328,18 +333,16 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"
dsink_fragment
=
T
.
alloc_fragment
([
block
],
accum_dtype
)
dsink_fragment
=
T
.
alloc_fragment
([
block
],
accum_dtype
)
sink
[
0
]
=
Sinks
[
bx
]
sink
[
0
]
=
Sinks
[
bx
]
T
.
copy
(
lse
[
bz
,
bx
,
by
*
block
:
(
by
+
1
)
*
block
],
lse_fragment
)
T
.
copy
(
lse
[
bz
,
bx
,
by
*
block
:
(
by
+
1
)
*
block
],
lse_fragment
)
T
.
copy
(
Delta
[
bz
,
bx
,
by
*
block
:
(
by
+
1
)
*
block
],
delta_fragment
)
T
.
copy
(
Delta
[
bz
,
bx
,
by
*
block
:
(
by
+
1
)
*
block
],
delta_fragment
)
for
i
in
T
.
Parallel
(
block
):
for
i
in
T
.
Parallel
(
block
):
dsink_fragment
[
i
]
=
-
T
.
exp2
(
Sinks
[
bx
]
*
1.44269504
-
dsink_fragment
[
i
]
=
-
T
.
exp2
(
Sinks
[
bx
]
*
1.44269504
-
lse_fragment
[
i
])
*
delta_fragment
[
i
]
lse_fragment
[
i
])
*
delta_fragment
[
i
]
T
.
copy
(
dsink_fragment
,
dsinks
[
bz
,
bx
,
by
*
block
:
(
by
+
1
)
*
block
])
T
.
copy
(
dsink_fragment
,
dsinks
[
bz
,
bx
,
by
*
block
:(
by
+
1
)
*
block
])
return
flash_bwd_dsink
return
flash_bwd_dsink
class
_attention
(
torch
.
autograd
.
Function
):
class
_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
sinks
,
window_size
):
def
forward
(
ctx
,
q
,
k
,
v
,
sinks
,
window_size
):
BATCH
,
H
,
N_CTX
,
D_HEAD
=
q
.
shape
BATCH
,
H
,
N_CTX
,
D_HEAD
=
q
.
shape
...
@@ -383,15 +386,15 @@ attention = _attention.apply
...
@@ -383,15 +386,15 @@ attention = _attention.apply
# Adapted and optimized from
# Adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def
ref_program
(
query
:
torch
.
Tensor
,
def
ref_program
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
sinks
:
torch
.
Tensor
,
sinks
:
torch
.
Tensor
,
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
dtype
:
torch
.
dtype
=
torch
.
float16
)
->
torch
.
Tensor
:
dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
query
=
query
.
transpose
(
1
,
2
).
contiguous
().
unsqueeze
(
query
=
query
.
transpose
(
1
,
2
).
contiguous
().
unsqueeze
(
3
)
# align with the original function's interface
3
)
# align with the original function's interface
key
=
key
.
transpose
(
1
,
2
).
contiguous
()
key
=
key
.
transpose
(
1
,
2
).
contiguous
()
value
=
value
.
transpose
(
1
,
2
).
contiguous
()
value
=
value
.
transpose
(
1
,
2
).
contiguous
()
...
@@ -426,29 +429,22 @@ def ref_program(query: torch.Tensor,
...
@@ -426,29 +429,22 @@ def ref_program(query: torch.Tensor,
output
=
torch
.
einsum
(
"bhmqk,bkhmd->bqhmd"
,
scores
,
value
.
float
())
output
=
torch
.
einsum
(
"bhmqk,bkhmd->bqhmd"
,
scores
,
value
.
float
())
output
=
output
.
reshape
(
batch_size
,
num_queries
,
num_key_value_heads
*
num_key_value_groups
,
output
=
output
.
reshape
(
batch_size
,
num_queries
,
num_key_value_heads
*
num_key_value_groups
,
head_dim
).
to
(
dtype
)
head_dim
).
to
(
dtype
)
return
output
.
transpose
(
1
,
2
).
contiguous
()
return
output
.
transpose
(
1
,
2
).
contiguous
()
def
main
(
BATCH
:
int
=
1
,
def
main
(
BATCH
:
int
=
1
,
H
:
int
=
1
,
N_CTX
:
int
=
512
,
D_HEAD
:
int
=
128
,
window_size
:
Optional
[
int
]
=
None
,
dtype
:
str
=
"float16"
):
H
:
int
=
1
,
N_CTX
:
int
=
512
,
D_HEAD
:
int
=
128
,
window_size
:
Optional
[
int
]
=
None
,
dtype
:
str
=
"float16"
):
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
torch_dtype
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
}[
dtype
]
if
window_size
is
not
None
:
if
window_size
is
not
None
:
print
(
'
Using sliding window attention.
'
)
print
(
"
Using sliding window attention.
"
)
assert
window_size
<=
N_CTX
assert
window_size
<=
N_CTX
flops_per_matmul
=
2.0
*
BATCH
*
H
*
min
(
flops_per_matmul
=
2.0
*
BATCH
*
H
*
min
(
window_size
,
N_CTX
//
2
)
*
N_CTX
*
D_HEAD
# just a rough estimation
window_size
,
N_CTX
//
2
)
*
N_CTX
*
D_HEAD
# just a rough estimation
else
:
else
:
print
(
'
Using full attention.
'
)
print
(
"
Using full attention.
"
)
flops_per_matmul
=
2.0
*
BATCH
*
H
*
N_CTX
*
N_CTX
*
D_HEAD
*
0.5
flops_per_matmul
=
2.0
*
BATCH
*
H
*
N_CTX
*
N_CTX
*
D_HEAD
*
0.5
total_flops
=
5
*
flops_per_matmul
total_flops
=
5
*
flops_per_matmul
Q
=
(
torch
.
randn
(
BATCH
,
H
,
N_CTX
,
D_HEAD
,
dtype
=
torch_dtype
,
device
=
"cuda"
).
requires_grad_
()
)
Q
=
torch
.
randn
(
BATCH
,
H
,
N_CTX
,
D_HEAD
,
dtype
=
torch_dtype
,
device
=
"cuda"
).
requires_grad_
()
K
=
torch
.
randn_like
(
Q
).
requires_grad_
()
K
=
torch
.
randn_like
(
Q
).
requires_grad_
()
V
=
torch
.
randn_like
(
Q
).
requires_grad_
()
V
=
torch
.
randn_like
(
Q
).
requires_grad_
()
sinks
=
torch
.
randn
(
H
,
dtype
=
torch_dtype
,
device
=
Q
.
device
).
requires_grad_
()
sinks
=
torch
.
randn
(
H
,
dtype
=
torch_dtype
,
device
=
Q
.
device
).
requires_grad_
()
...
@@ -473,16 +469,11 @@ def main(BATCH: int = 1,
...
@@ -473,16 +469,11 @@ def main(BATCH: int = 1,
"float16"
:
(
1e-2
,
1e-2
),
"float16"
:
(
1e-2
,
1e-2
),
"bfloat16"
:
(
2e-2
,
2e-2
),
"bfloat16"
:
(
2e-2
,
2e-2
),
}[
dtype
]
}[
dtype
]
assert
torch
.
allclose
(
O
,
O_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
'O max err:
{
(
O
-
O_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
O
,
O_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
"O max err:
{
(
O
-
O_ref
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
assert
torch
.
allclose
(
dV
,
dV_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
"dV max err:
{
(
dV
-
dV_ref
).
abs
().
max
()
}
"
dV
,
dV_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
'dV max err:
{
(
dV
-
dV_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
dK
,
dK_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
"dK max err:
{
(
dK
-
dK_ref
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
assert
torch
.
allclose
(
dQ
,
dQ_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
"dq max err:
{
(
dQ
-
dQ_ref
).
abs
().
max
()
}
"
dK
,
dK_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
'dK max err:
{
(
dK
-
dK_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
dsinks
,
dsinks_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
"dsinks max err:
{
(
dsinks
-
dsinks_ref
).
abs
().
max
()
}
"
assert
torch
.
allclose
(
dQ
,
dQ_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
'dq max err:
{
(
dQ
-
dQ_ref
).
abs
().
max
()
}
'
assert
torch
.
allclose
(
dsinks
,
dsinks_ref
,
rtol
=
rtol
,
atol
=
atol
),
f
'dsinks max err:
{
(
dsinks
-
dsinks_ref
).
abs
().
max
()
}
'
print
(
"All checks passed for tilelang kernels.✅"
)
print
(
"All checks passed for tilelang kernels.✅"
)
...
@@ -503,16 +494,11 @@ def main(BATCH: int = 1,
...
@@ -503,16 +494,11 @@ def main(BATCH: int = 1,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
1
,
help
=
'Batch size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
1
,
help
=
"Batch size"
)
parser
.
add_argument
(
'--h'
,
type
=
int
,
default
=
64
,
help
=
'Number of heads'
)
parser
.
add_argument
(
"--h"
,
type
=
int
,
default
=
64
,
help
=
"Number of heads"
)
parser
.
add_argument
(
'--n_ctx'
,
type
=
int
,
default
=
4096
,
help
=
'Context size'
)
parser
.
add_argument
(
"--n_ctx"
,
type
=
int
,
default
=
4096
,
help
=
"Context size"
)
parser
.
add_argument
(
'--d_head'
,
type
=
int
,
default
=
128
,
help
=
'Head dimension'
)
parser
.
add_argument
(
"--d_head"
,
type
=
int
,
default
=
128
,
help
=
"Head dimension"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--window_size"
,
type
=
int
,
default
=
None
,
help
=
"window size (default: None, which means full attention)"
)
'--window_size'
,
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
help
=
"dtype, can be float16 or bfloat16"
)
type
=
int
,
default
=
None
,
help
=
'window size (default: None, which means full attention)'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
"float16"
,
help
=
"dtype, can be float16 or bfloat16"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
h
,
args
.
n_ctx
,
args
.
d_head
,
args
.
window_size
,
args
.
dtype
)
main
(
args
.
batch
,
args
.
h
,
args
.
n_ctx
,
args
.
d_head
,
args
.
window_size
,
args
.
dtype
)
Prev
1
2
3
4
5
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment