Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
166a9585
Commit
166a9585
authored
Mar 08, 2025
by
You Jiacheng
Committed by
GitHub
Mar 08, 2025
Browse files
[Dev] Use SS-GEMM for PV in mla (#165)
It's slightly faster than T.copy then RS-GEMM, and simpler.
parent
d3f26ef8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
5 deletions
+2
-5
examples/deepseek_mla/example_mla_decode.py
examples/deepseek_mla/example_mla_decode.py
+2
-5
No files found.
examples/deepseek_mla/example_mla_decode.py
View file @
166a9585
...
@@ -31,7 +31,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -31,7 +31,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
K_pe_shared
=
T
.
alloc_shared
([
block_N
,
pe_dim
],
dtype
)
K_pe_shared
=
T
.
alloc_shared
([
block_N
,
pe_dim
],
dtype
)
O_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
O_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
acc_s
=
T
.
alloc_fragment
([
block_H
,
block_N
],
accum_dtype
)
acc_s
=
T
.
alloc_fragment
([
block_H
,
block_N
],
accum_dtype
)
acc_s_cast
=
T
.
alloc_fragment
([
block_H
,
block_N
],
dtype
)
acc_o
=
T
.
alloc_fragment
([
block_H
,
dim
],
accum_dtype
)
acc_o
=
T
.
alloc_fragment
([
block_H
,
dim
],
accum_dtype
)
scores_max
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_max
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_max_prev
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_max_prev
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
...
@@ -43,7 +42,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -43,7 +42,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
T
.
annotate_layout
({
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
S_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
S_shared
),
})
})
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bx
,
by
*
VALID_BLOCK_H
:(
by
+
1
)
*
VALID_BLOCK_H
,
:],
Q_shared
)
...
@@ -74,12 +72,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -74,12 +72,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
T
.
copy
(
acc_s
,
S_shared
)
T
.
copy
(
acc_s
,
S_shared
)
T
.
copy
(
S_shared
,
acc_s_cast
)
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
gemm
(
acc_s_cast
,
KV_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
T
.
gemm
(
S_shared
,
KV_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullCol
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
...
@@ -297,4 +294,4 @@ if __name__ == "__main__":
...
@@ -297,4 +294,4 @@ if __name__ == "__main__":
print
(
"All close"
)
print
(
"All close"
)
latency
=
mod
.
do_bench
(
mod
.
func
,
n_warmup
=
10
,
n_repeat
=
10
,
profiler
=
"torch"
)
latency
=
mod
.
do_bench
(
mod
.
func
,
n_warmup
=
10
,
n_repeat
=
10
,
profiler
=
"torch"
)
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
print
(
"Tile-lang: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
print
(
"Tile-lang: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment