Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1651 additions
and
1690 deletions
+1651
-1690
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
...les/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
+78
-108
examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
+91
-74
examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
+91
-74
examples/deepseek_mla/benchmark_mla.py
examples/deepseek_mla/benchmark_mla.py
+101
-97
examples/deepseek_mla/example_mla_decode.py
examples/deepseek_mla/example_mla_decode.py
+80
-101
examples/deepseek_mla/example_mla_decode_paged.py
examples/deepseek_mla/example_mla_decode_paged.py
+98
-128
examples/deepseek_mla/example_mla_decode_persistent.py
examples/deepseek_mla/example_mla_decode_persistent.py
+41
-57
examples/deepseek_mla/example_mla_decode_ws.py
examples/deepseek_mla/example_mla_decode_ws.py
+113
-132
examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py
...es/deepseek_mla/experimental/example_mla_decode_kv_fp8.py
+38
-47
examples/deepseek_mla/torch_refs.py
examples/deepseek_mla/torch_refs.py
+16
-13
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py
+188
-230
examples/deepseek_nsa/example_tilelang_nsa_bwd.py
examples/deepseek_nsa/example_tilelang_nsa_bwd.py
+107
-101
examples/deepseek_nsa/example_tilelang_nsa_decode.py
examples/deepseek_nsa/example_tilelang_nsa_decode.py
+22
-28
examples/deepseek_nsa/example_tilelang_nsa_fwd.py
examples/deepseek_nsa/example_tilelang_nsa_fwd.py
+25
-38
examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py
examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py
+94
-84
examples/deepseek_nsa/example_triton_nsa_bwd.py
examples/deepseek_nsa/example_triton_nsa_bwd.py
+219
-135
examples/deepseek_nsa/example_triton_nsa_fwd.py
examples/deepseek_nsa/example_triton_nsa_fwd.py
+72
-52
examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py
examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py
+89
-70
examples/deepseek_nsa/reference.py
examples/deepseek_nsa/reference.py
+52
-61
examples/deepseek_v32/fp8_lighting_indexer.py
examples/deepseek_v32/fp8_lighting_indexer.py
+36
-60
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
View file @
29051439
...
...
@@ -8,6 +8,7 @@ import argparse
def
get_configs
():
import
itertools
BLOCK_N
=
[
16
,
32
,
64
,
128
]
BLOCK_H
=
[
16
,
32
,
64
,
128
]
num_split
=
[
1
,
2
,
4
,
8
,
16
,
32
]
...
...
@@ -15,30 +16,26 @@ def get_configs():
_configs
=
list
(
itertools
.
product
(
BLOCK_N
,
BLOCK_H
,
num_split
,
threads
))
return
[{
"block_N"
:
c
[
0
],
"block_H"
:
c
[
1
],
"num_split"
:
c
[
2
],
"threads"
:
c
[
3
],
}
for
c
in
_configs
]
return
[
{
"block_N"
:
c
[
0
],
"block_H"
:
c
[
1
],
"num_split"
:
c
[
2
],
"threads"
:
c
[
3
],
}
for
c
in
_configs
]
@
tilelang
.
autotune
(
configs
=
get_configs
())
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
{
out_idx
=
[
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashmla_decode
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
threads
=
128
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
},
)
def
flashmla_decode
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
threads
=
128
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
kv_group_num
=
heads
//
kv_head_num
...
...
@@ -47,11 +44,11 @@ def flashmla_decode(batch,
@
T
.
macro
def
flash_attn
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
threads
=
threads
)
as
(
bx
,
by
):
Q_local
=
T
.
alloc_fragment
([
block_H
,
dim
],
dtype
)
...
...
@@ -70,24 +67,19 @@ def flashmla_decode(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
(
seqlen_kv
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
0
):
T
.
copy
(
KV
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
KV
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_local
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
...
...
@@ -107,20 +99,18 @@ def flashmla_decode(batch,
T
.
gemm
(
acc_s_cast
,
KV_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
T
.
copy
(
acc_o
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_local
=
T
.
alloc_fragment
([
block_H
,
dim
],
dtype
)
Q_pe_local
=
T
.
alloc_fragment
([
block_H
,
pe_dim
],
dtype
)
KV_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
...
...
@@ -136,8 +126,8 @@ def flashmla_decode(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_local
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_local
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -150,12 +140,7 @@ def flashmla_decode(batch,
T
.
copy
(
K_pe
[
bx
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_local
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_pe_local
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
...
...
@@ -176,14 +161,14 @@ def flashmla_decode(batch,
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
...
...
@@ -193,9 +178,11 @@ def flashmla_decode(batch,
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
...
...
@@ -218,26 +205,26 @@ def flashmla_decode(batch,
@
T
.
prim_func
def
main_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn_split
(
Q
,
Q_pe
,
KV
,
K_pe
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
...
...
@@ -262,43 +249,36 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
'
--autotune
'
,
action
=
'
store_true
'
,
help
=
'
auto tune
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
parser
.
add_argument
(
"
--autotune
"
,
action
=
"
store_true
"
,
help
=
"
auto tune
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
enable_autotune
=
args
.
autotune
...
...
@@ -314,17 +294,7 @@ if __name__ == "__main__":
if
enable_autotune
:
kernel
=
flashmla_decode
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
else
:
kernel
=
flashmla_decode
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
threads
=
threads
)
kernel
=
flashmla_decode
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
threads
=
threads
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
input_tensors
=
profiler
.
_get_inputs
()
tilelang_output
=
kernel
(
*
input_tensors
)
...
...
examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
View file @
29051439
...
...
@@ -32,8 +32,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
def
ref_mla
():
...
...
@@ -94,8 +93,7 @@ def _mla_attn_kernel(
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
...
...
@@ -141,9 +139,7 @@ def _mla_attn_kernel(
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
...
...
@@ -309,24 +305,30 @@ def mla_decode_triton(
@
torch
.
inference_mode
()
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
def
flash_mla_triton
():
num_kv_splits
=
32
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
mla_decode_triton
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
)
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
,
)
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
out_flash
=
flash_mla_triton
()
...
...
@@ -362,14 +364,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
if
target
not
in
[
"flash_mla_triton"
]:
...
...
@@ -377,21 +380,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
torch
.
set_default_dtype
(
dtype
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
...
...
@@ -408,19 +404,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_b
...
...
@@ -429,26 +422,22 @@ available_targets = [
"flash_mla_triton"
,
]
shape_configs
=
[{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
}
for
batch
in
[
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
for
head
in
[
128
]]
shape_configs
=
[
{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
,
}
for
batch
in
[
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
for
head
in
[
128
]
]
def
get_args
():
...
...
@@ -470,26 +459,54 @@ if __name__ == "__main__":
for
shape
in
shape_configs
:
if
args
.
all
:
for
target
in
available_targets
:
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
elif
args
.
compare
:
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
baseline
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perfa
:.
0
f
}
\n
'
f
"
{
args
.
baseline
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perfa
:.
0
f
}
\n
"
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
prefb
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
prefb
:.
0
f
}
\n
"
)
elif
args
.
one
:
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
View file @
29051439
...
...
@@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
def
ref_mla
():
...
...
@@ -91,8 +90,7 @@ def _mla_attn_kernel(
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
...
...
@@ -138,9 +136,7 @@ def _mla_attn_kernel(
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
...
...
@@ -306,24 +302,30 @@ def mla_decode_triton(
@
torch
.
inference_mode
()
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
def
flash_mla_triton
():
num_kv_splits
=
32
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
mla_decode_triton
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
)
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
,
)
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
out_flash
=
flash_mla_triton
()
...
...
@@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
if
target
not
in
[
"flash_mla_triton"
]:
...
...
@@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
torch
.
set_default_dtype
(
dtype
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
...
...
@@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_b
...
...
@@ -426,26 +419,22 @@ available_targets = [
"flash_mla_triton"
,
]
shape_configs
=
[{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
}
for
batch
in
[
64
,
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
for
head
in
[
128
]]
shape_configs
=
[
{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
,
}
for
batch
in
[
64
,
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
]
for
head
in
[
128
]
]
def
get_args
():
...
...
@@ -467,26 +456,54 @@ if __name__ == "__main__":
for
shape
in
shape_configs
:
if
args
.
all
:
for
target
in
available_targets
:
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
elif
args
.
compare
:
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
baseline
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perfa
:.
0
f
}
\n
'
f
"
{
args
.
baseline
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perfa
:.
0
f
}
\n
"
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
prefb
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
prefb
:.
0
f
}
\n
"
)
elif
args
.
one
:
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
examples/deepseek_mla/benchmark_mla.py
View file @
29051439
...
...
@@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
def
ref_mla
():
...
...
@@ -61,8 +60,7 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@
torch
.
inference_mode
()
def
run_flash_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flash_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
from
flash_mla
import
flash_mla_with_kvcache
,
get_mla_metadata
blocked_v
=
blocked_k
[...,
:
dv
]
...
...
@@ -87,14 +85,13 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@
torch
.
inference_mode
()
def
run_flashinfer
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flashinfer
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
# pip install flashinfer-python
import
flashinfer
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
kv_indptr
=
[
0
]
kv_indices
=
[]
...
...
@@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
mla_wrapper
=
flashinfer
.
mla
.
BatchMLAPagedAttentionWrapper
(
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
),
backend
=
"fa3"
)
mla_wrapper
=
flashinfer
.
mla
.
BatchMLAPagedAttentionWrapper
(
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
),
backend
=
"fa3"
)
mla_wrapper
.
plan
(
q_indptr
,
kv_indptr
,
...
...
@@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
)
def
flashinfer
():
output
,
lse
=
mla_wrapper
.
run
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
,
blocked_k_pe
,
return_lse
=
True
)
output
,
lse
=
mla_wrapper
.
run
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
,
blocked_k_pe
,
return_lse
=
True
)
return
output
.
view
(
b
,
-
1
,
h_q
,
dv
),
lse
.
view
(
b
,
h_q
,
1
)
out_flash
,
lse_flash
=
flashinfer
()
...
...
@@ -177,8 +168,7 @@ def _mla_attn_kernel(
offs_d_ckv
=
tl
.
arange
(
0
,
HEAD_DIM_CKV
)
cur_head
=
cur_head_id
*
BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
offs_q_nope
=
cur_batch
*
stride_q_nope_bs
+
cur_head
[:,
None
]
*
stride_q_nope_h
+
offs_d_ckv
[
None
,
:]
q_nope
=
tl
.
load
(
Q_nope
+
offs_q_nope
)
offs_d_kpe
=
tl
.
arange
(
0
,
HEAD_DIM_KPE
)
...
...
@@ -224,9 +214,7 @@ def _mla_attn_kernel(
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
offs_o
=
cur_batch
*
stride_o_b
+
cur_head
[:,
None
]
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
offs_d_ckv
[
None
,
:]
tl
.
store
(
O
+
offs_o
,
acc
/
e_sum
[:,
None
])
offs_o_1
=
cur_batch
*
stride_o_b
+
cur_head
*
stride_o_h
+
split_kv_id
*
stride_o_s
+
HEAD_DIM_CKV
tl
.
store
(
O
+
offs_o_1
,
e_max
+
tl
.
log
(
e_sum
))
...
...
@@ -393,24 +381,30 @@ def mla_decode_triton(
@
torch
.
inference_mode
()
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flash_mla_triton
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
blocked_v
=
blocked_k
[...,
:
dv
]
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
def
flash_mla_triton
():
num_kv_splits
=
32
o
=
torch
.
empty
([
b
*
s_q
,
h_q
,
dv
])
attn_logits
=
torch
.
empty
([
b
*
s_q
,
h_q
,
num_kv_splits
,
dv
+
1
])
mla_decode_triton
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
)
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
blocked_k_nope
.
view
(
-
1
,
dv
),
blocked_k_pe
.
view
(
-
1
,
d
-
dv
),
o
,
block_table
,
cache_seqlens
,
attn_logits
,
num_kv_splits
,
1
/
math
.
sqrt
(
d
),
block_size
,
)
return
o
.
view
([
b
,
s_q
,
h_q
,
dv
])
out_flash
=
flash_mla_triton
()
...
...
@@ -419,13 +413,10 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size,
@
torch
.
inference_mode
()
def
run_flash_mla_tilelang
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_flash_mla_tilelang
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
dpe
=
d
-
dv
num_kv_splits
=
1
...
...
@@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
out_partial
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dv
,
dtype
=
dtype
,
device
=
q
.
device
)
glse
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dtype
=
dtype
,
device
=
q
.
device
)
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
num_kv_splits
,
block_size
)
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
num_kv_splits
,
block_size
)
def
flash_mla_tilelang
():
out
=
kernel
(
...
...
@@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_a
,
lse_a
,
perf_a
=
baseline_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
if
target
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]
and
baseline
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]:
if
target
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]
and
baseline
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]:
# flashinfer has a different lse return value
# flash_mla_triton and flash_mla_tilelang doesn't return lse
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
print
(
f
"
{
target
}
:
{
b
=
}
,
{
s_q
=
}
, mean_seqlens=
{
cache_seqlens
.
float
().
mean
()
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
dtype
=
}
"
)
torch
.
set_default_dtype
(
dtype
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
...
...
@@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_b
,
lse_b
,
perf_b
=
target_func
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_b
...
...
@@ -558,26 +538,22 @@ available_targets = [
"flash_mla_triton"
,
]
shape_configs
=
[{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
}
for
batch
in
[
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
,
32768
]
for
head
in
[
128
]]
shape_configs
=
[
{
"b"
:
batch
,
"s_q"
:
1
,
"cache_seqlens"
:
torch
.
tensor
([
seqlen
+
2
*
i
for
i
in
range
(
batch
)],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"h_q"
:
head
,
"h_kv"
:
1
,
"d"
:
512
+
64
,
"dv"
:
512
,
"causal"
:
True
,
"dtype"
:
torch
.
float16
,
}
for
batch
in
[
128
]
for
seqlen
in
[
1024
,
2048
,
4096
,
8192
,
16384
,
32768
]
for
head
in
[
128
]
]
def
get_args
():
...
...
@@ -599,26 +575,54 @@ if __name__ == "__main__":
for
shape
in
shape_configs
:
if
args
.
all
:
for
target
in
available_targets
:
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
elif
args
.
compare
:
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
baseline
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perfa
:.
0
f
}
\n
'
f
"
{
args
.
baseline
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perfa
:.
0
f
}
\n
"
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
prefb
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
prefb
:.
0
f
}
\n
"
)
elif
args
.
one
:
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
],
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"
cache_seqlens
"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"
h_q
"
]
}
,
{
perf
:.
0
f
}
\n
'
f
"
{
args
.
target
}
,
{
shape
[
'b'
]
}
,
{
shape
[
'
cache_seqlens
'
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
'
h_q
'
]
}
,
{
perf
:.
0
f
}
\n
"
)
examples/deepseek_mla/example_mla_decode.py
View file @
29051439
...
...
@@ -8,11 +8,12 @@ import argparse
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
{
out_idx
=
[
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
}
)
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
softmax_scale
):
}
,
)
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
softmax_scale
):
scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
...
...
@@ -22,11 +23,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
macro
def
flash_attn
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
heads
//
min
(
block_H
,
kv_group_num
),
batch
,
threads
=
256
)
as
(
hid
,
bid
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
...
@@ -44,33 +45,24 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
T
.
annotate_layout
({
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
})
T
.
annotate_layout
(
{
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
}
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
(
seqlen_kv
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
T
.
copy
(
KV
[
bid
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
,
clear_accum
=
True
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
KV
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
,
clear_accum
=
True
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
...
...
@@ -90,20 +82,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:])
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bid
,
hid
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bid
,
hid
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
block_N
],
dtype
)
Q_pe_shared
=
T
.
alloc_shared
([
block_H
,
pe_dim
],
dtype
)
...
...
@@ -121,13 +111,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
})
T
.
annotate_layout
(
{
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
}
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -139,14 +131,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
copy
(
KV
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
...
...
@@ -168,16 +154,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
hid
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
...
...
@@ -187,9 +172,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
...
...
@@ -212,26 +199,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
prim_func
def
main_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn_split
(
Q
,
Q_pe
,
KV
,
K_pe
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
...
...
@@ -256,31 +243,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
...
...
@@ -298,10 +278,9 @@ def main(
BLOCK_N
=
64
BLOCK_H
=
min
(
64
,
heads
//
kv_heads
)
num_split
=
1
softmax_scale
=
(
dim
+
pe_dim
)
**-
0.5
softmax_scale
=
(
dim
+
pe_dim
)
**
-
0.5
kernel
=
flashattn
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
softmax_scale
)
kernel
=
flashattn
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
softmax_scale
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
1e-4
,
atol
=
1e-4
)
latency
=
profiler
.
do_bench
(
warmup
=
500
)
...
...
@@ -311,12 +290,12 @@ def main(
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
132
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
132
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
main
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
examples/deepseek_mla/example_mla_decode_paged.py
View file @
29051439
...
...
@@ -8,22 +8,14 @@ import math
@
tilelang
.
jit
(
out_idx
=
[
8
],
pass_configs
=
{
out_idx
=
[
8
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
mla_decode_tilelang
(
batch
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
block_N
,
block_H
,
num_split
,
block_size
,
softmax_scale
=
None
):
},
)
def
mla_decode_tilelang
(
batch
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
block_N
,
block_H
,
num_split
,
block_size
,
softmax_scale
=
None
):
if
softmax_scale
is
None
:
softmax_scale
=
(
dv
+
dpe
)
**-
0.5
softmax_scale
=
(
dv
+
dpe
)
**
-
0.5
scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
...
...
@@ -34,13 +26,13 @@ def mla_decode_tilelang(batch,
@
T
.
macro
def
flash_mla_kernel
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"int32"
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
"int32"
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"int32"
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
"int32"
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
with
T
.
Kernel
(
batch
,
h_q
//
min
(
block_H
,
kv_group_num
),
threads
=
256
)
as
(
bx
,
by
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dv
],
dtype
)
...
...
@@ -59,13 +51,15 @@ def mla_decode_tilelang(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
})
T
.
annotate_layout
(
{
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
}
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -73,25 +67,17 @@ def mla_decode_tilelang(batch,
loop_range
=
T
.
ceildiv
(
CACHE_SEQLENS
[
bx
],
block_N
)
for
kr
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
k
=
loop_range
-
1
-
kr
kv_start
=
BLOCK_TABLE
[
bx
,
(
k
*
block_N
)
//
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
kv_start
=
BLOCK_TABLE
[
bx
,
(
k
*
block_N
)
//
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
if
kr
==
0
:
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
...
...
@@ -109,21 +95,20 @@ def mla_decode_tilelang(batch,
for
i
,
j
in
T
.
Parallel
(
block_H
,
dv
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
@
T
.
macro
def
flash_mla_split_kv_kernel
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"int32"
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"int32"
),
CACHE_SEQLENS
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
):
with
T
.
Kernel
(
batch
,
h_q
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
h_q
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
256
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dv
],
dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
block_N
],
dtype
)
Q_pe_shared
=
T
.
alloc_shared
([
block_H
,
dpe
],
dtype
)
...
...
@@ -141,13 +126,15 @@ def mla_decode_tilelang(batch,
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
})
T
.
annotate_layout
(
{
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
}
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -155,28 +142,20 @@ def mla_decode_tilelang(batch,
total_blocks
=
T
.
ceildiv
(
CACHE_SEQLENS
[
bx
],
block_N
)
blocks_per_split
=
T
.
floordiv
(
total_blocks
,
num_split
)
remaining_blocks
=
T
.
floormod
(
total_blocks
,
num_split
)
loop_range
=
(
blocks_per_split
+
T
.
if_then_else
(
bz
<
remaining_blocks
,
1
,
0
)
)
loop_range
=
blocks_per_split
+
T
.
if_then_else
(
bz
<
remaining_blocks
,
1
,
0
)
start
=
(
blocks_per_split
*
bz
+
T
.
min
(
bz
,
remaining_blocks
))
*
block_N
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
kv_start
=
BLOCK_TABLE
[
bx
,
(
start
+
k
*
block_N
)
//
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
kv_start
=
BLOCK_TABLE
[
bx
,
(
start
+
k
*
block_N
)
//
block_size
]
*
block_size
+
(
k
*
block_N
)
%
block_size
T
.
copy
(
KV
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
kv_start
:
kv_start
+
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
start
+
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
start
+
k
*
block_N
+
j
>=
CACHE_SEQLENS
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
...
...
@@ -196,15 +175,15 @@ def mla_decode_tilelang(batch,
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
logsum
,
glse
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
T
.
copy
(
O_shared
,
Output_partial
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
bz
,
:])
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
with
T
.
Kernel
(
h_q
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dv
],
dtype
)
...
...
@@ -214,9 +193,11 @@ def mla_decode_tilelang(batch,
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
...
...
@@ -239,31 +220,30 @@ def mla_decode_tilelang(batch,
@
T
.
prim_func
def
main_split
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
flash_mla_split_kv_kernel
(
Q
,
Q_pe
,
KV
,
K_pe
,
block_table
,
cache_seqlens
,
glse
,
Output_partial
)
flash_mla_split_kv_kernel
(
Q
,
Q_pe
,
KV
,
K_pe
,
block_table
,
cache_seqlens
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
h_q
,
dpe
],
dtype
),
KV
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dv
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
*
max_seqlen_pad
,
h_kv
,
dpe
],
dtype
),
block_table
:
T
.
Tensor
([
batch
,
max_seqlen_pad
//
block_size
],
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
h_q
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
h_q
,
num_split
,
dv
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
h_q
,
dv
],
dtype
),
):
flash_mla_kernel
(
Q
,
Q_pe
,
KV
,
K_pe
,
block_table
,
cache_seqlens
,
Output
)
...
...
@@ -284,8 +264,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
s_q
=
query
.
shape
[
-
2
]
s_k
=
key
.
shape
[
-
2
]
attn_bias
=
torch
.
zeros
(
s_q
,
s_k
,
dtype
=
query
.
dtype
,
device
=
query
.
device
)
temp_mask
=
torch
.
ones
(
s_q
,
s_k
,
dtype
=
torch
.
bool
,
device
=
query
.
device
).
tril
(
diagonal
=
s_k
-
s_q
)
temp_mask
=
torch
.
ones
(
s_q
,
s_k
,
dtype
=
torch
.
bool
,
device
=
query
.
device
).
tril
(
diagonal
=
s_k
-
s_q
)
attn_bias
.
masked_fill_
(
temp_mask
.
logical_not
(),
float
(
"-inf"
))
attn_bias
.
to
(
query
.
dtype
)
attn_weight
+=
attn_bias
...
...
@@ -295,8 +274,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@
torch
.
inference_mode
()
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
# q: [b, s_q, h_q, d]
# block_table: [b, max_seqlen_pad // block_size]
# blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
...
...
@@ -325,13 +303,10 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
return
out_torch
def
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
def
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
assert
d
>
dv
,
"mla with rope dim should be larger than no rope dim"
q_nope
,
q_pe
=
q
[...,
:
dv
].
contiguous
(),
q
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
blocked_k_nope
,
blocked_k_pe
=
blocked_k
[...,
:
dv
].
contiguous
(),
blocked_k
[...,
dv
:].
contiguous
()
dpe
=
d
-
dv
num_kv_splits
=
1
...
...
@@ -341,8 +316,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_partial
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dv
,
dtype
=
dtype
,
device
=
q
.
device
)
glse
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dtype
=
dtype
,
device
=
q
.
device
)
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
num_kv_splits
,
block_size
,
softmax_scale
)
kernel
=
mla_decode_tilelang
(
b
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
BLOCK_N
,
BLOCK_H
,
num_kv_splits
,
block_size
,
softmax_scale
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
def
flash_mla_tilelang
():
...
...
@@ -360,8 +334,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_flash
=
flash_mla_tilelang
()
t
=
do_bench
(
flash_mla_tilelang
)
out_ref
=
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_ref
=
run_torch_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_flash
,
out_ref
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"All close"
)
return
out_flash
,
t
...
...
@@ -369,12 +342,12 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--h_q
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--h_kv
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--cache_seqlen
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv cache context length
'
)
parser
.
add_argument
(
'
--d
'
,
type
=
int
,
default
=
576
,
help
=
'
query/key head dim, d = dv + dpe
'
)
parser
.
add_argument
(
'
--dv
'
,
type
=
int
,
default
=
512
,
help
=
'
value head dim
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--h_q
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--h_kv
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--cache_seqlen
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv cache context length
"
)
parser
.
add_argument
(
"
--d
"
,
type
=
int
,
default
=
576
,
help
=
"
query/key head dim, d = dv + dpe
"
)
parser
.
add_argument
(
"
--dv
"
,
type
=
int
,
default
=
512
,
help
=
"
value head dim
"
)
args
=
parser
.
parse_args
()
b
,
h_q
,
h_kv
,
cache_seqlen
,
d
,
dv
=
args
.
batch
,
args
.
h_q
,
args
.
h_kv
,
args
.
cache_seqlen
,
args
.
d
,
args
.
dv
...
...
@@ -383,9 +356,7 @@ if __name__ == "__main__":
s_q
=
1
# for decode, s_q = 1
block_size
=
64
cache_seqlens
=
torch
.
tensor
([
cache_seqlen
+
2
*
i
for
i
in
range
(
b
)],
dtype
=
torch
.
int32
,
device
=
device
)
cache_seqlens
=
torch
.
tensor
([
cache_seqlen
+
2
*
i
for
i
in
range
(
b
)],
dtype
=
torch
.
int32
,
device
=
device
)
dpe
=
d
-
dv
causal
=
True
...
...
@@ -397,12 +368,11 @@ if __name__ == "__main__":
total_flops
=
s_q
*
total_seqlens
*
h_q
*
d
*
2
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
,
dtype
=
dtype
,
device
=
device
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
,
device
=
device
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
,
device
=
device
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
,
dtype
=
dtype
,
device
=
device
)
out_flash
,
latency
=
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
out_flash
,
latency
=
run_tilelang_mla
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
print
(
"Tile-lang: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
examples/deepseek_mla/example_mla_decode_persistent.py
View file @
29051439
...
...
@@ -9,11 +9,13 @@ import argparse
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
{
out_idx
=
[
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
kv_group_num
=
heads
//
kv_head_num
...
...
@@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
prim_func
def
main_split_persistent
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
sm_num
,
threads
=
256
)
as
(
block_id
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
...
@@ -53,11 +55,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
use_swizzle
(
10
)
total_tiles
=
batch
*
(
heads
//
min
(
block_H
,
kv_group_num
))
*
num_split
...
...
@@ -70,8 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
if
bid
<
batch
and
hid
*
VALID_BLOCK_H
<
heads
and
sid
<
num_split
:
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -83,26 +87,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
copy
(
KV
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
KV_shared
)
T
.
copy
(
K_pe
[
bid
,
kv_start
:
kv_end
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
...
...
@@ -117,11 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
])
T
.
copy
(
logsum
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
])
# T.copy(acc_o, O_shared)
T
.
copy
(
acc_o
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
,
:])
T
.
copy
(
acc_o
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
sid
,
:])
T
.
sync_grid
()
waves
=
T
.
ceildiv
(
heads
*
batch
,
sm_num
)
...
...
@@ -167,42 +158,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
qk_flops
=
2
*
batch
*
heads
*
kv_ctx
*
(
dim
+
pe_dim
)
...
...
examples/deepseek_mla/example_mla_decode_ws.py
View file @
29051439
...
...
@@ -13,14 +13,19 @@ import argparse
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
},
compile_flags
=
[
"-O3"
,
"-Wno-deprecated-declarations"
,
"-U__CUDA_NO_HALF_OPERATORS__"
,
"-U__CUDA_NO_HALF_CONVERSIONS__"
,
"-U__CUDA_NO_HALF2_OPERATORS__"
,
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__"
,
"--expt-relaxed-constexpr"
,
"--expt-extended-lambda"
,
"--ptxas-options=-v,--register-usage-level=10"
,
"-DNDEBUG"
"-O3"
,
"-Wno-deprecated-declarations"
,
"-U__CUDA_NO_HALF_OPERATORS__"
,
"-U__CUDA_NO_HALF_CONVERSIONS__"
,
"-U__CUDA_NO_HALF2_OPERATORS__"
,
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__"
,
"--expt-relaxed-constexpr"
,
"--expt-extended-lambda"
,
"--ptxas-options=-v,--register-usage-level=10"
,
"-DNDEBUG"
,
],
)
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
softmax_scale
):
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
,
softmax_scale
):
sm_scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
...
...
@@ -30,11 +35,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
macro
def
flash_attn
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
heads
//
min
(
block_H
,
kv_group_num
),
batch
,
threads
=
384
)
as
(
hid
,
bid
):
Q_shared_l
=
T
.
alloc_shared
([
block_H
,
dim
//
2
],
dtype
)
...
...
@@ -75,16 +80,16 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
tx
=
T
.
get_thread_binding
()
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
0
:
dim
//
2
],
Q_shared_l
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
dim
//
2
:
dim
],
Q_shared_r
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_tail_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
0
:
dim
//
2
],
Q_shared_l
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
dim
//
2
:
dim
],
Q_shared_r
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_tail_shared
)
T
.
barrier_arrive
(
bar_q
)
if
tx
<
128
:
T
.
set_max_nreg
(
240
,
1
)
T
.
fill
(
sumexp
,
0
)
T
.
fill
(
m_i
,
-
2
**
30
)
# avoid -inf - inf to cause nan
T
.
fill
(
m_i
,
-
(
2
**
30
)
)
# avoid -inf - inf to cause nan
T
.
fill
(
acc_o_l
,
0
)
T
.
barrier_wait
(
bar_q
,
0
)
...
...
@@ -166,8 +171,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
h_i
in
T
.
Parallel
(
block_H
):
sumexp
[
h_i
]
=
T
.
log2
(
sumexp
[
h_i
])
+
m_i
[
h_i
]
*
sm_scale
T
.
copy
(
acc_o_l
,
O_shared_l
)
T
.
copy
(
O_shared_l
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
0
:
dim
//
2
])
T
.
copy
(
O_shared_l
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
0
:
dim
//
2
])
elif
tx
>=
128
and
tx
<
256
:
T
.
set_max_nreg
(
168
,
1
)
...
...
@@ -197,8 +201,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o_r
[
h_i
,
d_i
]
/=
sum_exp_shared
[
h_i
]
T
.
copy
(
acc_o_r
,
O_shared_r
)
T
.
copy
(
O_shared_r
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
dim
//
2
:
dim
])
T
.
copy
(
O_shared_r
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
dim
//
2
:
dim
])
elif
tx
>=
256
:
# producer
...
...
@@ -211,19 +214,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
u
in
T
.
serial
(
4
):
for
v
in
T
.
vectorized
(
8
):
KV_shared_0_l
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_0_r
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
dim
//
2
+
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_0_l
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_0_r
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
dim
//
2
+
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
v
in
T
.
vectorized
(
8
):
K_tail_shared_0
[
r
*
16
+
(
tx
-
256
)
//
8
,
(
tx
-
256
)
%
8
*
8
+
v
]
=
K_pe
[
bid
,
kv_indices
,
cur_kv_head
,
(
tx
-
256
)
%
8
*
8
+
v
]
K_tail_shared_0
[
r
*
16
+
(
tx
-
256
)
//
8
,
(
tx
-
256
)
%
8
*
8
+
v
]
=
K_pe
[
bid
,
kv_indices
,
cur_kv_head
,
(
tx
-
256
)
%
8
*
8
+
v
]
T
.
cp_async_barrier_noinc
(
bar_k_0_ready
[
0
])
# Buffer 1
...
...
@@ -233,33 +234,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
u
in
T
.
serial
(
4
):
for
v
in
T
.
vectorized
(
8
):
KV_shared_1_l
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_1_r
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
dim
//
2
+
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_1_l
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_1_r
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
dim
//
2
+
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
v
in
T
.
vectorized
(
8
):
K_tail_shared_1
[
r
*
16
+
(
tx
-
256
)
//
8
,
(
tx
-
256
)
%
8
*
8
+
v
]
=
K_pe
[
bid
,
kv_indices
,
cur_kv_head
,
(
tx
-
256
)
%
8
*
8
+
v
]
K_tail_shared_1
[
r
*
16
+
(
tx
-
256
)
//
8
,
(
tx
-
256
)
%
8
*
8
+
v
]
=
K_pe
[
bid
,
kv_indices
,
cur_kv_head
,
(
tx
-
256
)
%
8
*
8
+
v
]
T
.
cp_async_barrier_noinc
(
bar_k_1_ready
[
0
])
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
384
)
as
(
bid
,
hid
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
num_split
,
threads
=
384
)
as
(
bid
,
hid
,
bz
):
Q_shared_l
=
T
.
alloc_shared
([
block_H
,
dim
//
2
],
dtype
)
Q_shared_r
=
T
.
alloc_shared
([
block_H
,
dim
//
2
],
dtype
)
Q_tail_shared
=
T
.
alloc_shared
([
block_H
,
pe_dim
],
dtype
)
...
...
@@ -298,16 +295,16 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
tx
=
T
.
get_thread_binding
()
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
0
:
dim
//
2
],
Q_shared_l
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
dim
//
2
:
dim
],
Q_shared_r
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_tail_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
0
:
dim
//
2
],
Q_shared_l
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
dim
//
2
:
dim
],
Q_shared_r
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_tail_shared
)
T
.
barrier_arrive
(
bar_q
)
if
tx
<
128
:
T
.
set_max_nreg
(
240
,
1
)
T
.
fill
(
sumexp
,
0
)
T
.
fill
(
m_i
,
-
2
**
30
)
# avoid -inf - inf to cause nan
T
.
fill
(
m_i
,
-
(
2
**
30
)
)
# avoid -inf - inf to cause nan
T
.
fill
(
acc_o_l
,
0
)
T
.
barrier_wait
(
bar_q
,
0
)
...
...
@@ -389,10 +386,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
h_i
in
T
.
Parallel
(
block_H
):
sumexp
[
h_i
]
=
T
.
log2
(
sumexp
[
h_i
])
+
m_i
[
h_i
]
*
sm_scale
T
.
copy
(
acc_o_l
,
O_shared_l
)
T
.
copy
(
O_shared_l
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
,
0
:
dim
//
2
])
T
.
copy
(
sumexp
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
])
T
.
copy
(
O_shared_l
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
,
0
:
dim
//
2
])
T
.
copy
(
sumexp
,
glse
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
])
elif
tx
>=
128
and
tx
<
256
:
T
.
set_max_nreg
(
168
,
1
)
...
...
@@ -422,9 +417,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o_r
[
h_i
,
d_i
]
/=
sum_exp_shared
[
h_i
]
T
.
copy
(
acc_o_r
,
O_shared_r
)
T
.
copy
(
O_shared_r
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
,
dim
//
2
:
dim
])
T
.
copy
(
O_shared_r
,
Output_partial
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
bz
,
dim
//
2
:
dim
])
elif
tx
>=
256
:
# producer
...
...
@@ -433,54 +426,48 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
# Buffer 0
T
.
barrier_wait
(
bar_k_0_free
[
0
],
((
i_i
&
1
)
^
1
))
for
r
in
T
.
serial
(
4
):
kv_indices
=
(
seqlen_kv
//
num_split
)
*
bz
+
(
i_i
*
2
)
*
block_N
+
r
*
16
+
(
tx
-
256
)
//
8
kv_indices
=
(
seqlen_kv
//
num_split
)
*
bz
+
(
i_i
*
2
)
*
block_N
+
r
*
16
+
(
tx
-
256
)
//
8
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
u
in
T
.
serial
(
4
):
for
v
in
T
.
vectorized
(
8
):
KV_shared_0_l
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_0_r
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
dim
//
2
+
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_0_l
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_0_r
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
dim
//
2
+
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
v
in
T
.
vectorized
(
8
):
K_tail_shared_0
[
r
*
16
+
(
tx
-
256
)
//
8
,
(
tx
-
256
)
%
8
*
8
+
v
]
=
K_pe
[
bid
,
kv_indices
,
cur_kv_head
,
(
tx
-
256
)
%
8
*
8
+
v
]
K_tail_shared_0
[
r
*
16
+
(
tx
-
256
)
//
8
,
(
tx
-
256
)
%
8
*
8
+
v
]
=
K_pe
[
bid
,
kv_indices
,
cur_kv_head
,
(
tx
-
256
)
%
8
*
8
+
v
]
T
.
cp_async_barrier_noinc
(
bar_k_0_ready
[
0
])
# Buffer 1
T
.
barrier_wait
(
bar_k_1_free
[
0
],
((
i_i
&
1
)
^
1
))
for
r
in
T
.
serial
(
4
):
kv_indices
=
(
seqlen_kv
//
num_split
)
*
bz
+
(
i_i
*
2
+
1
)
*
block_N
+
r
*
16
+
(
tx
-
256
)
//
8
kv_indices
=
(
seqlen_kv
//
num_split
)
*
bz
+
(
i_i
*
2
+
1
)
*
block_N
+
r
*
16
+
(
tx
-
256
)
//
8
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
u
in
T
.
serial
(
4
):
for
v
in
T
.
vectorized
(
8
):
KV_shared_1_l
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_1_r
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
dim
//
2
+
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_1_l
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
KV_shared_1_r
[
r
*
16
+
(
tx
-
256
)
//
8
,
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
=
KV
[
bid
,
kv_indices
,
cur_kv_head
,
dim
//
2
+
64
*
u
+
(
tx
-
256
)
%
8
*
8
+
v
]
with
T
.
attr
(
"default"
,
"async_scope"
,
1
):
for
v
in
T
.
vectorized
(
8
):
K_tail_shared_1
[
r
*
16
+
(
tx
-
256
)
//
8
,
(
tx
-
256
)
%
8
*
8
+
v
]
=
K_pe
[
bid
,
kv_indices
,
cur_kv_head
,
(
tx
-
256
)
%
8
*
8
+
v
]
K_tail_shared_1
[
r
*
16
+
(
tx
-
256
)
//
8
,
(
tx
-
256
)
%
8
*
8
+
v
]
=
K_pe
[
bid
,
kv_indices
,
cur_kv_head
,
(
tx
-
256
)
%
8
*
8
+
v
]
T
.
cp_async_barrier_noinc
(
bar_k_1_ready
[
0
])
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
hid
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
...
...
@@ -490,9 +477,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
...
...
@@ -515,26 +504,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
prim_func
def
main_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn_split
(
Q
,
Q_pe
,
KV
,
K_pe
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
...
...
@@ -559,31 +548,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
...
...
@@ -601,10 +583,9 @@ def main(
BLOCK_N
=
64
BLOCK_H
=
min
(
64
,
heads
//
kv_heads
)
num_split
=
1
softmax_scale
=
(
dim
+
pe_dim
)
**-
0.5
softmax_scale
=
(
dim
+
pe_dim
)
**
-
0.5
kernel
=
flashattn
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
softmax_scale
)
kernel
=
flashattn
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
,
BLOCK_N
,
BLOCK_H
,
num_split
,
softmax_scale
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
profiler
.
assert_allclose
(
ref_program
,
rtol
=
1e-4
,
atol
=
1e-4
)
latency
=
profiler
.
do_bench
(
warmup
=
500
)
...
...
@@ -614,12 +595,12 @@ def main(
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
132
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
132
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
main
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py
View file @
29051439
...
...
@@ -8,11 +8,13 @@ import argparse
@
tilelang
.
jit
(
out_idx
=
[
-
1
],
pass_configs
=
{
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
q_dtype
=
"float8_e4m3"
accum_dtype
=
"float"
...
...
@@ -22,11 +24,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
prim_func
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
q_dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
q_dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
min
(
block_H
,
kv_group_num
),
threads
=
256
)
as
(
bx
,
by
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
...
@@ -46,31 +48,27 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head
=
by
//
(
kv_group_num
//
block_H
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
})
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
annotate_layout
(
{
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
}
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q_pe
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
disable_warp_group_reg_alloc
()
loop_range
=
T
.
ceildiv
(
seqlen_kv
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
T
.
copy
(
KV
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
qKV_shared
)
T
.
copy
(
K_pe
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
KV
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
qKV_shared
)
T
.
copy
(
K_pe
[
bx
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared
)
T
.
copy
(
qKV_shared
,
KV_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_shared
,
KV_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
Q_pe_shared
,
K_pe_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
...
...
@@ -90,7 +88,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bx
,
by
*
VALID_BLOCK_H
:
(
by
+
1
)
*
VALID_BLOCK_H
,
:])
return
main_no_split
...
...
@@ -108,42 +106,35 @@ def ref_program(q, q_pe, kv, k_pe):
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
qk_flops
=
2
*
batch
*
heads
*
kv_ctx
*
(
dim
+
pe_dim
)
...
...
examples/deepseek_mla/torch_refs.py
View file @
29051439
...
...
@@ -11,7 +11,7 @@ def flash_split_ref(Q, Q_pe, KV, K_pe):
block_N
=
64
seqlen_kv
=
KV
.
size
(
1
)
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
acc_s
=
torch
.
empty
((
batch
,
nheads
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
acc_s_cast
=
torch
.
empty
((
batch
,
nheads
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
acc_o
=
torch
.
empty
((
batch
,
nheads
,
dim
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
...
...
@@ -31,18 +31,20 @@ def flash_split_ref(Q, Q_pe, KV, K_pe):
for
ks
in
range
(
num_split
):
acc_o
.
fill_
(
0
)
logsum
.
fill_
(
0
)
scores_max
.
fill_
(
float
(
'
-inf
'
))
scores_max_prev
.
fill_
(
float
(
'
-inf
'
))
scores_max
.
fill_
(
float
(
"
-inf
"
))
scores_max_prev
.
fill_
(
float
(
"
-inf
"
))
for
i
in
range
(
int
((
seqlen_kv
//
num_split
)
/
block_N
)):
acc_s
.
fill_
(
0
)
acc_s
=
torch
.
einsum
(
'bhd,bkhd->bhk'
,
Q_
,
KV_
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:])
# [batch, nheads, block_N]
acc_s
=
torch
.
einsum
(
"bhd,bkhd->bhk"
,
Q_
,
KV_
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
# [batch, nheads, block_N]
acc_s
+=
torch
.
einsum
(
'bhd,bkhd->bhk'
,
Q_pe_
,
K_pe_
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:])
"bhd,bkhd->bhk"
,
Q_pe_
,
K_pe_
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
scores_max_prev
=
scores_max
scores_max
=
acc_s
.
max
(
dim
=-
1
,
keepdim
=
False
).
values
# [batch, nheads]
scores_scale
=
torch
.
exp2
(
scores_max_prev
-
scores_max
)
# [batch, nheads]
...
...
@@ -50,9 +52,10 @@ def flash_split_ref(Q, Q_pe, KV, K_pe):
acc_s
=
torch
.
exp2
(
acc_s
-
scores_max
[:,
:,
None
])
acc_s_cast
=
acc_s
.
to
(
torch
.
float16
)
# [batch, nheads, block_N]
acc_o
+=
torch
.
einsum
(
'bhk,bkhd->bhd'
,
acc_s_cast
,
KV_
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:])
"bhk,bkhd->bhd"
,
acc_s_cast
,
KV_
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
scores_sum
=
acc_s
.
sum
(
dim
=-
1
,
keepdim
=
False
)
logsum
=
logsum
*
scores_scale
+
scores_sum
acc_o
/=
logsum
[:,
:,
None
]
...
...
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py
View file @
29051439
...
...
@@ -14,21 +14,44 @@ from fla.ops.utils import prepare_token_indices
from
fla.utils
import
autocast_custom_fwd
,
contiguous
@
triton
.
heuristics
({
'USE_OFFSETS'
:
lambda
args
:
args
[
'offsets'
]
is
not
None
,
'USE_BLOCK_COUNTS'
:
lambda
args
:
isinstance
(
args
[
'block_counts'
],
torch
.
Tensor
),
})
@
triton
.
heuristics
(
{
"USE_OFFSETS"
:
lambda
args
:
args
[
"offsets"
]
is
not
None
,
"USE_BLOCK_COUNTS"
:
lambda
args
:
isinstance
(
args
[
"block_counts"
],
torch
.
Tensor
),
}
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
]],
key
=
[
'
BS
'
,
'
BK
'
,
'
BV
'
],
key
=
[
"
BS
"
,
"
BK
"
,
"
BV
"
],
)
@
triton
.
jit
def
parallel_nsa_fwd_kernel
(
q
,
k
,
v
,
o_slc
,
o_swa
,
lse_slc
,
lse_swa
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
):
def
parallel_nsa_fwd_kernel
(
q
,
k
,
v
,
o_slc
,
o_swa
,
lse_slc
,
lse_swa
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
,
):
i_t
,
i_v
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
...
...
@@ -40,20 +63,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
NS
=
S
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_q
=
(
b_q
*
scale
).
to
(
b_q
.
dtype
)
p_o_slc
=
tl
.
make_block_ptr
(
o_slc
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_o_slc
=
tl
.
make_block_ptr
(
o_slc
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_lse_slc
=
lse_slc
+
(
bos
+
i_t
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
# [G, BV]
b_o_slc
=
tl
.
zeros
([
G
,
BV
],
dtype
=
tl
.
float32
)
b_m_slc
=
tl
.
full
([
G
],
float
(
'
-inf
'
),
dtype
=
tl
.
float32
)
b_m_slc
=
tl
.
full
([
G
],
float
(
"
-inf
"
),
dtype
=
tl
.
float32
)
b_acc_slc
=
tl
.
zeros
([
G
],
dtype
=
tl
.
float32
)
for
i
in
range
(
NS
):
i_s
=
tl
.
load
(
block_indices
+
i
).
to
(
tl
.
int32
)
*
BS
...
...
@@ -66,7 +87,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc
=
tl
.
load
(
p_v_slc
,
boundary_check
=
(
0
,
1
))
# [G, BS]
b_s_slc
=
tl
.
dot
(
b_q
,
b_k_slc
)
b_s_slc
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_slc
,
float
(
'
-inf
'
))
b_s_slc
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_slc
,
float
(
"
-inf
"
))
# [G]
b_m_slc
,
b_mp_slc
=
tl
.
maximum
(
b_m_slc
,
tl
.
max
(
b_s_slc
,
1
)),
b_m_slc
...
...
@@ -87,7 +108,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
class
ParallelNSAFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
contiguous
@
autocast_custom_fwd
...
...
@@ -100,8 +120,7 @@ class ParallelNSAFunction(torch.autograd.Function):
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices
=
prepare_token_indices
(
offsets
)
if
offsets
is
not
None
else
None
o
,
lse
=
parallel_nsa_fwd
(
q
=
q
,
k
=
k
,
v
=
v
,
block_indices
=
block_indices
,
block_size
=
block_size
,
scale
=
scale
)
o
,
lse
=
parallel_nsa_fwd
(
q
=
q
,
k
=
k
,
v
=
v
,
block_indices
=
block_indices
,
block_size
=
block_size
,
scale
=
scale
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
o
,
lse
)
ctx
.
block_indices
=
block_indices
ctx
.
block_size
=
block_size
...
...
@@ -172,7 +191,6 @@ def parallel_nsa_fwd(
@
torch
.
compile
class
ParallelNSAFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
contiguous
@
autocast_custom_fwd
...
...
@@ -195,7 +213,8 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size
=
window_size
,
scale
=
scale
,
offsets
=
offsets
,
token_indices
=
token_indices
)
token_indices
=
token_indices
,
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
o_slc
,
lse_slc
,
o_swa
,
lse_swa
)
ctx
.
block_indices
=
block_indices
ctx
.
block_counts
=
block_counts
...
...
@@ -207,18 +226,20 @@ class ParallelNSAFunction(torch.autograd.Function):
return
o_slc
.
to
(
q
.
dtype
),
o_swa
.
to
(
q
.
dtype
)
if
o_swa
is
not
None
else
o_swa
def
parallel_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
)
->
torch
.
Tensor
:
def
parallel_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
,
)
->
torch
.
Tensor
:
r
"""
Args:
q (torch.Tensor):
...
...
@@ -258,44 +279,44 @@ def parallel_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
scale
=
k
.
shape
[
-
1
]
**
-
0.5
if
cu_seqlens
is
not
None
:
assert
q
.
shape
[
0
]
==
1
,
"batch size must be 1 when cu_seqlens are provided"
if
head_first
:
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t d -> b t h d'
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t -> b t h'
),
(
g_slc
,
g_swa
))
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t d -> b t h d"
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t -> b t h"
),
(
g_slc
,
g_swa
))
if
isinstance
(
block_counts
,
torch
.
Tensor
):
block_counts
=
rearrange
(
block_counts
,
'
b h t -> b t h
'
)
block_counts
=
rearrange
(
block_counts
,
"
b h t -> b t h
"
)
assert
q
.
shape
[
2
]
%
(
k
.
shape
[
2
]
*
16
)
==
0
,
"Group size must be a multiple of 16 in NSA"
if
isinstance
(
block_counts
,
int
):
block_indices
=
block_indices
[:,
:,
:,
:
block_counts
]
block_counts
=
None
o_slc
,
o_swa
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
o_slc
,
o_swa
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
if
window_size
>
0
:
o
=
torch
.
addcmul
(
o_slc
*
g_slc
.
unsqueeze
(
-
1
),
o_swa
,
g_swa
.
unsqueeze
(
-
1
))
else
:
o
=
o_slc
*
g_slc
.
unsqueeze
(
-
1
)
if
head_first
:
o
=
rearrange
(
o
,
'
b t h d -> b h t d
'
)
o
=
rearrange
(
o
,
"
b t h d -> b h t d
"
)
return
o
def
naive_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
)
->
torch
.
Tensor
:
def
naive_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
,
)
->
torch
.
Tensor
:
r
"""
Args:
q (torch.Tensor):
...
...
@@ -335,26 +356,24 @@ def naive_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
scale
=
k
.
shape
[
-
1
]
**
-
0.5
if
cu_seqlens
is
not
None
:
assert
q
.
shape
[
0
]
==
1
,
"batch size must be 1 when cu_seqlens are provided"
if
head_first
:
raise
RuntimeError
(
"Sequences with variable lengths are not supported for head-first mode"
)
raise
RuntimeError
(
"Sequences with variable lengths are not supported for head-first mode"
)
if
head_first
:
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t d -> b t h d'
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t -> b t h'
),
(
g_slc
,
g_swa
))
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t d -> b t h d"
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t -> b t h"
),
(
g_slc
,
g_swa
))
if
isinstance
(
block_counts
,
torch
.
Tensor
):
block_counts
=
rearrange
(
block_counts
,
'
b h t -> b t h
'
)
block_counts
=
rearrange
(
block_counts
,
"
b h t -> b t h
"
)
dtype
=
q
.
dtype
G
=
q
.
shape
[
2
]
//
k
.
shape
[
2
]
BS
=
block_size
S
=
block_indices
.
shape
[
-
1
]
k
,
v
,
block_indices
=
(
repeat
(
x
,
'
b t h d -> b t (h g) d
'
,
g
=
G
)
for
x
in
(
k
,
v
,
block_indices
))
k
,
v
,
block_indices
=
(
repeat
(
x
,
"
b t h d -> b t (h g) d
"
,
g
=
G
)
for
x
in
(
k
,
v
,
block_indices
))
if
isinstance
(
block_counts
,
torch
.
Tensor
):
block_counts
=
repeat
(
block_counts
,
'
b t h -> b t (h g)
'
,
g
=
G
)
block_counts
=
repeat
(
block_counts
,
"
b t h -> b t (h g)
"
,
g
=
G
)
c
=
torch
.
arange
(
S
).
repeat_interleave
(
BS
).
unsqueeze
(
1
).
expand
(
-
1
,
q
.
shape
[
2
]).
to
(
q
.
device
)
q
,
k
,
v
=
map
(
lambda
x
:
x
.
float
(),
(
q
,
k
,
v
))
...
...
@@ -364,14 +383,11 @@ def naive_nsa(q: torch.Tensor,
if
cu_seqlens
is
None
:
varlen
=
False
B
,
T
=
q
.
shape
[:
2
]
cu_seqlens
=
torch
.
cat
(
[
block_indices
.
new_tensor
(
range
(
0
,
B
*
T
,
T
)),
block_indices
.
new_tensor
([
B
*
T
])])
cu_seqlens
=
torch
.
cat
([
block_indices
.
new_tensor
(
range
(
0
,
B
*
T
,
T
)),
block_indices
.
new_tensor
([
B
*
T
])])
for
i
in
range
(
len
(
cu_seqlens
)
-
1
):
if
not
varlen
:
q_b
,
k_b
,
v_b
,
g_slc_b
,
g_swa_b
,
i_b
=
q
[
i
],
k
[
i
],
v
[
i
],
g_slc
[
i
],
g_swa
[
i
],
block_indices
[
i
]
q_b
,
k_b
,
v_b
,
g_slc_b
,
g_swa_b
,
i_b
=
q
[
i
],
k
[
i
],
v
[
i
],
g_slc
[
i
],
g_swa
[
i
],
block_indices
[
i
]
if
isinstance
(
block_counts
,
torch
.
Tensor
):
s_b
=
block_counts
[
i
]
else
:
...
...
@@ -379,10 +395,10 @@ def naive_nsa(q: torch.Tensor,
else
:
T
=
cu_seqlens
[
i
+
1
]
-
cu_seqlens
[
i
]
q_b
,
k_b
,
v_b
,
g_slc_b
,
g_swa_b
,
i_b
=
map
(
lambda
x
:
x
[
0
][
cu_seqlens
[
i
]
:
cu_seqlens
[
i
+
1
]],
(
q
,
k
,
v
,
g_slc
,
g_swa
,
block_indices
)
)
lambda
x
:
x
[
0
][
cu_seqlens
[
i
]
:
cu_seqlens
[
i
+
1
]],
(
q
,
k
,
v
,
g_slc
,
g_swa
,
block_indices
)
)
if
isinstance
(
block_counts
,
torch
.
Tensor
):
s_b
=
block_counts
[
0
][
cu_seqlens
[
i
]
:
cu_seqlens
[
i
+
1
]]
s_b
=
block_counts
[
0
][
cu_seqlens
[
i
]
:
cu_seqlens
[
i
+
1
]]
else
:
s_b
=
block_counts
...
...
@@ -404,71 +420,58 @@ def naive_nsa(q: torch.Tensor,
else
:
s_i
=
s_b
# [S*BS, HQ, -1]
k_i_slc
,
v_i_slc
=
map
(
lambda
x
:
x
.
gather
(
0
,
i_i
.
clamp
(
0
,
T
-
1
).
unsqueeze
(
-
1
).
expand
(
*
i_i
.
shape
,
x
.
shape
[
-
1
])),
(
k_b
,
v_b
))
k_i_slc
,
v_i_slc
=
map
(
lambda
x
:
x
.
gather
(
0
,
i_i
.
clamp
(
0
,
T
-
1
).
unsqueeze
(
-
1
).
expand
(
*
i_i
.
shape
,
x
.
shape
[
-
1
])),
(
k_b
,
v_b
))
# [S*BS, HQ]
attn_slc
=
torch
.
einsum
(
'h d, n h d -> n h'
,
q_i
,
k_i_slc
).
masked_fill
(
torch
.
logical_or
(
i_i
<
0
,
i_i
>
i_q
)
|
(
c
>=
s_i
if
block_counts
is
not
None
else
False
),
float
(
'-inf'
)).
softmax
(
0
)
attn_slc
=
(
torch
.
einsum
(
"h d, n h d -> n h"
,
q_i
,
k_i_slc
)
.
masked_fill
(
torch
.
logical_or
(
i_i
<
0
,
i_i
>
i_q
)
|
(
c
>=
s_i
if
block_counts
is
not
None
else
False
),
float
(
"-inf"
))
.
softmax
(
0
)
)
if
not
varlen
:
o_slc
[
i
,
i_q
]
=
torch
.
einsum
(
'n h, n h v -> h v'
,
attn_slc
,
v_i_slc
)
*
g_slc_i
.
unsqueeze
(
-
1
)
o_slc
[
i
,
i_q
]
=
torch
.
einsum
(
"n h, n h v -> h v"
,
attn_slc
,
v_i_slc
)
*
g_slc_i
.
unsqueeze
(
-
1
)
else
:
o_slc
[
0
][
cu_seqlens
[
i
]
+
i_q
]
=
torch
.
einsum
(
'n h, n h v -> h v'
,
attn_slc
,
v_i_slc
)
*
g_slc_i
.
unsqueeze
(
-
1
)
o_slc
[
0
][
cu_seqlens
[
i
]
+
i_q
]
=
torch
.
einsum
(
"n h, n h v -> h v"
,
attn_slc
,
v_i_slc
)
*
g_slc_i
.
unsqueeze
(
-
1
)
if
window_size
>
0
:
k_i_swa
,
v_i_swa
=
map
(
lambda
x
:
x
[
max
(
0
,
i_q
-
window_size
+
1
):
i_q
+
1
],
(
k_b
,
v_b
))
attn_swa
=
torch
.
einsum
(
'h d, n h d -> n h'
,
q_i
,
k_i_swa
).
softmax
(
0
)
k_i_swa
,
v_i_swa
=
map
(
lambda
x
:
x
[
max
(
0
,
i_q
-
window_size
+
1
)
:
i_q
+
1
],
(
k_b
,
v_b
))
attn_swa
=
torch
.
einsum
(
"h d, n h d -> n h"
,
q_i
,
k_i_swa
).
softmax
(
0
)
if
not
varlen
:
o_swa
[
i
,
i_q
]
=
torch
.
einsum
(
'n h, n h v -> h v'
,
attn_swa
,
v_i_swa
)
*
g_swa_i
.
unsqueeze
(
-
1
)
o_swa
[
i
,
i_q
]
=
torch
.
einsum
(
"n h, n h v -> h v"
,
attn_swa
,
v_i_swa
)
*
g_swa_i
.
unsqueeze
(
-
1
)
else
:
o_swa
[
0
][
cu_seqlens
[
i
]
+
i_q
]
=
torch
.
einsum
(
'n h, n h v -> h v'
,
attn_swa
,
v_i_swa
)
*
g_swa_i
.
unsqueeze
(
-
1
)
o_swa
[
0
][
cu_seqlens
[
i
]
+
i_q
]
=
torch
.
einsum
(
"n h, n h v -> h v"
,
attn_swa
,
v_i_swa
)
*
g_swa_i
.
unsqueeze
(
-
1
)
if
head_first
:
o_slc
=
rearrange
(
o_slc
,
'
b t h d -> b h t d
'
)
o_swa
=
rearrange
(
o_swa
,
'
b t h d -> b h t d
'
)
o_slc
=
rearrange
(
o_slc
,
"
b t h d -> b h t d
"
)
o_swa
=
rearrange
(
o_swa
,
"
b t h d -> b h t d
"
)
return
o_slc
.
to
(
dtype
)
+
o_swa
.
to
(
dtype
)
if
o_swa
is
not
None
else
o_slc
.
to
(
dtype
)
def get_configs():
    """Build the autotuning search space as a list of keyword dictionaries.

    Every combination (Cartesian product) of the candidate values below is
    emitted as one ``{"block_T": ..., "num_stages": ..., "threads": ...}``
    dictionary. The result is passed to ``tilelang.autotune(configs=...)``
    by the decorator that follows this function.

    Returns:
        list[dict]: 3 * 5 * 5 = 75 configurations, ordered with ``block_T``
        varying slowest and ``threads`` varying fastest.
    """
    # Imported locally, as in the original, so the function does not rely on
    # a module-level import.
    import itertools

    iter_params = dict(
        block_T=[128, 256, 512],
        num_stages=[0, 1, 2, 4, 5],
        threads=[32, 64, 128, 256, 512],
    )
    # Iterating a dict yields its keys in insertion order, which is the same
    # order ``iter_params.values()`` feeds into ``product``; zip therefore
    # pairs each key with its own candidate value.
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@
tilelang
.
autotune
(
configs
=
get_configs
(),)
@
tilelang
.
autotune
(
configs
=
get_configs
(),
)
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
def
tilelang_sparse_attention
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
scale
=
None
,
block_size
=
64
,
groups
=
1
,
selected_blocks
=
16
,
block_T
=
128
,
num_stages
=
2
,
threads
=
32
):
}
)
def
tilelang_sparse_attention
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
scale
=
None
,
block_size
=
64
,
groups
=
1
,
selected_blocks
=
16
,
block_T
=
128
,
num_stages
=
2
,
threads
=
32
):
if
scale
is
None
:
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
else
:
scale
=
scale
*
1.44269504
# log2(e)
...
...
@@ -493,11 +496,11 @@ def tilelang_sparse_attention(batch,
@
T
.
prim_func
def
tilelang_sparse_attention
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
block_indices_dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
block_indices_dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
):
with
T
.
Kernel
(
seq_len
,
NV
,
batch
*
head_kv
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
G
,
BK
],
dtype
)
...
...
@@ -520,7 +523,7 @@ def tilelang_sparse_attention(batch,
i_b
,
i_h
=
i_bh
//
head_kv
,
i_bh
%
head_kv
NS
=
S
T
.
copy
(
Q
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:],
Q_shared
)
T
.
copy
(
Q
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
...
...
@@ -530,21 +533,15 @@ def tilelang_sparse_attention(batch,
i_s
=
BlockIndices
[
i_b
,
i_t
,
i_h
,
i
]
*
BS
if
i_s
<=
i_t
and
i_s
>=
0
:
# [BS, BK]
T
.
copy
(
K
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
:],
K_shared
)
T
.
copy
(
K
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
:],
K_shared
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
G
,
BS
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_t
>=
(
i_s
+
j
),
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_t
>=
(
i_s
+
j
),
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
# Softmax
T
.
copy
(
scores_max
,
scores_max_prev
)
...
...
@@ -564,45 +561,33 @@ def tilelang_sparse_attention(batch,
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
# V * softmax(Q * K)
T
.
copy
(
V
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
V_shared
)
T
.
copy
(
V
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
G
,
BV
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
T
.
copy
(
O_shared
,
Output
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
return
tilelang_sparse_attention
def generate_block_indices(batch, seq_len, heads, selected_blocks, block_size, device="cuda"):
    """Generate random block indices for the benchmark.

    For every (batch, token, head) position a random subset of block ids is
    drawn from ``range(max(1, t // block_size))``, where ``t`` is the token
    position, and written into the leading slots of that position's row.
    Slots that are not filled keep the initial fill value ``seq_len``.

    Args:
        batch: Number of sequences.
        seq_len: Number of token positions per sequence.
        heads: Number of heads that get their own block selection.
        selected_blocks: Maximum number of block ids stored per position.
        block_size: Number of tokens per block.
        device: Device on which the result tensor is allocated. Defaults to
            ``"cuda"``, which is what this function previously hard-coded.

    Returns:
        torch.Tensor: ``torch.long`` tensor of shape
        ``(batch, seq_len, heads, selected_blocks)``, sorted in ascending
        order along the last dimension.
    """
    # Start with every slot set to ``seq_len``. Each drawn block id is
    # strictly smaller than ``seq_len`` (it is < max(1, t // block_size)),
    # so after the final sort the untouched slots end up last in each row.
    block_indices = torch.full((batch, seq_len, heads, selected_blocks), seq_len, dtype=torch.long, device=device)

    for b in range(batch):
        for t in range(seq_len):
            # Loop-invariant with respect to the head index.
            num_candidates = max(1, t // block_size)
            for h in range(heads):
                # A random permutation truncated to at most
                # ``selected_blocks`` entries gives distinct block ids.
                i_i = torch.randperm(num_candidates)[:selected_blocks]
                block_indices[b, t, h, : len(i_i)] = i_i

    # ``sort`` returns (values, indices); only the sorted values are needed.
    return block_indices.sort(-1)[0]
def
benchmark_nsa
(
batch_size
,
seq_len
,
heads
,
head_query
,
dim
,
selected_blocks
,
block_size
,
dtype
,
scale
,
warmup
=
10
,
iterations
=
100
,
validate
=
False
):
def
benchmark_nsa
(
batch_size
,
seq_len
,
heads
,
head_query
,
dim
,
selected_blocks
,
block_size
,
dtype
,
scale
,
warmup
=
10
,
iterations
=
100
,
validate
=
False
):
"""Benchmark the TileLang Sparse Attention implementation."""
# Set random seed for reproducibility
...
...
@@ -628,14 +613,13 @@ def benchmark_nsa(batch_size,
print
(
f
"Profiler latency:
{
profiler_latency
}
ms"
)
# Create input tensors
Q
=
torch
.
randn
((
batch_size
,
seq_len
,
head_query
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch_size
,
seq_len
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch_size
,
seq_len
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
out
=
torch
.
empty
((
batch_size
,
seq_len
,
head_query
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
batch_size
,
seq_len
,
head_query
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
K
=
torch
.
randn
((
batch_size
,
seq_len
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch_size
,
seq_len
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
out
=
torch
.
empty
((
batch_size
,
seq_len
,
head_query
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
# Generate block indices
block_indices
=
generate_block_indices
(
batch_size
,
seq_len
,
heads
,
selected_blocks
,
block_size
).
to
(
torch
.
int32
)
block_indices
=
generate_block_indices
(
batch_size
,
seq_len
,
heads
,
selected_blocks
,
block_size
).
to
(
torch
.
int32
)
# Warmup
for
_
in
range
(
warmup
):
...
...
@@ -666,10 +650,9 @@ def benchmark_nsa(batch_size,
# Validate result against reference if requested
if
validate
:
g_slc
=
torch
.
ones
((
batch_size
,
seq_len
,
head_query
),
dtype
=
dtype
,
device
=
'cuda'
)
g_swa
=
torch
.
ones
((
batch_size
,
seq_len
,
head_query
),
dtype
=
dtype
,
device
=
'cuda'
)
block_counts
=
torch
.
randint
(
1
,
selected_blocks
+
1
,
(
batch_size
,
seq_len
,
heads
),
device
=
'cuda'
)
g_slc
=
torch
.
ones
((
batch_size
,
seq_len
,
head_query
),
dtype
=
dtype
,
device
=
"cuda"
)
g_swa
=
torch
.
ones
((
batch_size
,
seq_len
,
head_query
),
dtype
=
dtype
,
device
=
"cuda"
)
block_counts
=
torch
.
randint
(
1
,
selected_blocks
+
1
,
(
batch_size
,
seq_len
,
heads
),
device
=
"cuda"
)
ref
=
naive_nsa
(
q
=
Q
,
...
...
@@ -700,22 +683,13 @@ def benchmark_nsa(batch_size,
"head_query"
:
head_query
,
"dim"
:
dim
,
"selected_blocks"
:
selected_blocks
,
"block_size"
:
block_size
"block_size"
:
block_size
,
}
def
benchmark_triton_nsa
(
batch_size
,
seq_len
,
heads
,
head_query
,
dim
,
selected_blocks
,
block_size
,
dtype
,
scale
,
warmup
=
10
,
iterations
=
100
,
validate
=
False
):
def
benchmark_triton_nsa
(
batch_size
,
seq_len
,
heads
,
head_query
,
dim
,
selected_blocks
,
block_size
,
dtype
,
scale
,
warmup
=
10
,
iterations
=
100
,
validate
=
False
):
"""Benchmark the Triton-based TileLang Sparse Attention implementation."""
# Set random seed for reproducibility
...
...
@@ -723,18 +697,17 @@ def benchmark_triton_nsa(batch_size,
torch
.
random
.
manual_seed
(
0
)
# Create input tensors
Q
=
torch
.
randn
((
batch_size
,
seq_len
,
head_query
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch_size
,
seq_len
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch_size
,
seq_len
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
g_slc
=
torch
.
ones
((
batch_size
,
seq_len
,
head_query
),
dtype
=
dtype
,
device
=
'
cuda
'
)
g_swa
=
torch
.
ones
((
batch_size
,
seq_len
,
head_query
),
dtype
=
dtype
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
batch_size
,
seq_len
,
head_query
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
K
=
torch
.
randn
((
batch_size
,
seq_len
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch_size
,
seq_len
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
g_slc
=
torch
.
ones
((
batch_size
,
seq_len
,
head_query
),
dtype
=
dtype
,
device
=
"
cuda
"
)
g_swa
=
torch
.
ones
((
batch_size
,
seq_len
,
head_query
),
dtype
=
dtype
,
device
=
"
cuda
"
)
# Generate block indices
block_indices
=
generate_block_indices
(
batch_size
,
seq_len
,
heads
,
selected_blocks
,
block_size
)
block_counts
=
torch
.
randint
(
1
,
selected_blocks
+
1
,
(
batch_size
,
seq_len
,
heads
),
device
=
'cuda'
)
o_slc
=
torch
.
empty
((
batch_size
,
seq_len
,
head_query
,
dim
),
dtype
=
dtype
,
device
=
'cuda'
)
lse_slc
=
torch
.
empty
((
batch_size
,
seq_len
,
head_query
),
dtype
=
torch
.
float
,
device
=
'cuda'
)
block_counts
=
torch
.
randint
(
1
,
selected_blocks
+
1
,
(
batch_size
,
seq_len
,
heads
),
device
=
"cuda"
)
o_slc
=
torch
.
empty
((
batch_size
,
seq_len
,
head_query
,
dim
),
dtype
=
dtype
,
device
=
"cuda"
)
lse_slc
=
torch
.
empty
((
batch_size
,
seq_len
,
head_query
),
dtype
=
torch
.
float
,
device
=
"cuda"
)
# Warmup
for
_
in
range
(
warmup
):
...
...
@@ -750,7 +723,8 @@ def benchmark_triton_nsa(batch_size,
block_counts
=
block_counts
,
block_size
=
block_size
,
window_size
=
0
,
scale
=
scale
)
scale
=
scale
,
)
# Synchronize before timing
torch
.
cuda
.
synchronize
()
...
...
@@ -770,7 +744,8 @@ def benchmark_triton_nsa(batch_size,
block_counts
=
block_counts
,
block_size
=
block_size
,
window_size
=
0
,
scale
=
scale
)
scale
=
scale
,
)
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
...
...
@@ -815,54 +790,28 @@ def benchmark_triton_nsa(batch_size,
"head_query"
:
head_query
,
"dim"
:
dim
,
"selected_blocks"
:
selected_blocks
,
"block_size"
:
block_size
"block_size"
:
block_size
,
}
def
run_benchmark_suite
(
impl
=
'
all
'
):
def
run_benchmark_suite
(
impl
=
"
all
"
):
"""Run a suite of benchmarks with different configurations."""
# Define configurations to benchmark
configs
=
[
# Small model config - Note: head_query must be a multiple of heads*16 for Triton
{
"batch_size"
:
2
,
"seq_len"
:
1024
,
"heads"
:
8
,
"head_query"
:
8
*
16
,
"dim"
:
64
,
"selected_blocks"
:
8
,
"block_size"
:
32
},
{
"batch_size"
:
2
,
"seq_len"
:
1024
,
"heads"
:
8
,
"head_query"
:
8
*
16
,
"dim"
:
64
,
"selected_blocks"
:
8
,
"block_size"
:
32
},
# Medium model config
{
"batch_size"
:
2
,
"seq_len"
:
2048
,
"heads"
:
16
,
"head_query"
:
16
*
16
,
"dim"
:
64
,
"selected_blocks"
:
16
,
"block_size"
:
64
},
{
"batch_size"
:
2
,
"seq_len"
:
2048
,
"heads"
:
16
,
"head_query"
:
16
*
16
,
"dim"
:
64
,
"selected_blocks"
:
16
,
"block_size"
:
64
},
# Large model config
{
"batch_size"
:
1
,
"seq_len"
:
4096
,
"heads"
:
32
,
"head_query"
:
32
*
16
,
"dim"
:
128
,
"selected_blocks"
:
32
,
"block_size"
:
128
},
{
"batch_size"
:
1
,
"seq_len"
:
4096
,
"heads"
:
32
,
"head_query"
:
32
*
16
,
"dim"
:
128
,
"selected_blocks"
:
32
,
"block_size"
:
128
},
]
results
=
[]
for
config
in
configs
:
print
(
f
"Running benchmark with config:
{
config
}
"
)
if
impl
in
[
'
all
'
,
'
tilelang
'
]:
if
impl
in
[
"
all
"
,
"
tilelang
"
]:
print
(
"Benchmarking TileLang implementation:"
)
result
=
benchmark_nsa
(
batch_size
=
config
[
"batch_size"
],
...
...
@@ -874,12 +823,13 @@ def run_benchmark_suite(impl='all'):
block_size
=
config
[
"block_size"
],
dtype
=
torch
.
float16
,
scale
=
0.1
,
validate
=
False
)
validate
=
False
,
)
results
.
append
({
"impl"
:
"tilelang"
,
**
result
})
print
(
f
"Average time:
{
result
[
'avg_time_ms'
]:.
2
f
}
ms"
)
print
(
f
"Performance:
{
result
[
'tflops'
]:.
2
f
}
TFLOPs"
)
if
impl
in
[
'
all
'
,
'
triton
'
]:
if
impl
in
[
"
all
"
,
"
triton
"
]:
print
(
"Benchmarking Triton implementation:"
)
result
=
benchmark_triton_nsa
(
batch_size
=
config
[
"batch_size"
],
...
...
@@ -891,19 +841,24 @@ def run_benchmark_suite(impl='all'):
block_size
=
config
[
"block_size"
],
dtype
=
torch
.
float16
,
scale
=
0.1
,
validate
=
False
)
validate
=
False
,
)
results
.
append
({
"impl"
:
"triton"
,
**
result
})
print
(
f
"Average time:
{
result
[
'avg_time_ms'
]:.
2
f
}
ms"
)
print
(
f
"Performance:
{
result
[
'tflops'
]:.
2
f
}
TFLOPs"
)
if
impl
in
[
'
all
'
]:
if
impl
in
[
"
all
"
]:
# Print comparison if both implementations were run
tilelang_result
=
next
(
r
for
r
in
results
if
r
[
"impl"
]
==
"tilelang"
and
r
[
"batch_size"
]
==
config
[
"batch_size"
]
and
r
[
"seq_len"
]
==
config
[
"seq_len"
])
r
for
r
in
results
if
r
[
"impl"
]
==
"tilelang"
and
r
[
"batch_size"
]
==
config
[
"batch_size"
]
and
r
[
"seq_len"
]
==
config
[
"seq_len"
]
)
triton_result
=
next
(
r
for
r
in
results
if
r
[
"impl"
]
==
"triton"
and
r
[
"batch_size"
]
==
config
[
"batch_size"
]
and
r
[
"seq_len"
]
==
config
[
"seq_len"
])
r
for
r
in
results
if
r
[
"impl"
]
==
"triton"
and
r
[
"batch_size"
]
==
config
[
"batch_size"
]
and
r
[
"seq_len"
]
==
config
[
"seq_len"
]
)
speedup
=
tilelang_result
[
"avg_time_ms"
]
/
triton_result
[
"avg_time_ms"
]
print
(
f
"Speedup (Triton vs TileLang):
{
speedup
:.
2
f
}
x"
)
...
...
@@ -921,8 +876,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"Head dimension"
)
parser
.
add_argument
(
"--selected_blocks"
,
type
=
int
,
default
=
16
,
help
=
"Number of selected blocks"
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"Block size"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
help
=
"Data type (float16 or float32)"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
help
=
"Data type (float16 or float32)"
)
parser
.
add_argument
(
"--scale"
,
type
=
float
,
default
=
0.1
,
help
=
"Attention scale factor"
)
parser
.
add_argument
(
"--iterations"
,
type
=
int
,
default
=
100
,
help
=
"Number of iterations"
)
parser
.
add_argument
(
"--warmup"
,
type
=
int
,
default
=
10
,
help
=
"Warmup iterations"
)
...
...
@@ -933,7 +887,8 @@ if __name__ == "__main__":
type
=
str
,
default
=
"all"
,
choices
=
[
"tilelang"
,
"triton"
,
"all"
],
help
=
"Implementation to benchmark (tilelang, triton, or all)"
)
help
=
"Implementation to benchmark (tilelang, triton, or all)"
,
)
args
=
parser
.
parse_args
()
...
...
@@ -941,8 +896,7 @@ if __name__ == "__main__":
if
args
.
impl
in
[
"triton"
,
"all"
]
and
args
.
head_query
%
(
args
.
heads
*
16
)
!=
0
:
# Adjust head_query to nearest valid value
args
.
head_query
=
((
args
.
head_query
//
(
args
.
heads
*
16
))
+
1
)
*
(
args
.
heads
*
16
)
print
(
f
"Adjusted head_query to
{
args
.
head_query
}
to be compatible with Triton implementation"
)
print
(
f
"Adjusted head_query to
{
args
.
head_query
}
to be compatible with Triton implementation"
)
if
args
.
suite
:
run_benchmark_suite
(
impl
=
args
.
impl
)
...
...
@@ -963,12 +917,14 @@ if __name__ == "__main__":
scale
=
args
.
scale
,
warmup
=
args
.
warmup
,
iterations
=
args
.
iterations
,
validate
=
args
.
validate
)
validate
=
args
.
validate
,
)
print
(
"
\n
Benchmark Results (TileLang):"
)
print
(
f
"Configuration: batch=
{
args
.
batch
}
, seq_len=
{
args
.
seq_len
}
, heads=
{
args
.
heads
}
, "
+
f
"head_query=
{
args
.
head_query
}
, dim=
{
args
.
dim
}
, blocks=
{
args
.
selected_blocks
}
, "
+
f
"block_size=
{
args
.
block_size
}
"
)
f
"Configuration: batch=
{
args
.
batch
}
, seq_len=
{
args
.
seq_len
}
, heads=
{
args
.
heads
}
, "
+
f
"head_query=
{
args
.
head_query
}
, dim=
{
args
.
dim
}
, blocks=
{
args
.
selected_blocks
}
, "
+
f
"block_size=
{
args
.
block_size
}
"
)
print
(
f
"Average time:
{
result
[
'avg_time_ms'
]:.
2
f
}
ms"
)
print
(
f
"Performance:
{
result
[
'tflops'
]:.
2
f
}
TFLOPs"
)
...
...
@@ -986,11 +942,13 @@ if __name__ == "__main__":
scale
=
args
.
scale
,
warmup
=
args
.
warmup
,
iterations
=
args
.
iterations
,
validate
=
args
.
validate
)
validate
=
args
.
validate
,
)
print
(
"
\n
Benchmark Results (Triton):"
)
print
(
f
"Configuration: batch=
{
args
.
batch
}
, seq_len=
{
args
.
seq_len
}
, heads=
{
args
.
heads
}
, "
+
f
"head_query=
{
args
.
head_query
}
, dim=
{
args
.
dim
}
, blocks=
{
args
.
selected_blocks
}
, "
+
f
"block_size=
{
args
.
block_size
}
"
)
f
"Configuration: batch=
{
args
.
batch
}
, seq_len=
{
args
.
seq_len
}
, heads=
{
args
.
heads
}
, "
+
f
"head_query=
{
args
.
head_query
}
, dim=
{
args
.
dim
}
, blocks=
{
args
.
selected_blocks
}
, "
+
f
"block_size=
{
args
.
block_size
}
"
)
print
(
f
"Average time:
{
result
[
'avg_time_ms'
]:.
2
f
}
ms"
)
print
(
f
"Performance:
{
result
[
'tflops'
]:.
2
f
}
TFLOPs"
)
examples/deepseek_nsa/example_tilelang_nsa_bwd.py
View file @
29051439
...
...
@@ -7,6 +7,7 @@ import torch
import
triton
import
fla
if
parse
(
fla
.
__version__
)
<
parse
(
"0.2.1"
):
from
fla.ops.common.utils
import
prepare_token_indices
else
:
...
...
@@ -22,7 +23,8 @@ import tilelang
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
}
)
def
tilelang_kernel_fwd
(
batch
,
heads
,
...
...
@@ -34,11 +36,10 @@ def tilelang_kernel_fwd(
groups
=
1
,
selected_blocks
=
16
,
):
from
tilelang
import
language
as
T
if
scale
is
None
:
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
else
:
scale
=
scale
*
1.44269504
# log2(e)
...
...
@@ -67,12 +68,12 @@ def tilelang_kernel_fwd(
@
T
.
prim_func
def
native_sparse_attention
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
block_indices_dtype
),
O_slc
:
T
.
Tensor
(
o_slc_shape
,
dtype
),
LSE_slc
:
T
.
Tensor
(
lse_slc_shape
,
accum_dtype
),
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
block_indices_dtype
),
O_slc
:
T
.
Tensor
(
o_slc_shape
,
dtype
),
LSE_slc
:
T
.
Tensor
(
lse_slc_shape
,
accum_dtype
),
):
with
T
.
Kernel
(
seq_len
,
NV
,
batch
*
head_kv
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
G
,
BK
],
dtype
)
...
...
@@ -93,7 +94,7 @@ def tilelang_kernel_fwd(
i_b
,
i_h
=
i_bh
//
head_kv
,
i_bh
%
head_kv
NS
=
S
T
.
copy
(
Q
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:],
Q_shared
)
T
.
copy
(
Q
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
...
...
@@ -103,12 +104,11 @@ def tilelang_kernel_fwd(
i_s
=
BlockIndices
[
i_b
,
i_t
,
i_h
,
i
]
*
BS
if
i_s
<=
i_t
and
i_s
>=
0
:
# [BS, BK]
T
.
copy
(
K
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
:],
K_shared
)
T
.
copy
(
K
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
:],
K_shared
)
if
is_causal
:
for
k
,
j
in
T
.
Parallel
(
G
,
BS
):
acc_s
[
k
,
j
]
=
T
.
if_then_else
(
i_t
>=
(
i_s
+
j
),
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
k
,
j
]
=
T
.
if_then_else
(
i_t
>=
(
i_s
+
j
),
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
...
...
@@ -138,7 +138,7 @@ def tilelang_kernel_fwd(
acc_o
[
k
,
j
]
*=
scores_scale
[
k
]
# V * softmax(Q * K)
T
.
copy
(
V
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
V_shared
)
T
.
copy
(
V
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
G
,
BV
):
...
...
@@ -146,18 +146,20 @@ def tilelang_kernel_fwd(
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
O_slc
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
O_slc
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
)
for
i
in
T
.
Parallel
(
G
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
LSE_slc
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
])
T
.
copy
(
logsum
,
LSE_slc
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
])
return
native_sparse_attention
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
}
)
def
tilelang_kernel_bwd_dkv
(
batch
,
heads
,
...
...
@@ -172,7 +174,7 @@ def tilelang_kernel_bwd_dkv(
accum_dtype
=
"float"
,
):
if
scale
is
None
:
sm_scale
=
(
1.0
/
dim
)
**
0.5
sm_scale
=
(
1.0
/
dim
)
**
0.5
else
:
sm_scale
=
scale
...
...
@@ -207,15 +209,15 @@ def tilelang_kernel_bwd_dkv(
@
T
.
prim_func
def
flash_bwd_dkv
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
k_shape
,
dtype
),
V
:
T
.
Tensor
(
v_shape
,
dtype
),
LSE_slc
:
T
.
Tensor
(
lse_slc_shape
,
accum_dtype
),
Delta_slc
:
T
.
Tensor
(
delta_slc_shape
,
accum_dtype
),
DO_slc
:
T
.
Tensor
(
do_slc_shape
,
dtype
),
DK
:
T
.
Tensor
(
dk_shape
,
dtype
),
DV
:
T
.
Tensor
(
dv_shape
,
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"int32"
),
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
k_shape
,
dtype
),
V
:
T
.
Tensor
(
v_shape
,
dtype
),
LSE_slc
:
T
.
Tensor
(
lse_slc_shape
,
accum_dtype
),
Delta_slc
:
T
.
Tensor
(
delta_slc_shape
,
accum_dtype
),
DO_slc
:
T
.
Tensor
(
do_slc_shape
,
dtype
),
DK
:
T
.
Tensor
(
dk_shape
,
dtype
),
DV
:
T
.
Tensor
(
dv_shape
,
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"int32"
),
):
with
T
.
Kernel
(
NV
,
NS
,
B
*
H
,
threads
=
num_threads
)
as
(
i_v
,
i_s
,
i_bh
):
K_shared
=
T
.
alloc_shared
([
BS
,
BK
],
dtype
)
...
...
@@ -238,31 +240,33 @@ def tilelang_kernel_bwd_dkv(
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
T
.
copy
(
K
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BK
],
K_shared
)
T
.
copy
(
V
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BV
],
V_shared
)
T
.
copy
(
K
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BK
],
K_shared
)
T
.
copy
(
V
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BV
],
V_shared
)
# [BS, BK]
T
.
clear
(
dk
)
# [BS, BV]
T
.
clear
(
dv
)
T
.
annotate_layout
({
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
})
T
.
annotate_layout
(
{
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
}
)
loop_st
=
i_s
*
BS
loop_ed
=
seq_len
for
i
in
T
.
Pipelined
(
start
=
loop_st
,
stop
=
loop_ed
,
num_stages
=
0
,
start
=
loop_st
,
stop
=
loop_ed
,
num_stages
=
0
,
):
b_m_slc
=
BlockMask
[
i_b
,
i
,
i_h
,
i_s
]
if
b_m_slc
!=
0
:
# [G, BK]
T
.
copy
(
Q
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:
BK
],
Q_shared
)
T
.
copy
(
Q
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:
BK
],
Q_shared
)
T
.
clear
(
qkT
)
# [BS, BK] @ [G, BK] -> [BS, G]
T
.
gemm
(
...
...
@@ -273,7 +277,7 @@ def tilelang_kernel_bwd_dkv(
policy
=
T
.
GemmWarpPolicy
.
FullRow
,
)
# [G]
T
.
copy
(
LSE_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
],
lse_shared
)
T
.
copy
(
LSE_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
],
lse_shared
)
for
_i
,
_j
in
T
.
Parallel
(
BS
,
G
):
qkT
[
_i
,
_j
]
=
T
.
exp2
(
qkT
[
_i
,
_j
]
*
scale
-
lse_shared
[
_j
])
...
...
@@ -282,7 +286,7 @@ def tilelang_kernel_bwd_dkv(
qkT
[
_i
,
_j
]
=
T
.
if_then_else
(
i
>=
(
i_s
*
BS
+
_i
),
qkT
[
_i
,
_j
],
0
)
# [G, BV]
T
.
copy
(
DO_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:
BV
],
do
)
T
.
copy
(
DO_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:
BV
],
do
)
T
.
clear
(
dsT
)
# [BS, BV] @ [G, BV] -> [BS, G]
T
.
gemm
(
...
...
@@ -296,7 +300,7 @@ def tilelang_kernel_bwd_dkv(
# [BS, G] @ [G, BV] -> [BS, BV]
T
.
gemm
(
qkT_cast
,
do
,
dv
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
# [G]
T
.
copy
(
Delta_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
],
delta
)
T
.
copy
(
Delta_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
],
delta
)
for
i
,
j
in
T
.
Parallel
(
BS
,
G
):
dsT_cast
[
i
,
j
]
=
qkT
[
i
,
j
]
*
(
dsT
[
i
,
j
]
-
delta
[
j
])
*
sm_scale
...
...
@@ -305,8 +309,8 @@ def tilelang_kernel_bwd_dkv(
T
.
copy
(
dv
,
dv_shared
)
T
.
copy
(
dk
,
dk_shared
)
T
.
copy
(
dv_shared
,
DV
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BV
])
T
.
copy
(
dk_shared
,
DK
[
i_v
,
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BK
])
T
.
copy
(
dv_shared
,
DV
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BV
])
T
.
copy
(
dk_shared
,
DK
[
i_v
,
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BK
])
return
flash_bwd_dkv
...
...
@@ -321,9 +325,11 @@ def make_dq_layout(dQ):
)
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
}
)
def
tilelang_kernel_bwd_dqkv
(
batch
,
heads
,
...
...
@@ -338,7 +344,7 @@ def tilelang_kernel_bwd_dqkv(
accum_dtype
=
"float"
,
):
if
scale
is
None
:
sm_scale
=
(
1.0
/
dim
)
**
0.5
sm_scale
=
(
1.0
/
dim
)
**
0.5
else
:
sm_scale
=
scale
...
...
@@ -373,16 +379,16 @@ def tilelang_kernel_bwd_dqkv(
@
T
.
prim_func
def
flash_bwd_dqkv
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
k_shape
,
dtype
),
V
:
T
.
Tensor
(
v_shape
,
dtype
),
LSE_slc
:
T
.
Tensor
(
lse_slc_shape
,
accum_dtype
),
Delta_slc
:
T
.
Tensor
(
delta_slc_shape
,
accum_dtype
),
DO_slc
:
T
.
Tensor
(
do_slc_shape
,
dtype
),
DQ
:
T
.
Tensor
(
dq_shape
,
dtype
),
DK
:
T
.
Tensor
(
dk_shape
,
dtype
),
DV
:
T
.
Tensor
(
dv_shape
,
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"int32"
),
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
k_shape
,
dtype
),
V
:
T
.
Tensor
(
v_shape
,
dtype
),
LSE_slc
:
T
.
Tensor
(
lse_slc_shape
,
accum_dtype
),
Delta_slc
:
T
.
Tensor
(
delta_slc_shape
,
accum_dtype
),
DO_slc
:
T
.
Tensor
(
do_slc_shape
,
dtype
),
DQ
:
T
.
Tensor
(
dq_shape
,
dtype
),
DK
:
T
.
Tensor
(
dk_shape
,
dtype
),
DV
:
T
.
Tensor
(
dv_shape
,
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"int32"
),
):
with
T
.
Kernel
(
NV
,
NS
,
B
*
H
,
threads
=
num_threads
)
as
(
i_v
,
i_s
,
i_bh
):
K_shared
=
T
.
alloc_shared
([
BS
,
BK
],
dtype
)
...
...
@@ -406,31 +412,33 @@ def tilelang_kernel_bwd_dqkv(
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
T
.
copy
(
K
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BK
],
K_shared
)
T
.
copy
(
V
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BV
],
V_shared
)
T
.
copy
(
K
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BK
],
K_shared
)
T
.
copy
(
V
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BV
],
V_shared
)
# [BS, BK]
T
.
clear
(
dk
)
# [BS, BV]
T
.
clear
(
dv
)
T
.
annotate_layout
({
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
})
T
.
annotate_layout
(
{
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
}
)
loop_st
=
i_s
*
BS
loop_ed
=
seq_len
for
i
in
T
.
Pipelined
(
start
=
loop_st
,
stop
=
loop_ed
,
num_stages
=
0
,
start
=
loop_st
,
stop
=
loop_ed
,
num_stages
=
0
,
):
b_m_slc
=
BlockMask
[
i_b
,
i
,
i_h
,
i_s
]
if
b_m_slc
!=
0
:
# [G, BK]
T
.
copy
(
Q
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:
BK
],
Q_shared
)
T
.
copy
(
Q
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:
BK
],
Q_shared
)
T
.
clear
(
qkT
)
# [BS, BK] @ [G, BK] -> [BS, G]
T
.
gemm
(
...
...
@@ -441,7 +449,7 @@ def tilelang_kernel_bwd_dqkv(
policy
=
T
.
GemmWarpPolicy
.
FullRow
,
)
# [G]
T
.
copy
(
LSE_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
],
lse_shared
)
T
.
copy
(
LSE_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
],
lse_shared
)
for
_i
,
_j
in
T
.
Parallel
(
BS
,
G
):
qkT
[
_i
,
_j
]
=
T
.
exp2
(
qkT
[
_i
,
_j
]
*
scale
-
lse_shared
[
_j
])
...
...
@@ -450,7 +458,7 @@ def tilelang_kernel_bwd_dqkv(
qkT
[
_i
,
_j
]
=
T
.
if_then_else
(
i
>=
(
i_s
*
BS
+
_i
),
qkT
[
_i
,
_j
],
0
)
# [G, BV]
T
.
copy
(
DO_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:
BV
],
do
)
T
.
copy
(
DO_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:
BV
],
do
)
T
.
clear
(
dsT
)
# [BS, BV] @ [G, BV] -> [BS, G]
T
.
gemm
(
...
...
@@ -464,7 +472,7 @@ def tilelang_kernel_bwd_dqkv(
# [BS, G] @ [G, BV] -> [BS, BV]
T
.
gemm
(
qkT_cast
,
do
,
dv
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
# [G]
T
.
copy
(
Delta_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
],
delta
)
T
.
copy
(
Delta_slc
[
i_b
,
i
,
i_h
*
G
:
(
i_h
+
1
)
*
G
],
delta
)
for
_i
,
_j
in
T
.
Parallel
(
BS
,
G
):
dsT_cast
[
_i
,
_j
]
=
qkT
[
_i
,
_j
]
*
(
dsT
[
_i
,
_j
]
-
delta
[
_j
])
*
sm_scale
...
...
@@ -480,16 +488,18 @@ def tilelang_kernel_bwd_dqkv(
T
.
copy
(
dv
,
dv_shared
)
T
.
copy
(
dk
,
dk_shared
)
T
.
copy
(
dv_shared
,
DV
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BV
])
T
.
copy
(
dk_shared
,
DK
[
i_v
,
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BK
])
T
.
copy
(
dv_shared
,
DV
[
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BV
])
T
.
copy
(
dk_shared
,
DK
[
i_v
,
i_b
,
i_s
*
BS
:
(
i_s
+
1
)
*
BS
,
i_h
,
:
BK
])
return
flash_bwd_dqkv
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
out_idx
=
[
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
tilelang_kernel_preprocess
(
batch
,
heads
,
...
...
@@ -505,9 +515,9 @@ def tilelang_kernel_preprocess(
@
T
.
prim_func
def
flash_bwd_prep
(
O
:
T
.
Tensor
(
shape
,
dtype
),
# type: ignore
dO
:
T
.
Tensor
(
shape
,
dtype
),
# type: ignore
Delta
:
T
.
Tensor
([
batch
,
seq_len
,
heads
],
accum_dtype
),
# type: ignore
O
:
T
.
Tensor
(
shape
,
dtype
),
# type: ignore
dO
:
T
.
Tensor
(
shape
,
dtype
),
# type: ignore
Delta
:
T
.
Tensor
([
batch
,
seq_len
,
heads
],
accum_dtype
),
# type: ignore
):
with
T
.
Kernel
(
heads
,
T
.
ceildiv
(
seq_len
,
blk
),
batch
)
as
(
bx
,
by
,
bz
):
o
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
...
...
@@ -516,20 +526,22 @@ def tilelang_kernel_preprocess(
delta
=
T
.
alloc_fragment
([
blk
],
accum_dtype
)
T
.
clear
(
acc
)
for
k
in
range
(
T
.
ceildiv
(
dim
,
blk
)):
T
.
copy
(
O
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
dO
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
do
)
T
.
copy
(
O
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
dO
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:
(
k
+
1
)
*
blk
],
do
)
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
acc
[
i
,
j
]
+=
o
[
i
,
j
]
*
do
[
i
,
j
]
T
.
reduce_sum
(
acc
,
delta
,
1
)
T
.
copy
(
delta
,
Delta
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
])
T
.
copy
(
delta
,
Delta
[
bz
,
by
*
blk
:
(
by
+
1
)
*
blk
,
bx
])
return
flash_bwd_prep
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
out_idx
=
[
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
tilelang_kernel_block_mask
(
batch
,
heads
,
...
...
@@ -551,9 +563,9 @@ def tilelang_kernel_block_mask(
@
T
.
prim_func
def
flash_bwd_block_mask
(
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
dtype
),
# type: ignore
BlockCounts
:
T
.
Tensor
(
block_counts_shape
,
dtype
),
# type: ignore
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
dtype
),
# type: ignore
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
dtype
),
# type: ignore
BlockCounts
:
T
.
Tensor
(
block_counts_shape
,
dtype
),
# type: ignore
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
dtype
),
# type: ignore
):
with
T
.
Kernel
(
seq_len
,
batch
,
heads
*
S
)
as
(
bx
,
by
,
bz
):
i_t
,
i_b
,
i_hs
=
bx
,
by
,
bz
...
...
@@ -603,9 +615,7 @@ def parallel_nsa_bwd(
dk
=
torch
.
empty
(
NV
,
*
k
.
shape
,
dtype
=
k
.
dtype
,
device
=
q
.
device
)
dv
=
torch
.
empty
(
v
.
shape
,
dtype
=
v
.
dtype
,
device
=
q
.
device
)
block_mask
=
tilelang_kernel_block_mask
(
B
,
H
,
T
,
S
,
BS
)(
block_indices
.
to
(
torch
.
int32
),
block_counts
.
to
(
torch
.
int32
)).
to
(
torch
.
bool
)
block_mask
=
tilelang_kernel_block_mask
(
B
,
H
,
T
,
S
,
BS
)(
block_indices
.
to
(
torch
.
int32
),
block_counts
.
to
(
torch
.
int32
)).
to
(
torch
.
bool
)
fused_qkv_bwd_kernel
=
tilelang_kernel_bwd_dqkv
(
batch
=
B
,
...
...
@@ -618,8 +628,7 @@ def parallel_nsa_bwd(
selected_blocks
=
S
,
scale
=
scale
,
)
fused_qkv_bwd_kernel
(
q
,
k
,
v
,
lse_slc
,
delta_slc
,
do_slc
,
dq
,
dk
,
dv
,
block_mask
.
to
(
torch
.
int32
))
fused_qkv_bwd_kernel
(
q
,
k
,
v
,
lse_slc
,
delta_slc
,
do_slc
,
dq
,
dk
,
dv
,
block_mask
.
to
(
torch
.
int32
))
dq
=
dq
.
sum
(
0
)
dk
=
dk
.
sum
(
0
)
...
...
@@ -628,7 +637,6 @@ def parallel_nsa_bwd(
@
torch
.
compile
class
ParallelNSAFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
contiguous
@
autocast_custom_fwd
...
...
@@ -773,23 +781,21 @@ def parallel_nsa(
Outputs of shape `[B, SEQLEN, HQ, V]` if `head_first=False` else `[B, HQ, SEQLEN, V]`.
"""
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
scale
=
k
.
shape
[
-
1
]
**
-
0.5
if
cu_seqlens
is
not
None
:
assert
q
.
shape
[
0
]
==
1
,
"batch size must be 1 when cu_seqlens are provided"
if
head_first
:
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t d -> b t h d"
),
(
q
,
k
,
v
,
block_indices
))
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t d -> b t h d"
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t -> b t h"
),
(
g_slc
,
g_swa
))
if
isinstance
(
block_counts
,
torch
.
Tensor
):
block_counts
=
rearrange
(
block_counts
,
"b h t -> b t h"
)
assert
(
q
.
shape
[
2
]
%
(
k
.
shape
[
2
]
*
16
)
==
0
)
,
"Group size must be a multiple of 16 in NSA"
assert
q
.
shape
[
2
]
%
(
k
.
shape
[
2
]
*
16
)
==
0
,
"Group size must be a multiple of 16 in NSA"
if
isinstance
(
block_counts
,
int
):
block_indices
=
block_indices
[:,
:,
:,
:
block_counts
]
block_counts
=
None
o_slc
,
o_swa
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
o_slc
,
o_swa
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
if
window_size
>
0
:
o
=
torch
.
addcmul
(
o_slc
*
g_slc
.
unsqueeze
(
-
1
),
o_swa
,
g_swa
.
unsqueeze
(
-
1
))
else
:
...
...
@@ -814,7 +820,7 @@ if __name__ == "__main__":
for
t
in
range
(
T
):
for
h
in
range
(
H
):
i_i
=
torch
.
randperm
(
max
(
1
,
(
t
//
block_size
)))[:
S
]
block_indices
[
b
,
t
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
[
b
,
t
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
=
block_indices
.
sort
(
-
1
)[
0
]
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
B
,
T
,
H
),
device
=
"cuda"
)
...
...
examples/deepseek_nsa/example_tilelang_nsa_decode.py
View file @
29051439
...
...
@@ -16,7 +16,8 @@ tilelang.testing.set_random_seed(42)
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
native_sparse_attention
(
batch
,
heads
,
...
...
@@ -25,10 +26,10 @@ def native_sparse_attention(
scale
=
None
,
block_size
=
64
,
# Tile size for attention computation
groups
=
1
,
# Grouped query attention (GQA) groups
selected_blocks
=
16
# Number of blocks to select per attention head
selected_blocks
=
16
,
# Number of blocks to select per attention head
):
if
scale
is
None
:
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
# Modified shapes for inference (q has seq_len=1)a
q_shape
=
[
batch
,
1
,
heads
,
dim
]
# Changed seq_len to 1
...
...
@@ -53,12 +54,11 @@ def native_sparse_attention(
@
T
.
prim_func
def
native_sparse_attention
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
# [batch, 1, heads, dim]
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
# [batch, seq_len, head_kv, dim]
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
# Same shape as K
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
block_indices_dtype
),
# Selected block indices
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
# Output attention tensor
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
# [batch, 1, heads, dim]
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
# [batch, seq_len, head_kv, dim]
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
# Same shape as K
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
block_indices_dtype
),
# Selected block indices
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
# Output attention tensor
):
with
T
.
Kernel
(
1
,
NV
,
batch
*
head_kv
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
# Shared memory allocations for tile storage
...
...
@@ -82,7 +82,7 @@ def native_sparse_attention(
NS
=
S
# Copy Q for the single position
T
.
copy
(
Q
[
i_b
,
0
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:],
Q_shared
)
# Changed i_t to 0
T
.
copy
(
Q
[
i_b
,
0
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:],
Q_shared
)
# Changed i_t to 0
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
...
...
@@ -93,16 +93,11 @@ def native_sparse_attention(
i_s
=
BlockIndices
[
i_b
,
0
,
i_h
,
i
]
*
BS
# Get block offset
if
i_s
>=
0
:
# Skip invalid/padding blocks
# Load current key block to shared memory
T
.
copy
(
K
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
:],
K_shared
)
T
.
copy
(
K
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
:],
K_shared
)
# Compute QK^T attention scores
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
# Online softmax with numerical stability
# 1. Compute max for scaling
...
...
@@ -122,15 +117,14 @@ def native_sparse_attention(
T
.
copy
(
acc_s
,
acc_s_cast
)
# Accumulate attention-weighted values
T
.
copy
(
V
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
V_shared
)
T
.
copy
(
V
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
# Final normalization and output
for
i
,
j
in
T
.
Parallel
(
G
,
BV
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
# Normalize by logsum
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
i_b
,
0
,
i_h
*
G
:(
i_h
+
1
)
*
G
,
i_v
*
BV
:(
i_v
+
1
)
*
BV
])
# Changed i_t to 0
T
.
copy
(
O_shared
,
Output
[
i_b
,
0
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
# Changed i_t to 0
return
native_sparse_attention
...
...
@@ -149,21 +143,21 @@ def main():
selected_blocks
=
S
,
)
Q
=
torch
.
randn
((
B
,
SEQ_LEN_Q
,
HQ
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
K
=
torch
.
randn
((
B
,
SEQ_LEN
,
H
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
V
=
torch
.
randn
((
B
,
SEQ_LEN
,
H
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
Q
=
torch
.
randn
((
B
,
SEQ_LEN_Q
,
HQ
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
K
=
torch
.
randn
((
B
,
SEQ_LEN
,
H
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
V
=
torch
.
randn
((
B
,
SEQ_LEN
,
H
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
mask
=
torch
.
randint
(
0
,
2
,
(
B
,
SEQ_LEN
,
groups
),
device
=
'
cuda
'
)
DO
=
torch
.
randn
((
B
,
SEQ_LEN_Q
,
HQ
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
)
mask
=
torch
.
randint
(
0
,
2
,
(
B
,
SEQ_LEN
,
groups
),
device
=
"
cuda
"
)
DO
=
torch
.
randn
((
B
,
SEQ_LEN_Q
,
HQ
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
)
block_indices
=
torch
.
full
((
B
,
SEQ_LEN_Q
,
H
,
S
),
SEQ_LEN
,
dtype
=
torch
.
long
,
device
=
'
cuda
'
)
block_indices
=
torch
.
full
((
B
,
SEQ_LEN_Q
,
H
,
S
),
SEQ_LEN
,
dtype
=
torch
.
long
,
device
=
"
cuda
"
)
for
b
in
range
(
B
):
for
t
in
range
(
SEQ_LEN_Q
):
for
h
in
range
(
H
):
i_i
=
torch
.
randperm
(
max
(
1
,
(
t
//
block_size
)))[:
S
]
block_indices
[
b
,
t
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
[
b
,
t
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
=
block_indices
.
sort
(
-
1
)[
0
]
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
B
,
SEQ_LEN_Q
,
H
),
device
=
'
cuda
'
)
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
B
,
SEQ_LEN_Q
,
H
),
device
=
"
cuda
"
)
out
=
kernel
(
Q
,
K
,
V
,
block_indices
.
to
(
torch
.
int32
))
...
...
examples/deepseek_nsa/example_tilelang_nsa_fwd.py
View file @
29051439
...
...
@@ -14,18 +14,11 @@ tilelang.testing.set_random_seed(0)
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
def
native_sparse_attention
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
scale
=
None
,
block_size
=
64
,
groups
=
1
,
selected_blocks
=
16
):
},
)
def
native_sparse_attention
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
scale
=
None
,
block_size
=
64
,
groups
=
1
,
selected_blocks
=
16
):
if
scale
is
None
:
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
else
:
scale
=
scale
*
1.44269504
# log2(e)
...
...
@@ -52,11 +45,11 @@ def native_sparse_attention(batch,
@
T
.
prim_func
def
native_sparse_attention
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
block_indices_dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
block_indices_dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
):
with
T
.
Kernel
(
seq_len
,
NV
,
batch
*
head_kv
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
G
,
BK
],
dtype
)
...
...
@@ -77,7 +70,7 @@ def native_sparse_attention(batch,
i_b
,
i_h
=
i_bh
//
head_kv
,
i_bh
%
head_kv
NS
=
S
T
.
copy
(
Q
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:],
Q_shared
)
T
.
copy
(
Q
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
...
...
@@ -87,21 +80,15 @@ def native_sparse_attention(batch,
i_s
=
BlockIndices
[
i_b
,
i_t
,
i_h
,
i
]
*
BS
if
i_s
<=
i_t
and
i_s
>=
0
:
# [BS, BK]
T
.
copy
(
K
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
:],
K_shared
)
T
.
copy
(
K
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
:],
K_shared
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
G
,
BS
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_t
>=
(
i_s
+
j
),
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_t
>=
(
i_s
+
j
),
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
# Softmax
T
.
copy
(
scores_max
,
scores_max_prev
)
...
...
@@ -121,13 +108,13 @@ def native_sparse_attention(batch,
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
# V * softmax(Q * K)
T
.
copy
(
V
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
V_shared
)
T
.
copy
(
V
[
i_b
,
i_s
:
i_s
+
BS
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
G
,
BV
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
T
.
copy
(
O_shared
,
Output
[
i_b
,
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
return
native_sparse_attention
...
...
@@ -148,20 +135,20 @@ def main():
)
print
(
kernel
.
get_kernel_source
())
torch
.
random
.
manual_seed
(
0
)
Q
=
torch
.
randn
((
B
,
SEQ_LEN
,
HQ
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
K
=
torch
.
randn
((
B
,
SEQ_LEN
,
H
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
V
=
torch
.
randn
((
B
,
SEQ_LEN
,
H
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
g_slc
=
torch
.
ones
((
B
,
SEQ_LEN
,
HQ
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
g_swa
=
torch
.
ones
((
B
,
SEQ_LEN
,
HQ
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
DO
=
torch
.
randn
((
B
,
SEQ_LEN
,
HQ
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
)
block_indices
=
torch
.
full
((
B
,
SEQ_LEN
,
H
,
S
),
SEQ_LEN
,
dtype
=
torch
.
long
,
device
=
'
cuda
'
)
block_counts
=
torch
.
zeros
((
B
,
SEQ_LEN
,
H
),
dtype
=
torch
.
long
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
B
,
SEQ_LEN
,
HQ
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
K
=
torch
.
randn
((
B
,
SEQ_LEN
,
H
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
V
=
torch
.
randn
((
B
,
SEQ_LEN
,
H
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
g_slc
=
torch
.
ones
((
B
,
SEQ_LEN
,
HQ
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
g_swa
=
torch
.
ones
((
B
,
SEQ_LEN
,
HQ
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
DO
=
torch
.
randn
((
B
,
SEQ_LEN
,
HQ
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
)
block_indices
=
torch
.
full
((
B
,
SEQ_LEN
,
H
,
S
),
SEQ_LEN
,
dtype
=
torch
.
long
,
device
=
"
cuda
"
)
block_counts
=
torch
.
zeros
((
B
,
SEQ_LEN
,
H
),
dtype
=
torch
.
long
,
device
=
"
cuda
"
)
for
b
in
range
(
B
):
for
t
in
range
(
SEQ_LEN
):
for
h
in
range
(
H
):
i_i
=
torch
.
randperm
(
max
(
1
,
(
t
//
block_size
)))[:
S
]
block_indices
[
b
,
t
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
[
b
,
t
,
h
,
:
len
(
i_i
)]
=
i_i
block_counts
[
b
,
t
,
h
]
=
(
block_indices
[
b
,
t
,
h
]
!=
SEQ_LEN
).
sum
().
item
()
block_indices
=
block_indices
.
sort
(
-
1
)[
0
]
...
...
examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py
View file @
29051439
...
...
@@ -8,6 +8,7 @@ from tilelang import language as T
import
tilelang.testing
import
fla
if
parse
(
fla
.
__version__
)
<
parse
(
"0.2.1"
):
from
fla.ops.common.utils
import
prepare_token_indices
else
:
...
...
@@ -21,18 +22,11 @@ from einops import rearrange
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
def
native_sparse_attention_varlen
(
batch
,
heads
,
c_seq_len
,
dim
,
is_causal
,
scale
=
None
,
block_size
=
64
,
groups
=
1
,
selected_blocks
=
16
):
}
)
def
native_sparse_attention_varlen
(
batch
,
heads
,
c_seq_len
,
dim
,
is_causal
,
scale
=
None
,
block_size
=
64
,
groups
=
1
,
selected_blocks
=
16
):
if
scale
is
None
:
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
q_shape
=
[
c_seq_len
,
heads
,
dim
]
kv_shape
=
[
c_seq_len
,
head_kv
,
dim
]
...
...
@@ -66,14 +60,14 @@ def native_sparse_attention_varlen(batch,
@
T
.
prim_func
def
native_sparse_attention_varlen
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
O_slc
:
T
.
Tensor
(
o_slc_shape
,
dtype
),
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
block_indices_dtype
),
BlockCounts
:
T
.
Tensor
(
block_counts_shape
,
block_counts_dtype
),
Offsets
:
T
.
Tensor
(
offsets_shape
,
offsets_dtype
),
TokenIndices
:
T
.
Tensor
(
token_indices_shape
,
token_indices_dtype
),
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
O_slc
:
T
.
Tensor
(
o_slc_shape
,
dtype
),
BlockIndices
:
T
.
Tensor
(
block_indices_shape
,
block_indices_dtype
),
BlockCounts
:
T
.
Tensor
(
block_counts_shape
,
block_counts_dtype
),
Offsets
:
T
.
Tensor
(
offsets_shape
,
offsets_dtype
),
TokenIndices
:
T
.
Tensor
(
token_indices_shape
,
token_indices_dtype
),
):
with
T
.
Kernel
(
c_seq_len
,
NV
,
batch
*
head_kv
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
G
,
BK
],
dtype
)
...
...
@@ -100,7 +94,7 @@ def native_sparse_attention_varlen(batch,
current_seq_len
=
eos
-
bos
NS
=
BlockCounts
[
i_t
,
i_h
]
T
.
copy
(
Q
[
bos
+
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:
BK
],
Q_shared
)
T
.
copy
(
Q
[
bos
+
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
:
BK
],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
...
...
@@ -112,21 +106,15 @@ def native_sparse_attention_varlen(batch,
# [BS, BK]
# Lei: may have some padding issues
# we should learn from mha varlen templates to handle this
T
.
copy
(
K
[
bos
+
i_s
:
bos
+
i_s
+
BS
,
i_h
,
:
BK
],
K_shared
)
T
.
copy
(
K
[
bos
+
i_s
:
bos
+
i_s
+
BS
,
i_h
,
:
BK
],
K_shared
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
G
,
BS
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_t
>=
(
i_s
+
j
),
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_t
>=
(
i_s
+
j
),
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
# Softmax
T
.
copy
(
scores_max
,
scores_max_prev
)
...
...
@@ -146,13 +134,13 @@ def native_sparse_attention_varlen(batch,
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
# V * softmax(Q * K)
T
.
copy
(
V
[
bos
+
i_s
:
bos
+
i_s
+
BS
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
V_shared
)
T
.
copy
(
V
[
bos
+
i_s
:
bos
+
i_s
+
BS
,
i_h
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
G
,
BV
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
O_slc
[
bos
+
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
T
.
copy
(
O_shared
,
O_slc
[
bos
+
i_t
,
i_h
*
G
:
(
i_h
+
1
)
*
G
,
i_v
*
BV
:
(
i_v
+
1
)
*
BV
])
return
native_sparse_attention_varlen
...
...
@@ -190,17 +178,20 @@ def parallel_nsa_fwd(
o_slc
=
torch
.
empty
(
B
,
C_SEQ_LEN
,
HQ
,
V
,
dtype
=
v
.
dtype
,
device
=
q
.
device
)
kernel
(
q
.
view
(
C_SEQ_LEN
,
HQ
,
D
),
k
.
view
(
C_SEQ_LEN
,
H
,
D
),
v
.
view
(
C_SEQ_LEN
,
H
,
D
),
q
.
view
(
C_SEQ_LEN
,
HQ
,
D
),
k
.
view
(
C_SEQ_LEN
,
H
,
D
),
v
.
view
(
C_SEQ_LEN
,
H
,
D
),
o_slc
.
view
(
C_SEQ_LEN
,
HQ
,
V
),
block_indices
.
to
(
torch
.
int32
).
view
(
C_SEQ_LEN
,
H
,
S
),
block_counts
.
to
(
torch
.
int32
).
view
(
C_SEQ_LEN
,
H
),
offsets
.
to
(
torch
.
int32
),
token_indices
.
to
(
torch
.
int32
))
block_counts
.
to
(
torch
.
int32
).
view
(
C_SEQ_LEN
,
H
),
offsets
.
to
(
torch
.
int32
),
token_indices
.
to
(
torch
.
int32
),
)
return
o_slc
@
torch
.
compile
class
ParallelNSAFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
offsets
):
ctx
.
dtype
=
q
.
dtype
...
...
@@ -221,22 +212,25 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size
=
window_size
,
scale
=
scale
,
offsets
=
offsets
,
token_indices
=
token_indices
)
token_indices
=
token_indices
,
)
return
o_slc
.
to
(
q
.
dtype
)
def
parallel_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
)
->
torch
.
Tensor
:
def
parallel_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
,
)
->
torch
.
Tensor
:
r
"""
Args:
q (torch.Tensor):
...
...
@@ -276,29 +270,27 @@ def parallel_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
scale
=
k
.
shape
[
-
1
]
**
-
0.5
if
cu_seqlens
is
not
None
:
assert
q
.
shape
[
0
]
==
1
,
"batch size must be 1 when cu_seqlens are provided"
if
head_first
:
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t d -> b t h d'
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t -> b t h'
),
(
g_slc
,
g_swa
))
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t d -> b t h d"
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t -> b t h"
),
(
g_slc
,
g_swa
))
if
isinstance
(
block_counts
,
torch
.
Tensor
):
block_counts
=
rearrange
(
block_counts
,
'
b h t -> b t h
'
)
block_counts
=
rearrange
(
block_counts
,
"
b h t -> b t h
"
)
assert
q
.
shape
[
2
]
%
(
k
.
shape
[
2
]
*
16
)
==
0
,
"Group size must be a multiple of 16 in NSA"
if
isinstance
(
block_counts
,
int
):
block_indices
=
block_indices
[:,
:,
:,
:
block_counts
]
block_counts
=
None
o_slc
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
o_slc
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
if
window_size
>
0
:
assert
False
,
"Window size is not supported yet"
else
:
o
=
o_slc
*
g_slc
.
unsqueeze
(
-
1
)
if
head_first
:
o
=
rearrange
(
o
,
'
b t h d -> b h t d
'
)
o
=
rearrange
(
o
,
"
b t h d -> b h t d
"
)
return
o
...
...
@@ -306,41 +298,57 @@ if __name__ == "__main__":
N
,
C_SEQ_LEN
,
H
,
HQ
,
D
,
S
,
block_size
,
dtype
=
2
,
64
,
1
,
16
,
64
,
1
,
32
,
torch
.
float16
torch
.
manual_seed
(
42
)
# randomly split the sequence into N segments
offsets
=
torch
.
cat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
long
),
torch
.
arange
(
16
,
C_SEQ_LEN
)[
torch
.
randperm
(
C_SEQ_LEN
-
1
)[:
N
-
1
]],
torch
.
tensor
([
C_SEQ_LEN
],
dtype
=
torch
.
long
)
],
0
).
cuda
().
sort
()[
0
]
offsets
=
(
torch
.
cat
(
[
torch
.
tensor
([
0
],
dtype
=
torch
.
long
),
torch
.
arange
(
16
,
C_SEQ_LEN
)[
torch
.
randperm
(
C_SEQ_LEN
-
1
)[:
N
-
1
]],
torch
.
tensor
([
C_SEQ_LEN
],
dtype
=
torch
.
long
),
],
0
,
)
.
cuda
()
.
sort
()[
0
]
)
# seq-first required for inputs with variable lengths
perm_q
=
torch
.
randperm
(
C_SEQ_LEN
,
device
=
'cuda'
)
perm_k
=
torch
.
randperm
(
C_SEQ_LEN
,
device
=
'cuda'
)
perm_v
=
torch
.
randperm
(
C_SEQ_LEN
,
device
=
'cuda'
)
q
=
torch
.
linspace
(
0
,
1
,
steps
=
C_SEQ_LEN
,
dtype
=
dtype
,
device
=
'cuda'
)[
perm_q
].
view
(
1
,
C_SEQ_LEN
,
1
,
1
).
expand
(
1
,
C_SEQ_LEN
,
HQ
,
D
).
clone
().
requires_grad_
(
True
)
k
=
torch
.
linspace
(
0
,
1
,
steps
=
C_SEQ_LEN
,
dtype
=
dtype
,
device
=
'cuda'
)[
perm_k
].
view
(
1
,
C_SEQ_LEN
,
1
,
1
).
expand
(
1
,
C_SEQ_LEN
,
H
,
D
).
clone
().
requires_grad_
(
True
)
v
=
torch
.
linspace
(
0
,
1
,
steps
=
C_SEQ_LEN
,
dtype
=
dtype
,
device
=
'cuda'
)[
perm_v
].
view
(
1
,
C_SEQ_LEN
,
1
,
1
).
expand
(
1
,
C_SEQ_LEN
,
H
,
D
).
clone
().
requires_grad_
(
True
)
g_slc
=
torch
.
rand
((
1
,
C_SEQ_LEN
,
HQ
),
dtype
=
dtype
,
device
=
'cuda'
).
requires_grad_
(
True
)
g_swa
=
torch
.
rand
((
1
,
C_SEQ_LEN
,
HQ
),
dtype
=
dtype
,
device
=
'cuda'
).
requires_grad_
(
True
)
do
=
torch
.
randn
((
1
,
C_SEQ_LEN
,
HQ
,
D
),
dtype
=
dtype
,
device
=
'cuda'
)
perm_q
=
torch
.
randperm
(
C_SEQ_LEN
,
device
=
"cuda"
)
perm_k
=
torch
.
randperm
(
C_SEQ_LEN
,
device
=
"cuda"
)
perm_v
=
torch
.
randperm
(
C_SEQ_LEN
,
device
=
"cuda"
)
q
=
(
torch
.
linspace
(
0
,
1
,
steps
=
C_SEQ_LEN
,
dtype
=
dtype
,
device
=
"cuda"
)[
perm_q
]
.
view
(
1
,
C_SEQ_LEN
,
1
,
1
)
.
expand
(
1
,
C_SEQ_LEN
,
HQ
,
D
)
.
clone
()
.
requires_grad_
(
True
)
)
k
=
(
torch
.
linspace
(
0
,
1
,
steps
=
C_SEQ_LEN
,
dtype
=
dtype
,
device
=
"cuda"
)[
perm_k
]
.
view
(
1
,
C_SEQ_LEN
,
1
,
1
)
.
expand
(
1
,
C_SEQ_LEN
,
H
,
D
)
.
clone
()
.
requires_grad_
(
True
)
)
v
=
(
torch
.
linspace
(
0
,
1
,
steps
=
C_SEQ_LEN
,
dtype
=
dtype
,
device
=
"cuda"
)[
perm_v
]
.
view
(
1
,
C_SEQ_LEN
,
1
,
1
)
.
expand
(
1
,
C_SEQ_LEN
,
H
,
D
)
.
clone
()
.
requires_grad_
(
True
)
)
g_slc
=
torch
.
rand
((
1
,
C_SEQ_LEN
,
HQ
),
dtype
=
dtype
,
device
=
"cuda"
).
requires_grad_
(
True
)
g_swa
=
torch
.
rand
((
1
,
C_SEQ_LEN
,
HQ
),
dtype
=
dtype
,
device
=
"cuda"
).
requires_grad_
(
True
)
do
=
torch
.
randn
((
1
,
C_SEQ_LEN
,
HQ
,
D
),
dtype
=
dtype
,
device
=
"cuda"
)
token_indices
=
prepare_token_indices
(
offsets
).
tolist
()
block_indices
=
torch
.
full
((
1
,
C_SEQ_LEN
,
H
,
S
),
C_SEQ_LEN
,
dtype
=
torch
.
long
,
device
=
'
cuda
'
)
block_indices
=
torch
.
full
((
1
,
C_SEQ_LEN
,
H
,
S
),
C_SEQ_LEN
,
dtype
=
torch
.
long
,
device
=
"
cuda
"
)
for
i
in
range
(
C_SEQ_LEN
):
_
,
t
=
token_indices
[
i
]
for
h
in
range
(
H
):
i_i
=
torch
.
randperm
(
max
(
1
,
tilelang
.
cdiv
(
t
,
block_size
)))[:
S
]
block_indices
[
0
,
i
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
[
0
,
i
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
=
block_indices
.
sort
(
-
1
)[
0
]
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
1
,
C_SEQ_LEN
,
H
),
device
=
'
cuda
'
)
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
1
,
C_SEQ_LEN
,
H
),
device
=
"
cuda
"
)
ref
=
naive_nsa
(
q
=
q
,
...
...
@@ -351,7 +359,8 @@ if __name__ == "__main__":
block_indices
=
block_indices
,
block_counts
=
block_counts
,
block_size
=
block_size
,
cu_seqlens
=
offsets
)
cu_seqlens
=
offsets
,
)
tri
=
parallel_nsa
(
q
=
q
,
...
...
@@ -362,7 +371,8 @@ if __name__ == "__main__":
block_indices
=
block_indices
,
block_counts
=
block_counts
,
block_size
=
block_size
,
cu_seqlens
=
offsets
)
cu_seqlens
=
offsets
,
)
print
(
"tri"
,
tri
)
print
(
"ref"
,
ref
)
...
...
examples/deepseek_nsa/example_triton_nsa_bwd.py
View file @
29051439
...
...
@@ -8,6 +8,7 @@ import triton
import
triton.language
as
tl
import
fla
if
parse
(
fla
.
__version__
)
<
parse
(
"0.2.1"
):
from
fla.ops.common.utils
import
prepare_token_indices
else
:
...
...
@@ -17,21 +18,44 @@ from reference import naive_nsa
from
einops
import
rearrange
@
triton
.
heuristics
({
'USE_OFFSETS'
:
lambda
args
:
args
[
'offsets'
]
is
not
None
,
'USE_BLOCK_COUNTS'
:
lambda
args
:
isinstance
(
args
[
'block_counts'
],
torch
.
Tensor
),
})
@
triton
.
heuristics
(
{
"USE_OFFSETS"
:
lambda
args
:
args
[
"offsets"
]
is
not
None
,
"USE_BLOCK_COUNTS"
:
lambda
args
:
isinstance
(
args
[
"block_counts"
],
torch
.
Tensor
),
}
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
]],
key
=
[
'
BS
'
,
'
BK
'
,
'
BV
'
],
key
=
[
"
BS
"
,
"
BK
"
,
"
BV
"
],
)
@
triton
.
jit
def
parallel_nsa_fwd_kernel
(
q
,
k
,
v
,
o_slc
,
o_swa
,
lse_slc
,
lse_swa
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
):
def
parallel_nsa_fwd_kernel
(
q
,
k
,
v
,
o_slc
,
o_swa
,
lse_slc
,
lse_swa
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
,
):
i_t
,
i_v
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
...
...
@@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
# else:
NS
=
S
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_q
=
(
b_q
*
scale
).
to
(
b_q
.
dtype
)
p_o_slc
=
tl
.
make_block_ptr
(
o_slc
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_o_slc
=
tl
.
make_block_ptr
(
o_slc
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_lse_slc
=
lse_slc
+
(
bos
+
i_t
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
# [G, BV]
b_o_slc
=
tl
.
zeros
([
G
,
BV
],
dtype
=
tl
.
float32
)
b_m_slc
=
tl
.
full
([
G
],
float
(
'
-inf
'
),
dtype
=
tl
.
float32
)
b_m_slc
=
tl
.
full
([
G
],
float
(
"
-inf
"
),
dtype
=
tl
.
float32
)
b_acc_slc
=
tl
.
zeros
([
G
],
dtype
=
tl
.
float32
)
for
i
in
range
(
NS
):
i_s
=
tl
.
load
(
block_indices
+
i
).
to
(
tl
.
int32
)
*
BS
...
...
@@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc
=
tl
.
load
(
p_v_slc
,
boundary_check
=
(
0
,
1
))
# [G, BS]
b_s_slc
=
tl
.
dot
(
b_q
,
b_k_slc
)
b_s_slc
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_slc
,
float
(
'
-inf
'
))
b_s_slc
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_slc
,
float
(
"
-inf
"
))
# [G]
b_m_slc
,
b_mp_slc
=
tl
.
maximum
(
b_m_slc
,
tl
.
max
(
b_s_slc
,
1
)),
b_m_slc
...
...
@@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
class
ParallelNSAFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
contiguous
@
autocast_custom_fwd
...
...
@@ -105,8 +126,7 @@ class ParallelNSAFunction(torch.autograd.Function):
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices
=
prepare_token_indices
(
offsets
)
if
offsets
is
not
None
else
None
o
,
lse
=
parallel_nsa_fwd
(
q
=
q
,
k
=
k
,
v
=
v
,
block_indices
=
block_indices
,
block_size
=
block_size
,
scale
=
scale
)
o
,
lse
=
parallel_nsa_fwd
(
q
=
q
,
k
=
k
,
v
=
v
,
block_indices
=
block_indices
,
block_size
=
block_size
,
scale
=
scale
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
o
,
lse
)
ctx
.
block_indices
=
block_indices
ctx
.
block_size
=
block_size
...
...
@@ -134,7 +154,8 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size
=
ctx
.
window_size
,
scale
=
ctx
.
scale
,
offsets
=
ctx
.
offsets
,
token_indices
=
ctx
.
token_indices
)
token_indices
=
ctx
.
token_indices
,
)
return
dq
.
to
(
q
),
dk
.
to
(
k
),
dv
.
to
(
v
),
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
...
...
@@ -199,37 +220,56 @@ def parallel_nsa_fwd(
return
o_slc
,
lse_slc
,
o_swa
,
lse_swa
@
triton
.
heuristics
({
'
USE_OFFSETS
'
:
lambda
args
:
args
[
'
offsets
'
]
is
not
None
})
@
triton
.
heuristics
({
"
USE_OFFSETS
"
:
lambda
args
:
args
[
"
offsets
"
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
]],
key
=
[
'
BS
'
,
'
BK
'
,
'
BV
'
],
key
=
[
"
BS
"
,
"
BK
"
,
"
BV
"
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
parallel_nsa_bwd_kernel_dkv
(
q
,
k
,
v
,
lse_slc
,
lse_swa
,
delta_slc
,
delta_swa
,
do_slc
,
do_swa
,
dk
,
dv
,
block_mask
,
offsets
,
chunk_indices
,
scale
,
T
,
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
M
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
):
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
def
parallel_nsa_bwd_kernel_dkv
(
q
,
k
,
v
,
lse_slc
,
lse_swa
,
delta_slc
,
delta_swa
,
do_slc
,
do_swa
,
dk
,
dv
,
block_mask
,
offsets
,
chunk_indices
,
scale
,
T
,
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
M
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
):
i_v
,
i_s
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
USE_OFFSETS
:
i_n
,
i_s
=
tl
.
load
(
chunk_indices
+
i_s
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_s
*
2
+
1
).
to
(
tl
.
int32
)
i_n
,
i_s
=
tl
.
load
(
chunk_indices
+
i_s
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_s
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
offsets
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
offsets
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
p_k
=
tl
.
make_block_ptr
(
k
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_s
*
BS
,
0
),
(
BS
,
BK
),
(
1
,
0
))
p_v
=
tl
.
make_block_ptr
(
v
+
(
bos
*
H
+
i_h
)
*
V
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_s
*
BS
,
i_v
*
BV
),
(
BS
,
BV
),
(
1
,
0
))
p_dk
=
tl
.
make_block_ptr
(
dk
+
(
i_v
*
B
*
T
*
H
+
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_s
*
BS
,
0
),
(
BS
,
BK
),
(
1
,
0
))
p_dv
=
tl
.
make_block_ptr
(
dv
+
(
bos
*
H
+
i_h
)
*
V
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_s
*
BS
,
i_v
*
BV
),
(
BS
,
BV
),
(
1
,
0
))
p_k
=
tl
.
make_block_ptr
(
k
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_s
*
BS
,
0
),
(
BS
,
BK
),
(
1
,
0
))
p_v
=
tl
.
make_block_ptr
(
v
+
(
bos
*
H
+
i_h
)
*
V
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_s
*
BS
,
i_v
*
BV
),
(
BS
,
BV
),
(
1
,
0
))
p_dk
=
tl
.
make_block_ptr
(
dk
+
(
i_v
*
B
*
T
*
H
+
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_s
*
BS
,
0
),
(
BS
,
BK
),
(
1
,
0
))
p_dv
=
tl
.
make_block_ptr
(
dv
+
(
bos
*
H
+
i_h
)
*
V
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_s
*
BS
,
i_v
*
BV
),
(
BS
,
BV
),
(
1
,
0
))
# [BS, BK]
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
...
...
@@ -241,14 +281,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa,
for
i
in
range
(
i_s
*
BS
,
T
):
b_m_slc
=
tl
.
load
(
block_mask
+
(
bos
+
i
)
*
H
*
M
+
i_h
*
M
+
i_s
)
if
b_m_slc
:
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
# [G, BK]
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_q
=
(
b_q
*
scale
).
to
(
b_q
.
dtype
)
p_do_slc
=
tl
.
make_block_ptr
(
do_slc
+
(
bos
+
i
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_do_slc
=
tl
.
make_block_ptr
(
do_slc
+
(
bos
+
i
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_lse_slc
=
lse_slc
+
(
bos
+
i
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
p_delta_slc
=
delta_slc
+
(
bos
+
i
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
# [G, BV]
...
...
@@ -272,14 +310,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa,
if
WS
>
0
:
o_s
=
i_s
*
BS
+
tl
.
arange
(
0
,
BS
)
if
max
(
i_s
*
BS
,
i
-
WS
+
1
)
<
min
((
i_s
+
1
)
*
BS
,
i
+
1
):
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
# [G, BK]
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_q
=
(
b_q
*
scale
).
to
(
b_q
.
dtype
)
p_do_swa
=
tl
.
make_block_ptr
(
do_swa
+
(
bos
+
i
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_do_swa
=
tl
.
make_block_ptr
(
do_swa
+
(
bos
+
i
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_lse_swa
=
lse_swa
+
(
bos
+
i
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
p_delta_swa
=
delta_swa
+
(
bos
+
i
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
# [G, BV]
...
...
@@ -304,12 +340,19 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa,
tl
.
store
(
p_dv
,
b_dv
.
to
(
p_dv
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
@
triton
.
heuristics
(
{
'USE_BLOCK_COUNTS'
:
lambda
args
:
isinstance
(
args
[
'block_counts'
],
torch
.
Tensor
)})
@
triton
.
heuristics
({
"USE_BLOCK_COUNTS"
:
lambda
args
:
isinstance
(
args
[
"block_counts"
],
torch
.
Tensor
)})
@
triton
.
jit
def
parallel_nsa_kernel_mask
(
block_indices
,
block_counts
,
block_mask
,
T
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
NS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
):
def
parallel_nsa_kernel_mask
(
block_indices
,
block_counts
,
block_mask
,
T
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
NS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
,
):
i_t
,
i_b
,
i_hs
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_h
,
i_s
=
i_hs
//
S
,
i_hs
%
S
...
...
@@ -320,31 +363,56 @@ def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.cons
b_m
=
b_i
*
BS
<=
i_t
if
b_i
<
NS
and
b_i
>=
0
:
tl
.
store
(
block_mask
+
i_b
*
T
*
H
*
NS
+
i_t
*
H
*
NS
+
i_h
*
NS
+
b_i
,
b_m
.
to
(
block_mask
.
dtype
.
element_ty
))
tl
.
store
(
block_mask
+
i_b
*
T
*
H
*
NS
+
i_t
*
H
*
NS
+
i_h
*
NS
+
b_i
,
b_m
.
to
(
block_mask
.
dtype
.
element_ty
))
@
triton
.
heuristics
({
'USE_OFFSETS'
:
lambda
args
:
args
[
'offsets'
]
is
not
None
,
'USE_BLOCK_COUNTS'
:
lambda
args
:
isinstance
(
args
[
'block_counts'
],
torch
.
Tensor
)
})
@
triton
.
heuristics
(
{
"USE_OFFSETS"
:
lambda
args
:
args
[
"offsets"
]
is
not
None
,
"USE_BLOCK_COUNTS"
:
lambda
args
:
isinstance
(
args
[
"block_counts"
],
torch
.
Tensor
),
}
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
]],
key
=
[
'
BS
'
,
'
BK
'
,
'
BV
'
],
key
=
[
"
BS
"
,
"
BK
"
,
"
BV
"
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
parallel_nsa_bwd_kernel_dq
(
q
,
k
,
v
,
lse_slc
,
delta_slc
,
do_slc
,
lse_swa
,
delta_swa
,
do_swa
,
dq
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
):
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
def
parallel_nsa_bwd_kernel_dq
(
q
,
k
,
v
,
lse_slc
,
delta_slc
,
do_slc
,
lse_swa
,
delta_swa
,
do_swa
,
dq
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
,
):
i_t
,
i_v
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
USE_OFFSETS
:
i_n
,
i_t
=
tl
.
load
(
token_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
token_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
i_n
,
i_t
=
tl
.
load
(
token_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
token_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
offsets
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
offsets
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
...
...
@@ -449,27 +517,49 @@ def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, del
tl
.
store
(
p_dq
,
(
b_dq_slc
+
b_dq_swa
).
to
(
p_dq
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
@
triton
.
heuristics
({
'USE_OFFSETS'
:
lambda
args
:
args
[
'offsets'
]
is
not
None
,
'USE_BLOCK_COUNTS'
:
lambda
args
:
isinstance
(
args
[
'block_counts'
],
torch
.
Tensor
),
})
@
triton
.
heuristics
(
{
"USE_OFFSETS"
:
lambda
args
:
args
[
"offsets"
]
is
not
None
,
"USE_BLOCK_COUNTS"
:
lambda
args
:
isinstance
(
args
[
"block_counts"
],
torch
.
Tensor
),
}
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
]],
key
=
[
'
BS
'
,
'
BK
'
,
'
BV
'
],
key
=
[
"
BS
"
,
"
BK
"
,
"
BV
"
],
)
@
triton
.
jit
def
parallel_nsa_fwd_kernel
(
q
,
k
,
v
,
o_slc
,
o_swa
,
lse_slc
,
lse_swa
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
):
def
parallel_nsa_fwd_kernel
(
q
,
k
,
v
,
o_slc
,
o_swa
,
lse_slc
,
lse_swa
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
,
):
i_t
,
i_v
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
USE_OFFSETS
:
i_n
,
i_t
=
tl
.
load
(
token_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
token_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
i_n
,
i_t
=
tl
.
load
(
token_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
token_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
offsets
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
offsets
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
...
...
@@ -484,20 +574,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
else
:
NS
=
S
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_q
=
(
b_q
*
scale
).
to
(
b_q
.
dtype
)
p_o_slc
=
tl
.
make_block_ptr
(
o_slc
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_o_slc
=
tl
.
make_block_ptr
(
o_slc
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_lse_slc
=
lse_slc
+
(
bos
+
i_t
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
# [G, BV]
b_o_slc
=
tl
.
zeros
([
G
,
BV
],
dtype
=
tl
.
float32
)
b_m_slc
=
tl
.
full
([
G
],
float
(
'
-inf
'
),
dtype
=
tl
.
float32
)
b_m_slc
=
tl
.
full
([
G
],
float
(
"
-inf
"
),
dtype
=
tl
.
float32
)
b_acc_slc
=
tl
.
zeros
([
G
],
dtype
=
tl
.
float32
)
for
i
in
range
(
NS
):
i_s
=
tl
.
load
(
block_indices
+
i
).
to
(
tl
.
int32
)
*
BS
...
...
@@ -510,7 +598,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc
=
tl
.
load
(
p_v_slc
,
boundary_check
=
(
0
,
1
))
# [G, BS]
b_s_slc
=
tl
.
dot
(
b_q
,
b_k_slc
)
b_s_slc
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_slc
,
float
(
'
-inf
'
))
b_s_slc
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_slc
,
float
(
"
-inf
"
))
# [G]
b_m_slc
,
b_mp_slc
=
tl
.
maximum
(
b_m_slc
,
tl
.
max
(
b_s_slc
,
1
)),
b_m_slc
...
...
@@ -529,13 +617,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
tl
.
store
(
p_lse_slc
,
b_m_slc
.
to
(
p_lse_slc
.
dtype
.
element_ty
))
if
WS
>
0
:
p_o_swa
=
tl
.
make_block_ptr
(
o_swa
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_o_swa
=
tl
.
make_block_ptr
(
o_swa
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_lse_swa
=
lse_swa
+
(
bos
+
i_t
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
# [G, BV]
b_o_swa
=
tl
.
zeros
([
G
,
BV
],
dtype
=
tl
.
float32
)
b_m_swa
=
tl
.
full
([
G
],
float
(
'
-inf
'
),
dtype
=
tl
.
float32
)
b_m_swa
=
tl
.
full
([
G
],
float
(
"
-inf
"
),
dtype
=
tl
.
float32
)
b_acc_swa
=
tl
.
zeros
([
G
],
dtype
=
tl
.
float32
)
for
i_s
in
range
(
max
(
0
,
i_t
-
WS
+
1
),
i_t
+
1
,
BS
):
p_k_swa
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
H
*
K
),
(
0
,
i_s
),
(
BK
,
BS
),
(
0
,
1
))
...
...
@@ -546,7 +633,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_swa
=
tl
.
load
(
p_v_swa
,
boundary_check
=
(
0
,
1
))
# [G, BS]
b_s_swa
=
tl
.
dot
(
b_q
,
b_k_swa
)
b_s_swa
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_swa
,
float
(
'
-inf
'
))
b_s_swa
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_swa
,
float
(
"
-inf
"
))
# [G]
b_m_swa
,
b_mp_swa
=
tl
.
maximum
(
b_m_swa
,
tl
.
max
(
b_s_swa
,
1
)),
b_m_swa
...
...
@@ -593,14 +680,8 @@ def parallel_nsa_block_mask(
block_mask
=
torch
.
zeros
(
B
,
T
,
H
,
NS
,
dtype
=
torch
.
bool
,
device
=
block_indices
.
device
)
parallel_nsa_kernel_mask
[(
T
,
B
,
H
*
S
)](
block_indices
=
block_indices
,
block_counts
=
block_counts
,
block_mask
=
block_mask
,
T
=
T
,
H
=
H
,
S
=
S
,
BS
=
BS
,
NS
=
NS
)
block_indices
=
block_indices
,
block_counts
=
block_counts
,
block_mask
=
block_mask
,
T
=
T
,
H
=
H
,
S
=
S
,
BS
=
BS
,
NS
=
NS
)
return
block_mask
...
...
@@ -676,7 +757,8 @@ def parallel_nsa_bwd(
BS
=
BS
,
WS
=
WS
,
BK
=
BK
,
BV
=
BV
)
BV
=
BV
,
)
dq
=
dq
.
sum
(
0
)
if
offsets
is
not
None
:
...
...
@@ -719,14 +801,14 @@ def parallel_nsa_bwd(
BS
=
BS
,
WS
=
WS
,
BK
=
BK
,
BV
=
BV
)
BV
=
BV
,
)
dk
=
dk
.
sum
(
0
)
return
dq
,
dk
,
dv
@
torch
.
compile
class
ParallelNSAFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
contiguous
@
autocast_custom_fwd
...
...
@@ -749,7 +831,8 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size
=
window_size
,
scale
=
scale
,
offsets
=
offsets
,
token_indices
=
token_indices
)
token_indices
=
token_indices
,
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
o_slc
,
lse_slc
,
o_swa
,
lse_swa
)
ctx
.
block_indices
=
block_indices
ctx
.
block_counts
=
block_counts
...
...
@@ -781,22 +864,25 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size
=
ctx
.
window_size
,
scale
=
ctx
.
scale
,
offsets
=
ctx
.
offsets
,
token_indices
=
ctx
.
token_indices
)
token_indices
=
ctx
.
token_indices
,
)
return
dq
.
to
(
q
),
dk
.
to
(
k
),
dv
.
to
(
v
),
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
def
parallel_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
)
->
torch
.
Tensor
:
def
parallel_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
,
)
->
torch
.
Tensor
:
r
"""
Args:
q (torch.Tensor):
...
...
@@ -836,51 +922,49 @@ def parallel_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
scale
=
k
.
shape
[
-
1
]
**
-
0.5
if
cu_seqlens
is
not
None
:
assert
q
.
shape
[
0
]
==
1
,
"batch size must be 1 when cu_seqlens are provided"
if
head_first
:
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t d -> b t h d'
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t -> b t h'
),
(
g_slc
,
g_swa
))
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t d -> b t h d"
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t -> b t h"
),
(
g_slc
,
g_swa
))
if
isinstance
(
block_counts
,
torch
.
Tensor
):
block_counts
=
rearrange
(
block_counts
,
'
b h t -> b t h
'
)
block_counts
=
rearrange
(
block_counts
,
"
b h t -> b t h
"
)
assert
q
.
shape
[
2
]
%
(
k
.
shape
[
2
]
*
16
)
==
0
,
"Group size must be a multiple of 16 in NSA"
if
isinstance
(
block_counts
,
int
):
block_indices
=
block_indices
[:,
:,
:,
:
block_counts
]
block_counts
=
None
o_slc
,
o_swa
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
o_slc
,
o_swa
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
if
window_size
>
0
:
o
=
torch
.
addcmul
(
o_slc
*
g_slc
.
unsqueeze
(
-
1
),
o_swa
,
g_swa
.
unsqueeze
(
-
1
))
else
:
o
=
o_slc
*
g_slc
.
unsqueeze
(
-
1
)
if
head_first
:
o
=
rearrange
(
o
,
'
b t h d -> b h t d
'
)
o
=
rearrange
(
o
,
"
b t h d -> b h t d
"
)
return
o
if
__name__
==
"__main__"
:
B
,
T
,
H
,
HQ
,
D
,
S
,
block_size
,
dtype
=
2
,
64
,
1
,
16
,
32
,
1
,
32
,
torch
.
float16
torch
.
random
.
manual_seed
(
0
)
q
=
torch
.
randn
((
B
,
T
,
HQ
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
k
=
torch
.
randn
((
B
,
T
,
H
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
v
=
torch
.
randn
((
B
,
T
,
H
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
g_slc
=
torch
.
ones
((
B
,
T
,
HQ
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
g_swa
=
torch
.
ones
((
B
,
T
,
HQ
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
do
=
torch
.
randn
((
B
,
T
,
HQ
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
)
block_indices
=
torch
.
full
((
B
,
T
,
H
,
S
),
T
,
dtype
=
torch
.
long
,
device
=
'
cuda
'
)
q
=
torch
.
randn
((
B
,
T
,
HQ
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
k
=
torch
.
randn
((
B
,
T
,
H
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
v
=
torch
.
randn
((
B
,
T
,
H
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
g_slc
=
torch
.
ones
((
B
,
T
,
HQ
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
g_swa
=
torch
.
ones
((
B
,
T
,
HQ
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
do
=
torch
.
randn
((
B
,
T
,
HQ
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
)
block_indices
=
torch
.
full
((
B
,
T
,
H
,
S
),
T
,
dtype
=
torch
.
long
,
device
=
"
cuda
"
)
for
b
in
range
(
B
):
for
t
in
range
(
T
):
for
h
in
range
(
H
):
i_i
=
torch
.
randperm
(
max
(
1
,
(
t
//
block_size
)))[:
S
]
block_indices
[
b
,
t
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
[
b
,
t
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
=
block_indices
.
sort
(
-
1
)[
0
]
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
B
,
T
,
H
),
device
=
'
cuda
'
)
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
B
,
T
,
H
),
device
=
"
cuda
"
)
ref
=
naive_nsa
(
q
=
q
,
...
...
examples/deepseek_nsa/example_triton_nsa_fwd.py
View file @
29051439
...
...
@@ -8,6 +8,7 @@ import triton
import
triton.language
as
tl
import
fla
if
parse
(
fla
.
__version__
)
<
parse
(
"0.2.1"
):
from
fla.ops.common.utils
import
prepare_token_indices
else
:
...
...
@@ -17,21 +18,44 @@ from reference import naive_nsa
from
einops
import
rearrange
@
triton
.
heuristics
({
'USE_OFFSETS'
:
lambda
args
:
args
[
'offsets'
]
is
not
None
,
'USE_BLOCK_COUNTS'
:
lambda
args
:
isinstance
(
args
[
'block_counts'
],
torch
.
Tensor
),
})
@
triton
.
heuristics
(
{
"USE_OFFSETS"
:
lambda
args
:
args
[
"offsets"
]
is
not
None
,
"USE_BLOCK_COUNTS"
:
lambda
args
:
isinstance
(
args
[
"block_counts"
],
torch
.
Tensor
),
}
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
]],
key
=
[
'
BS
'
,
'
BK
'
,
'
BV
'
],
key
=
[
"
BS
"
,
"
BK
"
,
"
BV
"
],
)
@
triton
.
jit
def
parallel_nsa_fwd_kernel
(
q
,
k
,
v
,
o_slc
,
o_swa
,
lse_slc
,
lse_swa
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
):
def
parallel_nsa_fwd_kernel
(
q
,
k
,
v
,
o_slc
,
o_swa
,
lse_slc
,
lse_swa
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
,
):
i_t
,
i_v
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
...
...
@@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
# else:
NS
=
S
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_q
=
(
b_q
*
scale
).
to
(
b_q
.
dtype
)
p_o_slc
=
tl
.
make_block_ptr
(
o_slc
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_o_slc
=
tl
.
make_block_ptr
(
o_slc
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_lse_slc
=
lse_slc
+
(
bos
+
i_t
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
# [G, BV]
b_o_slc
=
tl
.
zeros
([
G
,
BV
],
dtype
=
tl
.
float32
)
b_m_slc
=
tl
.
full
([
G
],
float
(
'
-inf
'
),
dtype
=
tl
.
float32
)
b_m_slc
=
tl
.
full
([
G
],
float
(
"
-inf
"
),
dtype
=
tl
.
float32
)
b_acc_slc
=
tl
.
zeros
([
G
],
dtype
=
tl
.
float32
)
for
i
in
range
(
NS
):
i_s
=
tl
.
load
(
block_indices
+
i
).
to
(
tl
.
int32
)
*
BS
...
...
@@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc
=
tl
.
load
(
p_v_slc
,
boundary_check
=
(
0
,
1
))
# [G, BS]
b_s_slc
=
tl
.
dot
(
b_q
,
b_k_slc
)
b_s_slc
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_slc
,
float
(
'
-inf
'
))
b_s_slc
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_slc
,
float
(
"
-inf
"
))
# [G]
b_m_slc
,
b_mp_slc
=
tl
.
maximum
(
b_m_slc
,
tl
.
max
(
b_s_slc
,
1
)),
b_m_slc
...
...
@@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
class
ParallelNSAFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
contiguous
@
autocast_custom_fwd
...
...
@@ -105,8 +126,7 @@ class ParallelNSAFunction(torch.autograd.Function):
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices
=
prepare_token_indices
(
offsets
)
if
offsets
is
not
None
else
None
o
,
lse
=
parallel_nsa_fwd
(
q
=
q
,
k
=
k
,
v
=
v
,
block_indices
=
block_indices
,
block_size
=
block_size
,
scale
=
scale
)
o
,
lse
=
parallel_nsa_fwd
(
q
=
q
,
k
=
k
,
v
=
v
,
block_indices
=
block_indices
,
block_size
=
block_size
,
scale
=
scale
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
o
,
lse
)
ctx
.
block_indices
=
block_indices
ctx
.
block_size
=
block_size
...
...
@@ -177,7 +197,6 @@ def parallel_nsa_fwd(
@
torch
.
compile
class
ParallelNSAFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
contiguous
@
autocast_custom_fwd
...
...
@@ -200,7 +219,8 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size
=
window_size
,
scale
=
scale
,
offsets
=
offsets
,
token_indices
=
token_indices
)
token_indices
=
token_indices
,
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
o_slc
,
lse_slc
,
o_swa
,
lse_swa
)
ctx
.
block_indices
=
block_indices
ctx
.
block_counts
=
block_counts
...
...
@@ -212,18 +232,20 @@ class ParallelNSAFunction(torch.autograd.Function):
return
o_slc
.
to
(
q
.
dtype
),
o_swa
.
to
(
q
.
dtype
)
if
o_swa
is
not
None
else
o_swa
def
parallel_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
)
->
torch
.
Tensor
:
def
parallel_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
,
)
->
torch
.
Tensor
:
r
"""
Args:
q (torch.Tensor):
...
...
@@ -263,51 +285,49 @@ def parallel_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
scale
=
k
.
shape
[
-
1
]
**
-
0.5
if
cu_seqlens
is
not
None
:
assert
q
.
shape
[
0
]
==
1
,
"batch size must be 1 when cu_seqlens are provided"
if
head_first
:
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t d -> b t h d'
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t -> b t h'
),
(
g_slc
,
g_swa
))
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t d -> b t h d"
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t -> b t h"
),
(
g_slc
,
g_swa
))
if
isinstance
(
block_counts
,
torch
.
Tensor
):
block_counts
=
rearrange
(
block_counts
,
'
b h t -> b t h
'
)
block_counts
=
rearrange
(
block_counts
,
"
b h t -> b t h
"
)
assert
q
.
shape
[
2
]
%
(
k
.
shape
[
2
]
*
16
)
==
0
,
"Group size must be a multiple of 16 in NSA"
if
isinstance
(
block_counts
,
int
):
block_indices
=
block_indices
[:,
:,
:,
:
block_counts
]
block_counts
=
None
o_slc
,
o_swa
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
o_slc
,
o_swa
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
if
window_size
>
0
:
o
=
torch
.
addcmul
(
o_slc
*
g_slc
.
unsqueeze
(
-
1
),
o_swa
,
g_swa
.
unsqueeze
(
-
1
))
else
:
o
=
o_slc
*
g_slc
.
unsqueeze
(
-
1
)
if
head_first
:
o
=
rearrange
(
o
,
'
b t h d -> b h t d
'
)
o
=
rearrange
(
o
,
"
b t h d -> b h t d
"
)
return
o
if
__name__
==
"__main__"
:
B
,
T
,
H
,
HQ
,
D
,
S
,
block_size
,
dtype
=
2
,
64
,
1
,
16
,
32
,
1
,
32
,
torch
.
float16
torch
.
random
.
manual_seed
(
0
)
q
=
torch
.
randn
((
B
,
T
,
HQ
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
k
=
torch
.
randn
((
B
,
T
,
H
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
v
=
torch
.
randn
((
B
,
T
,
H
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
g_slc
=
torch
.
ones
((
B
,
T
,
HQ
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
g_swa
=
torch
.
ones
((
B
,
T
,
HQ
),
dtype
=
dtype
,
device
=
'
cuda
'
).
requires_grad_
(
True
)
do
=
torch
.
randn
((
B
,
T
,
HQ
,
D
),
dtype
=
dtype
,
device
=
'
cuda
'
)
block_indices
=
torch
.
full
((
B
,
T
,
H
,
S
),
T
,
dtype
=
torch
.
long
,
device
=
'
cuda
'
)
q
=
torch
.
randn
((
B
,
T
,
HQ
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
k
=
torch
.
randn
((
B
,
T
,
H
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
v
=
torch
.
randn
((
B
,
T
,
H
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
g_slc
=
torch
.
ones
((
B
,
T
,
HQ
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
g_swa
=
torch
.
ones
((
B
,
T
,
HQ
),
dtype
=
dtype
,
device
=
"
cuda
"
).
requires_grad_
(
True
)
do
=
torch
.
randn
((
B
,
T
,
HQ
,
D
),
dtype
=
dtype
,
device
=
"
cuda
"
)
block_indices
=
torch
.
full
((
B
,
T
,
H
,
S
),
T
,
dtype
=
torch
.
long
,
device
=
"
cuda
"
)
for
b
in
range
(
B
):
for
t
in
range
(
T
):
for
h
in
range
(
H
):
i_i
=
torch
.
randperm
(
max
(
1
,
(
t
//
block_size
)))[:
S
]
block_indices
[
b
,
t
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
[
b
,
t
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
=
block_indices
.
sort
(
-
1
)[
0
]
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
B
,
T
,
H
),
device
=
'
cuda
'
)
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
B
,
T
,
H
),
device
=
"
cuda
"
)
ref
=
naive_nsa
(
q
=
q
,
...
...
examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py
View file @
29051439
...
...
@@ -8,6 +8,7 @@ import triton
import
triton.language
as
tl
import
fla
if
parse
(
fla
.
__version__
)
<
parse
(
"0.2.1"
):
from
fla.ops.common.utils
import
prepare_token_indices
else
:
...
...
@@ -17,27 +18,49 @@ from reference import naive_nsa
from
einops
import
rearrange
@
triton
.
heuristics
({
'USE_OFFSETS'
:
lambda
args
:
args
[
'offsets'
]
is
not
None
,
'USE_BLOCK_COUNTS'
:
lambda
args
:
isinstance
(
args
[
'block_counts'
],
torch
.
Tensor
),
})
@
triton
.
heuristics
(
{
"USE_OFFSETS"
:
lambda
args
:
args
[
"offsets"
]
is
not
None
,
"USE_BLOCK_COUNTS"
:
lambda
args
:
isinstance
(
args
[
"block_counts"
],
torch
.
Tensor
),
}
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
]],
key
=
[
'
BS
'
,
'
BK
'
,
'
BV
'
],
key
=
[
"
BS
"
,
"
BK
"
,
"
BV
"
],
)
@
triton
.
jit
def
parallel_nsa_fwd_kernel
(
q
,
k
,
v
,
o_slc
,
o_swa
,
lse_slc
,
lse_swa
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
):
def
parallel_nsa_fwd_kernel
(
q
,
k
,
v
,
o_slc
,
o_swa
,
lse_slc
,
lse_swa
,
scale
,
block_indices
,
block_counts
,
offsets
,
token_indices
,
T
,
H
:
tl
.
constexpr
,
HQ
:
tl
.
constexpr
,
G
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
WS
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_OFFSETS
:
tl
.
constexpr
,
USE_BLOCK_COUNTS
:
tl
.
constexpr
,
):
i_t
,
i_v
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
USE_OFFSETS
:
i_n
,
i_t
=
tl
.
load
(
token_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
token_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
i_n
,
i_t
=
tl
.
load
(
token_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
token_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
offsets
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
offsets
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
...
...
@@ -52,20 +75,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
else
:
NS
=
S
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
(
1
,
0
))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_q
=
(
b_q
*
scale
).
to
(
b_q
.
dtype
)
p_o_slc
=
tl
.
make_block_ptr
(
o_slc
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_o_slc
=
tl
.
make_block_ptr
(
o_slc
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_lse_slc
=
lse_slc
+
(
bos
+
i_t
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
# [G, BV]
b_o_slc
=
tl
.
zeros
([
G
,
BV
],
dtype
=
tl
.
float32
)
b_m_slc
=
tl
.
full
([
G
],
float
(
'
-inf
'
),
dtype
=
tl
.
float32
)
b_m_slc
=
tl
.
full
([
G
],
float
(
"
-inf
"
),
dtype
=
tl
.
float32
)
b_acc_slc
=
tl
.
zeros
([
G
],
dtype
=
tl
.
float32
)
for
i
in
range
(
NS
):
i_s
=
tl
.
load
(
block_indices
+
i
).
to
(
tl
.
int32
)
*
BS
...
...
@@ -78,7 +99,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc
=
tl
.
load
(
p_v_slc
,
boundary_check
=
(
0
,
1
))
# [G, BS]
b_s_slc
=
tl
.
dot
(
b_q
,
b_k_slc
)
b_s_slc
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_slc
,
float
(
'
-inf
'
))
b_s_slc
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_slc
,
float
(
"
-inf
"
))
# [G]
b_m_slc
,
b_mp_slc
=
tl
.
maximum
(
b_m_slc
,
tl
.
max
(
b_s_slc
,
1
)),
b_m_slc
...
...
@@ -97,13 +118,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
tl
.
store
(
p_lse_slc
,
b_m_slc
.
to
(
p_lse_slc
.
dtype
.
element_ty
))
if
WS
>
0
:
p_o_swa
=
tl
.
make_block_ptr
(
o_swa
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_o_swa
=
tl
.
make_block_ptr
(
o_swa
+
(
bos
+
i_t
)
*
HQ
*
V
,
(
HQ
,
V
),
(
V
,
1
),
(
i_h
*
G
,
i_v
*
BV
),
(
G
,
BV
),
(
1
,
0
))
p_lse_swa
=
lse_swa
+
(
bos
+
i_t
)
*
HQ
+
i_h
*
G
+
tl
.
arange
(
0
,
G
)
# [G, BV]
b_o_swa
=
tl
.
zeros
([
G
,
BV
],
dtype
=
tl
.
float32
)
b_m_swa
=
tl
.
full
([
G
],
float
(
'
-inf
'
),
dtype
=
tl
.
float32
)
b_m_swa
=
tl
.
full
([
G
],
float
(
"
-inf
"
),
dtype
=
tl
.
float32
)
b_acc_swa
=
tl
.
zeros
([
G
],
dtype
=
tl
.
float32
)
for
i_s
in
range
(
max
(
0
,
i_t
-
WS
+
1
),
i_t
+
1
,
BS
):
p_k_swa
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
H
*
K
),
(
0
,
i_s
),
(
BK
,
BS
),
(
0
,
1
))
...
...
@@ -114,7 +134,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_swa
=
tl
.
load
(
p_v_swa
,
boundary_check
=
(
0
,
1
))
# [G, BS]
b_s_swa
=
tl
.
dot
(
b_q
,
b_k_swa
)
b_s_swa
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_swa
,
float
(
'
-inf
'
))
b_s_swa
=
tl
.
where
((
i_t
>=
(
i_s
+
tl
.
arange
(
0
,
BS
)))[
None
,
:],
b_s_swa
,
float
(
"
-inf
"
))
# [G]
b_m_swa
,
b_mp_swa
=
tl
.
maximum
(
b_m_swa
,
tl
.
max
(
b_s_swa
,
1
)),
b_m_swa
...
...
@@ -196,7 +216,6 @@ def parallel_nsa_fwd(
@
torch
.
compile
class
ParallelNSAFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
contiguous
@
autocast_custom_fwd
...
...
@@ -219,7 +238,8 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size
=
window_size
,
scale
=
scale
,
offsets
=
offsets
,
token_indices
=
token_indices
)
token_indices
=
token_indices
,
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
o_slc
,
lse_slc
,
o_swa
,
lse_swa
)
ctx
.
block_indices
=
block_indices
ctx
.
block_counts
=
block_counts
...
...
@@ -231,18 +251,20 @@ class ParallelNSAFunction(torch.autograd.Function):
return
o_slc
.
to
(
q
.
dtype
),
o_swa
.
to
(
q
.
dtype
)
if
o_swa
is
not
None
else
o_swa
def
parallel_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
)
->
torch
.
Tensor
:
def
parallel_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
,
)
->
torch
.
Tensor
:
r
"""
Args:
q (torch.Tensor):
...
...
@@ -282,29 +304,27 @@ def parallel_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
scale
=
k
.
shape
[
-
1
]
**
-
0.5
if
cu_seqlens
is
not
None
:
assert
q
.
shape
[
0
]
==
1
,
"batch size must be 1 when cu_seqlens are provided"
if
head_first
:
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t d -> b t h d'
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t -> b t h'
),
(
g_slc
,
g_swa
))
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t d -> b t h d"
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t -> b t h"
),
(
g_slc
,
g_swa
))
if
isinstance
(
block_counts
,
torch
.
Tensor
):
block_counts
=
rearrange
(
block_counts
,
'
b h t -> b t h
'
)
block_counts
=
rearrange
(
block_counts
,
"
b h t -> b t h
"
)
assert
q
.
shape
[
2
]
%
(
k
.
shape
[
2
]
*
16
)
==
0
,
"Group size must be a multiple of 16 in NSA"
if
isinstance
(
block_counts
,
int
):
block_indices
=
block_indices
[:,
:,
:,
:
block_counts
]
block_counts
=
None
o_slc
,
o_swa
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
o_slc
,
o_swa
=
ParallelNSAFunction
.
apply
(
q
,
k
,
v
,
block_indices
,
block_counts
,
block_size
,
window_size
,
scale
,
cu_seqlens
)
if
window_size
>
0
:
o
=
torch
.
addcmul
(
o_slc
*
g_slc
.
unsqueeze
(
-
1
),
o_swa
,
g_swa
.
unsqueeze
(
-
1
))
else
:
o
=
o_slc
*
g_slc
.
unsqueeze
(
-
1
)
if
head_first
:
o
=
rearrange
(
o
,
'
b t h d -> b h t d
'
)
o
=
rearrange
(
o
,
"
b t h d -> b h t d
"
)
return
o
...
...
@@ -312,38 +332,35 @@ if __name__ == "__main__":
N
,
T
,
H
,
HQ
,
D
,
S
,
block_size
,
dtype
=
2
,
64
,
1
,
16
,
64
,
1
,
32
,
torch
.
float16
torch
.
manual_seed
(
42
)
# randomly split the sequence into N segments
offsets
=
torch
.
cat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
long
),
torch
.
arange
(
16
,
T
)[
torch
.
randperm
(
T
-
1
)[:
N
-
1
]],
torch
.
tensor
([
T
],
dtype
=
torch
.
long
)
],
0
).
cuda
().
sort
()[
0
]
offsets
=
(
torch
.
cat
(
[
torch
.
tensor
([
0
],
dtype
=
torch
.
long
),
torch
.
arange
(
16
,
T
)[
torch
.
randperm
(
T
-
1
)[:
N
-
1
]],
torch
.
tensor
([
T
],
dtype
=
torch
.
long
)],
0
,
)
.
cuda
()
.
sort
()[
0
]
)
# offsets.shape is [N+1]
# seq-first required for inputs with variable lengths
perm_q
=
torch
.
randperm
(
T
,
device
=
'cuda'
)
perm_k
=
torch
.
randperm
(
T
,
device
=
'cuda'
)
perm_v
=
torch
.
randperm
(
T
,
device
=
'cuda'
)
q
=
torch
.
linspace
(
0
,
1
,
steps
=
T
,
dtype
=
dtype
,
device
=
'cuda'
)[
perm_q
].
view
(
1
,
T
,
1
,
1
).
expand
(
1
,
T
,
HQ
,
D
).
clone
().
requires_grad_
(
True
)
k
=
torch
.
linspace
(
0
,
1
,
steps
=
T
,
dtype
=
dtype
,
device
=
'cuda'
)[
perm_k
].
view
(
1
,
T
,
1
,
1
).
expand
(
1
,
T
,
H
,
D
).
clone
().
requires_grad_
(
True
)
v
=
torch
.
linspace
(
0
,
1
,
steps
=
T
,
dtype
=
dtype
,
device
=
'cuda'
)[
perm_v
].
view
(
1
,
T
,
1
,
1
).
expand
(
1
,
T
,
H
,
D
).
clone
().
requires_grad_
(
True
)
g_slc
=
torch
.
rand
((
1
,
T
,
HQ
),
dtype
=
dtype
,
device
=
'cuda'
).
requires_grad_
(
True
)
g_swa
=
torch
.
rand
((
1
,
T
,
HQ
),
dtype
=
dtype
,
device
=
'cuda'
).
requires_grad_
(
True
)
do
=
torch
.
randn
((
1
,
T
,
HQ
,
D
),
dtype
=
dtype
,
device
=
'cuda'
)
perm_q
=
torch
.
randperm
(
T
,
device
=
"cuda"
)
perm_k
=
torch
.
randperm
(
T
,
device
=
"cuda"
)
perm_v
=
torch
.
randperm
(
T
,
device
=
"cuda"
)
q
=
torch
.
linspace
(
0
,
1
,
steps
=
T
,
dtype
=
dtype
,
device
=
"cuda"
)[
perm_q
].
view
(
1
,
T
,
1
,
1
).
expand
(
1
,
T
,
HQ
,
D
).
clone
().
requires_grad_
(
True
)
k
=
torch
.
linspace
(
0
,
1
,
steps
=
T
,
dtype
=
dtype
,
device
=
"cuda"
)[
perm_k
].
view
(
1
,
T
,
1
,
1
).
expand
(
1
,
T
,
H
,
D
).
clone
().
requires_grad_
(
True
)
v
=
torch
.
linspace
(
0
,
1
,
steps
=
T
,
dtype
=
dtype
,
device
=
"cuda"
)[
perm_v
].
view
(
1
,
T
,
1
,
1
).
expand
(
1
,
T
,
H
,
D
).
clone
().
requires_grad_
(
True
)
g_slc
=
torch
.
rand
((
1
,
T
,
HQ
),
dtype
=
dtype
,
device
=
"cuda"
).
requires_grad_
(
True
)
g_swa
=
torch
.
rand
((
1
,
T
,
HQ
),
dtype
=
dtype
,
device
=
"cuda"
).
requires_grad_
(
True
)
do
=
torch
.
randn
((
1
,
T
,
HQ
,
D
),
dtype
=
dtype
,
device
=
"cuda"
)
token_indices
=
prepare_token_indices
(
offsets
).
tolist
()
block_indices
=
torch
.
full
((
1
,
T
,
H
,
S
),
T
,
dtype
=
torch
.
long
,
device
=
'
cuda
'
)
block_indices
=
torch
.
full
((
1
,
T
,
H
,
S
),
T
,
dtype
=
torch
.
long
,
device
=
"
cuda
"
)
for
i
in
range
(
T
):
_
,
t
=
token_indices
[
i
]
for
h
in
range
(
H
):
i_i
=
torch
.
randperm
(
max
(
1
,
triton
.
cdiv
(
t
,
block_size
)))[:
S
]
block_indices
[
0
,
i
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
[
0
,
i
,
h
,
:
len
(
i_i
)]
=
i_i
block_indices
=
block_indices
.
sort
(
-
1
)[
0
]
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
1
,
T
,
H
),
device
=
'
cuda
'
)
block_counts
=
torch
.
randint
(
1
,
S
+
1
,
(
1
,
T
,
H
),
device
=
"
cuda
"
)
ref
=
naive_nsa
(
q
=
q
,
...
...
@@ -354,7 +371,8 @@ if __name__ == "__main__":
block_indices
=
block_indices
,
block_counts
=
block_counts
,
block_size
=
block_size
,
cu_seqlens
=
offsets
)
cu_seqlens
=
offsets
,
)
tri
=
parallel_nsa
(
q
=
q
,
...
...
@@ -365,7 +383,8 @@ if __name__ == "__main__":
block_indices
=
block_indices
,
block_counts
=
block_counts
,
block_size
=
block_size
,
cu_seqlens
=
offsets
)
cu_seqlens
=
offsets
,
)
print
(
"tri"
,
tri
)
print
(
"ref"
,
ref
)
...
...
examples/deepseek_nsa/reference.py
View file @
29051439
...
...
@@ -6,18 +6,20 @@ from typing import Union
from
einops
import
rearrange
,
repeat
def
naive_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
)
->
torch
.
Tensor
:
def
naive_nsa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g_slc
:
torch
.
Tensor
,
g_swa
:
torch
.
Tensor
,
block_indices
:
torch
.
LongTensor
,
block_counts
:
Optional
[
Union
[
torch
.
LongTensor
,
int
]]
=
None
,
block_size
:
int
=
64
,
window_size
:
int
=
0
,
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
,
)
->
torch
.
Tensor
:
r
"""
Args:
q (torch.Tensor):
...
...
@@ -57,26 +59,24 @@ def naive_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
scale
=
k
.
shape
[
-
1
]
**
-
0.5
if
cu_seqlens
is
not
None
:
assert
q
.
shape
[
0
]
==
1
,
"batch size must be 1 when cu_seqlens are provided"
if
head_first
:
raise
RuntimeError
(
"Sequences with variable lengths are not supported for head-first mode"
)
raise
RuntimeError
(
"Sequences with variable lengths are not supported for head-first mode"
)
if
head_first
:
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t d -> b t h d'
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t -> b t h'
),
(
g_slc
,
g_swa
))
q
,
k
,
v
,
block_indices
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t d -> b t h d"
),
(
q
,
k
,
v
,
block_indices
))
g_slc
,
g_swa
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t -> b t h"
),
(
g_slc
,
g_swa
))
if
isinstance
(
block_counts
,
torch
.
Tensor
):
block_counts
=
rearrange
(
block_counts
,
'
b h t -> b t h
'
)
block_counts
=
rearrange
(
block_counts
,
"
b h t -> b t h
"
)
dtype
=
q
.
dtype
G
=
q
.
shape
[
2
]
//
k
.
shape
[
2
]
BS
=
block_size
S
=
block_indices
.
shape
[
-
1
]
k
,
v
,
block_indices
=
(
repeat
(
x
,
'
b t h d -> b t (h g) d
'
,
g
=
G
)
for
x
in
(
k
,
v
,
block_indices
))
k
,
v
,
block_indices
=
(
repeat
(
x
,
"
b t h d -> b t (h g) d
"
,
g
=
G
)
for
x
in
(
k
,
v
,
block_indices
))
if
isinstance
(
block_counts
,
torch
.
Tensor
):
block_counts
=
repeat
(
block_counts
,
'
b t h -> b t (h g)
'
,
g
=
G
)
block_counts
=
repeat
(
block_counts
,
"
b t h -> b t (h g)
"
,
g
=
G
)
c
=
torch
.
arange
(
S
).
repeat_interleave
(
BS
).
unsqueeze
(
1
).
expand
(
-
1
,
q
.
shape
[
2
]).
to
(
q
.
device
)
q
,
k
,
v
=
map
(
lambda
x
:
x
.
float
(),
(
q
,
k
,
v
))
...
...
@@ -86,14 +86,11 @@ def naive_nsa(q: torch.Tensor,
if
cu_seqlens
is
None
:
varlen
=
False
B
,
T
=
q
.
shape
[:
2
]
cu_seqlens
=
torch
.
cat
(
[
block_indices
.
new_tensor
(
range
(
0
,
B
*
T
,
T
)),
block_indices
.
new_tensor
([
B
*
T
])])
cu_seqlens
=
torch
.
cat
([
block_indices
.
new_tensor
(
range
(
0
,
B
*
T
,
T
)),
block_indices
.
new_tensor
([
B
*
T
])])
for
i
in
range
(
len
(
cu_seqlens
)
-
1
):
if
not
varlen
:
q_b
,
k_b
,
v_b
,
g_slc_b
,
g_swa_b
,
i_b
=
q
[
i
],
k
[
i
],
v
[
i
],
g_slc
[
i
],
g_swa
[
i
],
block_indices
[
i
]
q_b
,
k_b
,
v_b
,
g_slc_b
,
g_swa_b
,
i_b
=
q
[
i
],
k
[
i
],
v
[
i
],
g_slc
[
i
],
g_swa
[
i
],
block_indices
[
i
]
if
isinstance
(
block_counts
,
torch
.
Tensor
):
s_b
=
block_counts
[
i
]
else
:
...
...
@@ -101,10 +98,10 @@ def naive_nsa(q: torch.Tensor,
else
:
T
=
cu_seqlens
[
i
+
1
]
-
cu_seqlens
[
i
]
q_b
,
k_b
,
v_b
,
g_slc_b
,
g_swa_b
,
i_b
=
map
(
lambda
x
:
x
[
0
][
cu_seqlens
[
i
]
:
cu_seqlens
[
i
+
1
]],
(
q
,
k
,
v
,
g_slc
,
g_swa
,
block_indices
)
)
lambda
x
:
x
[
0
][
cu_seqlens
[
i
]
:
cu_seqlens
[
i
+
1
]],
(
q
,
k
,
v
,
g_slc
,
g_swa
,
block_indices
)
)
if
isinstance
(
block_counts
,
torch
.
Tensor
):
s_b
=
block_counts
[
0
][
cu_seqlens
[
i
]
:
cu_seqlens
[
i
+
1
]]
s_b
=
block_counts
[
0
][
cu_seqlens
[
i
]
:
cu_seqlens
[
i
+
1
]]
else
:
s_b
=
block_counts
...
...
@@ -126,34 +123,28 @@ def naive_nsa(q: torch.Tensor,
else
:
s_i
=
s_b
# [S*BS, HQ, -1]
k_i_slc
,
v_i_slc
=
map
(
lambda
x
:
x
.
gather
(
0
,
i_i
.
clamp
(
0
,
T
-
1
).
unsqueeze
(
-
1
).
expand
(
*
i_i
.
shape
,
x
.
shape
[
-
1
])),
(
k_b
,
v_b
))
k_i_slc
,
v_i_slc
=
map
(
lambda
x
:
x
.
gather
(
0
,
i_i
.
clamp
(
0
,
T
-
1
).
unsqueeze
(
-
1
).
expand
(
*
i_i
.
shape
,
x
.
shape
[
-
1
])),
(
k_b
,
v_b
))
# [S*BS, HQ]
attn_slc
=
torch
.
einsum
(
'h d, n h d -> n h'
,
q_i
,
k_i_slc
).
masked_fill
(
torch
.
logical_or
(
i_i
<
0
,
i_i
>
i_q
)
|
(
c
>=
s_i
if
block_counts
is
not
None
else
False
),
float
(
'-inf'
)).
softmax
(
0
)
attn_slc
=
(
torch
.
einsum
(
"h d, n h d -> n h"
,
q_i
,
k_i_slc
)
.
masked_fill
(
torch
.
logical_or
(
i_i
<
0
,
i_i
>
i_q
)
|
(
c
>=
s_i
if
block_counts
is
not
None
else
False
),
float
(
"-inf"
))
.
softmax
(
0
)
)
if
not
varlen
:
o_slc
[
i
,
i_q
]
=
torch
.
einsum
(
'n h, n h v -> h v'
,
attn_slc
,
v_i_slc
)
*
g_slc_i
.
unsqueeze
(
-
1
)
o_slc
[
i
,
i_q
]
=
torch
.
einsum
(
"n h, n h v -> h v"
,
attn_slc
,
v_i_slc
)
*
g_slc_i
.
unsqueeze
(
-
1
)
else
:
o_slc
[
0
][
cu_seqlens
[
i
]
+
i_q
]
=
torch
.
einsum
(
'n h, n h v -> h v'
,
attn_slc
,
v_i_slc
)
*
g_slc_i
.
unsqueeze
(
-
1
)
o_slc
[
0
][
cu_seqlens
[
i
]
+
i_q
]
=
torch
.
einsum
(
"n h, n h v -> h v"
,
attn_slc
,
v_i_slc
)
*
g_slc_i
.
unsqueeze
(
-
1
)
if
window_size
>
0
:
k_i_swa
,
v_i_swa
=
map
(
lambda
x
:
x
[
max
(
0
,
i_q
-
window_size
+
1
):
i_q
+
1
],
(
k_b
,
v_b
))
attn_swa
=
torch
.
einsum
(
'h d, n h d -> n h'
,
q_i
,
k_i_swa
).
softmax
(
0
)
k_i_swa
,
v_i_swa
=
map
(
lambda
x
:
x
[
max
(
0
,
i_q
-
window_size
+
1
)
:
i_q
+
1
],
(
k_b
,
v_b
))
attn_swa
=
torch
.
einsum
(
"h d, n h d -> n h"
,
q_i
,
k_i_swa
).
softmax
(
0
)
if
not
varlen
:
o_swa
[
i
,
i_q
]
=
torch
.
einsum
(
'n h, n h v -> h v'
,
attn_swa
,
v_i_swa
)
*
g_swa_i
.
unsqueeze
(
-
1
)
o_swa
[
i
,
i_q
]
=
torch
.
einsum
(
"n h, n h v -> h v"
,
attn_swa
,
v_i_swa
)
*
g_swa_i
.
unsqueeze
(
-
1
)
else
:
o_swa
[
0
][
cu_seqlens
[
i
]
+
i_q
]
=
torch
.
einsum
(
'n h, n h v -> h v'
,
attn_swa
,
v_i_swa
)
*
g_swa_i
.
unsqueeze
(
-
1
)
o_swa
[
0
][
cu_seqlens
[
i
]
+
i_q
]
=
torch
.
einsum
(
"n h, n h v -> h v"
,
attn_swa
,
v_i_swa
)
*
g_swa_i
.
unsqueeze
(
-
1
)
if
head_first
:
o_slc
=
rearrange
(
o_slc
,
'
b t h d -> b h t d
'
)
o_swa
=
rearrange
(
o_swa
,
'
b t h d -> b h t d
'
)
o_slc
=
rearrange
(
o_slc
,
"
b t h d -> b h t d
"
)
o_swa
=
rearrange
(
o_swa
,
"
b t h d -> b h t d
"
)
return
o_slc
.
to
(
dtype
)
+
o_swa
.
to
(
dtype
)
if
o_swa
is
not
None
else
o_slc
.
to
(
dtype
)
...
...
@@ -187,7 +178,7 @@ def naive_nsa_simple(
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
scale
=
k
.
shape
[
-
1
]
**-
0.5
scale
=
k
.
shape
[
-
1
]
**
-
0.5
dtype
=
q
.
dtype
HQ
=
q
.
shape
[
2
]
...
...
@@ -197,8 +188,8 @@ def naive_nsa_simple(
BS
=
block_size
S
=
block_indices
.
shape
[
-
1
]
SELECTED_BLOCKS_SIZE
=
S
*
BS
k
,
v
,
block_indices
=
(
repeat
(
x
,
'
b t h d -> b t (h g) d
'
,
g
=
G
)
for
x
in
(
k
,
v
,
block_indices
))
block_counts
=
repeat
(
block_counts
,
'
b t h -> b t (h g)
'
,
g
=
G
)
k
,
v
,
block_indices
=
(
repeat
(
x
,
"
b t h d -> b t (h g) d
"
,
g
=
G
)
for
x
in
(
k
,
v
,
block_indices
))
block_counts
=
repeat
(
block_counts
,
"
b t h -> b t (h g)
"
,
g
=
G
)
c
=
torch
.
arange
(
S
).
repeat_interleave
(
BS
).
unsqueeze
(
1
).
expand
(
-
1
,
q
.
shape
[
2
]).
to
(
q
.
device
)
q
,
k
,
v
=
map
(
lambda
x
:
x
.
float
(),
(
q
,
k
,
v
))
o
=
torch
.
zeros_like
(
v
)
...
...
@@ -228,10 +219,10 @@ def naive_nsa_simple(
v_i
[
t
,
h
]
=
v_b
[
selected_block_index
,
h
,
:]
# [S*BS, HQ]
attn
=
torch
.
einsum
(
'
h d, n h d -> n h
'
,
q_i
,
k_i
)
attn
=
attn
.
masked_fill
((
i_i
>
i_q
)
|
(
c
>=
s_i
),
float
(
'
-inf
'
))
attn
=
torch
.
einsum
(
"
h d, n h d -> n h
"
,
q_i
,
k_i
)
attn
=
attn
.
masked_fill
((
i_i
>
i_q
)
|
(
c
>=
s_i
),
float
(
"
-inf
"
))
attn
=
torch
.
softmax
(
attn
,
dim
=
0
)
o
[
i
,
i_q
]
=
torch
.
einsum
(
'
n h, n h v -> h v
'
,
attn
,
v_i
)
o
[
i
,
i_q
]
=
torch
.
einsum
(
"
n h, n h v -> h v
"
,
attn
,
v_i
)
return
o
.
to
(
dtype
)
...
...
@@ -265,7 +256,7 @@ def naive_nsa_simple_inference(
o (torch.Tensor):
Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
scale
=
k
.
shape
[
-
1
]
**-
0.5
scale
=
k
.
shape
[
-
1
]
**
-
0.5
dtype
=
q
.
dtype
HQ
=
q
.
shape
[
2
]
...
...
@@ -275,8 +266,8 @@ def naive_nsa_simple_inference(
BS
=
block_size
S
=
block_indices
.
shape
[
-
1
]
SELECTED_BLOCKS_SIZE
=
S
*
BS
k
,
v
,
block_indices
=
(
repeat
(
x
,
'
b t h d -> b t (h g) d
'
,
g
=
G
)
for
x
in
(
k
,
v
,
block_indices
))
block_counts
=
repeat
(
block_counts
,
'
b t h -> b t (h g)
'
,
g
=
G
)
k
,
v
,
block_indices
=
(
repeat
(
x
,
"
b t h d -> b t (h g) d
"
,
g
=
G
)
for
x
in
(
k
,
v
,
block_indices
))
block_counts
=
repeat
(
block_counts
,
"
b t h -> b t (h g)
"
,
g
=
G
)
c
=
torch
.
arange
(
S
).
repeat_interleave
(
BS
).
unsqueeze
(
1
).
expand
(
-
1
,
q
.
shape
[
2
]).
to
(
q
.
device
)
q
,
k
,
v
=
map
(
lambda
x
:
x
.
float
(),
(
q
,
k
,
v
))
o
=
torch
.
zeros_like
(
q
)
...
...
@@ -306,9 +297,9 @@ def naive_nsa_simple_inference(
v_i
[
t
,
h
]
=
v_b
[
selected_block_index
,
h
,
:]
# [S*BS, HQ]
attn
=
torch
.
einsum
(
'
h d, n h d -> n h
'
,
q_i
,
k_i
)
attn
=
attn
.
masked_fill
((
c
>=
s_i
),
float
(
'
-inf
'
))
attn
=
torch
.
einsum
(
"
h d, n h d -> n h
"
,
q_i
,
k_i
)
attn
=
attn
.
masked_fill
((
c
>=
s_i
),
float
(
"
-inf
"
))
attn
=
torch
.
softmax
(
attn
,
dim
=
0
)
o
[
i
,
0
]
=
torch
.
einsum
(
'
n h, n h v -> h v
'
,
attn
,
v_i
)
o
[
i
,
0
]
=
torch
.
einsum
(
"
n h, n h v -> h v
"
,
attn
,
v_i
)
return
o
.
to
(
dtype
)
examples/deepseek_v32/fp8_lighting_indexer.py
View file @
29051439
...
...
@@ -28,11 +28,11 @@ def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_rai
if
should_raise
:
assert
False
if
not
torch
.
isclose
(
a
.
masked_fill
(
a_finite
,
0
),
b
.
masked_fill
(
b_finite
,
0
),
rtol
=
0
,
atol
=
0
,
equal_nan
=
True
,
a
.
masked_fill
(
a_finite
,
0
),
b
.
masked_fill
(
b_finite
,
0
),
rtol
=
0
,
atol
=
0
,
equal_nan
=
True
,
).
all
():
display_error_message
(
f
"
{
tensor_name
}
Error: nonfinite value mismatch"
)
if
should_raise
:
...
...
@@ -55,13 +55,10 @@ def get_configs():
threads
=
[
128
,
256
],
block_Q
=
[
1
,
2
,
4
],
)
return
[{
k
:
v
for
k
,
v
in
zip
(
iter_params
,
values
)
}
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
return
[{
k
:
v
for
k
,
v
in
zip
(
iter_params
,
values
)}
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
class
SupplyProg
:
def
__init__
(
self
):
self
.
tensors_dict
=
{}
...
...
@@ -88,7 +85,8 @@ supply_prog = SupplyProg()
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
},)
},
)
def
mqa_attn_return_logits
(
heads
,
index_dim
,
...
...
@@ -113,16 +111,15 @@ def mqa_attn_return_logits(
@
T
.
prim_func
def
mqa_attn_return_logits_kernel
(
IndexQ
:
T
.
Tensor
(
index_q_shape
,
dtype
),
# type: ignore
IndexK
:
T
.
Tensor
(
index_k_shape
,
dtype
),
# type: ignore
IndexKScale
:
T
.
Tensor
(
index_k_scale_shape
,
accum_dtype
),
# type: ignore
Logits
:
T
.
Tensor
(
logits_shape
,
accum_dtype
),
# type: ignore
Weights
:
T
.
Tensor
([
seq_len
,
heads
],
accum_dtype
),
# type: ignore
CuSeqLenKS
:
T
.
Tensor
([
seq_len
],
index_dtype
),
# type: ignore
CuSeqLenKE
:
T
.
Tensor
([
seq_len
],
index_dtype
),
# type: ignore
IndexQ
:
T
.
Tensor
(
index_q_shape
,
dtype
),
# type: ignore
IndexK
:
T
.
Tensor
(
index_k_shape
,
dtype
),
# type: ignore
IndexKScale
:
T
.
Tensor
(
index_k_scale_shape
,
accum_dtype
),
# type: ignore
Logits
:
T
.
Tensor
(
logits_shape
,
accum_dtype
),
# type: ignore
Weights
:
T
.
Tensor
([
seq_len
,
heads
],
accum_dtype
),
# type: ignore
CuSeqLenKS
:
T
.
Tensor
([
seq_len
],
index_dtype
),
# type: ignore
CuSeqLenKE
:
T
.
Tensor
([
seq_len
],
index_dtype
),
# type: ignore
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_Q
),
threads
=
threads
)
as
bx
:
index_q_shared
=
T
.
alloc_shared
([
block_Q
*
heads
,
index_dim
],
dtype
)
index_k_shared
=
T
.
alloc_shared
([
block_N
,
index_dim
],
dtype
)
index_k_scale_fragment
=
T
.
alloc_fragment
([
block_N
],
accum_dtype
)
...
...
@@ -140,17 +137,14 @@ def mqa_attn_return_logits(
cu_k_e_max
[
0
]
=
-
2147483648
for
bq_i
in
T
.
serial
(
block_Q
):
cu_k_s_min
[
0
]
=
T
.
min
(
cu_k_s_min
[
0
],
T
.
min
(
CuSeqLenKS
[
seq_len_i
+
bq_i
],
seq_len_kv
))
cu_k_s_min
[
0
]
=
T
.
min
(
cu_k_s_min
[
0
],
T
.
min
(
CuSeqLenKS
[
seq_len_i
+
bq_i
],
seq_len_kv
))
for
bq_i
in
T
.
serial
(
block_Q
):
cu_k_e_max
[
0
]
=
T
.
max
(
cu_k_e_max
[
0
],
T
.
min
(
CuSeqLenKE
[
seq_len_i
+
bq_i
],
seq_len_kv
))
cu_k_e_max
[
0
]
=
T
.
max
(
cu_k_e_max
[
0
],
T
.
min
(
CuSeqLenKE
[
seq_len_i
+
bq_i
],
seq_len_kv
))
T
.
copy
(
IndexQ
[
seq_len_i
*
heads
,
0
],
index_q_shared
)
T
.
copy
(
Weights
[
seq_len_i
,
0
],
weights
)
for
nbn_i
in
T
.
Pipelined
(
T
.
ceildiv
(
cu_k_e_max
[
0
]
-
cu_k_s_min
[
0
],
block_N
),
num_stages
=
num_stages
):
for
nbn_i
in
T
.
Pipelined
(
T
.
ceildiv
(
cu_k_e_max
[
0
]
-
cu_k_s_min
[
0
],
block_N
),
num_stages
=
num_stages
):
T
.
copy
(
IndexK
[
cu_k_s_min
[
0
]
+
nbn_i
*
block_N
,
0
],
index_k_shared
)
T
.
copy
(
IndexKScale
[
cu_k_s_min
[
0
]
+
nbn_i
*
block_N
],
index_k_scale_fragment
)
...
...
@@ -164,15 +158,14 @@ def mqa_attn_return_logits(
)
for
bn_i
,
bq_i
,
h_i
in
T
.
Parallel
(
block_N
,
block_Q
,
heads
):
s_reshaped
[
bn_i
,
bq_i
,
h_i
]
=
(
T
.
max
(
s_reshaped
[
bn_i
,
bq_i
,
h_i
],
0
)
*
weights
[
bq_i
,
h_i
])
*
index_k_scale_fragment
[
bn_i
]
s_reshaped
[
bn_i
,
bq_i
,
h_i
]
=
(
T
.
max
(
s_reshaped
[
bn_i
,
bq_i
,
h_i
],
0
)
*
weights
[
bq_i
,
h_i
])
*
index_k_scale_fragment
[
bn_i
]
T
.
reduce_sum
(
s_reshaped
,
logits
,
dim
=-
1
,
clear
=
True
)
for
bq_i
,
bn_i
in
T
.
Parallel
(
block_Q
,
block_N
):
Logits
[
seq_len_i
+
bq_i
,
cu_k_s_min
[
0
]
+
nbn_i
*
block_N
+
bn_i
]
=
(
logits
[
bn_i
,
bq_i
])
Logits
[
seq_len_i
+
bq_i
,
cu_k_s_min
[
0
]
+
nbn_i
*
block_N
+
bn_i
]
=
logits
[
bn_i
,
bq_i
]
return
mqa_attn_return_logits_kernel
...
...
@@ -190,9 +183,9 @@ def clean_logits_(
@
T
.
prim_func
def
clean_logits_kernel
(
Logits
:
T
.
Tensor
([
seq_len
,
seq_len_kv
],
dtype
),
# type: ignore
CuSeqLenKS
:
T
.
Tensor
([
seq_len
],
indices_dtype
),
# type: ignore
CuSeqLenKE
:
T
.
Tensor
([
seq_len
],
indices_dtype
),
# type: ignore
Logits
:
T
.
Tensor
([
seq_len
,
seq_len_kv
],
dtype
),
# type: ignore
CuSeqLenKS
:
T
.
Tensor
([
seq_len
],
indices_dtype
),
# type: ignore
CuSeqLenKE
:
T
.
Tensor
([
seq_len
],
indices_dtype
),
# type: ignore
):
with
T
.
Kernel
(
seq_len
,
threads
=
threads
)
as
bx
:
tx
=
T
.
thread_binding
(
0
,
threads
,
thread
=
"threadIdx.x"
)
...
...
@@ -210,13 +203,7 @@ def clean_logits_(
return
clean_logits_kernel
def
mqa_attn_return_logits_interface
(
q
,
kv
,
kv_scales
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
,
clean_logits
=
True
):
def
mqa_attn_return_logits_interface
(
q
,
kv
,
kv_scales
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
,
clean_logits
=
True
):
seq_len
,
heads
,
index_dim
=
q
.
shape
seq_len_kv
=
kv
.
shape
[
0
]
...
...
@@ -238,20 +225,19 @@ def mqa_attn_return_logits_interface(q,
return
logits
def
ref_fp8_mqa_logits
(
q
:
torch
.
Tensor
,
kv
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
cu_seqlen_ks
:
torch
.
Tensor
,
cu_seqlen_ke
:
torch
.
Tensor
):
def
ref_fp8_mqa_logits
(
q
:
torch
.
Tensor
,
kv
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
cu_seqlen_ks
:
torch
.
Tensor
,
cu_seqlen_ke
:
torch
.
Tensor
):
k
=
kv
q
=
q
.
float
()
k
=
k
.
float
()
seq_len_kv
=
kv
.
shape
[
0
]
mask_lo
=
torch
.
arange
(
0
,
seq_len_kv
,
device
=
'
cuda
'
)[
None
,
:]
>=
cu_seqlen_ks
[:,
None
]
mask_hi
=
torch
.
arange
(
0
,
seq_len_kv
,
device
=
'
cuda
'
)[
None
,
:]
<
cu_seqlen_ke
[:,
None
]
mask_lo
=
torch
.
arange
(
0
,
seq_len_kv
,
device
=
"
cuda
"
)[
None
,
:]
>=
cu_seqlen_ks
[:,
None
]
mask_hi
=
torch
.
arange
(
0
,
seq_len_kv
,
device
=
"
cuda
"
)[
None
,
:]
<
cu_seqlen_ke
[:,
None
]
mask
=
mask_lo
&
mask_hi
score
=
torch
.
einsum
(
'
mhd,nd->hmn
'
,
q
,
k
)
score
=
torch
.
einsum
(
"
mhd,nd->hmn
"
,
q
,
k
)
logits
=
(
score
.
relu
()
*
weights
.
unsqueeze
(
-
1
).
transpose
(
0
,
1
)).
sum
(
dim
=
0
)
logits
=
logits
.
masked_fill
(
~
mask
,
float
(
'
-inf
'
))
logits
=
logits
.
masked_fill
(
~
mask
,
float
(
"
-inf
"
))
cost
=
mask
.
sum
()
return
logits
,
cost
...
...
@@ -265,32 +251,22 @@ def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
weights
=
torch
.
randn
(
S
,
H
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
p
=
(
torch
.
randn
(
S
,
SKV
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
*
4
).
softmax
(
dim
=-
1
)
ks
,
ke
=
generate_random_cu_seqlens
(
per_cp_seqlen
=
S
,
cp_size
=
4
,
cp_rank
=
3
,
kv_stride
=
kv_stride
,
average_q_len
=
2048
)
ks
,
ke
=
generate_random_cu_seqlens
(
per_cp_seqlen
=
S
,
cp_size
=
4
,
cp_rank
=
3
,
kv_stride
=
kv_stride
,
average_q_len
=
2048
)
logits_ref
,
cost_ref
=
ref_fp8_mqa_logits
(
q
=
q
,
kv
=
kv
,
weights
=
weights
,
cu_seqlen_ks
=
ks
,
cu_seqlen_ke
=
ke
)
logits_ref
,
cost_ref
=
ref_fp8_mqa_logits
(
q
=
q
,
kv
=
kv
,
weights
=
weights
,
cu_seqlen_ks
=
ks
,
cu_seqlen_ke
=
ke
)
q_fp8
=
q
.
to
(
torch
.
float8_e4m3fn
)
kv_fp8
,
kv_scales
=
per_custom_dims_cast_to_fp8
(
kv
,
(
0
,),
False
)
logits_tl
=
mqa_attn_return_logits_interface
(
q
=
q_fp8
,
kv
=
kv_fp8
,
kv_scales
=
kv_scales
,
weights
=
weights
,
cu_seqlen_ks
=
ks
,
cu_seqlen_ke
=
ke
)
diff
=
validate_tensor_match
(
logits_ref
,
logits_tl
,
tolerance
=
1e-14
,
tensor_name
=
"logits"
,
should_raise
=
False
)
logits_tl
=
mqa_attn_return_logits_interface
(
q
=
q_fp8
,
kv
=
kv_fp8
,
kv_scales
=
kv_scales
,
weights
=
weights
,
cu_seqlen_ks
=
ks
,
cu_seqlen_ke
=
ke
)
diff
=
validate_tensor_match
(
logits_ref
,
logits_tl
,
tolerance
=
1e-14
,
tensor_name
=
"logits"
,
should_raise
=
False
)
print
(
f
"diff:
{
diff
}
"
)
from
tilelang.profiler
import
do_bench
def
logits_fn
():
return
mqa_attn_return_logits_interface
(
q
=
q_fp8
,
kv
=
kv_fp8
,
kv_scales
=
kv_scales
,
weights
=
weights
,
cu_seqlen_ks
=
ks
,
cu_seqlen_ke
=
ke
)
return
mqa_attn_return_logits_interface
(
q
=
q_fp8
,
kv
=
kv_fp8
,
kv_scales
=
kv_scales
,
weights
=
weights
,
cu_seqlen_ks
=
ks
,
cu_seqlen_ke
=
ke
)
with
torch
.
profiler
.
profile
(
activities
=
[
torch
.
profiler
.
ProfilerActivity
.
CUDA
])
as
prof
:
logits_fn
()
...
...
Prev
1
2
3
4
5
6
7
8
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment