Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
667632cc
Unverified
Commit
667632cc
authored
Dec 22, 2025
by
guchaoyang
Committed by
GitHub
Dec 22, 2025
Browse files
Merge branch 'main' into dcu
parents
d6dd2ddf
a874e4e8
Changes
313
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
887 additions
and
1121 deletions
+887
-1121
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py
...tention/example_triton_sparse_gqa_decode_varlen_indice.py
+53
-100
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py
...attention/example_triton_sparse_gqa_decode_varlen_mask.py
+52
-95
examples/blocksparse_attention/heuristic.py
examples/blocksparse_attention/heuristic.py
+1
-2
examples/blocksparse_attention/test_example_blocksparse_attention.py
...ocksparse_attention/test_example_blocksparse_attention.py
+4
-16
examples/blocksparse_gemm/example_blocksparse_gemm.py
examples/blocksparse_gemm/example_blocksparse_gemm.py
+30
-40
examples/cast/example_group_per_split_token_cast_to_fp8.py
examples/cast/example_group_per_split_token_cast_to_fp8.py
+32
-33
examples/cast/example_per_token_cast_to_fp8.py
examples/cast/example_per_token_cast_to_fp8.py
+15
-20
examples/cast/example_triton_cast_to_fp8.py
examples/cast/example_triton_cast_to_fp8.py
+1
-3
examples/cast/test_example_cast.py
examples/cast/test_example_cast.py
+1
-2
examples/conftest.py
examples/conftest.py
+2
-5
examples/convolution/example_convolution.py
examples/convolution/example_convolution.py
+25
-43
examples/convolution/example_convolution_autotune.py
examples/convolution/example_convolution_autotune.py
+48
-91
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
+21
-24
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
...les/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
+84
-110
examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
+91
-74
examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
+91
-74
examples/deepseek_mla/benchmark_mla.py
examples/deepseek_mla/benchmark_mla.py
+101
-97
examples/deepseek_mla/example_mla_decode.py
examples/deepseek_mla/example_mla_decode.py
+86
-103
examples/deepseek_mla/example_mla_decode_paged.py
examples/deepseek_mla/example_mla_decode_paged.py
+104
-130
examples/deepseek_mla/example_mla_decode_persistent.py
examples/deepseek_mla/example_mla_decode_persistent.py
+45
-59
No files found.
Too many changes to show.
To preserve performance only
313 of 313+
files are displayed.
Plain diff
Email patch
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py
View file @
667632cc
...
@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic
...
@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
key
=
[
"BLOCK_H"
,
"BLOCK_N"
,
"BLOCK_D"
],
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_H'
,
'BLOCK_N'
,
'BLOCK_D'
],
)
)
@
triton
.
jit
@
triton
.
jit
def
_split_kernel
(
def
_split_kernel
(
...
@@ -79,16 +75,11 @@ def _split_kernel(
...
@@ -79,16 +75,11 @@ def _split_kernel(
loop_range
=
blocks_per_split
loop_range
=
blocks_per_split
q_ptr
+=
batch_idx
*
stride_q_b
+
head_idx_q
*
stride_q_h
q_ptr
+=
batch_idx
*
stride_q_b
+
head_idx_q
*
stride_q_h
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
mask_ptr
+=
batch_idx
*
stride_mask_b
+
head_idx_kv
*
stride_mask_h
mask_ptr
+=
batch_idx
*
stride_mask_b
+
head_idx_kv
*
stride_mask_h
q
=
tl
.
load
(
q
=
tl
.
load
(
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
start
=
blocks_per_split
*
split_idx
+
tl
.
minimum
(
split_idx
,
remaining_blocks
)
start
=
blocks_per_split
*
split_idx
+
tl
.
minimum
(
split_idx
,
remaining_blocks
)
for
i
in
range
(
loop_range
):
for
i
in
range
(
loop_range
):
block_idx
=
tl
.
load
(
mask_ptr
+
(
start
+
i
)
*
stride_mask_s
)
block_idx
=
tl
.
load
(
mask_ptr
+
(
start
+
i
)
*
stride_mask_s
)
...
@@ -119,23 +110,18 @@ def _split_kernel(
...
@@ -119,23 +110,18 @@ def _split_kernel(
acc
=
acc
*
l_recip
acc
=
acc
*
l_recip
acc
=
acc
.
to
(
o_partial_ptr
.
dtype
.
element_ty
)
acc
=
acc
.
to
(
o_partial_ptr
.
dtype
.
element_ty
)
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
tl
.
store
(
lse_partial_ptr
,
m_i
,
mask
=
offs_h
<
gqa_group_size
)
tl
.
store
(
lse_partial_ptr
,
m_i
,
mask
=
offs_h
<
gqa_group_size
)
o_partial_ptr
+=
batch_idx
*
stride_o_b
+
(
o_partial_ptr
+=
(
head_idx_q
+
batch_idx
*
stride_o_b
+
(
head_idx_q
+
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
)
tl
.
store
(
o_partial_ptr
,
acc
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
tl
.
store
(
o_partial_ptr
,
acc
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
key
=
[
"BLOCK_D"
],
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_D'
],
)
)
@
triton
.
jit
@
triton
.
jit
def
_merge_kernel
(
def
_merge_kernel
(
...
@@ -163,18 +149,15 @@ def _merge_kernel(
...
@@ -163,18 +149,15 @@ def _merge_kernel(
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
lse_offsets
=
lse_partial_ptr
+
batch_idx
*
lse_partial_stride_b
+
head_idx
*
lse_partial_stride_h
lse_offsets
=
lse_partial_ptr
+
batch_idx
*
lse_partial_stride_b
+
head_idx
*
lse_partial_stride_h
lse
=
tl
.
load
(
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse_max
=
tl
.
max
(
lse
)
lse_max
=
tl
.
max
(
lse
)
o_offsets
=
o_partial_ptr
+
batch_idx
*
o_partial_stride_b
+
head_idx
*
o_partial_stride_h
o_offsets
=
o_partial_ptr
+
batch_idx
*
o_partial_stride_b
+
head_idx
*
o_partial_stride_h
o_partial
=
tl
.
load
(
o_partial
=
tl
.
load
(
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
,
mask
=
offs_splits
[:,
None
]
<
num_splits
)
)
sumexp_normalized_splitk
=
tl
.
exp
(
lse
-
lse_max
)
sumexp_normalized_splitk
=
tl
.
exp
(
lse
-
lse_max
)
sumexp_normalized
=
tl
.
sum
(
sumexp_normalized_splitk
,
axis
=
0
)
sumexp_normalized
=
tl
.
sum
(
sumexp_normalized_splitk
,
axis
=
0
)
numerator_normalized
=
tl
.
sum
(
o_partial
*
sumexp_normalized_splitk
[:,
None
],
axis
=
0
)
numerator_normalized
=
tl
.
sum
(
o_partial
*
sumexp_normalized_splitk
[:,
None
],
axis
=
0
)
...
@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton(
...
@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton(
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_n_blocks
=
max_selected_blocks
num_n_blocks
=
max_selected_blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
# kv_seqlen * (dim + dim_v) * 2
dim
+
dim_v
)
*
2
#kv_seqlen * (dim + dim_v) * 2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
num_sm
=
64
num_sm
=
64
# num_sm = self.num_sm
# num_sm = self.num_sm
num_splits
=
num_splits_heuristic
(
num_splits
=
num_splits_heuristic
(
total_mblocks
,
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
num_sm
,
)
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...
@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton(
...
@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton(
return
output
return
output
def
ref_program_torch
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
def
ref_program_torch
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
block_size
):
batch
,
heads
,
dim
=
query
.
shape
batch
,
heads
,
dim
=
query
.
shape
heads_kv
=
key
.
shape
[
2
]
heads_kv
=
key
.
shape
[
2
]
dim_v
=
value
.
shape
[
-
1
]
dim_v
=
value
.
shape
[
-
1
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
key
=
rearrange
(
key
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
query
=
rearrange
(
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
scores
=
einsum
(
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask
=
torch
.
zeros_like
(
scores
)
sparse_mask
=
torch
.
zeros_like
(
scores
)
# Assign mask values based on block_indices
# Assign mask values based on block_indices
...
@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
...
@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
valid_indices
=
block_indices
[
b
,
h
]
# Extract indices for this batch and head
valid_indices
=
block_indices
[
b
,
h
]
# Extract indices for this batch and head
for
idx
in
valid_indices
:
for
idx
in
valid_indices
:
if
idx
>=
0
:
if
idx
>=
0
:
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
"
-inf
"
))
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
'
cuda
'
).
unsqueeze
(
0
)
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
"
cuda
"
).
unsqueeze
(
0
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
'-inf'
))
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
"-inf"
))
attention
=
F
.
softmax
(
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
out
=
einsum
(
attention
,
value
,
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, heads_kv, dim]
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
return
out
return
out
def
ref_program_fa
(
query
,
key
,
value
,
cache_seqlens
):
def
ref_program_fa
(
query
,
key
,
value
,
cache_seqlens
):
# latency reference
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from
flash_attn
import
flash_attn_with_kvcache
#fa2
from
flash_attn
import
flash_attn_with_kvcache
# fa2
query
=
query
.
unsqueeze
(
1
)
query
=
query
.
unsqueeze
(
1
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
output
.
squeeze
(
1
)
output
=
output
.
squeeze
(
1
)
return
output
return
output
def
main
(
batch
=
64
,
def
main
(
batch
=
64
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
sparse_ratio
=
sparse_ratio
sparse_ratio
=
sparse_ratio
block_size
=
block_size
block_size
=
block_size
...
@@ -369,34 +331,29 @@ def main(batch=64,
...
@@ -369,34 +331,29 @@ def main(batch=64,
dtype
=
torch
.
float16
dtype
=
torch
.
float16
block_H
=
64
block_H
=
64
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
"
cuda
"
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# Ensure at least one element equals cache_seqlen
# Ensure at least one element equals cache_seqlen
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
'cuda'
).
item
()
# Select a random index
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
"cuda"
).
item
()
# Select a random index
cache_seqlens
[
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
print
(
"cache_seqlens: "
,
cache_seqlens
)
print
(
"cache_seqlens: "
,
cache_seqlens
)
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
# Initialize block_indices with -1 (for padding blocks)
# Initialize block_indices with -1 (for padding blocks)
block_indices
=
torch
.
full
((
batch
,
heads_kv
,
max_selected_blocks
),
block_indices
=
torch
.
full
((
batch
,
heads_kv
,
max_selected_blocks
),
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
-
1
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
# Assign valid indices while ensuring no duplicates within each batch-group
# Assign valid indices while ensuring no duplicates within each batch-group
for
b
in
range
(
batch
):
for
b
in
range
(
batch
):
max_valid_block
=
max_valid_num_blocks
[
b
].
item
()
# Max valid blocks for this batch
max_valid_block
=
max_valid_num_blocks
[
b
].
item
()
# Max valid blocks for this batch
if
max_valid_block
>
0
:
# Ensure there's at least one valid block
if
max_valid_block
>
0
:
# Ensure there's at least one valid block
for
h
in
range
(
heads_kv
):
for
h
in
range
(
heads_kv
):
valid_indices
=
torch
.
randperm
(
valid_indices
=
torch
.
randperm
(
max_valid_block
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)[:
max_selected_blocks
]
max_valid_block
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)[:
max_selected_blocks
]
block_indices
[
b
,
h
,
:
len
(
valid_indices
)]
=
valid_indices
block_indices
[
b
,
h
,
:
len
(
valid_indices
)]
=
valid_indices
# Sort indices within each batch-group for consistency
# Sort indices within each batch-group for consistency
block_indices
,
_
=
block_indices
.
sort
(
dim
=-
1
,
descending
=
True
)
block_indices
,
_
=
block_indices
.
sort
(
dim
=-
1
,
descending
=
True
)
...
@@ -408,8 +365,7 @@ def main(batch=64,
...
@@ -408,8 +365,7 @@ def main(batch=64,
max_num_blocks
=
torch
.
max
(
max_valid_num_blocks
).
item
()
max_num_blocks
=
torch
.
max
(
max_valid_num_blocks
).
item
()
print
(
"max_num_blocks: "
,
max_num_blocks
)
print
(
"max_num_blocks: "
,
max_num_blocks
)
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
block_size
)
triton_out
=
block_sparse_flash_decode_gqa_indice_triton
(
triton_out
=
block_sparse_flash_decode_gqa_indice_triton
(
Q
,
Q
,
...
@@ -423,8 +379,7 @@ def main(batch=64,
...
@@ -423,8 +379,7 @@ def main(batch=64,
)
)
print
(
"max difference: "
,
torch
.
max
(
torch
.
abs
(
ref
-
triton_out
)))
print
(
"max difference: "
,
torch
.
max
(
torch
.
abs
(
ref
-
triton_out
)))
assert
torch
.
allclose
(
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
print
(
"Passed the ref test!"
)
print
(
"Passed the ref test!"
)
# Measure performance
# Measure performance
...
@@ -466,15 +421,13 @@ def main(batch=64,
...
@@ -466,15 +421,13 @@ def main(batch=64,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
64
,
help
=
'batch size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
64
,
help
=
"batch size"
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.8
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.8
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"block_size"
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
32
,
help
=
'block_size'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
args
.
sparse_ratio
,
args
.
block_size
)
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py
View file @
667632cc
...
@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic
...
@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
key
=
[
"BLOCK_H"
,
"BLOCK_N"
,
"BLOCK_D"
],
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_H'
,
'BLOCK_N'
,
'BLOCK_D'
],
)
)
@
triton
.
jit
@
triton
.
jit
def
_split_kernel
(
def
_split_kernel
(
...
@@ -77,16 +73,11 @@ def _split_kernel(
...
@@ -77,16 +73,11 @@ def _split_kernel(
loop_range
=
blocks_per_split
loop_range
=
blocks_per_split
q_ptr
+=
batch_idx
*
stride_q_b
+
head_idx_q
*
stride_q_h
q_ptr
+=
batch_idx
*
stride_q_b
+
head_idx_q
*
stride_q_h
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
mask_ptr
+=
batch_idx
*
stride_mask_b
+
head_idx_kv
*
stride_mask_h
mask_ptr
+=
batch_idx
*
stride_mask_b
+
head_idx_kv
*
stride_mask_h
q
=
tl
.
load
(
q
=
tl
.
load
(
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
start
=
blocks_per_split
*
split_idx
+
tl
.
minimum
(
split_idx
,
remaining_blocks
)
start
=
blocks_per_split
*
split_idx
+
tl
.
minimum
(
split_idx
,
remaining_blocks
)
for
block_idx
in
range
(
loop_range
):
for
block_idx
in
range
(
loop_range
):
start_n
=
(
start
+
block_idx
)
*
BLOCK_N
start_n
=
(
start
+
block_idx
)
*
BLOCK_N
...
@@ -117,23 +108,18 @@ def _split_kernel(
...
@@ -117,23 +108,18 @@ def _split_kernel(
acc
=
acc
*
l_recip
acc
=
acc
*
l_recip
acc
=
acc
.
to
(
o_partial_ptr
.
dtype
.
element_ty
)
acc
=
acc
.
to
(
o_partial_ptr
.
dtype
.
element_ty
)
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
tl
.
store
(
lse_partial_ptr
,
m_i
,
mask
=
offs_h
<
gqa_group_size
)
tl
.
store
(
lse_partial_ptr
,
m_i
,
mask
=
offs_h
<
gqa_group_size
)
o_partial_ptr
+=
batch_idx
*
stride_o_b
+
(
o_partial_ptr
+=
(
head_idx_q
+
batch_idx
*
stride_o_b
+
(
head_idx_q
+
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
)
tl
.
store
(
o_partial_ptr
,
acc
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
tl
.
store
(
o_partial_ptr
,
acc
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
key
=
[
"BLOCK_D"
],
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_D'
],
)
)
@
triton
.
jit
@
triton
.
jit
def
_merge_kernel
(
def
_merge_kernel
(
...
@@ -161,18 +147,15 @@ def _merge_kernel(
...
@@ -161,18 +147,15 @@ def _merge_kernel(
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
lse_offsets
=
lse_partial_ptr
+
batch_idx
*
lse_partial_stride_b
+
head_idx
*
lse_partial_stride_h
lse_offsets
=
lse_partial_ptr
+
batch_idx
*
lse_partial_stride_b
+
head_idx
*
lse_partial_stride_h
lse
=
tl
.
load
(
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse_max
=
tl
.
max
(
lse
)
lse_max
=
tl
.
max
(
lse
)
o_offsets
=
o_partial_ptr
+
batch_idx
*
o_partial_stride_b
+
head_idx
*
o_partial_stride_h
o_offsets
=
o_partial_ptr
+
batch_idx
*
o_partial_stride_b
+
head_idx
*
o_partial_stride_h
o_partial
=
tl
.
load
(
o_partial
=
tl
.
load
(
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
,
mask
=
offs_splits
[:,
None
]
<
num_splits
)
)
sumexp_normalized_splitk
=
tl
.
exp
(
lse
-
lse_max
)
sumexp_normalized_splitk
=
tl
.
exp
(
lse
-
lse_max
)
sumexp_normalized
=
tl
.
sum
(
sumexp_normalized_splitk
,
axis
=
0
)
sumexp_normalized
=
tl
.
sum
(
sumexp_normalized_splitk
,
axis
=
0
)
numerator_normalized
=
tl
.
sum
(
o_partial
*
sumexp_normalized_splitk
[:,
None
],
axis
=
0
)
numerator_normalized
=
tl
.
sum
(
o_partial
*
sumexp_normalized_splitk
[:,
None
],
axis
=
0
)
...
@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton(
...
@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton(
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_n_blocks
=
max_selected_blocks
num_n_blocks
=
max_selected_blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
# kv_seqlen * (dim + dim_v) * 2
dim
+
dim_v
)
*
2
#kv_seqlen * (dim + dim_v) * 2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
num_sm
=
64
num_sm
=
64
# num_sm = self.num_sm
# num_sm = self.num_sm
num_splits
=
num_splits_heuristic
(
num_splits
=
num_splits_heuristic
(
total_mblocks
,
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
num_sm
,
)
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...
@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton(
...
@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton(
return
output
return
output
def
ref_program_torch
(
query
,
key
,
value
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
def
ref_program_torch
(
query
,
key
,
value
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
block_size
):
batch
,
heads
,
dim
=
query
.
shape
batch
,
heads
,
dim
=
query
.
shape
heads_kv
=
key
.
shape
[
2
]
heads_kv
=
key
.
shape
[
2
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
key
=
rearrange
(
key
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
query
=
rearrange
(
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
scores
=
einsum
(
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask
=
torch
.
zeros_like
(
scores
)
sparse_mask
=
torch
.
zeros_like
(
scores
)
# Assign mask values
# Assign mask values
...
@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
...
@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
for
h
in
range
(
heads_kv
):
for
h
in
range
(
heads_kv
):
for
idx
in
range
(
num_blocks
):
for
idx
in
range
(
num_blocks
):
if
block_mask
[
b
,
h
,
idx
]:
if
block_mask
[
b
,
h
,
idx
]:
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
"
-inf
"
))
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
'
cuda
'
).
unsqueeze
(
0
)
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
"
cuda
"
).
unsqueeze
(
0
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
'-inf'
))
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
"-inf"
))
attention
=
F
.
softmax
(
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
out
=
einsum
(
attention
,
value
,
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, heads_kv, dim]
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
return
out
return
out
def
ref_program_fa
(
query
,
key
,
value
,
cache_seqlens
):
def
ref_program_fa
(
query
,
key
,
value
,
cache_seqlens
):
# latency reference
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from
flash_attn
import
flash_attn_with_kvcache
#fa2
from
flash_attn
import
flash_attn_with_kvcache
# fa2
query
=
query
.
unsqueeze
(
1
)
query
=
query
.
unsqueeze
(
1
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
output
.
squeeze
(
1
)
output
=
output
.
squeeze
(
1
)
return
output
return
output
def
main
(
batch
=
64
,
def
main
(
batch
=
64
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
block_size
=
block_size
block_size
=
block_size
sparse_ratio
=
sparse_ratio
sparse_ratio
=
sparse_ratio
...
@@ -363,14 +325,13 @@ def main(batch=64,
...
@@ -363,14 +325,13 @@ def main(batch=64,
dtype
=
torch
.
float16
dtype
=
torch
.
float16
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
"
cuda
"
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
# Ensure at least one element equals cache_seqlen
# Ensure at least one element equals cache_seqlen
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
'cuda'
).
item
()
# Select a random index
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
"cuda"
).
item
()
# Select a random index
cache_seqlens
[
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
num_blocks
=
(
max_cache_seqlen
+
block_size
-
1
)
//
block_size
num_blocks
=
(
max_cache_seqlen
+
block_size
-
1
)
//
block_size
...
@@ -379,7 +340,7 @@ def main(batch=64,
...
@@ -379,7 +340,7 @@ def main(batch=64,
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
# Initialize block_mask with false (for padding blocks)
# Initialize block_mask with false (for padding blocks)
block_mask
=
torch
.
zeros
((
batch
,
heads_kv
,
num_blocks
),
dtype
=
torch
.
bool
,
device
=
'
cuda
'
)
block_mask
=
torch
.
zeros
((
batch
,
heads_kv
,
num_blocks
),
dtype
=
torch
.
bool
,
device
=
"
cuda
"
)
# Assign valid indices while ensuring no duplicates within each batch-group
# Assign valid indices while ensuring no duplicates within each batch-group
for
b
in
range
(
batch
):
for
b
in
range
(
batch
):
...
@@ -387,11 +348,10 @@ def main(batch=64,
...
@@ -387,11 +348,10 @@ def main(batch=64,
valid_num_block
=
valid_num_blocks
[
b
].
item
()
# Valid blocks for this batch
valid_num_block
=
valid_num_blocks
[
b
].
item
()
# Valid blocks for this batch
if
valid_num_block
>
0
:
# Ensure there's at least one valid block
if
valid_num_block
>
0
:
# Ensure there's at least one valid block
for
h
in
range
(
heads_kv
):
for
h
in
range
(
heads_kv
):
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
'
cuda
'
)[:
valid_num_block
]
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
"
cuda
"
)[:
valid_num_block
]
block_mask
[
b
,
h
,
perm
]
=
True
block_mask
[
b
,
h
,
perm
]
=
True
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
block_size
)
triton_out
=
block_sparse_flash_decode_gqa_mask_triton
(
triton_out
=
block_sparse_flash_decode_gqa_mask_triton
(
Q
,
Q
,
...
@@ -404,8 +364,7 @@ def main(batch=64,
...
@@ -404,8 +364,7 @@ def main(batch=64,
)
)
# print("max difference: ", torch.max(torch.abs(ref - triton_out)))
# print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert
torch
.
allclose
(
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
print
(
"Passed the ref test!"
)
print
(
"Passed the ref test!"
)
# Measure performance
# Measure performance
...
@@ -448,15 +407,13 @@ def main(batch=64,
...
@@ -448,15 +407,13 @@ def main(batch=64,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
64
,
help
=
'batch size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
64
,
help
=
"batch size"
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.8
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.8
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"block_size"
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
32
,
help
=
'block_size'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
args
.
sparse_ratio
,
args
.
block_size
)
examples/blocksparse_attention/heuristic.py
View file @
667632cc
import
math
import
math
def
num_splits_heuristic
(
total_mblocks
,
num_SMs
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
def
num_splits_heuristic
(
total_mblocks
,
num_SMs
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
,
max_splits
):
is_causal_or_local
,
max_splits
):
"""
"""
Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
...
...
examples/blocksparse_attention/test_example_blocksparse_attention.py
View file @
667632cc
...
@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
...
@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def
test_example_triton_sparse_gqa_decode_varlen_indice
():
def
test_example_triton_sparse_gqa_decode_varlen_indice
():
example_triton_sparse_gqa_decode_varlen_indice
.
main
(
example_triton_sparse_gqa_decode_varlen_indice
.
main
(
batch
=
8
,
batch
=
8
,
heads
=
8
,
heads_kv
=
4
,
max_cache_seqlen
=
2048
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
heads
=
8
,
)
heads_kv
=
4
,
max_cache_seqlen
=
2048
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
)
def
test_example_triton_sparse_gqa_decode_varlen_mask
():
def
test_example_triton_sparse_gqa_decode_varlen_mask
():
example_triton_sparse_gqa_decode_varlen_mask
.
main
(
example_triton_sparse_gqa_decode_varlen_mask
.
main
(
batch
=
16
,
batch
=
16
,
heads
=
16
,
heads_kv
=
8
,
max_cache_seqlen
=
1024
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
heads
=
16
,
)
heads_kv
=
8
,
max_cache_seqlen
=
1024
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/blocksparse_gemm/example_blocksparse_gemm.py
View file @
667632cc
...
@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
...
@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--sparsity"
,
type
=
float
,
default
=
0.5
,
help
=
"Sparsity ratio (0-1)"
)
parser
.
add_argument
(
"--sparsity"
,
type
=
float
,
default
=
0.5
,
help
=
"Sparsity ratio (0-1)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune"
)
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune"
)
args
,
_
=
parser
.
parse_known_args
()
args
,
_
=
parser
.
parse_known_args
()
M
,
N
,
K
=
args
.
m
,
args
.
n
,
args
.
k
M
,
N
,
K
=
args
.
m
,
args
.
n
,
args
.
k
...
@@ -41,17 +40,19 @@ def get_configs():
...
@@ -41,17 +40,19 @@ def get_configs():
thread_num
=
[
128
,
256
]
thread_num
=
[
128
,
256
]
enable_rasterization
=
[
True
,
False
]
enable_rasterization
=
[
True
,
False
]
_configs
=
list
(
_configs
=
list
(
itertools
.
product
(
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasterization
))
itertools
.
product
(
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasterization
))
return
[{
return
[
"block_M"
:
c
[
0
],
{
"block_N"
:
c
[
1
],
"block_M"
:
c
[
0
],
"block_K"
:
c
[
2
],
"block_N"
:
c
[
1
],
"num_stages"
:
c
[
3
],
"block_K"
:
c
[
2
],
"thread_num"
:
c
[
4
],
"num_stages"
:
c
[
3
],
"enable_rasteration"
:
c
[
5
],
"thread_num"
:
c
[
4
],
}
for
c
in
_configs
]
"enable_rasteration"
:
c
[
5
],
}
for
c
in
_configs
]
def
ref_program
(
A
,
B
,
BlockMask
,
block_M
,
block_N
,
block_K
):
def
ref_program
(
A
,
B
,
BlockMask
,
block_M
,
block_N
,
block_K
):
...
@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
...
@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
for
k
in
range
(
K
//
block_K
):
for
k
in
range
(
K
//
block_K
):
if
BlockMask
[
i
,
j
,
k
]:
if
BlockMask
[
i
,
j
,
k
]:
accu
+=
(
accu
+=
A
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
A
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
].
to
(
k
*
block_K
:
(
k
+
1
)
*
block_K
,
j
*
block_N
:
(
j
+
1
)
*
block_N
torch
.
float32
)
@
B
[
k
*
block_K
:(
k
+
1
)
*
block_K
,
].
to
(
torch
.
float32
)
j
*
block_N
:(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
))
ref_c
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
ref_c
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
return
ref_c
return
ref_c
...
@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]):
...
@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]):
return
input_tensors
return
input_tensors
@
tilelang
.
autotune
(
configs
=
get_configs
(),)
@
tilelang
.
autotune
(
configs
=
get_configs
(),
)
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
blocksparse_matmul
(
M
,
def
blocksparse_matmul
(
N
,
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
K
,
):
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
@
T
.
prim_func
@
T
.
prim_func
def
block_sparse_matmul
(
def
block_sparse_matmul
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -134,7 +126,6 @@ def blocksparse_matmul(M,
...
@@ -134,7 +126,6 @@ def blocksparse_matmul(M,
def
main
():
def
main
():
# Initialize input matrices A and B on the GPU with half precision
# Initialize input matrices A and B on the GPU with half precision
a
=
torch
.
randn
(
M
,
K
).
cuda
().
half
()
a
=
torch
.
randn
(
M
,
K
).
cuda
().
half
()
b
=
torch
.
randn
(
K
,
N
).
cuda
().
half
()
b
=
torch
.
randn
(
K
,
N
).
cuda
().
half
()
...
@@ -147,8 +138,7 @@ def main():
...
@@ -147,8 +138,7 @@ def main():
best_config
=
kernel
.
config
best_config
=
kernel
.
config
best_latency
=
kernel
.
latency
best_latency
=
kernel
.
latency
block_M
,
block_N
,
block_K
=
best_config
[
"block_M"
],
best_config
[
"block_N"
],
best_config
[
block_M
,
block_N
,
block_K
=
best_config
[
"block_M"
],
best_config
[
"block_N"
],
best_config
[
"block_K"
]
"block_K"
]
print
(
f
"Best Config:
{
best_config
}
"
)
print
(
f
"Best Config:
{
best_config
}
"
)
print
(
f
"Sparsity Ratio:
{
sparsity
}
"
)
print
(
f
"Sparsity Ratio:
{
sparsity
}
"
)
...
@@ -163,10 +153,10 @@ def main():
...
@@ -163,10 +153,10 @@ def main():
block_K
=
DEFAULT_BLOCK_K
,
block_K
=
DEFAULT_BLOCK_K
,
num_stages
=
DEFAULT_NUM_STAGES
,
num_stages
=
DEFAULT_NUM_STAGES
,
thread_num
=
DEFAULT_THREAD_NUM
,
thread_num
=
DEFAULT_THREAD_NUM
,
enable_rasteration
=
DEFAULT_ENABLE_RASTERIZATION
)
enable_rasteration
=
DEFAULT_ENABLE_RASTERIZATION
,
)
block_M
,
block_N
,
block_K
=
DEFAULT_BLOCK_M
,
DEFAULT_BLOCK_N
,
DEFAULT_BLOCK_K
block_M
,
block_N
,
block_K
=
DEFAULT_BLOCK_M
,
DEFAULT_BLOCK_N
,
DEFAULT_BLOCK_K
print
(
f
"Using default kernel with block size (
{
block_M
}
,
{
block_N
}
,
{
block_K
}
)"
)
print
(
f
"Using default kernel with block size (
{
block_M
}
,
{
block_N
}
,
{
block_K
}
)"
)
# Create block mask with desired sparsity
# Create block mask with desired sparsity
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
...
...
examples/cast/example_group_per_split_token_cast_to_fp8.py
View file @
667632cc
...
@@ -5,8 +5,8 @@ from typing import Tuple
...
@@ -5,8 +5,8 @@ from typing import Tuple
from
tilelang.utils.tensor
import
torch_assert_close
from
tilelang.utils.tensor
import
torch_assert_close
# support bfloat16, float, float16
# support bfloat16, float, float16
dtype
=
"
bfloat16
"
dtype
=
T
.
bfloat16
accum_dtype
=
"
float
"
accum_dtype
=
T
.
float
32
@
tilelang
.
jit
(
out_idx
=
[
2
,
3
])
@
tilelang
.
jit
(
out_idx
=
[
2
,
3
])
...
@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
...
@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
fp8_max
=
448.0
fp8_max
=
448.0
@
T
.
prim_func
@
T
.
prim_func
def
group_per_split_token_cast
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
batch_sizes
:
T
.
Tensor
(
def
group_per_split_token_cast
(
(
BG
,),
"int32"
),
X_fp8
:
T
.
Tensor
((
BG
,
M_max
,
N
),
"float8_e4m3"
),
X_amax
:
T
.
Tensor
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
(
BG
,
M_max
,
T
.
ceildiv
(
N
,
group_size
)),
accum_dtype
)):
batch_sizes
:
T
.
Tensor
((
BG
,),
T
.
int32
),
with
T
.
Kernel
(
X_fp8
:
T
.
Tensor
((
BG
,
M_max
,
N
),
T
.
float8_e4m3fn
),
T
.
ceildiv
(
M_max
,
blk_m
),
T
.
ceildiv
(
N
,
group_size
),
BG
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
X_amax
:
T
.
Tensor
((
BG
,
M_max
,
T
.
ceildiv
(
N
,
group_size
)),
accum_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
M_max
,
blk_m
),
T
.
ceildiv
(
N
,
group_size
),
BG
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
row
=
bx
row
=
bx
row_g_id
=
by
row_g_id
=
by
bg
=
bz
bg
=
bz
...
@@ -28,39 +30,35 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
...
@@ -28,39 +30,35 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
y_amax_local
=
T
.
alloc_fragment
((
blk_m
,),
accum_dtype
)
y_amax_local
=
T
.
alloc_fragment
((
blk_m
,),
accum_dtype
)
y_s_local
=
T
.
alloc_fragment
((
blk_m
,),
accum_dtype
)
y_s_local
=
T
.
alloc_fragment
((
blk_m
,),
accum_dtype
)
y_q_local
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
accum_dtype
)
y_q_local
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
accum_dtype
)
y_q_local_fp8
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
"
float8_e4m3
"
)
y_q_local_fp8
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
T
.
float8_e4m3
fn
)
row_offset
=
T
.
alloc_fragment
((
1
,),
"
int32
"
)
row_offset
=
T
.
alloc_fragment
((
1
,),
T
.
int32
)
T
.
annotate_layout
({
T
.
annotate_layout
(
y_local
:
{
T
.
Fragment
(
y_local
:
T
.
Fragment
(
y_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
y_local
.
shape
,
}
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
)
})
row_offset
[
0
]
=
0
row_offset
[
0
]
=
0
for
i
in
T
.
serial
(
bg
):
for
i
in
T
.
serial
(
bg
):
row_offset
[
0
]
+=
batch_sizes
[
i
]
row_offset
[
0
]
+=
batch_sizes
[
i
]
T
.
copy
(
T
.
copy
(
X
[
row_offset
[
0
]
+
row
*
blk_m
:
row_offset
[
0
]
+
(
row
+
1
)
*
blk_m
,
X
[
row_offset
[
0
]
+
row
*
blk_m
:
row_offset
[
0
]
+
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
],
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
],
y_local
)
y_local
,
)
T
.
reduce_absmax
(
y_local
,
y_amax_local
,
dim
=
1
)
T
.
reduce_absmax
(
y_local
,
y_amax_local
,
dim
=
1
)
for
i
in
T
.
Parallel
(
blk_m
):
for
i
in
T
.
Parallel
(
blk_m
):
y_amax_local
[
i
]
=
T
.
max
(
y_amax_local
[
i
],
1e-4
)
y_amax_local
[
i
]
=
T
.
max
(
y_amax_local
[
i
],
1e-4
)
y_s_local
[
i
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_s_local
[
i
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_amax_local
[
i
]
/
fp8_max
,
0
)
y_amax_local
[
i
]
/
fp8_max
,
0
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
group_size
):
for
i
,
j
in
T
.
Parallel
(
blk_m
,
group_size
):
y_q_local
[
i
,
j
]
=
T
.
clamp
(
y_local
[
i
,
j
]
/
y_s_local
[
i
],
fp8_min
,
fp8_max
)
y_q_local
[
i
,
j
]
=
T
.
clamp
(
y_local
[
i
,
j
]
/
y_s_local
[
i
],
fp8_min
,
fp8_max
)
T
.
copy
(
y_q_local
,
y_q_local_fp8
)
T
.
copy
(
y_q_local
,
y_q_local_fp8
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
group_size
):
for
i
,
j
in
T
.
Parallel
(
blk_m
,
group_size
):
y_q_local_fp8
[
i
,
j
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_q_local_fp8
[
i
,
j
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_q_local
[
i
,
j
],
0
)
y_q_local
[
i
,
j
],
0
)
for
i
in
T
.
Parallel
(
blk_m
):
for
i
in
T
.
Parallel
(
blk_m
):
X_amax
[
bg
,
row
*
blk_m
+
i
,
row_g_id
]
=
y_s_local
[
i
]
X_amax
[
bg
,
row
*
blk_m
+
i
,
row_g_id
]
=
y_s_local
[
i
]
T
.
copy
(
T
.
copy
(
y_q_local_fp8
,
X_fp8
[
bg
,
row
*
blk_m
:
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
])
y_q_local_fp8
,
X_fp8
[
bg
,
row
*
blk_m
:(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
])
return
group_per_split_token_cast
return
group_per_split_token_cast
...
@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
...
@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
return
x
.
squeeze
(
0
)
if
remove_dim
else
x
return
x
.
squeeze
(
0
)
if
remove_dim
else
x
# Normal layout requires transposing
# Normal layout requires transposing
aligned_x
=
torch
.
transpose
(
aligned_x
=
torch
.
transpose
(
torch
.
empty
((
b
,
n
,
aligned_m
),
device
=
x
.
device
,
dtype
=
x
.
dtype
),
1
,
2
)
torch
.
empty
((
b
,
n
,
aligned_m
),
device
=
x
.
device
,
dtype
=
x
.
dtype
),
1
,
2
)
aligned_x
[:,
:
m
,
:]
=
x
aligned_x
[:,
:
m
,
:]
=
x
aligned_x
=
aligned_x
[:,
:
m
,
:]
aligned_x
=
aligned_x
[:,
:
m
,
:]
return
aligned_x
.
squeeze
(
0
)
if
remove_dim
else
aligned_x
return
aligned_x
.
squeeze
(
0
)
if
remove_dim
else
aligned_x
...
@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
...
@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
x_fp8
=
x_fp8
.
view
(
m
,
-
1
)[:,
:
n
].
contiguous
()
x_fp8
=
x_fp8
.
view
(
m
,
-
1
)[:,
:
n
].
contiguous
()
return
x_fp8
,
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
return
x_fp8
,
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
ref_program
(
x
:
torch
.
Tensor
,
batch_sizes
:
torch
.
Tensor
)
->
\
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
ref_program
(
x
:
torch
.
Tensor
,
batch_sizes
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# assert x.shape[0] == batch_sizes.sum()
# assert x.shape[0] == batch_sizes.sum()
M_max
=
ceil_div
(
batch_sizes
.
max
(),
128
)
*
128
M_max
=
ceil_div
(
batch_sizes
.
max
(),
128
)
*
128
split_x
=
torch
.
split
(
x
,
batch_sizes
.
tolist
(),
dim
=
0
)
split_x
=
torch
.
split
(
x
,
batch_sizes
.
tolist
(),
dim
=
0
)
padded_x
=
[
torch
.
nn
.
functional
.
pad
(
t
,
(
0
,
0
,
0
,
M_max
-
t
.
shape
[
0
]))
for
t
in
split_x
]
padded_x
=
[
torch
.
nn
.
functional
.
pad
(
t
,
(
0
,
0
,
0
,
M_max
-
t
.
shape
[
0
]))
for
t
in
split_x
]
num_groups
,
m
,
n
=
batch_sizes
.
shape
[
0
],
M_max
,
x
.
shape
[
1
]
num_groups
,
m
,
n
=
batch_sizes
.
shape
[
0
],
M_max
,
x
.
shape
[
1
]
x_fp8
=
(
torch
.
empty
((
num_groups
,
m
,
n
),
device
=
'cuda'
,
dtype
=
torch
.
float8_e4m3fn
),
x_fp8
=
(
torch
.
empty
((
num_groups
,
m
,
n
//
128
),
device
=
'cuda'
,
dtype
=
torch
.
float
))
torch
.
empty
((
num_groups
,
m
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float8_e4m3fn
),
torch
.
empty
((
num_groups
,
m
,
n
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float
),
)
for
i
in
range
(
num_groups
):
for
i
in
range
(
num_groups
):
x_fp8
[
0
][
i
],
x_fp8
[
1
][
i
]
=
ref_per_token_cast_to_fp8
(
padded_x
[
i
])
x_fp8
[
0
][
i
],
x_fp8
[
1
][
i
]
=
ref_per_token_cast_to_fp8
(
padded_x
[
i
])
x_fp8
=
(
x_fp8
[
0
],
get_col_major_tma_aligned_tensor
(
x_fp8
[
1
]))
x_fp8
=
(
x_fp8
[
0
],
get_col_major_tma_aligned_tensor
(
x_fp8
[
1
]))
...
@@ -164,11 +163,11 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
...
@@ -164,11 +163,11 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
def
main
(
M
=
8192
,
N
=
8192
,
BG
=
2
,
blk_m
=
8
,
batch_sizes
=
None
):
def
main
(
M
=
8192
,
N
=
8192
,
BG
=
2
,
blk_m
=
8
,
batch_sizes
=
None
):
if
batch_sizes
is
None
:
if
batch_sizes
is
None
:
batch_sizes
=
[
2048
,
6144
]
batch_sizes
=
[
2048
,
6144
]
if
dtype
==
"
float
"
:
if
dtype
==
T
.
float
:
x
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
x
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
elif
dtype
==
"
float16
"
:
elif
dtype
==
T
.
float16
:
x
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
x
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
elif
dtype
==
"
bfloat16
"
:
elif
dtype
==
T
.
bfloat16
:
x
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
x
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
else
:
else
:
raise
ValueError
(
f
"Unsupported dtype:
{
dtype
}
"
)
raise
ValueError
(
f
"Unsupported dtype:
{
dtype
}
"
)
...
...
examples/cast/example_per_token_cast_to_fp8.py
View file @
667632cc
...
@@ -7,14 +7,15 @@ from tilelang.utils.tensor import torch_assert_close
...
@@ -7,14 +7,15 @@ from tilelang.utils.tensor import torch_assert_close
@
tilelang
.
jit
(
out_idx
=
[
1
,
2
])
@
tilelang
.
jit
(
out_idx
=
[
1
,
2
])
def
per_token_cast_to_fp8
(
M
,
N
,
blk_m
):
def
per_token_cast_to_fp8
(
M
,
N
,
blk_m
):
dtype
=
"
float
"
dtype
=
T
.
float
group_size
=
128
group_size
=
128
fp8_min
=
-
448.0
fp8_min
=
-
448.0
fp8_max
=
448.0
fp8_max
=
448.0
@
T
.
prim_func
@
T
.
prim_func
def
per_token_cast
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
X_fp8
:
T
.
Tensor
((
M
,
N
),
"float8_e4m3"
),
def
per_token_cast
(
X_amax
:
T
.
Tensor
((
M
,
T
.
ceildiv
(
N
,
group_size
)),
dtype
)):
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
X_fp8
:
T
.
Tensor
((
M
,
N
),
T
.
float8_e4m3fn
),
X_amax
:
T
.
Tensor
((
M
,
T
.
ceildiv
(
N
,
group_size
)),
dtype
)
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
blk_m
),
T
.
ceildiv
(
N
,
group_size
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
blk_m
),
T
.
ceildiv
(
N
,
group_size
),
threads
=
128
)
as
(
bx
,
by
):
row
=
bx
row
=
bx
row_g_id
=
by
row_g_id
=
by
...
@@ -22,18 +23,15 @@ def per_token_cast_to_fp8(M, N, blk_m):
...
@@ -22,18 +23,15 @@ def per_token_cast_to_fp8(M, N, blk_m):
y_amax_local
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
y_amax_local
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
y_s_local
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
y_s_local
=
T
.
alloc_fragment
((
blk_m
,),
dtype
)
y_q_local
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
dtype
)
y_q_local
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
dtype
)
y_q_local_fp8
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
"float8_e4m3"
)
y_q_local_fp8
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
T
.
float8_e4m3fn
)
T
.
annotate_layout
({
T
.
annotate_layout
(
y_local
:
{
T
.
Fragment
(
y_local
:
T
.
Fragment
(
y_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
y_local
.
shape
,
}
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
)
})
T
.
copy
(
X
[
row
*
blk_m
:
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
],
y_local
)
T
.
copy
(
X
[
row
*
blk_m
:(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
],
y_local
)
T
.
reduce_absmax
(
y_local
,
y_amax_local
,
dim
=
1
)
T
.
reduce_absmax
(
y_local
,
y_amax_local
,
dim
=
1
)
for
i
in
T
.
Parallel
(
blk_m
):
for
i
in
T
.
Parallel
(
blk_m
):
y_amax_local
[
i
]
=
T
.
max
(
y_amax_local
[
i
],
1e-4
)
y_amax_local
[
i
]
=
T
.
max
(
y_amax_local
[
i
],
1e-4
)
...
@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
...
@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
T
.
copy
(
y_q_local
,
y_q_local_fp8
)
T
.
copy
(
y_q_local
,
y_q_local_fp8
)
for
i
in
T
.
Parallel
(
blk_m
):
for
i
in
T
.
Parallel
(
blk_m
):
X_amax
[
row
*
blk_m
+
i
,
row_g_id
]
=
y_s_local
[
i
]
X_amax
[
row
*
blk_m
+
i
,
row_g_id
]
=
y_s_local
[
i
]
T
.
copy
(
T
.
copy
(
y_q_local_fp8
,
X_fp8
[
row
*
blk_m
:
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
])
y_q_local_fp8
,
X_fp8
[
row
*
blk_m
:(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
])
return
per_token_cast
return
per_token_cast
...
@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8):
...
@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8):
from
example_triton_cast_to_fp8
import
per_token_group_quant_fp8
from
example_triton_cast_to_fp8
import
per_token_group_quant_fp8
def
run_triton
():
def
run_triton
():
x_fp8_triton_
,
x_amax_triton_
=
per_token_group_quant_fp8
(
x_fp8_triton_
,
x_amax_triton_
=
per_token_group_quant_fp8
(
x
,
128
,
1e-4
,
dtype
=
torch
.
float8_e4m3fn
,
column_major_scales
=
False
)
x
,
128
,
1e-4
,
dtype
=
torch
.
float8_e4m3fn
,
column_major_scales
=
False
)
return
x_fp8_triton_
,
x_amax_triton_
return
x_fp8_triton_
,
x_amax_triton_
x_fp8_triton
,
x_amax_triton
=
run_triton
()
x_fp8_triton
,
x_amax_triton
=
run_triton
()
...
...
examples/cast/example_triton_cast_to_fp8.py
View file @
667632cc
...
@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
...
@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
scaling factor for quantization.
"""
"""
assert
(
x
.
shape
[
-
1
]
%
assert
x
.
shape
[
-
1
]
%
group_size
==
0
,
f
"the last dimension of `x`
{
x
.
shape
[
-
1
]
}
must be divisible by `group_size`
{
group_size
}
"
group_size
==
0
),
(
f
"the last dimension of `x`
{
x
.
shape
[
-
1
]
}
must be divisible "
f
"by `group_size`
{
group_size
}
"
)
assert
x
.
stride
(
-
1
)
==
1
,
"`x` groups must be contiguous"
assert
x
.
stride
(
-
1
)
==
1
,
"`x` groups must be contiguous"
finfo
=
torch
.
finfo
(
dtype
)
finfo
=
torch
.
finfo
(
dtype
)
...
...
examples/cast/test_example_cast.py
View file @
667632cc
...
@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8
...
@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8
def
test_example_group_per_split_token_cast_to_fp8
():
def
test_example_group_per_split_token_cast_to_fp8
():
example_group_per_split_token_cast_to_fp8
.
main
(
example_group_per_split_token_cast_to_fp8
.
main
(
M
=
1024
,
N
=
1024
,
BG
=
2
,
blk_m
=
4
,
batch_sizes
=
[
128
,
896
])
M
=
1024
,
N
=
1024
,
BG
=
2
,
blk_m
=
4
,
batch_sizes
=
[
128
,
896
])
def
test_example_per_token_cast_to_fp8
():
def
test_example_per_token_cast_to_fp8
():
...
...
examples/conftest.py
View file @
667632cc
...
@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
...
@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
"warnings"
,
"warnings"
,
"error"
,
"error"
,
}
}
if
(
sum
(
if
sum
(
len
(
terminalreporter
.
stats
.
get
(
k
,
[]))
for
k
in
known_types
.
difference
({
"skipped"
,
"deselected"
}))
==
0
:
len
(
terminalreporter
.
stats
.
get
(
k
,
[]))
for
k
in
known_types
.
difference
({
"skipped"
,
"deselected"
}))
==
0
):
terminalreporter
.
write_sep
(
terminalreporter
.
write_sep
(
"!"
,
"!"
,
(
f
"Error: No tests were collected. "
(
f
"Error: No tests were collected.
{
dict
(
sorted
((
k
,
len
(
v
))
for
k
,
v
in
terminalreporter
.
stats
.
items
()))
}
"
),
f
"
{
dict
(
sorted
((
k
,
len
(
v
))
for
k
,
v
in
terminalreporter
.
stats
.
items
()))
}
"
),
)
)
pytest
.
exit
(
"No tests were collected."
,
returncode
=
5
)
pytest
.
exit
(
"No tests were collected."
,
returncode
=
5
)
examples/convolution/example_convolution.py
View file @
667632cc
...
@@ -14,7 +14,6 @@ def check_hopper():
...
@@ -14,7 +14,6 @@ def check_hopper():
def
ref_program
(
stride
,
padding
,
dilation
):
def
ref_program
(
stride
,
padding
,
dilation
):
def
main
(
A
,
B
):
def
main
(
A
,
B
):
A
=
A
.
permute
(
0
,
3
,
1
,
2
)
# N, H, W, C -> N, C, H, W
A
=
A
.
permute
(
0
,
3
,
1
,
2
)
# N, H, W, C -> N, C, H, W
B
=
B
.
permute
(
3
,
2
,
0
,
1
)
# H, W, C, F -> F, C, H, W
B
=
B
.
permute
(
3
,
2
,
0
,
1
)
# H, W, C, F -> F, C, H, W
...
@@ -26,38 +25,21 @@ def ref_program(stride, padding, dilation):
...
@@ -26,38 +25,21 @@ def ref_program(stride, padding, dilation):
@
tilelang
.
jit
(
out_idx
=
[
2
])
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
convolution
(
N
,
def
convolution
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
KH
,
KW
=
K
,
K
KH
,
KW
=
K
,
K
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
dtype
=
"
float16
"
dtype
=
T
.
float16
accum_dtype
=
"
float
"
accum_dtype
=
T
.
float
32
is_hopper
=
check_hopper
()
is_hopper
=
check_hopper
()
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -66,11 +48,13 @@ def convolution(N,
...
@@ -66,11 +48,13 @@ def convolution(N,
kernel_flat
=
T
.
Tensor
((
KH
*
KW
*
C
,
F
),
dtype
,
kernel
.
data
)
kernel_flat
=
T
.
Tensor
((
KH
*
KW
*
C
,
F
),
dtype
,
kernel
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
T
.
annotate_layout
({
T
.
annotate_layout
(
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
{
data_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
data_shared
),
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
kernel_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
kernel_shared
),
data_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
data_shared
),
})
kernel_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
kernel_shared
),
}
)
T
.
clear
(
out_local
)
T
.
clear
(
out_local
)
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
...
@@ -82,10 +66,8 @@ def convolution(N,
...
@@ -82,10 +66,8 @@ def convolution(N,
m
=
by
*
block_M
+
i
m
=
by
*
block_M
+
i
access_h
=
m
%
(
OH
*
OW
)
//
OW
*
S
+
k
//
(
KW
*
C
)
*
D
-
P
access_h
=
m
%
(
OH
*
OW
)
//
OW
*
S
+
k
//
(
KW
*
C
)
*
D
-
P
access_w
=
m
%
OW
*
S
+
k
//
C
%
KW
*
D
-
P
access_w
=
m
%
OW
*
S
+
k
//
C
%
KW
*
D
-
P
in_bound
=
((
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
in_bound
=
(
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
)
(
access_w
<
W
))
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
T
.
copy
(
kernel_flat
[
k_iter
*
block_K
,
bx
*
block_N
],
kernel_shared
)
T
.
copy
(
kernel_flat
[
k_iter
*
block_K
,
bx
*
block_N
],
kernel_shared
)
T
.
gemm
(
data_shared
,
kernel_shared
,
out_local
)
T
.
gemm
(
data_shared
,
kernel_shared
,
out_local
)
...
@@ -97,15 +79,15 @@ def convolution(N,
...
@@ -97,15 +79,15 @@ def convolution(N,
def
main
(
argv
=
None
):
def
main
(
argv
=
None
):
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--n
'
,
type
=
int
,
default
=
128
,
help
=
'n'
)
parser
.
add_argument
(
"
--n
"
,
type
=
int
,
default
=
128
,
help
=
"n"
)
parser
.
add_argument
(
'
--c
'
,
type
=
int
,
default
=
128
,
help
=
'c'
)
parser
.
add_argument
(
"
--c
"
,
type
=
int
,
default
=
128
,
help
=
"c"
)
parser
.
add_argument
(
'
--h
'
,
type
=
int
,
default
=
64
,
help
=
'h'
)
parser
.
add_argument
(
"
--h
"
,
type
=
int
,
default
=
64
,
help
=
"h"
)
parser
.
add_argument
(
'
--w
'
,
type
=
int
,
default
=
64
,
help
=
'w'
)
parser
.
add_argument
(
"
--w
"
,
type
=
int
,
default
=
64
,
help
=
"w"
)
parser
.
add_argument
(
'
--f
'
,
type
=
int
,
default
=
128
,
help
=
'f'
)
parser
.
add_argument
(
"
--f
"
,
type
=
int
,
default
=
128
,
help
=
"f"
)
parser
.
add_argument
(
'
--k
'
,
type
=
int
,
default
=
3
,
help
=
'k'
)
parser
.
add_argument
(
"
--k
"
,
type
=
int
,
default
=
3
,
help
=
"k"
)
parser
.
add_argument
(
'
--s
'
,
type
=
int
,
default
=
1
,
help
=
's'
)
parser
.
add_argument
(
"
--s
"
,
type
=
int
,
default
=
1
,
help
=
"s"
)
parser
.
add_argument
(
'
--d
'
,
type
=
int
,
default
=
1
,
help
=
'd'
)
parser
.
add_argument
(
"
--d
"
,
type
=
int
,
default
=
1
,
help
=
"d"
)
parser
.
add_argument
(
'
--p
'
,
type
=
int
,
default
=
1
,
help
=
'p'
)
parser
.
add_argument
(
"
--p
"
,
type
=
int
,
default
=
1
,
help
=
"p"
)
args
=
parser
.
parse_args
(
argv
)
args
=
parser
.
parse_args
(
argv
)
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
args
.
n
,
args
.
c
,
args
.
h
,
args
.
w
,
args
.
f
,
args
.
k
,
args
.
s
,
args
.
d
,
args
.
p
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
args
.
n
,
args
.
c
,
args
.
h
,
args
.
w
,
args
.
f
,
args
.
k
,
args
.
s
,
args
.
d
,
args
.
p
...
...
examples/convolution/example_convolution_autotune.py
View file @
667632cc
...
@@ -14,7 +14,6 @@ def check_hopper():
...
@@ -14,7 +14,6 @@ def check_hopper():
def
ref_program
(
stride
,
padding
,
dilation
):
def
ref_program
(
stride
,
padding
,
dilation
):
def
main
(
A
,
B
):
def
main
(
A
,
B
):
A
=
A
.
permute
(
0
,
3
,
1
,
2
)
# N, H, W, C -> N, C, H, W
A
=
A
.
permute
(
0
,
3
,
1
,
2
)
# N, H, W, C -> N, C, H, W
B
=
B
.
permute
(
3
,
2
,
0
,
1
)
# H, W, C, F -> F, C, H, W
B
=
B
.
permute
(
3
,
2
,
0
,
1
)
# H, W, C, F -> F, C, H, W
...
@@ -40,7 +39,8 @@ def get_configs():
...
@@ -40,7 +39,8 @@ def get_configs():
num_stages
,
num_stages
,
thread_num
,
thread_num
,
enable_rasterization
,
enable_rasterization
,
))
)
)
configs
=
[
configs
=
[
{
{
...
@@ -50,7 +50,8 @@ def get_configs():
...
@@ -50,7 +50,8 @@ def get_configs():
"num_stages"
:
c
[
3
],
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
# keep param name for backward-compat
"enable_rasteration"
:
c
[
5
],
# keep param name for backward-compat
}
for
c
in
_configs
}
for
c
in
_configs
]
]
return
configs
return
configs
...
@@ -64,69 +65,32 @@ def get_heuristic_config() -> dict:
...
@@ -64,69 +65,32 @@ def get_heuristic_config() -> dict:
sm_version
=
sm_major
*
10
+
sm_minor
sm_version
=
sm_major
*
10
+
sm_minor
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
if
sm_version
in
{
80
}:
if
sm_version
in
{
80
}:
return
{
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
elif
sm_version
in
{
90
}:
elif
sm_version
in
{
90
}:
return
{
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
else
:
else
:
return
{
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
@
tilelang
.
autotune
(
configs
=
get_configs
())
@
tilelang
.
autotune
(
configs
=
get_configs
())
@
tilelang
.
jit
(
out_idx
=
[
2
])
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
convolution
(
N
,
def
convolution
(
C
,
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
H
,
):
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
KH
,
KW
=
K
,
K
KH
,
KW
=
K
,
K
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
dtype
=
"
float16
"
dtype
=
T
.
float16
accum_dtype
=
"
float
"
accum_dtype
=
T
.
float
32
is_hopper
=
check_hopper
()
is_hopper
=
check_hopper
()
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -136,9 +100,11 @@ def convolution(N,
...
@@ -136,9 +100,11 @@ def convolution(N,
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
if
is_hopper
:
if
is_hopper
:
T
.
annotate_layout
({
T
.
annotate_layout
(
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
{
})
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
}
)
T
.
clear
(
out_local
)
T
.
clear
(
out_local
)
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
...
@@ -150,10 +116,8 @@ def convolution(N,
...
@@ -150,10 +116,8 @@ def convolution(N,
m
=
by
*
block_M
+
i
m
=
by
*
block_M
+
i
access_h
=
m
%
(
OH
*
OW
)
//
OW
*
S
+
k
//
(
KW
*
C
)
*
D
-
P
access_h
=
m
%
(
OH
*
OW
)
//
OW
*
S
+
k
//
(
KW
*
C
)
*
D
-
P
access_w
=
m
%
OW
*
S
+
k
//
C
%
KW
*
D
-
P
access_w
=
m
%
OW
*
S
+
k
//
C
%
KW
*
D
-
P
in_bound
=
((
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
in_bound
=
(
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
)
(
access_w
<
W
))
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
T
.
copy
(
kernel_flat
[
k_iter
*
block_K
,
bx
*
block_N
],
kernel_shared
)
T
.
copy
(
kernel_flat
[
k_iter
*
block_K
,
bx
*
block_N
],
kernel_shared
)
T
.
gemm
(
data_shared
,
kernel_shared
,
out_local
)
T
.
gemm
(
data_shared
,
kernel_shared
,
out_local
)
...
@@ -166,17 +130,19 @@ def convolution(N,
...
@@ -166,17 +130,19 @@ def convolution(N,
return
main
return
main
def
main
(
n
:
int
=
128
,
def
main
(
c
:
int
=
128
,
n
:
int
=
128
,
h
:
int
=
64
,
c
:
int
=
128
,
w
:
int
=
64
,
h
:
int
=
64
,
f
:
int
=
128
,
w
:
int
=
64
,
k
:
int
=
3
,
f
:
int
=
128
,
s
:
int
=
1
,
k
:
int
=
3
,
d
:
int
=
1
,
s
:
int
=
1
,
p
:
int
=
1
,
d
:
int
=
1
,
use_autotune
:
bool
=
False
,
p
:
int
=
1
,
with_roller
:
bool
=
True
):
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
True
,
):
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
n
,
c
,
h
,
w
,
f
,
k
,
s
,
d
,
p
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
n
,
c
,
h
,
w
,
f
,
k
,
s
,
d
,
p
ref_prog
=
ref_program
(
S
,
P
,
D
)
ref_prog
=
ref_program
(
S
,
P
,
D
)
...
@@ -196,25 +162,16 @@ def main(n: int = 128,
...
@@ -196,25 +162,16 @@ def main(n: int = 128,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Autotuned MatMul Benchmark"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Autotuned MatMul Benchmark"
)
parser
.
add_argument
(
'--n'
,
type
=
int
,
default
=
128
,
help
=
'n'
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
128
,
help
=
"n"
)
parser
.
add_argument
(
'--c'
,
type
=
int
,
default
=
128
,
help
=
'c'
)
parser
.
add_argument
(
"--c"
,
type
=
int
,
default
=
128
,
help
=
"c"
)
parser
.
add_argument
(
'--h'
,
type
=
int
,
default
=
64
,
help
=
'h'
)
parser
.
add_argument
(
"--h"
,
type
=
int
,
default
=
64
,
help
=
"h"
)
parser
.
add_argument
(
'--w'
,
type
=
int
,
default
=
64
,
help
=
'w'
)
parser
.
add_argument
(
"--w"
,
type
=
int
,
default
=
64
,
help
=
"w"
)
parser
.
add_argument
(
'--f'
,
type
=
int
,
default
=
128
,
help
=
'f'
)
parser
.
add_argument
(
"--f"
,
type
=
int
,
default
=
128
,
help
=
"f"
)
parser
.
add_argument
(
'--k'
,
type
=
int
,
default
=
3
,
help
=
'k'
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
3
,
help
=
"k"
)
parser
.
add_argument
(
'--s'
,
type
=
int
,
default
=
1
,
help
=
's'
)
parser
.
add_argument
(
"--s"
,
type
=
int
,
default
=
1
,
help
=
"s"
)
parser
.
add_argument
(
'--d'
,
type
=
int
,
default
=
1
,
help
=
'd'
)
parser
.
add_argument
(
"--d"
,
type
=
int
,
default
=
1
,
help
=
"d"
)
parser
.
add_argument
(
'--p'
,
type
=
int
,
default
=
1
,
help
=
'p'
)
parser
.
add_argument
(
"--p"
,
type
=
int
,
default
=
1
,
help
=
"p"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
"--use_autotune"
,
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
True
,
help
=
"Whether to enable BitBLAS roller for search space"
)
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
True
,
help
=
"Whether to enable BitBLAS roller for search space"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
n
,
args
.
c
,
args
.
h
,
args
.
w
,
args
.
f
,
args
.
k
,
args
.
s
,
args
.
d
,
args
.
p
,
args
.
use_autotune
,
main
(
args
.
n
,
args
.
c
,
args
.
h
,
args
.
w
,
args
.
f
,
args
.
k
,
args
.
s
,
args
.
d
,
args
.
p
,
args
.
use_autotune
,
args
.
with_roller
)
args
.
with_roller
)
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
View file @
667632cc
...
@@ -20,11 +20,11 @@ def tl_gemm(
...
@@ -20,11 +20,11 @@ def tl_gemm(
accum_dtype
,
accum_dtype
,
):
):
assert
in_dtype
in
[
assert
in_dtype
in
[
"
float8_e4m3
"
,
T
.
float8_e4m3
fn
,
],
"Currently only float8_e4m3 is supported"
],
"Currently only float8_e4m3 is supported"
assert
out_dtype
in
[
assert
out_dtype
in
[
"
bfloat16
"
,
T
.
bfloat16
,
"
float32
"
,
T
.
float32
,
],
"Currently only float16 and float32 are supported"
],
"Currently only float16 and float32 are supported"
group_size
=
128
group_size
=
128
...
@@ -41,18 +41,17 @@ def tl_gemm(
...
@@ -41,18 +41,17 @@ def tl_gemm(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
scales_a
:
T
.
Tensor
(
Scales_A_shape
,
"
float32
"
),
scales_a
:
T
.
Tensor
(
Scales_A_shape
,
T
.
float32
),
scales_b
:
T
.
Tensor
(
Scales_B_shape
,
"
float32
"
),
scales_b
:
T
.
Tensor
(
Scales_B_shape
,
T
.
float32
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
)
Scale_C_shared
=
T
.
alloc_shared
((
block_M
),
"
float32
"
)
Scale_C_shared
=
T
.
alloc_shared
((
block_M
),
T
.
float32
)
C_local
=
T
.
alloc_fragment
(
C_shared_shape
,
accum_dtype
)
C_local
=
T
.
alloc_fragment
(
C_shared_shape
,
accum_dtype
)
C_local_accum
=
T
.
alloc_fragment
(
C_shared_shape
,
accum_dtype
)
C_local_accum
=
T
.
alloc_fragment
(
C_shared_shape
,
accum_dtype
)
...
@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
...
@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
m
,
n
=
x
.
shape
m
,
n
=
x
.
shape
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
x_padded
=
torch
.
zeros
(
ceildiv
(
m
,
128
)
*
128
,
ceildiv
(
n
,
128
)
*
128
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
ceildiv
(
m
,
128
)
*
128
,
ceildiv
(
n
,
128
)
*
128
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
x_view
.
size
(
0
),
x_view
.
size
(
2
))
def
ref_deepgemm_fp8
(
A_fp8
,
B_fp8
,
A_scale
,
B_scale
,
out_dtype
):
def
ref_deepgemm_fp8
(
A_fp8
,
B_fp8
,
A_scale
,
B_scale
,
out_dtype
):
...
@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
...
@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
c_acc
.
zero_
()
c_acc
.
zero_
()
for
k
in
range
(
ceildiv
(
K
,
128
)):
for
k
in
range
(
ceildiv
(
K
,
128
)):
c
=
torch
.
_scaled_mm
(
c
=
torch
.
_scaled_mm
(
A_fp8
[
i
*
128
:
(
i
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
],
A_fp8
[
i
*
128
:
(
i
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
],
B_fp8
[
j
*
128
:
(
j
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
].
T
,
B_fp8
[
j
*
128
:
(
j
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
].
T
,
scale_a
=
A_scales
[
i
,
k
].
view
(
128
,
1
).
contiguous
(),
scale_a
=
A_scales
[
i
,
k
].
view
(
128
,
1
).
contiguous
(),
scale_b
=
B_scales
[
j
,
k
].
view
(
1
,
128
).
contiguous
(),
scale_b
=
B_scales
[
j
,
k
].
view
(
1
,
128
).
contiguous
(),
out_dtype
=
torch
.
bfloat16
)
out_dtype
=
torch
.
bfloat16
,
)
c_acc
+=
c
.
to
(
torch
.
float32
)
c_acc
+=
c
.
to
(
torch
.
float32
)
C
[
i
*
128
:
(
i
+
1
)
*
128
,
j
*
128
:
(
j
+
1
)
*
128
]
=
c_acc
.
to
(
out_dtype
)
C
[
i
*
128
:
(
i
+
1
)
*
128
,
j
*
128
:
(
j
+
1
)
*
128
]
=
c_acc
.
to
(
out_dtype
)
return
C
return
C
...
@@ -179,11 +176,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp
...
@@ -179,11 +176,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp
def
main
():
def
main
():
assert_tl_gemm_correctness
(
1024
,
1024
,
8192
,
128
,
"
float8_e4m3
"
,
"
bfloat16
"
,
"
float32
"
)
assert_tl_gemm_correctness
(
1024
,
1024
,
8192
,
128
,
T
.
float8_e4m3
fn
,
T
.
bfloat16
,
T
.
float32
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
for
dtype
in
[
"
float8_e4m3
"
]:
for
dtype
in
[
T
.
float8_e4m3
fn
]:
for
out_dtype
in
[
"
bfloat16
"
,
"
float32
"
]:
for
out_dtype
in
[
T
.
bfloat16
,
T
.
float32
]:
for
block_N
in
[
16
,
32
,
64
,
128
]:
for
block_N
in
[
16
,
32
,
64
,
128
]:
assert_tl_gemm_correctness
(
1024
,
1024
,
8192
,
block_N
,
dtype
,
out_dtype
,
"
float32
"
)
assert_tl_gemm_correctness
(
1024
,
1024
,
8192
,
block_N
,
dtype
,
out_dtype
,
T
.
float32
)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
View file @
667632cc
...
@@ -8,6 +8,7 @@ import argparse
...
@@ -8,6 +8,7 @@ import argparse
def
get_configs
():
def
get_configs
():
import
itertools
import
itertools
BLOCK_N
=
[
16
,
32
,
64
,
128
]
BLOCK_N
=
[
16
,
32
,
64
,
128
]
BLOCK_H
=
[
16
,
32
,
64
,
128
]
BLOCK_H
=
[
16
,
32
,
64
,
128
]
num_split
=
[
1
,
2
,
4
,
8
,
16
,
32
]
num_split
=
[
1
,
2
,
4
,
8
,
16
,
32
]
...
@@ -15,43 +16,39 @@ def get_configs():
...
@@ -15,43 +16,39 @@ def get_configs():
_configs
=
list
(
itertools
.
product
(
BLOCK_N
,
BLOCK_H
,
num_split
,
threads
))
_configs
=
list
(
itertools
.
product
(
BLOCK_N
,
BLOCK_H
,
num_split
,
threads
))
return
[{
return
[
"block_N"
:
c
[
0
],
{
"block_H"
:
c
[
1
],
"block_N"
:
c
[
0
],
"num_split"
:
c
[
2
],
"block_H"
:
c
[
1
],
"threads"
:
c
[
3
],
"num_split"
:
c
[
2
],
}
for
c
in
_configs
]
"threads"
:
c
[
3
],
}
for
c
in
_configs
]
@
tilelang
.
autotune
(
configs
=
get_configs
())
@
tilelang
.
autotune
(
configs
=
get_configs
())
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
{
out_idx
=
[
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
def
flashmla_decode
(
batch
,
)
heads
,
def
flashmla_decode
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
threads
=
128
):
kv_head_num
,
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
seqlen_kv
,
dtype
=
T
.
float16
dim
,
accum_dtype
=
T
.
float32
pe_dim
,
block_N
,
block_H
,
num_split
,
threads
=
128
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
kv_group_num
=
heads
//
kv_head_num
kv_group_num
=
heads
//
kv_head_num
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
assert
kv_head_num
==
1
,
"kv_head_num must be 1"
assert
kv_head_num
==
1
,
"kv_head_num must be 1"
@
T
.
macro
@
T
.
macro
def
flash_attn
(
def
flash_attn
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
threads
=
threads
)
as
(
bx
,
by
):
Q_local
=
T
.
alloc_fragment
([
block_H
,
dim
],
dtype
)
Q_local
=
T
.
alloc_fragment
([
block_H
,
dim
],
dtype
)
...
@@ -70,27 +67,24 @@ def flashmla_decode(batch,
...
@@ -70,27 +67,24 @@ def flashmla_decode(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
(
seqlen_kv
,
block_N
)
loop_range
=
T
.
ceildiv
(
seqlen_kv
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
0
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
0
):
T
.
copy
(
KV
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
KV
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
K_pe
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_local
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_local
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
T
.
gemm
(
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
@@ -105,20 +99,18 @@ def flashmla_decode(batch,
...
@@ -105,20 +99,18 @@ def flashmla_decode(batch,
T
.
gemm
(
acc_s_cast
,
KV_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
acc_s_cast
,
KV_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
T
.
copy
(
acc_o
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
@
T
.
macro
@
T
.
macro
def
flash_attn_split
(
def
flash_attn_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_local
=
T
.
alloc_fragment
([
block_H
,
dim
],
dtype
)
Q_local
=
T
.
alloc_fragment
([
block_H
,
dim
],
dtype
)
Q_pe_local
=
T
.
alloc_fragment
([
block_H
,
pe_dim
],
dtype
)
Q_pe_local
=
T
.
alloc_fragment
([
block_H
,
pe_dim
],
dtype
)
KV_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
KV_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
...
@@ -134,8 +126,8 @@ def flashmla_decode(batch,
...
@@ -134,8 +126,8 @@ def flashmla_decode(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -148,15 +140,12 @@ def flashmla_decode(batch,
...
@@ -148,15 +140,12 @@ def flashmla_decode(batch,
T
.
copy
(
K_pe
[
bx
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
K_pe
[
bx
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_local
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_local
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
T
.
gemm
(
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
@@ -172,14 +161,14 @@ def flashmla_decode(batch,
...
@@ -172,14 +161,14 @@ def flashmla_decode(batch,
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
T
.
copy
(
acc_o
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
@
T
.
macro
@
T
.
macro
def
combine
(
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
...
@@ -189,9 +178,11 @@ def flashmla_decode(batch,
...
@@ -189,9 +178,11 @@ def flashmla_decode(batch,
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
{
})
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
T
.
clear
(
o_accum_local
)
...
@@ -214,26 +205,26 @@ def flashmla_decode(batch,
...
@@ -214,26 +205,26 @@ def flashmla_decode(batch,
@
T
.
prim_func
@
T
.
prim_func
def
main_split
(
def
main_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
flash_attn_split
(
Q
,
Q_pe
,
KV
,
K_pe
,
glse
,
Output_partial
)
flash_attn_split
(
Q
,
Q_pe
,
KV
,
K_pe
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
@
T
.
prim_func
def
main_no_split
(
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
...
@@ -258,43 +249,36 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
...
@@ -258,43 +249,36 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
return
out
return
out
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
parser
.
add_argument
(
'
--autotune
'
,
action
=
'
store_true
'
,
help
=
'
auto tune
'
)
parser
.
add_argument
(
"
--autotune
"
,
action
=
"
store_true
"
,
help
=
"
auto tune
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
enable_autotune
=
args
.
autotune
enable_autotune
=
args
.
autotune
...
@@ -310,17 +294,7 @@ if __name__ == "__main__":
...
@@ -310,17 +294,7 @@ if __name__ == "__main__":
if
enable_autotune
:
if
enable_autotune
:
kernel
=
flashmla_decode
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
kernel
=
flashmla_decode
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
else
:
else
:
kernel
=
flashmla_decode
(
kernel
=
flashmla_decode
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
threads
=
threads
)
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
threads
=
threads
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
input_tensors
=
profiler
.
_get_inputs
()
input_tensors
=
profiler
.
_get_inputs
()
tilelang_output
=
kernel
(
*
input_tensors
)
tilelang_output
=
kernel
(
*
input_tensors
)
...
...
examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
View file @
667632cc
...
@@ -32,8 +32,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
...
@@ -32,8 +32,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
blocked_v
=
blocked_k
[...,
:
dv
]
def
ref_mla
():
def
ref_mla
():
...
@@ -94,8 +93,7 @@ def _mla_attn_kernel(
...
@@ -94,8 +93,7 @@ def _mla_attn_kernel(
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
None
,
:]
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
...
@@ -141,9 +139,7 @@ def _mla_attn_kernel(
...
@@ -141,9 +139,7 @@ def _mla_attn_kernel(
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
e_max
=
n_e_max
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
...
@@ -309,24 +305,30 @@ def mla_decode_triton(
...
@@ -309,24 +305,30 @@ def mla_decode_triton(
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
blocked_v
=
blocked_k
[...,
:
dv
]
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
dv
:].
contiguous
()
def
flash_mla_triton
():
def
flash_mla_triton
():
num_kv_splits
=
32
num_kv_splits
=
32
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
mla_decode_triton
(
mla_decode_triton
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
q_nope
.
view
(
-
1
,
h_q
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
)
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
,
)
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
out_flash
=
flash_mla_triton
()
out_flash
=
flash_mla_triton
()
...
@@ -362,14 +364,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
...
@@ -362,14 +364,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_size
=
64
block_table
=
torch
.
arange
(
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
out_a
,
lse_a
,
perf_a
=
baseline_func
(
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
)
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
if
target
not
in
[
"flash_mla_triton"
]:
if
target
not
in
[
"flash_mla_triton"
]:
...
@@ -377,21 +380,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
...
@@ -377,21 +380,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
print
(
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -408,19 +404,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
...
@@ -408,19 +404,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_size
=
64
block_table
=
torch
.
arange
(
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
out_b
,
lse_b
,
perf_b
=
target_func
(
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_b
return
bytes
/
10
**
6
/
perf_b
...
@@ -429,26 +422,22 @@ available_targets = [
...
@@ -429,26 +422,22 @@ available_targets = [
"flash_mla_triton"
,
"flash_mla_triton"
,
]
]
shape_configs
=
[{
shape_configs
=
[
"b"
:
{
batch
,
"b"
:
batch
,
"s_q"
:
"s_q"
:
1
,
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"cache_seqlens"
:
"h_q"
:
head
,
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_kv"
:
1
,
"h_q"
:
"d"
:
512
+
64
,
head
,
"dv"
:
512
,
"h_kv"
:
"causal"
:
True
,
1
,
"dtype"
:
torch
.
float16
,
"d"
:
}
512
+
64
,
for
batch
in
[
128
]
"dv"
:
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
512
,
for
head
in
[
128
]
"causal"
:
]
True
,
"dtype"
:
torch
.
float16
}
for
batch
in
[
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
for
head
in
[
128
]]
def
get_args
():
def
get_args
():
...
@@ -470,26 +459,54 @@ if __name__ == "__main__":
...
@@ -470,26 +459,54 @@ if __name__ == "__main__":
for
shape
in
shape_configs
:
for
shape
in
shape_configs
:
if
args
.
all
:
if
args
.
all
:
for
target
in
available_targets
:
for
target
in
available_targets
:
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
perf
=
compare_a
(
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
target
,
shape
[
"causal"
],
shape
[
"dtype"
])
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
fout
.
write
(
f
'
{
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
)
elif
args
.
compare
:
elif
args
.
compare
:
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
perfa
,
prefb
=
compare_ab
(
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
args
.
baseline
,
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
fout
.
write
(
f
'
{
args
.
baseline
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perfa
:.
0
f
}
\n
'
f
"
{
args
.
baseline
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perfa
:.
0
f
}
\n
"
)
)
fout
.
write
(
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
prefb
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
prefb
:.
0
f
}
\n
"
)
)
elif
args
.
one
:
elif
args
.
one
:
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
perf
=
compare_a
(
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
args
.
target
,
shape
[
"causal"
],
shape
[
"dtype"
])
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
View file @
667632cc
...
@@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
...
@@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
blocked_v
=
blocked_k
[...,
:
dv
]
def
ref_mla
():
def
ref_mla
():
...
@@ -91,8 +90,7 @@ def _mla_attn_kernel(
...
@@ -91,8 +90,7 @@ def _mla_attn_kernel(
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
None
,
:]
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
...
@@ -138,9 +136,7 @@ def _mla_attn_kernel(
...
@@ -138,9 +136,7 @@ def _mla_attn_kernel(
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
e_max
=
n_e_max
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
...
@@ -306,24 +302,30 @@ def mla_decode_triton(
...
@@ -306,24 +302,30 @@ def mla_decode_triton(
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
blocked_v
=
blocked_k
[...,
:
dv
]
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
dv
:].
contiguous
()
def
flash_mla_triton
():
def
flash_mla_triton
():
num_kv_splits
=
32
num_kv_splits
=
32
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
mla_decode_triton
(
mla_decode_triton
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
q_nope
.
view
(
-
1
,
h_q
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
)
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
,
)
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
out_flash
=
flash_mla_triton
()
out_flash
=
flash_mla_triton
()
...
@@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
...
@@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_size
=
64
block_table
=
torch
.
arange
(
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
out_a
,
lse_a
,
perf_a
=
baseline_func
(
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
)
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
if
target
not
in
[
"flash_mla_triton"
]:
if
target
not
in
[
"flash_mla_triton"
]:
...
@@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
...
@@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
print
(
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
...
@@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_size
=
64
block_table
=
torch
.
arange
(
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
out_b
,
lse_b
,
perf_b
=
target_func
(
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_b
return
bytes
/
10
**
6
/
perf_b
...
@@ -426,26 +419,22 @@ available_targets = [
...
@@ -426,26 +419,22 @@ available_targets = [
"flash_mla_triton"
,
"flash_mla_triton"
,
]
]
shape_configs
=
[{
shape_configs
=
[
"b"
:
{
batch
,
"b"
:
batch
,
"s_q"
:
"s_q"
:
1
,
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"cache_seqlens"
:
"h_q"
:
head
,
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_kv"
:
1
,
"h_q"
:
"d"
:
512
+
64
,
head
,
"dv"
:
512
,
"h_kv"
:
"causal"
:
True
,
1
,
"dtype"
:
torch
.
float16
,
"d"
:
}
512
+
64
,
for
batch
in
[
64
,
128
]
"dv"
:
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
512
,
for
head
in
[
128
]
"causal"
:
]
True
,
"dtype"
:
torch
.
float16
}
for
batch
in
[
64
,
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
for
head
in
[
128
]]
def
get_args
():
def
get_args
():
...
@@ -467,26 +456,54 @@ if __name__ == "__main__":
...
@@ -467,26 +456,54 @@ if __name__ == "__main__":
for
shape
in
shape_configs
:
for
shape
in
shape_configs
:
if
args
.
all
:
if
args
.
all
:
for
target
in
available_targets
:
for
target
in
available_targets
:
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
perf
=
compare_a
(
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
target
,
shape
[
"causal"
],
shape
[
"dtype"
])
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
fout
.
write
(
f
'
{
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
)
elif
args
.
compare
:
elif
args
.
compare
:
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
perfa
,
prefb
=
compare_ab
(
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
args
.
baseline
,
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
fout
.
write
(
f
'
{
args
.
baseline
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perfa
:.
0
f
}
\n
'
f
"
{
args
.
baseline
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perfa
:.
0
f
}
\n
"
)
)
fout
.
write
(
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
prefb
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
prefb
:.
0
f
}
\n
"
)
)
elif
args
.
one
:
elif
args
.
one
:
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
perf
=
compare_a
(
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
args
.
target
,
shape
[
"causal"
],
shape
[
"dtype"
])
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
)
examples/deepseek_mla/benchmark_mla.py
View file @
667632cc
...
@@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
...
@@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
blocked_v
=
blocked_k
[...,
:
dv
]
def
ref_mla
():
def
ref_mla
():
...
@@ -61,8 +60,7 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
...
@@ -61,8 +60,7 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_flash_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
def
run_flash_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
h_kv
,
d
,
dv
,
causal
,
dtype
):
from
flash_mla
import
flash_mla_with_kvcache
,
get_mla_metadata
from
flash_mla
import
flash_mla_with_kvcache
,
get_mla_metadata
blocked_v
=
blocked_k
[...,
:
dv
]
blocked_v
=
blocked_k
[...,
:
dv
]
...
@@ -87,14 +85,13 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
...
@@ -87,14 +85,13 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_flashinfer
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
def
run_flashinfer
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
# pip install flashinfer-python
# pip install flashinfer-python
import
flashinfer
import
flashinfer
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
dv
:].
contiguous
()
kv_indptr
=
[
0
]
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_indices
=
[]
...
@@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
...
@@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
mla_wrapper
=
flashinfer
.
mla
.
BatchMLAPagedAttentionWrapper
(
mla_wrapper
=
flashinfer
.
mla
.
BatchMLAPagedAttentionWrapper
(
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
),
backend
=
"fa3"
)
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
),
backend
=
"fa3"
)
mla_wrapper
.
plan
(
mla_wrapper
.
plan
(
q_indptr
,
q_indptr
,
kv_indptr
,
kv_indptr
,
...
@@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
...
@@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
)
)
def
flashinfer
():
def
flashinfer
():
output
,
lse
=
mla_wrapper
.
run
(
output
,
lse
=
mla_wrapper
.
run
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
,
blocked_k_pe
,
return_lse
=
True
)
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
,
blocked_k_pe
,
return_lse
=
True
)
return
output
.
view
(
b
,
-
1
,
h_q
,
dv
),
lse
.
view
(
b
,
h_q
,
1
)
return
output
.
view
(
b
,
-
1
,
h_q
,
dv
),
lse
.
view
(
b
,
h_q
,
1
)
out_flash
,
lse_flash
=
flashinfer
()
out_flash
,
lse_flash
=
flashinfer
()
...
@@ -177,8 +168,7 @@ def _mla_attn_kernel(
...
@@ -177,8 +168,7 @@ def _mla_attn_kernel(
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
None
,
:]
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
...
@@ -224,9 +214,7 @@ def _mla_attn_kernel(
...
@@ -224,9 +214,7 @@ def _mla_attn_kernel(
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
e_max
=
n_e_max
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
...
@@ -393,24 +381,30 @@ def mla_decode_triton(
...
@@ -393,24 +381,30 @@ def mla_decode_triton(
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
blocked_v
=
blocked_k
[...,
:
dv
]
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
dv
:].
contiguous
()
def
flash_mla_triton
():
def
flash_mla_triton
():
num_kv_splits
=
32
num_kv_splits
=
32
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
mla_decode_triton
(
mla_decode_triton
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
q_nope
.
view
(
-
1
,
h_q
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
)
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
,
)
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
out_flash
=
flash_mla_triton
()
out_flash
=
flash_mla_triton
()
...
@@ -419,13 +413,10 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size,
...
@@ -419,13 +413,10 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size,
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_flash_mla_tilelang
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
def
run_flash_mla_tilelang
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
dv
:].
contiguous
()
dpe
=
d
-
dv
dpe
=
d
-
dv
num_kv_splits
=
1
num_kv_splits
=
1
...
@@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
...
@@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
out_partial
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dv
,
dtype
=
dtype
,
device
=
q
.
device
)
out_partial
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dv
,
dtype
=
dtype
,
device
=
q
.
device
)
glse
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dtype
=
dtype
,
device
=
q
.
device
)
glse
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dtype
=
dtype
,
device
=
q
.
device
)
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
num_kv_splits
,
block_size
)
num_kv_splits
,
block_size
)
def
flash_mla_tilelang
():
def
flash_mla_tilelang
():
out
=
kernel
(
out
=
kernel
(
...
@@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
...
@@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_size
=
64
block_table
=
torch
.
arange
(
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
out_a
,
lse_a
,
perf_a
=
baseline_func
(
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
)
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
if
target
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
if
target
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]
and
baseline
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]:
]
and
baseline
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]:
# flashinfer has a different lse return value
# flashinfer has a different lse return value
# flash_mla_triton and flash_mla_tilelang doesn't return lse
# flash_mla_triton and flash_mla_tilelang doesn't return lse
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
print
(
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
...
@@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_size
=
64
block_table
=
torch
.
arange
(
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
out_b
,
lse_b
,
perf_b
=
target_func
(
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_b
return
bytes
/
10
**
6
/
perf_b
...
@@ -558,26 +538,22 @@ available_targets = [
...
@@ -558,26 +538,22 @@ available_targets = [
"flash_mla_triton"
,
"flash_mla_triton"
,
]
]
shape_configs
=
[{
shape_configs
=
[
"b"
:
{
batch
,
"b"
:
batch
,
"s_q"
:
"s_q"
:
1
,
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"cache_seqlens"
:
"h_q"
:
head
,
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_kv"
:
1
,
"h_q"
:
"d"
:
512
+
64
,
head
,
"dv"
:
512
,
"h_kv"
:
"causal"
:
True
,
1
,
"dtype"
:
torch
.
float16
,
"d"
:
}
512
+
64
,
for
batch
in
[
128
]
"dv"
:
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
,
32768
]
512
,
for
head
in
[
128
]
"causal"
:
]
True
,
"dtype"
:
torch
.
float16
}
for
batch
in
[
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
,
32768
]
for
head
in
[
128
]]
def
get_args
():
def
get_args
():
...
@@ -599,26 +575,54 @@ if __name__ == "__main__":
...
@@ -599,26 +575,54 @@ if __name__ == "__main__":
for
shape
in
shape_configs
:
for
shape
in
shape_configs
:
if
args
.
all
:
if
args
.
all
:
for
target
in
available_targets
:
for
target
in
available_targets
:
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
perf
=
compare_a
(
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
target
,
shape
[
"causal"
],
shape
[
"dtype"
])
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
fout
.
write
(
f
'
{
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
)
elif
args
.
compare
:
elif
args
.
compare
:
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
perfa
,
prefb
=
compare_ab
(
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
args
.
baseline
,
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
fout
.
write
(
f
'
{
args
.
baseline
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perfa
:.
0
f
}
\n
'
f
"
{
args
.
baseline
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perfa
:.
0
f
}
\n
"
)
)
fout
.
write
(
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
prefb
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
prefb
:.
0
f
}
\n
"
)
)
elif
args
.
one
:
elif
args
.
one
:
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
perf
=
compare_a
(
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
args
.
target
,
shape
[
"causal"
],
shape
[
"dtype"
])
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
)
examples/deepseek_mla/example_mla_decode.py
View file @
667632cc
...
@@ -8,25 +8,26 @@ import argparse
...
@@ -8,25 +8,26 @@ import argparse
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
{
out_idx
=
[
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
}
)
}
,
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
)
softmax_scale
):
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
softmax_scale
):
scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
dtype
=
"
float16
"
dtype
=
T
.
float16
accum_dtype
=
"
float
"
accum_dtype
=
T
.
float
32
kv_group_num
=
heads
//
kv_head_num
kv_group_num
=
heads
//
kv_head_num
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
assert
kv_head_num
==
1
,
"kv_head_num must be 1"
assert
kv_head_num
==
1
,
"kv_head_num must be 1"
@
T
.
macro
@
T
.
macro
def
flash_attn
(
def
flash_attn
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
with
T
.
Kernel
(
heads
//
min
(
block_H
,
kv_group_num
),
batch
,
threads
=
256
)
as
(
hid
,
bid
):
with
T
.
Kernel
(
heads
//
min
(
block_H
,
kv_group_num
),
batch
,
threads
=
256
)
as
(
hid
,
bid
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
@@ -44,36 +45,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -44,36 +45,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
T
.
annotate_layout
({
T
.
annotate_layout
(
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
{
})
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
}
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
(
seqlen_kv
,
block_N
)
loop_range
=
T
.
ceildiv
(
seqlen_kv
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
T
.
copy
(
KV
[
bid
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
KV
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
K_pe
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
gemm
(
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
,
clear_accum
=
True
)
Q_shared
,
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
,
clear_accum
=
True
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
@@ -88,20 +82,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -88,20 +82,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:])
@
T
.
macro
@
T
.
macro
def
flash_attn_split
(
def
flash_attn_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bid
,
hid
,
bz
):
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bid
,
hid
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
block_N
],
dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
block_N
],
dtype
)
Q_pe_shared
=
T
.
alloc_shared
([
block_H
,
pe_dim
],
dtype
)
Q_pe_shared
=
T
.
alloc_shared
([
block_H
,
pe_dim
],
dtype
)
...
@@ -119,13 +111,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -119,13 +111,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
T
.
annotate_layout
(
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
{
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
})
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
}
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -137,17 +131,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -137,17 +131,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
copy
(
KV
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
KV
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
K_pe
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
@@ -164,16 +154,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -164,16 +154,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
bz
,
:])
@
T
.
macro
@
T
.
macro
def
combine
(
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
hid
,
bz
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
hid
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
...
@@ -183,9 +172,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -183,9 +172,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
{
})
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
T
.
clear
(
o_accum_local
)
...
@@ -208,26 +199,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -208,26 +199,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
prim_func
@
T
.
prim_func
def
main_split
(
def
main_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
flash_attn_split
(
Q
,
Q_pe
,
KV
,
K_pe
,
glse
,
Output_partial
)
flash_attn_split
(
Q
,
Q_pe
,
KV
,
K_pe
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
@
T
.
prim_func
def
main_no_split
(
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
...
@@ -252,31 +243,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
...
@@ -252,31 +243,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
return
out
return
out
...
@@ -294,10 +278,9 @@ def main(
...
@@ -294,10 +278,9 @@ def main(
BLOCK_N
=
64
BLOCK_N
=
64
BLOCK_H
=
min
(
64
,
heads
//
kv_heads
)
BLOCK_H
=
min
(
64
,
heads
//
kv_heads
)
num_split
=
1
num_split
=
1
softmax_scale
=
(
dim
+
pe_dim
)
**-
0.5
softmax_scale
=
(
dim
+
pe_dim
)
**
-
0.5
kernel
=
flashattn
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
kernel
=
flashattn
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
softmax_scale
)
softmax_scale
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
1e-4
,
atol
=
1e-4
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
1e-4
,
atol
=
1e-4
)
latency
=
profiler
.
do_bench
(
warmup
=
500
)
latency
=
profiler
.
do_bench
(
warmup
=
500
)
...
@@ -307,12 +290,12 @@ def main(
...
@@ -307,12 +290,12 @@ def main(
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
132
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
132
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
main
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
main
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
examples/deepseek_mla/example_mla_decode_paged.py
View file @
667632cc
...
@@ -8,25 +8,17 @@ import math
...
@@ -8,25 +8,17 @@ import math
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
8
],
pass_configs
=
{
out_idx
=
[
8
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
def
mla_decode_tilelang
(
batch
,
)
h_q
,
def
mla_decode_tilelang
(
batch
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
block_N
,
block_H
,
num_split
,
block_size
,
softmax_scale
=
None
):
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
block_N
,
block_H
,
num_split
,
block_size
,
softmax_scale
=
None
):
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
(
dv
+
dpe
)
**-
0.5
softmax_scale
=
(
dv
+
dpe
)
**
-
0.5
scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
dtype
=
"
float16
"
dtype
=
T
.
float16
accum_dtype
=
"
float
"
accum_dtype
=
T
.
float
32
kv_group_num
=
h_q
//
h_kv
kv_group_num
=
h_q
//
h_kv
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
assert
h_kv
==
1
,
"h_kv must be 1"
assert
h_kv
==
1
,
"h_kv must be 1"
...
@@ -34,13 +26,13 @@ def mla_decode_tilelang(batch,
...
@@ -34,13 +26,13 @@ def mla_decode_tilelang(batch,
@
T
.
macro
@
T
.
macro
def
flash_mla_kernel
(
def
flash_mla_kernel
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"
int32
"
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
T
.
int32
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
"
int32
"
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
T
.
int32
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
):
with
T
.
Kernel
(
batch
,
h_q
//
min
(
block_H
,
kv_group_num
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
batch
,
h_q
//
min
(
block_H
,
kv_group_num
),
threads
=
256
)
as
(
bx
,
by
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dv
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_H
,
dv
],
dtype
)
...
@@ -59,13 +51,15 @@ def mla_decode_tilelang(batch,
...
@@ -59,13 +51,15 @@ def mla_decode_tilelang(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
T
.
annotate_layout
(
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
{
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
})
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
}
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -73,26 +67,20 @@ def mla_decode_tilelang(batch,
...
@@ -73,26 +67,20 @@ def mla_decode_tilelang(batch,
loop_range
=
T
.
ceildiv
(
CACHE_SEQLENS
[
bx
],
block_N
)
loop_range
=
T
.
ceildiv
(
CACHE_SEQLENS
[
bx
],
block_N
)
for
kr
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
for
kr
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
k
=
loop_range
-
1
-
kr
k
=
loop_range
-
1
-
kr
kv_start
=
BLOCK_TABLE
[
bx
,
(
k
*
block_N
)
//
kv_start
=
BLOCK_TABLE
[
bx
,
(
k
*
block_N
)
//
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
if
kr
==
0
:
if
kr
==
0
:
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
@@ -107,21 +95,20 @@ def mla_decode_tilelang(batch,
...
@@ -107,21 +95,20 @@ def mla_decode_tilelang(batch,
for
i
,
j
in
T
.
Parallel
(
block_H
,
dv
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dv
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
@
T
.
macro
@
T
.
macro
def
flash_mla_split_kv_kernel
(
def
flash_mla_split_kv_kernel
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"
int32
"
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
T
.
int32
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
"
int32
"
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
T
.
int32
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
batch
,
h_q
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bx
,
by
,
bz
):
batch
,
h_q
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dv
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_H
,
dv
],
dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
block_N
],
dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
block_N
],
dtype
)
Q_pe_shared
=
T
.
alloc_shared
([
block_H
,
dpe
],
dtype
)
Q_pe_shared
=
T
.
alloc_shared
([
block_H
,
dpe
],
dtype
)
...
@@ -139,13 +126,15 @@ def mla_decode_tilelang(batch,
...
@@ -139,13 +126,15 @@ def mla_decode_tilelang(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
T
.
annotate_layout
(
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
{
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
})
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
}
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -153,29 +142,23 @@ def mla_decode_tilelang(batch,
...
@@ -153,29 +142,23 @@ def mla_decode_tilelang(batch,
total_blocks
=
T
.
ceildiv
(
CACHE_SEQLENS
[
bx
],
block_N
)
total_blocks
=
T
.
ceildiv
(
CACHE_SEQLENS
[
bx
],
block_N
)
blocks_per_split
=
T
.
floordiv
(
total_blocks
,
num_split
)
blocks_per_split
=
T
.
floordiv
(
total_blocks
,
num_split
)
remaining_blocks
=
T
.
floormod
(
total_blocks
,
num_split
)
remaining_blocks
=
T
.
floormod
(
total_blocks
,
num_split
)
loop_range
=
(
blocks_per_split
+
T
.
if_then_else
(
bz
<
remaining_blocks
,
1
,
0
)
)
loop_range
=
blocks_per_split
+
T
.
if_then_else
(
bz
<
remaining_blocks
,
1
,
0
)
start
=
(
blocks_per_split
*
bz
+
T
.
min
(
bz
,
remaining_blocks
))
*
block_N
start
=
(
blocks_per_split
*
bz
+
T
.
min
(
bz
,
remaining_blocks
))
*
block_N
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
kv_start
=
BLOCK_TABLE
[
bx
,
(
start
+
k
*
block_N
)
//
kv_start
=
BLOCK_TABLE
[
bx
,
(
start
+
k
*
block_N
)
//
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
start
+
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
start
+
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
@@ -192,15 +175,15 @@ def mla_decode_tilelang(batch,
...
@@ -192,15 +175,15 @@ def mla_decode_tilelang(batch,
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
T
.
copy
(
O_shared
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
@
T
.
macro
@
T
.
macro
def
combine
(
def
combine
(
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
):
with
T
.
Kernel
(
h_q
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
with
T
.
Kernel
(
h_q
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dv
],
dtype
)
po_local
=
T
.
alloc_fragment
([
dv
],
dtype
)
...
@@ -210,9 +193,11 @@ def mla_decode_tilelang(batch,
...
@@ -210,9 +193,11 @@ def mla_decode_tilelang(batch,
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
{
})
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
T
.
clear
(
o_accum_local
)
...
@@ -235,31 +220,30 @@ def mla_decode_tilelang(batch,
...
@@ -235,31 +220,30 @@ def mla_decode_tilelang(batch,
@
T
.
prim_func
@
T
.
prim_func
def
main_split
(
def
main_split
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"
int32
"
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
T
.
int32
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"
int32
"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
T
.
int32
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
):
flash_mla_split_kv_kernel
(
Q
,
Q_pe
,
KV
,
K_pe
,
block_table
,
cache_seqlens
,
glse
,
flash_mla_split_kv_kernel
(
Q
,
Q_pe
,
KV
,
K_pe
,
block_table
,
cache_seqlens
,
glse
,
Output_partial
)
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
@
T
.
prim_func
def
main_no_split
(
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"
int32
"
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
T
.
int32
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"
int32
"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
T
.
int32
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
):
flash_mla_kernel
(
Q
,
Q_pe
,
KV
,
K_pe
,
block_table
,
cache_seqlens
,
Output
)
flash_mla_kernel
(
Q
,
Q_pe
,
KV
,
K_pe
,
block_table
,
cache_seqlens
,
Output
)
...
@@ -280,8 +264,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
...
@@ -280,8 +264,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
s_q
=
query
.
shape
[
-
2
]
s_q
=
query
.
shape
[
-
2
]
s_k
=
key
.
shape
[
-
2
]
s_k
=
key
.
shape
[
-
2
]
attn_bias
=
torch
.
zeros
(
s_q
,
s_k
,
dtype
=
query
.
dtype
,
device
=
query
.
device
)
attn_bias
=
torch
.
zeros
(
s_q
,
s_k
,
dtype
=
query
.
dtype
,
device
=
query
.
device
)
temp_mask
=
torch
.
ones
(
temp_mask
=
torch
.
ones
(
s_q
,
s_k
,
dtype
=
torch
.
bool
,
device
=
query
.
device
).
tril
(
diagonal
=
s_k
-
s_q
)
s_q
,
s_k
,
dtype
=
torch
.
bool
,
device
=
query
.
device
).
tril
(
diagonal
=
s_k
-
s_q
)
attn_bias
.
masked_fill_
(
temp_mask
.
logical_not
(),
float
(
"-inf"
))
attn_bias
.
masked_fill_
(
temp_mask
.
logical_not
(),
float
(
"-inf"
))
attn_bias
.
to
(
query
.
dtype
)
attn_bias
.
to
(
query
.
dtype
)
attn_weight
+=
attn_bias
attn_weight
+=
attn_bias
...
@@ -291,8 +274,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
...
@@ -291,8 +274,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
h_kv
,
d
,
dv
,
causal
,
dtype
):
# q: [b, s_q, h_q, d]
# q: [b, s_q, h_q, d]
# block_table: [b, max_seqlen_pad // block_size]
# block_table: [b, max_seqlen_pad // block_size]
# blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
# blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
...
@@ -321,13 +303,10 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
...
@@ -321,13 +303,10 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
return
out_torch
return
out_torch
def
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
def
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
dv
:].
contiguous
()
dpe
=
d
-
dv
dpe
=
d
-
dv
num_kv_splits
=
1
num_kv_splits
=
1
...
@@ -337,8 +316,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
...
@@ -337,8 +316,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_partial
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dv
,
dtype
=
dtype
,
device
=
q
.
device
)
out_partial
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dv
,
dtype
=
dtype
,
device
=
q
.
device
)
glse
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dtype
=
dtype
,
device
=
q
.
device
)
glse
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dtype
=
dtype
,
device
=
q
.
device
)
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
num_kv_splits
,
block_size
,
softmax_scale
)
num_kv_splits
,
block_size
,
softmax_scale
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
def
flash_mla_tilelang
():
def
flash_mla_tilelang
():
...
@@ -356,8 +334,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
...
@@ -356,8 +334,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_flash
=
flash_mla_tilelang
()
out_flash
=
flash_mla_tilelang
()
t
=
do_bench
(
flash_mla_tilelang
)
t
=
do_bench
(
flash_mla_tilelang
)
out_ref
=
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
out_ref
=
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_flash
,
out_ref
,
rtol
=
0.01
,
atol
=
0.01
)
torch
.
testing
.
assert_close
(
out_flash
,
out_ref
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"All close"
)
print
(
"All close"
)
return
out_flash
,
t
return
out_flash
,
t
...
@@ -365,12 +342,12 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
...
@@ -365,12 +342,12 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--h_q
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
"
--h_q
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
'
--h_kv
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
"
--h_kv
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
'
--cache_seqlen
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv cache context length
'
)
parser
.
add_argument
(
"
--cache_seqlen
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv cache context length
"
)
parser
.
add_argument
(
'
--d
'
,
type
=
int
,
default
=
576
,
help
=
'
query/key head dim, d = dv + dpe
'
)
parser
.
add_argument
(
"
--d
"
,
type
=
int
,
default
=
576
,
help
=
"
query/key head dim, d = dv + dpe
"
)
parser
.
add_argument
(
'
--dv
'
,
type
=
int
,
default
=
512
,
help
=
'
value head dim
'
)
parser
.
add_argument
(
"
--dv
"
,
type
=
int
,
default
=
512
,
help
=
"
value head dim
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
b
,
h_q
,
h_kv
,
cache_seqlen
,
d
,
dv
=
args
.
batch
,
args
.
h_q
,
args
.
h_kv
,
args
.
cache_seqlen
,
args
.
d
,
args
.
dv
b
,
h_q
,
h_kv
,
cache_seqlen
,
d
,
dv
=
args
.
batch
,
args
.
h_q
,
args
.
h_kv
,
args
.
cache_seqlen
,
args
.
d
,
args
.
dv
...
@@ -379,9 +356,7 @@ if __name__ == "__main__":
...
@@ -379,9 +356,7 @@ if __name__ == "__main__":
s_q
=
1
# for decode, s_q = 1
s_q
=
1
# for decode, s_q = 1
block_size
=
64
block_size
=
64
cache_seqlens
=
torch
.
tensor
([
cache_seqlen
+
2
*
i
for
i
in
range
(
b
)],
cache_seqlens
=
torch
.
tensor
([
cache_seqlen
+
2
*
i
for
i
in
range
(
b
)],
dtype
=
torch
.
int32
,
device
=
device
)
dtype
=
torch
.
int32
,
device
=
device
)
dpe
=
d
-
dv
dpe
=
d
-
dv
causal
=
True
causal
=
True
...
@@ -393,12 +368,11 @@ if __name__ == "__main__":
...
@@ -393,12 +368,11 @@ if __name__ == "__main__":
total_flops
=
s_q
*
total_seqlens
*
h_q
*
d
*
2
total_flops
=
s_q
*
total_seqlens
*
h_q
*
d
*
2
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
,
dtype
=
dtype
,
device
=
device
)
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
,
dtype
=
dtype
,
device
=
device
)
block_table
=
torch
.
arange
(
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
,
device
=
device
).
view
(
b
,
max_seqlen_pad
//
block_size
)
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
,
device
=
device
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
,
dtype
=
dtype
,
device
=
device
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
,
dtype
=
dtype
,
device
=
device
)
out_flash
,
latency
=
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
out_flash
,
latency
=
run_tilelang_mla
(
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
print
(
"Tile-lang: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
print
(
"Tile-lang: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
examples/deepseek_mla/example_mla_decode_persistent.py
View file @
667632cc
...
@@ -9,13 +9,15 @@ import argparse
...
@@ -9,13 +9,15 @@ import argparse
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
{
out_idx
=
[
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
):
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
"
float16
"
dtype
=
T
.
float16
accum_dtype
=
"
float
"
accum_dtype
=
T
.
float
32
kv_group_num
=
heads
//
kv_head_num
kv_group_num
=
heads
//
kv_head_num
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
VALID_BLOCK_H
=
min
(
block_H
,
kv_group_num
)
assert
kv_head_num
==
1
,
"kv_head_num must be 1"
assert
kv_head_num
==
1
,
"kv_head_num must be 1"
...
@@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
prim_func
@
T
.
prim_func
def
main_split_persistent
(
def
main_split_persistent
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
with
T
.
Kernel
(
sm_num
,
threads
=
256
)
as
(
block_id
):
with
T
.
Kernel
(
sm_num
,
threads
=
256
)
as
(
block_id
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
@@ -53,11 +55,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -53,11 +55,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
{
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
})
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
total_tiles
=
batch
*
(
heads
//
min
(
block_H
,
kv_group_num
))
*
num_split
total_tiles
=
batch
*
(
heads
//
min
(
block_H
,
kv_group_num
))
*
num_split
...
@@ -70,8 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -70,8 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
if
bid
<
batch
and
hid
*
VALID_BLOCK_H
<
heads
and
sid
<
num_split
:
if
bid
<
batch
and
hid
*
VALID_BLOCK_H
<
heads
and
sid
<
num_split
:
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -83,24 +87,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -83,24 +87,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
copy
(
KV
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
KV
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
K_pe
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
Q_shared
,
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
scores_max
[
i
]
*
scale
)
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
...
@@ -115,11 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -115,11 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
])
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
])
# T.copy(acc_o, O_shared)
# T.copy(acc_o, O_shared)
T
.
copy
(
T
.
copy
(
acc_o
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
,
:])
acc_o
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
,
:])
T
.
sync_grid
()
T
.
sync_grid
()
waves
=
T
.
ceildiv
(
heads
*
batch
,
sm_num
)
waves
=
T
.
ceildiv
(
heads
*
batch
,
sm_num
)
...
@@ -165,42 +158,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
...
@@ -165,42 +158,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
return
out
return
out
def
main
():
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
qk_flops
=
2
*
batch
*
heads
*
kv_ctx
*
(
dim
+
pe_dim
)
qk_flops
=
2
*
batch
*
heads
*
kv_ctx
*
(
dim
+
pe_dim
)
...
...
Prev
1
2
3
4
5
6
7
8
9
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment