Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
16561159
Unverified
Commit
16561159
authored
Sep 29, 2025
by
Wenxuan Tan
Committed by
GitHub
Sep 30, 2025
Browse files
[Bugfix] Fix flops comp and softmax scale in mla (#900)
* fix flops comp and softmax scale * format
parent
54fc6ba0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
14 deletions
+25
-14
examples/deepseek_mla/benchmark_mla.py
examples/deepseek_mla/benchmark_mla.py
+10
-10
examples/deepseek_mla/example_mla_decode_paged.py
examples/deepseek_mla/example_mla_decode_paged.py
+15
-4
No files found.
examples/deepseek_mla/benchmark_mla.py
View file @
16561159
...
@@ -87,7 +87,7 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
...
@@ -87,7 +87,7 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
run_flash
_
infer
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
def
run_flashinfer
(
q
,
block_table
,
blocked_k
,
max_seqlen_pad
,
block_size
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
# pip install flashinfer-python
# pip install flashinfer-python
import
flashinfer
import
flashinfer
...
@@ -128,7 +128,7 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
...
@@ -128,7 +128,7 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
blocked_k
.
dtype
,
blocked_k
.
dtype
,
)
)
def
flash
_
infer
():
def
flashinfer
():
output
,
lse
=
mla_wrapper
.
run
(
output
,
lse
=
mla_wrapper
.
run
(
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_nope
.
view
(
-
1
,
h_q
,
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
q_pe
.
view
(
-
1
,
h_q
,
d
-
dv
),
...
@@ -137,8 +137,8 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
...
@@ -137,8 +137,8 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
return_lse
=
True
)
return_lse
=
True
)
return
output
.
view
(
b
,
-
1
,
h_q
,
dv
),
lse
.
view
(
b
,
h_q
,
1
)
return
output
.
view
(
b
,
-
1
,
h_q
,
dv
),
lse
.
view
(
b
,
h_q
,
1
)
out_flash
,
lse_flash
=
flash
_
infer
()
out_flash
,
lse_flash
=
flashinfer
()
t
=
triton
.
testing
.
do_bench
(
flash
_
infer
)
t
=
triton
.
testing
.
do_bench
(
flashinfer
)
return
out_flash
,
lse_flash
,
t
return
out_flash
,
lse_flash
,
t
...
@@ -459,7 +459,7 @@ FUNC_TABLE = {
...
@@ -459,7 +459,7 @@ FUNC_TABLE = {
"torch"
:
run_torch_mla
,
"torch"
:
run_torch_mla
,
"tilelang"
:
run_flash_mla_tilelang
,
"tilelang"
:
run_flash_mla_tilelang
,
"flash_mla"
:
run_flash_mla
,
"flash_mla"
:
run_flash_mla
,
"flash
_
infer"
:
run_flash
_
infer
,
"flashinfer"
:
run_flashinfer
,
"flash_mla_triton"
:
run_flash_mla_triton
,
"flash_mla_triton"
:
run_flash_mla_triton
,
}
}
...
@@ -496,9 +496,9 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
...
@@ -496,9 +496,9 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
)
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
torch
.
testing
.
assert_close
(
out_b
.
float
(),
out_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"out"
if
target
not
in
[
"flash
_
infer"
,
"flash_mla_triton"
,
"tilelang"
if
target
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]
and
baseline
not
in
[
"flash
_
infer"
,
"flash_mla_triton"
,
"tilelang"
]:
]
and
baseline
not
in
[
"flashinfer"
,
"flash_mla_triton"
,
"tilelang"
]:
# flash
_
infer has a different lse return value
# flashinfer has a different lse return value
# flash_mla_triton and flash_mla_tilelang doesn't return lse
# flash_mla_triton and flash_mla_tilelang doesn't return lse
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
torch
.
testing
.
assert_close
(
lse_b
.
float
(),
lse_a
.
float
(),
atol
=
1e-2
,
rtol
=
1e-2
),
"lse"
...
@@ -554,7 +554,7 @@ available_targets = [
...
@@ -554,7 +554,7 @@ available_targets = [
"torch"
,
"torch"
,
"tilelang"
,
"tilelang"
,
"flash_mla"
,
"flash_mla"
,
"flash
_
infer"
,
"flashinfer"
,
"flash_mla_triton"
,
"flash_mla_triton"
,
]
]
...
...
examples/deepseek_mla/example_mla_decode_paged.py
View file @
16561159
...
@@ -11,8 +11,19 @@ import math
...
@@ -11,8 +11,19 @@ import math
out_idx
=
[
8
],
pass_configs
=
{
out_idx
=
[
8
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
})
def
mla_decode_tilelang
(
batch
,
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
block_N
,
block_H
,
num_split
,
def
mla_decode_tilelang
(
batch
,
block_size
,
softmax_scale
):
h_q
,
h_kv
,
max_seqlen_pad
,
dv
,
dpe
,
block_N
,
block_H
,
num_split
,
block_size
,
softmax_scale
=
None
):
if
softmax_scale
is
None
:
softmax_scale
=
(
dv
+
dpe
)
**-
0.5
scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
scale
=
float
(
softmax_scale
*
1.44269504
)
# log2(e)
dtype
=
"float16"
dtype
=
"float16"
accum_dtype
=
"float"
accum_dtype
=
"float"
...
@@ -322,7 +333,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
...
@@ -322,7 +333,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
num_kv_splits
=
1
num_kv_splits
=
1
BLOCK_N
=
64
BLOCK_N
=
64
BLOCK_H
=
min
(
64
,
h_q
//
h_kv
)
BLOCK_H
=
min
(
64
,
h_q
//
h_kv
)
softmax_scale
=
(
d
+
dv
)
**-
0.5
softmax_scale
=
d
**-
0.5
out_partial
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dv
,
dtype
=
dtype
,
device
=
q
.
device
)
out_partial
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dv
,
dtype
=
dtype
,
device
=
q
.
device
)
glse
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dtype
=
dtype
,
device
=
q
.
device
)
glse
=
torch
.
empty
(
b
,
h_q
,
num_kv_splits
,
dtype
=
dtype
,
device
=
q
.
device
)
...
@@ -379,7 +390,7 @@ if __name__ == "__main__":
...
@@ -379,7 +390,7 @@ if __name__ == "__main__":
max_seqlen
=
cache_seqlens
.
max
().
item
()
max_seqlen
=
cache_seqlens
.
max
().
item
()
max_seqlen_pad
=
math
.
ceil
(
max_seqlen
/
256
)
*
256
max_seqlen_pad
=
math
.
ceil
(
max_seqlen
/
256
)
*
256
total_flops
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
total_flops
=
s_q
*
total_seqlens
*
h_q
*
d
*
2
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
,
dtype
=
dtype
,
device
=
device
)
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
,
dtype
=
dtype
,
device
=
device
)
block_table
=
torch
.
arange
(
block_table
=
torch
.
arange
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment