Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
667632cc
Unverified
Commit
667632cc
authored
Dec 22, 2025
by
guchaoyang
Committed by
GitHub
Dec 22, 2025
Browse files
Merge branch 'main' into dcu
parents
d6dd2ddf
a874e4e8
Changes
313
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
887 additions
and
1121 deletions
+887
-1121
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py
...tention/example_triton_sparse_gqa_decode_varlen_indice.py
+53
-100
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py
...attention/example_triton_sparse_gqa_decode_varlen_mask.py
+52
-95
examples/blocksparse_attention/heuristic.py
examples/blocksparse_attention/heuristic.py
+1
-2
examples/blocksparse_attention/test_example_blocksparse_attention.py
...ocksparse_attention/test_example_blocksparse_attention.py
+4
-16
examples/blocksparse_gemm/example_blocksparse_gemm.py
examples/blocksparse_gemm/example_blocksparse_gemm.py
+30
-40
examples/cast/example_group_per_split_token_cast_to_fp8.py
examples/cast/example_group_per_split_token_cast_to_fp8.py
+32
-33
examples/cast/example_per_token_cast_to_fp8.py
examples/cast/example_per_token_cast_to_fp8.py
+15
-20
examples/cast/example_triton_cast_to_fp8.py
examples/cast/example_triton_cast_to_fp8.py
+1
-3
examples/cast/test_example_cast.py
examples/cast/test_example_cast.py
+1
-2
examples/conftest.py
examples/conftest.py
+2
-5
examples/convolution/example_convolution.py
examples/convolution/example_convolution.py
+25
-43
examples/convolution/example_convolution_autotune.py
examples/convolution/example_convolution_autotune.py
+48
-91
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
+21
-24
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
...les/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
+84
-110
examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
+91
-74
examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
+91
-74
examples/deepseek_mla/benchmark_mla.py
examples/deepseek_mla/benchmark_mla.py
+101
-97
examples/deepseek_mla/example_mla_decode.py
examples/deepseek_mla/example_mla_decode.py
+86
-103
examples/deepseek_mla/example_mla_decode_paged.py
examples/deepseek_mla/example_mla_decode_paged.py
+104
-130
examples/deepseek_mla/example_mla_decode_persistent.py
examples/deepseek_mla/example_mla_decode_persistent.py
+45
-59
No files found.
Too many changes to show.
To preserve performance only
313 of 313+
files are displayed.
Plain diff
Email patch
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py
View file @
667632cc
...
...
@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_H'
,
'BLOCK_N'
,
'BLOCK_D'
],
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
key
=
[
"BLOCK_H"
,
"BLOCK_N"
,
"BLOCK_D"
],
)
@
triton
.
jit
def
_split_kernel
(
...
...
@@ -79,16 +75,11 @@ def _split_kernel(
loop_range
=
blocks_per_split
q_ptr
+=
batch_idx
*
stride_q_b
+
head_idx_q
*
stride_q_h
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
mask_ptr
+=
batch_idx
*
stride_mask_b
+
head_idx_kv
*
stride_mask_h
q
=
tl
.
load
(
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
q
=
tl
.
load
(
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
start
=
blocks_per_split
*
split_idx
+
tl
.
minimum
(
split_idx
,
remaining_blocks
)
for
i
in
range
(
loop_range
):
block_idx
=
tl
.
load
(
mask_ptr
+
(
start
+
i
)
*
stride_mask_s
)
...
...
@@ -119,23 +110,18 @@ def _split_kernel(
acc
=
acc
*
l_recip
acc
=
acc
.
to
(
o_partial_ptr
.
dtype
.
element_ty
)
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
tl
.
store
(
lse_partial_ptr
,
m_i
,
mask
=
offs_h
<
gqa_group_size
)
o_partial_ptr
+=
batch_idx
*
stride_o_b
+
(
head_idx_q
+
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
o_partial_ptr
+=
(
batch_idx
*
stride_o_b
+
(
head_idx_q
+
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
)
tl
.
store
(
o_partial_ptr
,
acc
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_D'
],
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
key
=
[
"BLOCK_D"
],
)
@
triton
.
jit
def
_merge_kernel
(
...
...
@@ -163,18 +149,15 @@ def _merge_kernel(
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
lse_offsets
=
lse_partial_ptr
+
batch_idx
*
lse_partial_stride_b
+
head_idx
*
lse_partial_stride_h
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse_max
=
tl
.
max
(
lse
)
o_offsets
=
o_partial_ptr
+
batch_idx
*
o_partial_stride_b
+
head_idx
*
o_partial_stride_h
o_partial
=
tl
.
load
(
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
)
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
,
)
sumexp_normalized_splitk
=
tl
.
exp
(
lse
-
lse_max
)
sumexp_normalized
=
tl
.
sum
(
sumexp_normalized_splitk
,
axis
=
0
)
numerator_normalized
=
tl
.
sum
(
o_partial
*
sumexp_normalized_splitk
[:,
None
],
axis
=
0
)
...
...
@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton(
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_n_blocks
=
max_selected_blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
#kv_seqlen * (dim + dim_v) * 2
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
# kv_seqlen * (dim + dim_v) * 2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
num_sm
=
64
# num_sm = self.num_sm
num_splits
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...
...
@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton(
return
output
def
ref_program_torch
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
def
ref_program_torch
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
batch
,
heads
,
dim
=
query
.
shape
heads_kv
=
key
.
shape
[
2
]
dim_v
=
value
.
shape
[
-
1
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
key
=
rearrange
(
key
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
query
=
rearrange
(
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask
=
torch
.
zeros_like
(
scores
)
# Assign mask values based on block_indices
...
...
@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
valid_indices
=
block_indices
[
b
,
h
]
# Extract indices for this batch and head
for
idx
in
valid_indices
:
if
idx
>=
0
:
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
'
-inf
'
))
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
"
-inf
"
))
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
'
cuda
'
).
unsqueeze
(
0
)
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
"
cuda
"
).
unsqueeze
(
0
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
'-inf'
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
"-inf"
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
out
=
einsum
(
attention
,
value
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
def
ref_program_fa
(
query
,
key
,
value
,
cache_seqlens
):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from
flash_attn
import
flash_attn_with_kvcache
#fa2
from
flash_attn
import
flash_attn_with_kvcache
# fa2
query
=
query
.
unsqueeze
(
1
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
output
.
squeeze
(
1
)
return
output
def
main
(
batch
=
64
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
def
main
(
batch
=
64
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
sparse_ratio
=
sparse_ratio
block_size
=
block_size
...
...
@@ -369,34 +331,29 @@ def main(batch=64,
dtype
=
torch
.
float16
block_H
=
64
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
'
cuda
'
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
"
cuda
"
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# Ensure at least one element equals cache_seqlen
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
'cuda'
).
item
()
# Select a random index
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
"cuda"
).
item
()
# Select a random index
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
print
(
"cache_seqlens: "
,
cache_seqlens
)
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
# Initialize block_indices with -1 (for padding blocks)
block_indices
=
torch
.
full
((
batch
,
heads_kv
,
max_selected_blocks
),
-
1
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
block_indices
=
torch
.
full
((
batch
,
heads_kv
,
max_selected_blocks
),
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# Assign valid indices while ensuring no duplicates within each batch-group
for
b
in
range
(
batch
):
max_valid_block
=
max_valid_num_blocks
[
b
].
item
()
# Max valid blocks for this batch
if
max_valid_block
>
0
:
# Ensure there's at least one valid block
for
h
in
range
(
heads_kv
):
valid_indices
=
torch
.
randperm
(
max_valid_block
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)[:
max_selected_blocks
]
block_indices
[
b
,
h
,
:
len
(
valid_indices
)]
=
valid_indices
valid_indices
=
torch
.
randperm
(
max_valid_block
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)[:
max_selected_blocks
]
block_indices
[
b
,
h
,
:
len
(
valid_indices
)]
=
valid_indices
# Sort indices within each batch-group for consistency
block_indices
,
_
=
block_indices
.
sort
(
dim
=-
1
,
descending
=
True
)
...
...
@@ -408,8 +365,7 @@ def main(batch=64,
max_num_blocks
=
torch
.
max
(
max_valid_num_blocks
).
item
()
print
(
"max_num_blocks: "
,
max_num_blocks
)
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
triton_out
=
block_sparse_flash_decode_gqa_indice_triton
(
Q
,
...
...
@@ -423,8 +379,7 @@ def main(batch=64,
)
print
(
"max difference: "
,
torch
.
max
(
torch
.
abs
(
ref
-
triton_out
)))
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
print
(
"Passed the ref test!"
)
# Measure performance
...
...
@@ -466,15 +421,13 @@ def main(batch=64,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
64
,
help
=
'batch size'
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.8
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
32
,
help
=
'block_size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
64
,
help
=
"batch size"
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.8
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"block_size"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py
View file @
667632cc
...
...
@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_H'
,
'BLOCK_N'
,
'BLOCK_D'
],
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
key
=
[
"BLOCK_H"
,
"BLOCK_N"
,
"BLOCK_D"
],
)
@
triton
.
jit
def
_split_kernel
(
...
...
@@ -77,16 +73,11 @@ def _split_kernel(
loop_range
=
blocks_per_split
q_ptr
+=
batch_idx
*
stride_q_b
+
head_idx_q
*
stride_q_h
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
mask_ptr
+=
batch_idx
*
stride_mask_b
+
head_idx_kv
*
stride_mask_h
q
=
tl
.
load
(
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
q
=
tl
.
load
(
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
start
=
blocks_per_split
*
split_idx
+
tl
.
minimum
(
split_idx
,
remaining_blocks
)
for
block_idx
in
range
(
loop_range
):
start_n
=
(
start
+
block_idx
)
*
BLOCK_N
...
...
@@ -117,23 +108,18 @@ def _split_kernel(
acc
=
acc
*
l_recip
acc
=
acc
.
to
(
o_partial_ptr
.
dtype
.
element_ty
)
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
tl
.
store
(
lse_partial_ptr
,
m_i
,
mask
=
offs_h
<
gqa_group_size
)
o_partial_ptr
+=
batch_idx
*
stride_o_b
+
(
head_idx_q
+
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
o_partial_ptr
+=
(
batch_idx
*
stride_o_b
+
(
head_idx_q
+
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
)
tl
.
store
(
o_partial_ptr
,
acc
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_D'
],
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
key
=
[
"BLOCK_D"
],
)
@
triton
.
jit
def
_merge_kernel
(
...
...
@@ -161,18 +147,15 @@ def _merge_kernel(
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
lse_offsets
=
lse_partial_ptr
+
batch_idx
*
lse_partial_stride_b
+
head_idx
*
lse_partial_stride_h
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse_max
=
tl
.
max
(
lse
)
o_offsets
=
o_partial_ptr
+
batch_idx
*
o_partial_stride_b
+
head_idx
*
o_partial_stride_h
o_partial
=
tl
.
load
(
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
)
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
,
)
sumexp_normalized_splitk
=
tl
.
exp
(
lse
-
lse_max
)
sumexp_normalized
=
tl
.
sum
(
sumexp_normalized_splitk
,
axis
=
0
)
numerator_normalized
=
tl
.
sum
(
o_partial
*
sumexp_normalized_splitk
[:,
None
],
axis
=
0
)
...
...
@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton(
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_n_blocks
=
max_selected_blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
#kv_seqlen * (dim + dim_v) * 2
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
# kv_seqlen * (dim + dim_v) * 2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
num_sm
=
64
# num_sm = self.num_sm
num_splits
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...
...
@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton(
return
output
def
ref_program_torch
(
query
,
key
,
value
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
def
ref_program_torch
(
query
,
key
,
value
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
batch
,
heads
,
dim
=
query
.
shape
heads_kv
=
key
.
shape
[
2
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
key
=
rearrange
(
key
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
query
=
rearrange
(
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask
=
torch
.
zeros_like
(
scores
)
# Assign mask values
...
...
@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
for
h
in
range
(
heads_kv
):
for
idx
in
range
(
num_blocks
):
if
block_mask
[
b
,
h
,
idx
]:
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
"
-inf
"
))
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
'
cuda
'
).
unsqueeze
(
0
)
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
"
cuda
"
).
unsqueeze
(
0
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
'-inf'
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
"-inf"
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
out
=
einsum
(
attention
,
value
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
def
ref_program_fa
(
query
,
key
,
value
,
cache_seqlens
):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from
flash_attn
import
flash_attn_with_kvcache
#fa2
from
flash_attn
import
flash_attn_with_kvcache
# fa2
query
=
query
.
unsqueeze
(
1
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
output
.
squeeze
(
1
)
return
output
def
main
(
batch
=
64
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
def
main
(
batch
=
64
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
block_size
=
block_size
sparse_ratio
=
sparse_ratio
...
...
@@ -363,14 +325,13 @@ def main(batch=64,
dtype
=
torch
.
float16
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
'
cuda
'
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
"
cuda
"
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
# Ensure at least one element equals cache_seqlen
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
'cuda'
).
item
()
# Select a random index
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
"cuda"
).
item
()
# Select a random index
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
num_blocks
=
(
max_cache_seqlen
+
block_size
-
1
)
//
block_size
...
...
@@ -379,7 +340,7 @@ def main(batch=64,
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
# Initialize block_mask with false (for padding blocks)
block_mask
=
torch
.
zeros
((
batch
,
heads_kv
,
num_blocks
),
dtype
=
torch
.
bool
,
device
=
'
cuda
'
)
block_mask
=
torch
.
zeros
((
batch
,
heads_kv
,
num_blocks
),
dtype
=
torch
.
bool
,
device
=
"
cuda
"
)
# Assign valid indices while ensuring no duplicates within each batch-group
for
b
in
range
(
batch
):
...
...
@@ -387,11 +348,10 @@ def main(batch=64,
valid_num_block
=
valid_num_blocks
[
b
].
item
()
# Valid blocks for this batch
if
valid_num_block
>
0
:
# Ensure there's at least one valid block
for
h
in
range
(
heads_kv
):
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
'
cuda
'
)[:
valid_num_block
]
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
"
cuda
"
)[:
valid_num_block
]
block_mask
[
b
,
h
,
perm
]
=
True
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
triton_out
=
block_sparse_flash_decode_gqa_mask_triton
(
Q
,
...
...
@@ -404,8 +364,7 @@ def main(batch=64,
)
# print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
print
(
"Passed the ref test!"
)
# Measure performance
...
...
@@ -448,15 +407,13 @@ def main(batch=64,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
64
,
help
=
'batch size'
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.8
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
32
,
help
=
'block_size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
64
,
help
=
"batch size"
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.8
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"block_size"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
examples/blocksparse_attention/heuristic.py
View file @
667632cc
import
math
def
num_splits_heuristic
(
total_mblocks
,
num_SMs
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
,
max_splits
):
def
num_splits_heuristic
(
total_mblocks
,
num_SMs
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
,
max_splits
):
"""
Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
...
...
examples/blocksparse_attention/test_example_blocksparse_attention.py
View file @
667632cc
...
...
@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def
test_example_triton_sparse_gqa_decode_varlen_indice
():
example_triton_sparse_gqa_decode_varlen_indice
.
main
(
batch
=
8
,
heads
=
8
,
heads_kv
=
4
,
max_cache_seqlen
=
2048
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
)
batch
=
8
,
heads
=
8
,
heads_kv
=
4
,
max_cache_seqlen
=
2048
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
)
def
test_example_triton_sparse_gqa_decode_varlen_mask
():
example_triton_sparse_gqa_decode_varlen_mask
.
main
(
batch
=
16
,
heads
=
16
,
heads_kv
=
8
,
max_cache_seqlen
=
1024
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
)
batch
=
16
,
heads
=
16
,
heads_kv
=
8
,
max_cache_seqlen
=
1024
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
)
if
__name__
==
"__main__"
:
...
...
examples/blocksparse_gemm/example_blocksparse_gemm.py
View file @
667632cc
...
...
@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--sparsity"
,
type
=
float
,
default
=
0.5
,
help
=
"Sparsity ratio (0-1)"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune"
)
args
,
_
=
parser
.
parse_known_args
()
M
,
N
,
K
=
args
.
m
,
args
.
n
,
args
.
k
...
...
@@ -41,17 +40,19 @@ def get_configs():
thread_num
=
[
128
,
256
]
enable_rasterization
=
[
True
,
False
]
_configs
=
list
(
itertools
.
product
(
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasterization
))
_configs
=
list
(
itertools
.
product
(
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasterization
))
return
[{
"block_M"
:
c
[
0
],
"block_N"
:
c
[
1
],
"block_K"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
}
for
c
in
_configs
]
return
[
{
"block_M"
:
c
[
0
],
"block_N"
:
c
[
1
],
"block_K"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
}
for
c
in
_configs
]
def
ref_program
(
A
,
B
,
BlockMask
,
block_M
,
block_N
,
block_K
):
...
...
@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
for
k
in
range
(
K
//
block_K
):
if
BlockMask
[
i
,
j
,
k
]:
accu
+=
(
A
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
k
*
block_K
:(
k
+
1
)
*
block_K
,
j
*
block_N
:(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
))
ref_c
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
accu
+=
A
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
k
*
block_K
:
(
k
+
1
)
*
block_K
,
j
*
block_N
:
(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
)
ref_c
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
return
ref_c
...
...
@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]):
return
input_tensors
@
tilelang
.
autotune
(
configs
=
get_configs
(),)
@
tilelang
.
autotune
(
configs
=
get_configs
(),
)
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
blocksparse_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
blocksparse_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
@
T
.
prim_func
def
block_sparse_matmul
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
@@ -134,7 +126,6 @@ def blocksparse_matmul(M,
def
main
():
# Initialize input matrices A and B on the GPU with half precision
a
=
torch
.
randn
(
M
,
K
).
cuda
().
half
()
b
=
torch
.
randn
(
K
,
N
).
cuda
().
half
()
...
...
@@ -147,8 +138,7 @@ def main():
best_config
=
kernel
.
config
best_latency
=
kernel
.
latency
block_M
,
block_N
,
block_K
=
best_config
[
"block_M"
],
best_config
[
"block_N"
],
best_config
[
"block_K"
]
block_M
,
block_N
,
block_K
=
best_config
[
"block_M"
],
best_config
[
"block_N"
],
best_config
[
"block_K"
]
print
(
f
"Best Config:
{
best_config
}
"
)
print
(
f
"Sparsity Ratio:
{
sparsity
}
"
)
...
...
@@ -163,10 +153,10 @@ def main():
block_K
=
DEFAULT_BLOCK_K
,
num_stages
=
DEFAULT_NUM_STAGES
,
thread_num
=
DEFAULT_THREAD_NUM
,
enable_rasteration
=
DEFAULT_ENABLE_RASTERIZATION
)
enable_rasteration
=
DEFAULT_ENABLE_RASTERIZATION
,
)
block_M
,
block_N
,
block_K
=
DEFAULT_BLOCK_M
,
DEFAULT_BLOCK_N
,
DEFAULT_BLOCK_K
print
(
f
"Using default kernel with block size (
{
block_M
}
,
{
block_N
}
,
{
block_K
}
)"
)
# Create block mask with desired sparsity
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
...
...
examples/cast/example_group_per_split_token_cast_to_fp8.py
View file @
667632cc
...
...
@@ -5,8 +5,8 @@ from typing import Tuple
from
tilelang.utils.tensor
import
torch_assert_close
# support bfloat16, float, float16
dtype
=
"
bfloat16
"
accum_dtype
=
"
float
"
dtype
=
T
.
bfloat16
accum_dtype
=
T
.
float
32
@
tilelang
.
jit
(
out_idx
=
[
2
,
3
])
...
...
@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
fp8_max
=
448.0
@
T
.
prim_func
def
group_per_split_token_cast
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
batch_sizes
:
T
.
Tensor
(
(
BG
,),
"int32"
),
X_fp8
:
T
.
Tensor
((
BG
,
M_max
,
N
),
"float8_e4m3"
),
X_amax
:
T
.
Tensor
(
(
BG
,
M_max
,
T
.
ceildiv
(
N
,
group_size
)),
accum_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M_max
,
blk_m
),
T
.
ceildiv
(
N
,
group_size
),
BG
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
def
group_per_split_token_cast
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
batch_sizes
:
T
.
Tensor
((
BG
,),
T
.
int32
),
X_fp8
:
T
.
Tensor
((
BG
,
M_max
,
N
),
T
.
float8_e4m3fn
),
X_amax
:
T
.
Tensor
((
BG
,
M_max
,
T
.
ceildiv
(
N
,
group_size
)),
accum_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
M_max
,
blk_m
),
T
.
ceildiv
(
N
,
group_size
),
BG
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
row
=
bx
row_g_id
=
by
bg
=
bz
...
...
@@ -28,39 +30,35 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
y_amax_local
=
T
.
alloc_fragment
((
blk_m
,),
accum_dtype
)
y_s_local
=
T
.
alloc_fragment
((
blk_m
,),
accum_dtype
)
y_q_local
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
accum_dtype
)
y_q_local_fp8
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
"
float8_e4m3
"
)
row_offset
=
T
.
alloc_fragment
((
1
,),
"
int32
"
)
y_q_local_fp8
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
T
.
float8_e4m3
fn
)
row_offset
=
T
.
alloc_fragment
((
1
,),
T
.
int32
)
T
.
annotate_layout
({
y_local
:
T
.
Fragment
(
y_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
})
T
.
annotate_layout
(
{
y_local
:
T
.
Fragment
(
y_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
}
)
row_offset
[
0
]
=
0
for
i
in
T
.
serial
(
bg
):
row_offset
[
0
]
+=
batch_sizes
[
i
]
T
.
copy
(
X
[
row_offset
[
0
]
+
row
*
blk_m
:
row_offset
[
0
]
+
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
],
y_local
)
X
[
row_offset
[
0
]
+
row
*
blk_m
:
row_offset
[
0
]
+
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
],
y_local
,
)
T
.
reduce_absmax
(
y_local
,
y_amax_local
,
dim
=
1
)
for
i
in
T
.
Parallel
(
blk_m
):
y_amax_local
[
i
]
=
T
.
max
(
y_amax_local
[
i
],
1e-4
)
y_s_local
[
i
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_amax_local
[
i
]
/
fp8_max
,
0
)
y_s_local
[
i
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_amax_local
[
i
]
/
fp8_max
,
0
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
group_size
):
y_q_local
[
i
,
j
]
=
T
.
clamp
(
y_local
[
i
,
j
]
/
y_s_local
[
i
],
fp8_min
,
fp8_max
)
T
.
copy
(
y_q_local
,
y_q_local_fp8
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
group_size
):
y_q_local_fp8
[
i
,
j
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_q_local
[
i
,
j
],
0
)
y_q_local_fp8
[
i
,
j
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_q_local
[
i
,
j
],
0
)
for
i
in
T
.
Parallel
(
blk_m
):
X_amax
[
bg
,
row
*
blk_m
+
i
,
row_g_id
]
=
y_s_local
[
i
]
T
.
copy
(
y_q_local_fp8
,
X_fp8
[
bg
,
row
*
blk_m
:(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
])
T
.
copy
(
y_q_local_fp8
,
X_fp8
[
bg
,
row
*
blk_m
:
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
])
return
group_per_split_token_cast
...
...
@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
return
x
.
squeeze
(
0
)
if
remove_dim
else
x
# Normal layout requires transposing
aligned_x
=
torch
.
transpose
(
torch
.
empty
((
b
,
n
,
aligned_m
),
device
=
x
.
device
,
dtype
=
x
.
dtype
),
1
,
2
)
aligned_x
=
torch
.
transpose
(
torch
.
empty
((
b
,
n
,
aligned_m
),
device
=
x
.
device
,
dtype
=
x
.
dtype
),
1
,
2
)
aligned_x
[:,
:
m
,
:]
=
x
aligned_x
=
aligned_x
[:,
:
m
,
:]
return
aligned_x
.
squeeze
(
0
)
if
remove_dim
else
aligned_x
...
...
@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
x_fp8
=
x_fp8
.
view
(
m
,
-
1
)[:,
:
n
].
contiguous
()
return
x_fp8
,
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
ref_program
(
x
:
torch
.
Tensor
,
batch_sizes
:
torch
.
Tensor
)
->
\
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
ref_program
(
x
:
torch
.
Tensor
,
batch_sizes
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# assert x.shape[0] == batch_sizes.sum()
M_max
=
ceil_div
(
batch_sizes
.
max
(),
128
)
*
128
split_x
=
torch
.
split
(
x
,
batch_sizes
.
tolist
(),
dim
=
0
)
padded_x
=
[
torch
.
nn
.
functional
.
pad
(
t
,
(
0
,
0
,
0
,
M_max
-
t
.
shape
[
0
]))
for
t
in
split_x
]
num_groups
,
m
,
n
=
batch_sizes
.
shape
[
0
],
M_max
,
x
.
shape
[
1
]
x_fp8
=
(
torch
.
empty
((
num_groups
,
m
,
n
),
device
=
'cuda'
,
dtype
=
torch
.
float8_e4m3fn
),
torch
.
empty
((
num_groups
,
m
,
n
//
128
),
device
=
'cuda'
,
dtype
=
torch
.
float
))
x_fp8
=
(
torch
.
empty
((
num_groups
,
m
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float8_e4m3fn
),
torch
.
empty
((
num_groups
,
m
,
n
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float
),
)
for
i
in
range
(
num_groups
):
x_fp8
[
0
][
i
],
x_fp8
[
1
][
i
]
=
ref_per_token_cast_to_fp8
(
padded_x
[
i
])
x_fp8
=
(
x_fp8
[
0
],
get_col_major_tma_aligned_tensor
(
x_fp8
[
1
]))
...
...
@@ -164,11 +163,11 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
def
main
(
M
=
8192
,
N
=
8192
,
BG
=
2
,
blk_m
=
8
,
batch_sizes
=
None
):
if
batch_sizes
is
None
:
batch_sizes
=
[
2048
,
6144
]
if
dtype
==
"
float
"
:
if
dtype
==
T
.
float
:
x
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
elif
dtype
==
"
float16
"
:
elif
dtype
==
T
.
float16
:
x
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
elif
dtype
==
"
bfloat16
"
:
elif
dtype
==
T
.
bfloat16
:
x
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
else
:
raise
ValueError
(
f
"Unsupported dtype:
{
dtype
}
"
)
...
...
examples/cast/example_per_token_cast_to_fp8.py
View file @
667632cc
...
...
@@ -7,14 +7,15 @@ from tilelang.utils.tensor import torch_assert_close
@
tilelang
.
jit
(
out_idx
=
[
1
,
2
])
def
per_token_cast_to_fp8
(
M
,
N
,
blk_m
):
dtype
=
"
float
"
dtype
=
T
.
float
group_size
=
128
fp8_min
=
-
448.0
fp8_max
=
448.0
@
T
.
prim_func
def
per_token_cast
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
X_fp8
:
T
.
Tensor
((
M
,
N
),
"float8_e4m3"
),
X_amax
:
T
.
Tensor
((
M
,
T
.
ceildiv
(
N
,
group_size
)),
dtype
)):
def
per_token_cast
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
X_fp8
:
T
.
Tensor
((
M
,
N
),
T
.
float8_e4m3fn
),
X_amax
:
T
.
Tensor
((
M
,
T
.
ceildiv
(
N
,
group_size
)),
dtype
)
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
blk_m
),
T
.
ceildiv
(
N
,
group_size
),
threads
=
128
)
as
(
bx
,
by
):
row
=
bx
row_g_id
=
by
...
...
@@ -22,18 +23,15 @@ def per_token_cast_to_fp8(M, N, blk_m):
y_amax_local
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
y_s_local
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
y_q_local
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
dtype
)
y_q_local_fp8
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
"float8_e4m3"
)
T
.
annotate_layout
({
y_local
:
T
.
Fragment
(
y_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
})
T
.
copy
(
X
[
row
*
blk_m
:(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
],
y_local
)
y_q_local_fp8
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
T
.
float8_e4m3fn
)
T
.
annotate_layout
(
{
y_local
:
T
.
Fragment
(
y_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
}
)
T
.
copy
(
X
[
row
*
blk_m
:
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
],
y_local
)
T
.
reduce_absmax
(
y_local
,
y_amax_local
,
dim
=
1
)
for
i
in
T
.
Parallel
(
blk_m
):
y_amax_local
[
i
]
=
T
.
max
(
y_amax_local
[
i
],
1e-4
)
...
...
@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
T
.
copy
(
y_q_local
,
y_q_local_fp8
)
for
i
in
T
.
Parallel
(
blk_m
):
X_amax
[
row
*
blk_m
+
i
,
row_g_id
]
=
y_s_local
[
i
]
T
.
copy
(
y_q_local_fp8
,
X_fp8
[
row
*
blk_m
:(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
])
T
.
copy
(
y_q_local_fp8
,
X_fp8
[
row
*
blk_m
:
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
])
return
per_token_cast
...
...
@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8):
from
example_triton_cast_to_fp8
import
per_token_group_quant_fp8
def
run_triton
():
x_fp8_triton_
,
x_amax_triton_
=
per_token_group_quant_fp8
(
x
,
128
,
1e-4
,
dtype
=
torch
.
float8_e4m3fn
,
column_major_scales
=
False
)
x_fp8_triton_
,
x_amax_triton_
=
per_token_group_quant_fp8
(
x
,
128
,
1e-4
,
dtype
=
torch
.
float8_e4m3fn
,
column_major_scales
=
False
)
return
x_fp8_triton_
,
x_amax_triton_
x_fp8_triton
,
x_amax_triton
=
run_triton
()
...
...
examples/cast/example_triton_cast_to_fp8.py
View file @
667632cc
...
...
@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert
(
x
.
shape
[
-
1
]
%
group_size
==
0
),
(
f
"the last dimension of `x`
{
x
.
shape
[
-
1
]
}
must be divisible "
f
"by `group_size`
{
group_size
}
"
)
assert
x
.
shape
[
-
1
]
%
group_size
==
0
,
f
"the last dimension of `x`
{
x
.
shape
[
-
1
]
}
must be divisible by `group_size`
{
group_size
}
"
assert
x
.
stride
(
-
1
)
==
1
,
"`x` groups must be contiguous"
finfo
=
torch
.
finfo
(
dtype
)
...
...
examples/cast/test_example_cast.py
View file @
667632cc
...
...
@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8
def
test_example_group_per_split_token_cast_to_fp8
():
example_group_per_split_token_cast_to_fp8
.
main
(
M
=
1024
,
N
=
1024
,
BG
=
2
,
blk_m
=
4
,
batch_sizes
=
[
128
,
896
])
example_group_per_split_token_cast_to_fp8
.
main
(
M
=
1024
,
N
=
1024
,
BG
=
2
,
blk_m
=
4
,
batch_sizes
=
[
128
,
896
])
def
test_example_per_token_cast_to_fp8
():
...
...
examples/conftest.py
View file @
667632cc
...
...
@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
"warnings"
,
"error"
,
}
if
(
sum
(
len
(
terminalreporter
.
stats
.
get
(
k
,
[]))
for
k
in
known_types
.
difference
({
"skipped"
,
"deselected"
}))
==
0
):
if
sum
(
len
(
terminalreporter
.
stats
.
get
(
k
,
[]))
for
k
in
known_types
.
difference
({
"skipped"
,
"deselected"
}))
==
0
:
terminalreporter
.
write_sep
(
"!"
,
(
f
"Error: No tests were collected. "
f
"
{
dict
(
sorted
((
k
,
len
(
v
))
for
k
,
v
in
terminalreporter
.
stats
.
items
()))
}
"
),
(
f
"Error: No tests were collected.
{
dict
(
sorted
((
k
,
len
(
v
))
for
k
,
v
in
terminalreporter
.
stats
.
items
()))
}
"
),
)
pytest
.
exit
(
"No tests were collected."
,
returncode
=
5
)
examples/convolution/example_convolution.py
View file @
667632cc
...
...
@@ -14,7 +14,6 @@ def check_hopper():
def
ref_program
(
stride
,
padding
,
dilation
):
def
main
(
A
,
B
):
A
=
A
.
permute
(
0
,
3
,
1
,
2
)
# N, H, W, C -> N, C, H, W
B
=
B
.
permute
(
3
,
2
,
0
,
1
)
# H, W, C, F -> F, C, H, W
...
...
@@ -26,38 +25,21 @@ def ref_program(stride, padding, dilation):
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
convolution
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
convolution
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
KH
,
KW
=
K
,
K
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
dtype
=
"
float16
"
accum_dtype
=
"
float
"
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
is_hopper
=
check_hopper
()
@
T
.
prim_func
def
main
(
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -66,11 +48,13 @@ def convolution(N,
kernel_flat
=
T
.
Tensor
((
KH
*
KW
*
C
,
F
),
dtype
,
kernel
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
T
.
annotate_layout
({
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
data_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
data_shared
),
kernel_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
kernel_shared
),
})
T
.
annotate_layout
(
{
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
data_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
data_shared
),
kernel_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
kernel_shared
),
}
)
T
.
clear
(
out_local
)
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
...
...
@@ -82,10 +66,8 @@ def convolution(N,
m
=
by
*
block_M
+
i
access_h
=
m
%
(
OH
*
OW
)
//
OW
*
S
+
k
//
(
KW
*
C
)
*
D
-
P
access_w
=
m
%
OW
*
S
+
k
//
C
%
KW
*
D
-
P
in_bound
=
((
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
))
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
in_bound
=
(
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
)
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
T
.
copy
(
kernel_flat
[
k_iter
*
block_K
,
bx
*
block_N
],
kernel_shared
)
T
.
gemm
(
data_shared
,
kernel_shared
,
out_local
)
...
...
@@ -97,15 +79,15 @@ def convolution(N,
def
main
(
argv
=
None
):
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--n
'
,
type
=
int
,
default
=
128
,
help
=
'n'
)
parser
.
add_argument
(
'
--c
'
,
type
=
int
,
default
=
128
,
help
=
'c'
)
parser
.
add_argument
(
'
--h
'
,
type
=
int
,
default
=
64
,
help
=
'h'
)
parser
.
add_argument
(
'
--w
'
,
type
=
int
,
default
=
64
,
help
=
'w'
)
parser
.
add_argument
(
'
--f
'
,
type
=
int
,
default
=
128
,
help
=
'f'
)
parser
.
add_argument
(
'
--k
'
,
type
=
int
,
default
=
3
,
help
=
'k'
)
parser
.
add_argument
(
'
--s
'
,
type
=
int
,
default
=
1
,
help
=
's'
)
parser
.
add_argument
(
'
--d
'
,
type
=
int
,
default
=
1
,
help
=
'd'
)
parser
.
add_argument
(
'
--p
'
,
type
=
int
,
default
=
1
,
help
=
'p'
)
parser
.
add_argument
(
"
--n
"
,
type
=
int
,
default
=
128
,
help
=
"n"
)
parser
.
add_argument
(
"
--c
"
,
type
=
int
,
default
=
128
,
help
=
"c"
)
parser
.
add_argument
(
"
--h
"
,
type
=
int
,
default
=
64
,
help
=
"h"
)
parser
.
add_argument
(
"
--w
"
,
type
=
int
,
default
=
64
,
help
=
"w"
)
parser
.
add_argument
(
"
--f
"
,
type
=
int
,
default
=
128
,
help
=
"f"
)
parser
.
add_argument
(
"
--k
"
,
type
=
int
,
default
=
3
,
help
=
"k"
)
parser
.
add_argument
(
"
--s
"
,
type
=
int
,
default
=
1
,
help
=
"s"
)
parser
.
add_argument
(
"
--d
"
,
type
=
int
,
default
=
1
,
help
=
"d"
)
parser
.
add_argument
(
"
--p
"
,
type
=
int
,
default
=
1
,
help
=
"p"
)
args
=
parser
.
parse_args
(
argv
)
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
args
.
n
,
args
.
c
,
args
.
h
,
args
.
w
,
args
.
f
,
args
.
k
,
args
.
s
,
args
.
d
,
args
.
p
...
...
examples/convolution/example_convolution_autotune.py
View file @
667632cc
...
...
@@ -14,7 +14,6 @@ def check_hopper():
def
ref_program
(
stride
,
padding
,
dilation
):
def
main
(
A
,
B
):
A
=
A
.
permute
(
0
,
3
,
1
,
2
)
# N, H, W, C -> N, C, H, W
B
=
B
.
permute
(
3
,
2
,
0
,
1
)
# H, W, C, F -> F, C, H, W
...
...
@@ -40,7 +39,8 @@ def get_configs():
num_stages
,
thread_num
,
enable_rasterization
,
))
)
)
configs
=
[
{
...
...
@@ -50,7 +50,8 @@ def get_configs():
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
# keep param name for backward-compat
}
for
c
in
_configs
}
for
c
in
_configs
]
return
configs
...
...
@@ -64,69 +65,32 @@ def get_heuristic_config() -> dict:
sm_version
=
sm_major
*
10
+
sm_minor
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
if
sm_version
in
{
80
}:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
elif
sm_version
in
{
90
}:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
else
:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
@
tilelang
.
autotune
(
configs
=
get_configs
())
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
convolution
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
convolution
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
KH
,
KW
=
K
,
K
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
dtype
=
"
float16
"
accum_dtype
=
"
float
"
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
is_hopper
=
check_hopper
()
@
T
.
prim_func
def
main
(
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -136,9 +100,11 @@ def convolution(N,
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
if
is_hopper
:
T
.
annotate_layout
({
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
})
T
.
annotate_layout
(
{
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
}
)
T
.
clear
(
out_local
)
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
...
...
@@ -150,10 +116,8 @@ def convolution(N,
m
=
by
*
block_M
+
i
access_h
=
m
%
(
OH
*
OW
)
//
OW
*
S
+
k
//
(
KW
*
C
)
*
D
-
P
access_w
=
m
%
OW
*
S
+
k
//
C
%
KW
*
D
-
P
in_bound
=
((
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
))
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
in_bound
=
(
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
)
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
T
.
copy
(
kernel_flat
[
k_iter
*
block_K
,
bx
*
block_N
],
kernel_shared
)
T
.
gemm
(
data_shared
,
kernel_shared
,
out_local
)
...
...
@@ -166,17 +130,19 @@ def convolution(N,
return
main
def
main
(
n
:
int
=
128
,
c
:
int
=
128
,
h
:
int
=
64
,
w
:
int
=
64
,
f
:
int
=
128
,
k
:
int
=
3
,
s
:
int
=
1
,
d
:
int
=
1
,
p
:
int
=
1
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
True
):
def
main
(
n
:
int
=
128
,
c
:
int
=
128
,
h
:
int
=
64
,
w
:
int
=
64
,
f
:
int
=
128
,
k
:
int
=
3
,
s
:
int
=
1
,
d
:
int
=
1
,
p
:
int
=
1
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
True
,
):
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
n
,
c
,
h
,
w
,
f
,
k
,
s
,
d
,
p
ref_prog
=
ref_program
(
S
,
P
,
D
)
...
...
@@ -196,25 +162,16 @@ def main(n: int = 128,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Autotuned MatMul Benchmark"
)
parser
.
add_argument
(
'--n'
,
type
=
int
,
default
=
128
,
help
=
'n'
)
parser
.
add_argument
(
'--c'
,
type
=
int
,
default
=
128
,
help
=
'c'
)
parser
.
add_argument
(
'--h'
,
type
=
int
,
default
=
64
,
help
=
'h'
)
parser
.
add_argument
(
'--w'
,
type
=
int
,
default
=
64
,
help
=
'w'
)
parser
.
add_argument
(
'--f'
,
type
=
int
,
default
=
128
,
help
=
'f'
)
parser
.
add_argument
(
'--k'
,
type
=
int
,
default
=
3
,
help
=
'k'
)
parser
.
add_argument
(
'--s'
,
type
=
int
,
default
=
1
,
help
=
's'
)
parser
.
add_argument
(
'--d'
,
type
=
int
,
default
=
1
,
help
=
'd'
)
parser
.
add_argument
(
'--p'
,
type
=
int
,
default
=
1
,
help
=
'p'
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
True
,
help
=
"Whether to enable BitBLAS roller for search space"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
128
,
help
=
"n"
)
parser
.
add_argument
(
"--c"
,
type
=
int
,
default
=
128
,
help
=
"c"
)
parser
.
add_argument
(
"--h"
,
type
=
int
,
default
=
64
,
help
=
"h"
)
parser
.
add_argument
(
"--w"
,
type
=
int
,
default
=
64
,
help
=
"w"
)
parser
.
add_argument
(
"--f"
,
type
=
int
,
default
=
128
,
help
=
"f"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
3
,
help
=
"k"
)
parser
.
add_argument
(
"--s"
,
type
=
int
,
default
=
1
,
help
=
"s"
)
parser
.
add_argument
(
"--d"
,
type
=
int
,
default
=
1
,
help
=
"d"
)
parser
.
add_argument
(
"--p"
,
type
=
int
,
default
=
1
,
help
=
"p"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
True
,
help
=
"Whether to enable BitBLAS roller for search space"
)
args
=
parser
.
parse_args
()
main
(
args
.
n
,
args
.
c
,
args
.
h
,
args
.
w
,
args
.
f
,
args
.
k
,
args
.
s
,
args
.
d
,
args
.
p
,
args
.
use_autotune
,
args
.
with_roller
)
main
(
args
.
n
,
args
.
c
,
args
.
h
,
args
.
w
,
args
.
f
,
args
.
k
,
args
.
s
,
args
.
d
,
args
.
p
,
args
.
use_autotune
,
args
.
with_roller
)
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
View file @
667632cc
...
...
@@ -20,11 +20,11 @@ def tl_gemm(
accum_dtype
,
):
assert
in_dtype
in
[
"
float8_e4m3
"
,
T
.
float8_e4m3
fn
,
],
"Currently only float8_e4m3 is supported"
assert
out_dtype
in
[
"
bfloat16
"
,
"
float32
"
,
T
.
bfloat16
,
T
.
float32
,
],
"Currently only float16 and float32 are supported"
group_size
=
128
...
...
@@ -41,18 +41,17 @@ def tl_gemm(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
scales_a
:
T
.
Tensor
(
Scales_A_shape
,
"
float32
"
),
scales_b
:
T
.
Tensor
(
Scales_B_shape
,
"
float32
"
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
scales_a
:
T
.
Tensor
(
Scales_A_shape
,
T
.
float32
),
scales_b
:
T
.
Tensor
(
Scales_B_shape
,
T
.
float32
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
)
Scale_C_shared
=
T
.
alloc_shared
((
block_M
),
"
float32
"
)
Scale_C_shared
=
T
.
alloc_shared
((
block_M
),
T
.
float32
)
C_local
=
T
.
alloc_fragment
(
C_shared_shape
,
accum_dtype
)
C_local_accum
=
T
.
alloc_fragment
(
C_shared_shape
,
accum_dtype
)
...
...
@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
m
,
n
=
x
.
shape
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
ceildiv
(
m
,
128
)
*
128
,
ceildiv
(
n
,
128
)
*
128
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
=
torch
.
zeros
(
ceildiv
(
m
,
128
)
*
128
,
ceildiv
(
n
,
128
)
*
128
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
def
ref_deepgemm_fp8
(
A_fp8
,
B_fp8
,
A_scale
,
B_scale
,
out_dtype
):
...
...
@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
c_acc
.
zero_
()
for
k
in
range
(
ceildiv
(
K
,
128
)):
c
=
torch
.
_scaled_mm
(
A_fp8
[
i
*
128
:
(
i
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
],
B_fp8
[
j
*
128
:
(
j
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
].
T
,
A_fp8
[
i
*
128
:
(
i
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
],
B_fp8
[
j
*
128
:
(
j
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
].
T
,
scale_a
=
A_scales
[
i
,
k
].
view
(
128
,
1
).
contiguous
(),
scale_b
=
B_scales
[
j
,
k
].
view
(
1
,
128
).
contiguous
(),
out_dtype
=
torch
.
bfloat16
)
out_dtype
=
torch
.
bfloat16
,
)
c_acc
+=
c
.
to
(
torch
.
float32
)
C
[
i
*
128
:
(
i
+
1
)
*
128
,
j
*
128
:
(
j
+
1
)
*
128
]
=
c_acc
.
to
(
out_dtype
)
C
[
i
*
128
:
(
i
+
1
)
*
128
,
j
*
128
:
(
j
+
1
)
*
128
]
=
c_acc
.
to
(
out_dtype
)
return
C
...
...
@@ -179,11 +176,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp
def
main
():
assert_tl_gemm_correctness
(
1024
,
1024
,
8192
,
128
,
"
float8_e4m3
"
,
"
bfloat16
"
,
"
float32
"
)
assert_tl_gemm_correctness
(
1024
,
1024
,
8192
,
128
,
T
.
float8_e4m3
fn
,
T
.
bfloat16
,
T
.
float32
)
if
__name__
==
"__main__"
:
for
dtype
in
[
"
float8_e4m3
"
]:
for
out_dtype
in
[
"
bfloat16
"
,
"
float32
"
]:
for
dtype
in
[
T
.
float8_e4m3
fn
]:
for
out_dtype
in
[
T
.
bfloat16
,
T
.
float32
]:
for
block_N
in
[
16
,
32
,
64
,
128
]:
assert_tl_gemm_correctness
(
1024
,
1024
,
8192
,
block_N
,
dtype
,
out_dtype
,
"
float32
"
)
assert_tl_gemm_correctness
(
1024
,
1024
,
8192
,
block_N
,
dtype
,
out_dtype
,
T
.
float32
)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
View file @
667632cc
...
...
@@ -8,6 +8,7 @@ import argparse
def
get_configs
():
import
itertools
BLOCK_N
=
[
16
,
32
,
64
,
128
]
BLOCK_H
=
[
16
,
32
,
64
,
128
]
num_split
=
[
1
,
2
,
4
,
8
,
16
,
32
]
...
...
@@ -15,43 +16,39 @@ def get_configs():
_configs
=
list
(
itertools
.
product
(
BLOCK_N
,
BLOCK_H
,
num_split
,
threads
))
return
[{
"block_N"
:
c
[
0
],
"block_H"
:
c
[
1
],
"num_split"
:
c
[
2
],
"threads"
:
c
[
3
],
}
for
c
in
_configs
]
return
[
{
"block_N"
:
c
[
0
],
"block_H"
:
c
[
1
],
"num_split"
:
c
[
2
],
"threads"
:
c
[
3
],
}
for
c
in
_configs
]
@
tilelang
.
autotune
(
configs
=
get_configs
())
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
{
out_idx
=
[
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashmla_decode
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
threads
=
128
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
},
)
def
flashmla_decode
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
threads
=
128
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
T
.
float16
accum_dtype
=
T
.
float32
kv_group_num
=
heads
//
kv_head_num
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
assert
kv_head_num
==
1
,
"kv_head_num must be 1"
@
T
.
macro
def
flash_attn
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
threads
=
threads
)
as
(
bx
,
by
):
Q_local
=
T
.
alloc_fragment
([
block_H
,
dim
],
dtype
)
...
...
@@ -70,27 +67,24 @@ def flashmla_decode(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
(
seqlen_kv
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
0
):
T
.
copy
(
KV
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
KV
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_local
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
...
@@ -105,20 +99,18 @@ def flashmla_decode(batch,
T
.
gemm
(
acc_s_cast
,
KV_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
T
.
copy
(
acc_o
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_local
=
T
.
alloc_fragment
([
block_H
,
dim
],
dtype
)
Q_pe_local
=
T
.
alloc_fragment
([
block_H
,
pe_dim
],
dtype
)
KV_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
...
...
@@ -134,8 +126,8 @@ def flashmla_decode(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -148,15 +140,12 @@ def flashmla_decode(batch,
T
.
copy
(
K_pe
[
bx
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_local
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
...
@@ -172,14 +161,14 @@ def flashmla_decode(batch,
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
...
...
@@ -189,9 +178,11 @@ def flashmla_decode(batch,
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
...
...
@@ -214,26 +205,26 @@ def flashmla_decode(batch,
@
T
.
prim_func
def
main_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn_split
(
Q
,
Q_pe
,
KV
,
K_pe
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
...
...
@@ -258,43 +249,36 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
'
--autotune
'
,
action
=
'
store_true
'
,
help
=
'
auto tune
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
parser
.
add_argument
(
"
--autotune
"
,
action
=
"
store_true
"
,
help
=
"
auto tune
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
enable_autotune
=
args
.
autotune
...
...
@@ -310,17 +294,7 @@ if __name__ == "__main__":
if
enable_autotune
:
kernel
=
flashmla_decode
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
else
:
kernel
=
flashmla_decode
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
threads
=
threads
)
kernel
=
flashmla_decode
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
threads
=
threads
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
input_tensors
=
profiler
.
_get_inputs
()
tilelang_output
=
kernel
(
*
input_tensors
)
...
...
examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
View file @
667632cc
...
...
@@ -32,8 +32,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
def
ref_mla
():
...
...
@@ -94,8 +93,7 @@ def _mla_attn_kernel(
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
...
...
@@ -141,9 +139,7 @@ def _mla_attn_kernel(
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
...
...
@@ -309,24 +305,30 @@ def mla_decode_triton(
@
torch
.
inference_mode
()
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
def
flash_mla_triton
():
num_kv_splits
=
32
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
mla_decode_triton
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
)
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
,
)
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
out_flash
=
flash_mla_triton
()
...
...
@@ -362,14 +364,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
if
target
not
in
[
"flash_mla_triton"
]:
...
...
@@ -377,21 +380,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
torch
.
set_default_dtype
(
dtype
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
...
...
@@ -408,19 +404,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_b
...
...
@@ -429,26 +422,22 @@ available_targets = [
"flash_mla_triton"
,
]
shape_configs
=
[{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
}
for
batch
in
[
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
for
head
in
[
128
]]
shape_configs
=
[
{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
,
}
for
batch
in
[
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
for
head
in
[
128
]
]
def
get_args
():
...
...
@@ -470,26 +459,54 @@ if __name__ == "__main__":
for
shape
in
shape_configs
:
if
args
.
all
:
for
target
in
available_targets
:
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
elif
args
.
compare
:
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
baseline
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perfa
:.
0
f
}
\n
'
f
"
{
args
.
baseline
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perfa
:.
0
f
}
\n
"
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
prefb
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
prefb
:.
0
f
}
\n
"
)
elif
args
.
one
:
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
View file @
667632cc
...
...
@@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
def
ref_mla
():
...
...
@@ -91,8 +90,7 @@ def _mla_attn_kernel(
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
...
...
@@ -138,9 +136,7 @@ def _mla_attn_kernel(
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
...
...
@@ -306,24 +302,30 @@ def mla_decode_triton(
@
torch
.
inference_mode
()
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
def
flash_mla_triton
():
num_kv_splits
=
32
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
mla_decode_triton
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
)
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
,
)
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
out_flash
=
flash_mla_triton
()
...
...
@@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
if
target
not
in
[
"flash_mla_triton"
]:
...
...
@@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
torch
.
set_default_dtype
(
dtype
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
...
...
@@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_b
...
...
@@ -426,26 +419,22 @@ available_targets = [
"flash_mla_triton"
,
]
shape_configs
=
[{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
}
for
batch
in
[
64
,
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
for
head
in
[
128
]]
shape_configs
=
[
{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
,
}
for
batch
in
[
64
,
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
for
head
in
[
128
]
]
def
get_args
():
...
...
@@ -467,26 +456,54 @@ if __name__ == "__main__":
for
shape
in
shape_configs
:
if
args
.
all
:
for
target
in
available_targets
:
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
elif
args
.
compare
:
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
baseline
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perfa
:.
0
f
}
\n
'
f
"
{
args
.
baseline
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perfa
:.
0
f
}
\n
"
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
prefb
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
prefb
:.
0
f
}
\n
"
)
elif
args
.
one
:
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
examples/deepseek_mla/benchmark_mla.py
View file @
667632cc
...
...
@@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
def
ref_mla
():
...
...
@@ -61,8 +60,7 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@
torch
.
inference_mode
()
def
run_flash_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flash_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
from
flash_mla
import
flash_mla_with_kvcache
,
get_mla_metadata
blocked_v
=
blocked_k
[...,
:
dv
]
...
...
@@ -87,14 +85,13 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@
torch
.
inference_mode
()
def
run_flashinfer
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flashinfer
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
# pip install flashinfer-python
import
flashinfer
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
kv_indptr
=
[
0
]
kv_indices
=
[]
...
...
@@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
mla_wrapper
=
flashinfer
.
mla
.
BatchMLAPagedAttentionWrapper
(
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
),
backend
=
"fa3"
)
mla_wrapper
=
flashinfer
.
mla
.
BatchMLAPagedAttentionWrapper
(
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
),
backend
=
"fa3"
)
mla_wrapper
.
plan
(
q_indptr
,
kv_indptr
,
...
...
@@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
)
def
flashinfer
():
output
,
lse
=
mla_wrapper
.
run
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
,
blocked_k_pe
,
return_lse
=
True
)
output
,
lse
=
mla_wrapper
.
run
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
,
blocked_k_pe
,
return_lse
=
True
)
return
output
.
view
(
b
,
-
1
,
h_q
,
dv
),
lse
.
view
(
b
,
h_q
,
1
)
out_flash
,
lse_flash
=
flashinfer
()
...
...
@@ -177,8 +168,7 @@ def _mla_attn_kernel(
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
...
...
@@ -224,9 +214,7 @@ def _mla_attn_kernel(
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
...
...
@@ -393,24 +381,30 @@ def mla_decode_triton(
@
torch
.
inference_mode
()
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
def
flash_mla_triton
():
num_kv_splits
=
32
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
mla_decode_triton
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
)
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
,
)
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
out_flash
=
flash_mla_triton
()
...
...
@@ -419,13 +413,10 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size,
@
torch
.
inference_mode
()
def
run_flash_mla_tilelang
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flash_mla_tilelang
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
dpe
=
d
-
dv
num_kv_splits
=
1
...
...
@@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
out_partial
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dv
,
dtype
=
dtype
,
device
=
q
.
device
)
glse
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dtype
=
dtype
,
device
=
q
.
device
)
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
num_kv_splits
,
block_size
)
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
num_kv_splits
,
block_size
)
def
flash_mla_tilelang
():
out
=
kernel
(
...
...
@@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
if
target
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]
and
baseline
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]:
if
target
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]
and
baseline
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]:
# flashinfer has a different lse return value
# flash_mla_triton and flash_mla_tilelang doesn't return lse
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
torch
.
set_default_dtype
(
dtype
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
...
...
@@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_b
...
...
@@ -558,26 +538,22 @@ available_targets = [
"flash_mla_triton"
,
]
shape_configs
=
[{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
}
for
batch
in
[
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
,
32768
]
for
head
in
[
128
]]
shape_configs
=
[
{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
,
}
for
batch
in
[
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
,
32768
]
for
head
in
[
128
]
]
def
get_args
():
...
...
@@ -599,26 +575,54 @@ if __name__ == "__main__":
for
shape
in
shape_configs
:
if
args
.
all
:
for
target
in
available_targets
:
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
elif
args
.
compare
:
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
baseline
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perfa
:.
0
f
}
\n
'
f
"
{
args
.
baseline
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perfa
:.
0
f
}
\n
"
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
prefb
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
prefb
:.
0
f
}
\n
"
)
elif
args
.
one
:
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
examples/deepseek_mla/example_mla_decode.py
View file @
667632cc
...
...
@@ -8,25 +8,26 @@ import argparse
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
{
out_idx
=
[
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
}
)
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
softmax_scale
):
}
,
)
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
softmax_scale
):
scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
dtype
=
"
float16
"
accum_dtype
=
"
float
"
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
kv_group_num
=
heads
//
kv_head_num
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
assert
kv_head_num
==
1
,
"kv_head_num must be 1"
@
T
.
macro
def
flash_attn
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
heads
//
min
(
block_H
,
kv_group_num
),
batch
,
threads
=
256
)
as
(
hid
,
bid
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
...
@@ -44,36 +45,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
T
.
annotate_layout
({
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
})
T
.
annotate_layout
(
{
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
}
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
(
seqlen_kv
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
T
.
copy
(
KV
[
bid
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
,
clear_accum
=
True
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
KV
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
,
clear_accum
=
True
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
...
@@ -88,20 +82,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:])
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bid
,
hid
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bid
,
hid
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
block_N
],
dtype
)
Q_pe_shared
=
T
.
alloc_shared
([
block_H
,
pe_dim
],
dtype
)
...
...
@@ -119,13 +111,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
})
T
.
annotate_layout
(
{
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
}
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -137,17 +131,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
copy
(
KV
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
...
@@ -164,16 +154,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
hid
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
...
...
@@ -183,9 +172,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
...
...
@@ -208,26 +199,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
prim_func
def
main_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn_split
(
Q
,
Q_pe
,
KV
,
K_pe
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
...
...
@@ -252,31 +243,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
...
...
@@ -294,10 +278,9 @@ def main(
BLOCK_N
=
64
BLOCK_H
=
min
(
64
,
heads
//
kv_heads
)
num_split
=
1
softmax_scale
=
(
dim
+
pe_dim
)
**-
0.5
softmax_scale
=
(
dim
+
pe_dim
)
**
-
0.5
kernel
=
flashattn
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
softmax_scale
)
kernel
=
flashattn
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
softmax_scale
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
1e-4
,
atol
=
1e-4
)
latency
=
profiler
.
do_bench
(
warmup
=
500
)
...
...
@@ -307,12 +290,12 @@ def main(
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
132
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
132
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
main
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
examples/deepseek_mla/example_mla_decode_paged.py
View file @
667632cc
...
...
@@ -8,25 +8,17 @@ import math
@
tilelang
.
jit
(
out_idx
=
[
8
],
pass_configs
=
{
out_idx
=
[
8
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
mla_decode_tilelang
(
batch
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
block_N
,
block_H
,
num_split
,
block_size
,
softmax_scale
=
None
):
},
)
def
mla_decode_tilelang
(
batch
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
block_N
,
block_H
,
num_split
,
block_size
,
softmax_scale
=
None
):
if
softmax_scale
is
None
:
softmax_scale
=
(
dv
+
dpe
)
**-
0.5
softmax_scale
=
(
dv
+
dpe
)
**
-
0.5
scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
dtype
=
"
float16
"
accum_dtype
=
"
float
"
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
kv_group_num
=
h_q
//
h_kv
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
assert
h_kv
==
1
,
"h_kv must be 1"
...
...
@@ -34,13 +26,13 @@ def mla_decode_tilelang(batch,
@
T
.
macro
def
flash_mla_kernel
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"
int32
"
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
"
int32
"
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
T
.
int32
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
T
.
int32
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
with
T
.
Kernel
(
batch
,
h_q
//
min
(
block_H
,
kv_group_num
),
threads
=
256
)
as
(
bx
,
by
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dv
],
dtype
)
...
...
@@ -59,13 +51,15 @@ def mla_decode_tilelang(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
})
T
.
annotate_layout
(
{
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
}
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -73,26 +67,20 @@ def mla_decode_tilelang(batch,
loop_range
=
T
.
ceildiv
(
CACHE_SEQLENS
[
bx
],
block_N
)
for
kr
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
k
=
loop_range
-
1
-
kr
kv_start
=
BLOCK_TABLE
[
bx
,
(
k
*
block_N
)
//
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
kv_start
=
BLOCK_TABLE
[
bx
,
(
k
*
block_N
)
//
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
if
kr
==
0
:
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
...
@@ -107,21 +95,20 @@ def mla_decode_tilelang(batch,
for
i
,
j
in
T
.
Parallel
(
block_H
,
dv
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
@
T
.
macro
def
flash_mla_split_kv_kernel
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"
int32
"
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
"
int32
"
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
T
.
int32
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
T
.
int32
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
):
with
T
.
Kernel
(
batch
,
h_q
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
h_q
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dv
],
dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
block_N
],
dtype
)
Q_pe_shared
=
T
.
alloc_shared
([
block_H
,
dpe
],
dtype
)
...
...
@@ -139,13 +126,15 @@ def mla_decode_tilelang(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
})
T
.
annotate_layout
(
{
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
}
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -153,29 +142,23 @@ def mla_decode_tilelang(batch,
total_blocks
=
T
.
ceildiv
(
CACHE_SEQLENS
[
bx
],
block_N
)
blocks_per_split
=
T
.
floordiv
(
total_blocks
,
num_split
)
remaining_blocks
=
T
.
floormod
(
total_blocks
,
num_split
)
loop_range
=
(
blocks_per_split
+
T
.
if_then_else
(
bz
<
remaining_blocks
,
1
,
0
)
)
loop_range
=
blocks_per_split
+
T
.
if_then_else
(
bz
<
remaining_blocks
,
1
,
0
)
start
=
(
blocks_per_split
*
bz
+
T
.
min
(
bz
,
remaining_blocks
))
*
block_N
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
kv_start
=
BLOCK_TABLE
[
bx
,
(
start
+
k
*
block_N
)
//
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
kv_start
=
BLOCK_TABLE
[
bx
,
(
start
+
k
*
block_N
)
//
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
start
+
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
start
+
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
...
@@ -192,15 +175,15 @@ def mla_decode_tilelang(batch,
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
T
.
copy
(
O_shared
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
with
T
.
Kernel
(
h_q
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dv
],
dtype
)
...
...
@@ -210,9 +193,11 @@ def mla_decode_tilelang(batch,
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
...
...
@@ -235,31 +220,30 @@ def mla_decode_tilelang(batch,
@
T
.
prim_func
def
main_split
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"
int32
"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"
int32
"
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
T
.
int32
),
cache_seqlens
:
T
.
Tensor
([
batch
],
T
.
int32
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
flash_mla_split_kv_kernel
(
Q
,
Q_pe
,
KV
,
K_pe
,
block_table
,
cache_seqlens
,
glse
,
Output_partial
)
flash_mla_split_kv_kernel
(
Q
,
Q_pe
,
KV
,
K_pe
,
block_table
,
cache_seqlens
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"
int32
"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"
int32
"
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
T
.
int32
),
cache_seqlens
:
T
.
Tensor
([
batch
],
T
.
int32
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
flash_mla_kernel
(
Q
,
Q_pe
,
KV
,
K_pe
,
block_table
,
cache_seqlens
,
Output
)
...
...
@@ -280,8 +264,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
s_q
=
query
.
shape
[
-
2
]
s_k
=
key
.
shape
[
-
2
]
attn_bias
=
torch
.
zeros
(
s_q
,
s_k
,
dtype
=
query
.
dtype
,
device
=
query
.
device
)
temp_mask
=
torch
.
ones
(
s_q
,
s_k
,
dtype
=
torch
.
bool
,
device
=
query
.
device
).
tril
(
diagonal
=
s_k
-
s_q
)
temp_mask
=
torch
.
ones
(
s_q
,
s_k
,
dtype
=
torch
.
bool
,
device
=
query
.
device
).
tril
(
diagonal
=
s_k
-
s_q
)
attn_bias
.
masked_fill_
(
temp_mask
.
logical_not
(),
float
(
"-inf"
))
attn_bias
.
to
(
query
.
dtype
)
attn_weight
+=
attn_bias
...
...
@@ -291,8 +274,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
# q: [b, s_q, h_q, d]
# block_table: [b, max_seqlen_pad // block_size]
# blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
...
...
@@ -321,13 +303,10 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
return
out_torch
def
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
dpe
=
d
-
dv
num_kv_splits
=
1
...
...
@@ -337,8 +316,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_partial
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dv
,
dtype
=
dtype
,
device
=
q
.
device
)
glse
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dtype
=
dtype
,
device
=
q
.
device
)
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
num_kv_splits
,
block_size
,
softmax_scale
)
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
num_kv_splits
,
block_size
,
softmax_scale
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
def
flash_mla_tilelang
():
...
...
@@ -356,8 +334,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_flash
=
flash_mla_tilelang
()
t
=
do_bench
(
flash_mla_tilelang
)
out_ref
=
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_ref
=
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_flash
,
out_ref
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"All close"
)
return
out_flash
,
t
...
...
@@ -365,12 +342,12 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--h_q
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--h_kv
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--cache_seqlen
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv cache context length
'
)
parser
.
add_argument
(
'
--d
'
,
type
=
int
,
default
=
576
,
help
=
'
query/key head dim, d = dv + dpe
'
)
parser
.
add_argument
(
'
--dv
'
,
type
=
int
,
default
=
512
,
help
=
'
value head dim
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--h_q
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--h_kv
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--cache_seqlen
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv cache context length
"
)
parser
.
add_argument
(
"
--d
"
,
type
=
int
,
default
=
576
,
help
=
"
query/key head dim, d = dv + dpe
"
)
parser
.
add_argument
(
"
--dv
"
,
type
=
int
,
default
=
512
,
help
=
"
value head dim
"
)
args
=
parser
.
parse_args
()
b
,
h_q
,
h_kv
,
cache_seqlen
,
d
,
dv
=
args
.
batch
,
args
.
h_q
,
args
.
h_kv
,
args
.
cache_seqlen
,
args
.
d
,
args
.
dv
...
...
@@ -379,9 +356,7 @@ if __name__ == "__main__":
s_q
=
1
# for decode, s_q = 1
block_size
=
64
cache_seqlens
=
torch
.
tensor
([
cache_seqlen
+
2
*
i
for
i
in
range
(
b
)],
dtype
=
torch
.
int32
,
device
=
device
)
cache_seqlens
=
torch
.
tensor
([
cache_seqlen
+
2
*
i
for
i
in
range
(
b
)],
dtype
=
torch
.
int32
,
device
=
device
)
dpe
=
d
-
dv
causal
=
True
...
...
@@ -393,12 +368,11 @@ if __name__ == "__main__":
total_flops
=
s_q
*
total_seqlens
*
h_q
*
d
*
2
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
,
dtype
=
dtype
,
device
=
device
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
,
device
=
device
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
,
device
=
device
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
,
dtype
=
dtype
,
device
=
device
)
out_flash
,
latency
=
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_flash
,
latency
=
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
print
(
"Tile-lang: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
examples/deepseek_mla/example_mla_decode_persistent.py
View file @
667632cc
...
...
@@ -9,13 +9,15 @@ import argparse
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
{
out_idx
=
[
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
"
float16
"
accum_dtype
=
"
float
"
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
kv_group_num
=
heads
//
kv_head_num
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
assert
kv_head_num
==
1
,
"kv_head_num must be 1"
...
...
@@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
prim_func
def
main_split_persistent
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
sm_num
,
threads
=
256
)
as
(
block_id
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
...
@@ -53,11 +55,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
use_swizzle
(
10
)
total_tiles
=
batch
*
(
heads
//
min
(
block_H
,
kv_group_num
))
*
num_split
...
...
@@ -70,8 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
if
bid
<
batch
and
hid
*
VALID_BLOCK_H
<
heads
and
sid
<
num_split
:
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -83,24 +87,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
copy
(
KV
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
...
...
@@ -115,11 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
])
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
])
# T.copy(acc_o, O_shared)
T
.
copy
(
acc_o
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
,
:])
T
.
copy
(
acc_o
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
,
:])
T
.
sync_grid
()
waves
=
T
.
ceildiv
(
heads
*
batch
,
sm_num
)
...
...
@@ -165,42 +158,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
qk_flops
=
2
*
batch
*
heads
*
kv_ctx
*
(
dim
+
pe_dim
)
...
...
Prev
1
2
3
4
5
6
7
8
9
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment