Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
b6798030
Commit
b6798030
authored
Feb 25, 2025
by
yangsijia.614
Browse files
fix(benchmark): store 'compare' and 'one' perf results in csv files and visualize them
parent
4edea86f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
8 deletions
+24
-8
benchmark/bench_flash_mla.py
benchmark/bench_flash_mla.py
+12
-6
benchmark/visualize.py
benchmark/visualize.py
+12
-2
No files found.
benchmark/bench_flash_mla.py
View file @
b6798030
# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
import
argparse
import
math
import
random
import
flashinfer
import
torch
import
triton
import
triton.language
as
tl
import
argparse
# pip install flashinfer-python
from
flash_mla
import
get_mla_metadata
,
flash_mla_with_kvcache
import
flashinfer
from
flash_mla
import
flash_mla_with_kvcache
,
get_mla_metadata
def
scaled_dot_product_attention
(
query
,
key
,
value
,
h_q
,
h_kv
,
is_causal
=
False
):
query
=
query
.
float
()
...
...
@@ -443,6 +444,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"perf
{
baseline
}
:
{
perf_a
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_a
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_a
:.
0
f
}
GB/s"
)
print
(
f
"perf
{
target
}
:
{
perf_b
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
perf_b
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
perf_b
:.
0
f
}
GB/s"
)
return
bytes
/
10
**
6
/
perf_a
,
bytes
/
10
**
6
/
perf_b
def
compare_a
(
target
,
b
,
s_q
,
cache_seqlens
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
dtype
):
...
...
@@ -501,7 +503,8 @@ def get_args():
if
__name__
==
"__main__"
:
args
=
get_args
()
with
open
(
"all_perf.csv"
,
"w"
)
as
fout
:
benchmark_type
=
"all"
if
args
.
all
else
f
"
{
args
.
baseline
}
_vs_
{
args
.
target
}
"
if
args
.
compare
else
args
.
target
with
open
(
f
"
{
benchmark_type
}
_perf.csv"
,
"w"
)
as
fout
:
fout
.
write
(
"name,batch,seqlen,head,bw
\n
"
)
for
shape
in
shape_configs
:
if
args
.
all
:
...
...
@@ -509,6 +512,9 @@ if __name__ == "__main__":
perf
=
compare_a
(
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
fout
.
write
(
f
'
{
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"cache_seqlens"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"h_q"
]
}
,
{
perf
:.
0
f
}
\n
'
)
elif
args
.
compare
:
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
perfa
,
prefb
=
compare_ab
(
args
.
baseline
,
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
fout
.
write
(
f
'
{
args
.
baseline
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"cache_seqlens"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"h_q"
]
}
,
{
perfa
:.
0
f
}
\n
'
)
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"cache_seqlens"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"h_q"
]
}
,
{
prefb
:.
0
f
}
\n
'
)
elif
args
.
one
:
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
\ No newline at end of file
perf
=
compare_a
(
args
.
target
,
shape
[
"b"
],
shape
[
"s_q"
],
shape
[
"cache_seqlens"
],
shape
[
"h_q"
],
shape
[
"h_kv"
],
shape
[
"d"
],
shape
[
"dv"
],
shape
[
"causal"
],
shape
[
"dtype"
])
fout
.
write
(
f
'
{
args
.
target
}
,
{
shape
[
"b"
]
}
,
{
shape
[
"cache_seqlens"
].
float
().
mean
().
cpu
().
item
():.
0
f
}
,
{
shape
[
"h_q"
]
}
,
{
perf
:.
0
f
}
\n
'
)
\ No newline at end of file
benchmark/visualize.py
View file @
b6798030
import
argparse
import
matplotlib.pyplot
as
plt
import
pandas
as
pd
file_path
=
'all_perf.csv'
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Visualize benchmark results'
)
parser
.
add_argument
(
'--file'
,
type
=
str
,
default
=
'all_perf.csv'
,
help
=
'Path to the CSV file with benchmark results (default: all_perf.csv)'
)
return
parser
.
parse_args
()
args
=
parse_args
()
file_path
=
args
.
file
df
=
pd
.
read_csv
(
file_path
)
...
...
@@ -16,4 +26,4 @@ plt.xlabel('seqlen')
plt
.
ylabel
(
'bw (GB/s)'
)
plt
.
legend
()
plt
.
savefig
(
'bandwidth_vs_seqlen.png'
)
\ No newline at end of file
plt
.
savefig
(
f
'
{
file_path
.
split
(
"."
)[
0
].
split
(
"/"
)[
-
1
]
}
_bandwidth_vs_seqlen.png'
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment