Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6cb25a15
Commit
6cb25a15
authored
Oct 16, 2023
by
turneram
Browse files
Add improvements
parent
7a80bd37
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
80 additions
and
32 deletions
+80
-32
tools/gemm_perf.py
tools/gemm_perf.py
+80
-32
No files found.
tools/gemm_perf.py
View file @
6cb25a15
...
@@ -3,12 +3,6 @@ import subprocess, csv, re, datetime, argparse, os, enum
...
@@ -3,12 +3,6 @@ import subprocess, csv, re, datetime, argparse, os, enum
from
subprocess
import
STDOUT
from
subprocess
import
STDOUT
class
GEMM_Provider
(
enum
.
Enum
):
CK
=
1
ROCBLAS
=
2
MLIR
=
3
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
"GEMM performance tools"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"GEMM performance tools"
)
parser
.
add_argument
(
'--gemm'
,
parser
.
add_argument
(
'--gemm'
,
...
@@ -28,7 +22,15 @@ def parse_args():
...
@@ -28,7 +22,15 @@ def parse_args():
parser
.
add_argument
(
'--exclude-lens'
,
parser
.
add_argument
(
'--exclude-lens'
,
'-e'
,
'-e'
,
nargs
=
'+'
,
nargs
=
'+'
,
help
=
'Exclude lengths from [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096]'
)
help
=
'Exclude lengths from [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096]
\
Lengths not excluded will be permuted as m, n, k, (o) inputs'
)
parser
.
add_argument
(
'--lens-from-file'
,
'-l'
,
help
=
'Run a list of problem lens from file containing rows of
\
m0 n0 k0 (o0)
\
m1 n1 k1 (o1)
\
...
\
(oi) only used for gemm-softmax-gemm'
)
parser
.
add_argument
(
'--timeout'
,
parser
.
add_argument
(
'--timeout'
,
'-t'
,
'-t'
,
type
=
int
,
type
=
int
,
...
@@ -50,6 +52,16 @@ class CSVFile:
...
@@ -50,6 +52,16 @@ class CSVFile:
cw
.
writerow
(
row
)
cw
.
writerow
(
row
)
class
GEMM_Provider
(
enum
.
Enum
):
CK
=
1
ROCBLAS
=
2
MLIR
=
3
def
get_migraphx_root
():
return
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)))
def
get_device_name
():
def
get_device_name
():
out
=
subprocess
.
run
(
"rocminfo"
,
out
=
subprocess
.
run
(
"rocminfo"
,
capture_output
=
True
,
capture_output
=
True
,
...
@@ -59,11 +71,41 @@ def get_device_name():
...
@@ -59,11 +71,41 @@ def get_device_name():
return
matches
[
0
]
return
matches
[
0
]
def
verify_format
(
file
,
single
):
if
not
file
:
return
False
format
=
r
"^\d\s\d\s\d\s*$"
if
single
else
r
"^\d\s\d\s\d\s\d\s*$"
with
open
(
file
,
'r'
)
as
f
:
return
all
([
bool
(
re
.
match
(
format
,
line
))
for
line
in
f
.
readlines
()])
def
parse_lens
(
file
):
with
open
(
file
,
'r'
)
as
f
:
return
[
tuple
(
map
(
int
,
line
.
split
()))
for
line
in
f
.
readlines
()]
def
get_total_time
(
output
):
summary
=
re
.
findall
(
"Summary.*"
,
output
)[
0
].
replace
(
"
\\
n"
,
"
\n
"
)
total_time
=
re
.
findall
(
"Total time: \d+\.\d*"
,
summary
)[
0
]
return
float
(
total_time
.
replace
(
"Total time: "
,
""
))
def
verify_output
(
output
,
provider
):
summary
=
re
.
findall
(
"Summary.*"
,
output
)[
0
].
replace
(
"
\\
n"
,
"
\n
"
)
if
provider
==
GEMM_Provider
.
CK
:
return
bool
(
re
.
match
(
".*ck_gemm.*"
,
summary
))
if
provider
==
GEMM_Provider
.
ROCBLAS
:
return
bool
(
re
.
match
(
".*gpu::gemm.*"
,
summary
))
if
provider
==
GEMM_Provider
.
MLIR
:
return
bool
(
re
.
match
(
".*MLIR.*"
,
summary
))
#######
def
get_gemm_time
(
config
,
fp16
,
provider
,
timeout
):
def
get_gemm_time
(
config
,
fp16
,
provider
,
timeout
):
model
=
"../test/onnx/matmul_half.onnx"
root
=
get_migraphx_root
()
model
=
f
"
{
root
}
/test/onnx/matmul_half.onnx"
b
,
m
,
n
,
k
=
config
b
,
m
,
n
,
k
=
config
prec_str
=
"--fp16"
if
fp16
else
"--int8"
prec_str
=
"--fp16"
if
fp16
else
"--int8"
cmd
=
f
"
..
/build/bin/driver perf
{
model
}
--input-dim @1
{
b
}
{
m
}
{
k
}
@2
{
b
}
{
k
}
{
n
}
{
prec_str
}
"
cmd
=
f
"
{
root
}
/build/bin/driver perf
{
model
}
--input-dim @1
{
b
}
{
m
}
{
k
}
@2
{
b
}
{
k
}
{
n
}
{
prec_str
}
--exhaustive-tune
"
use_CK
=
"1"
if
provider
==
GEMM_Provider
.
CK
else
"0"
use_CK
=
"1"
if
provider
==
GEMM_Provider
.
CK
else
"0"
use_MLIR
=
"1"
if
provider
==
GEMM_Provider
.
MLIR
else
"0"
use_MLIR
=
"1"
if
provider
==
GEMM_Provider
.
MLIR
else
"0"
...
@@ -76,22 +118,20 @@ def get_gemm_time(config, fp16, provider, timeout):
...
@@ -76,22 +118,20 @@ def get_gemm_time(config, fp16, provider, timeout):
MIGRAPHX_ENABLE_CK
=
use_CK
,
MIGRAPHX_ENABLE_CK
=
use_CK
,
MIGRAPHX_ENABLE_MLIR
=
use_MLIR
))
MIGRAPHX_ENABLE_MLIR
=
use_MLIR
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"An exception occurred:
{
str
(
e
)
}
"
)
print
(
f
"
{
provider
.
name
}
GEMM
{
b
}
,
{
m
}
,
{
n
}
,
{
k
}
:"
,
end
=
" "
)
print
(
"-100.0"
)
return
-
100.0
return
-
100.0
summary
=
re
.
findall
(
"Summary.*"
,
str
(
out
.
stdout
))[
0
].
replace
(
"
\\
n"
,
"
\n
"
)
verify_output
(
str
(
out
.
stdout
),
provider
)
total_time
=
re
.
findall
(
"T
otal
time
: \d+\.\d*"
,
summary
)[
0
]
total_time
=
get_t
otal
_
time
(
str
(
out
.
stdout
))
total_time
=
total_time
.
replace
(
"Total time: "
,
"
"
)
print
(
f
"
{
provider
.
name
}
finished in
{
total_time
}
"
)
return
float
(
total_time
)
return
total_time
def
get_gemm_softmax_gemm_time
(
config
,
provider
,
timeout
):
def
get_gemm_softmax_gemm_time
(
config
,
provider
,
timeout
):
model
=
"../test/onnx/gemm_softmax_gemm_half.onnx"
root
=
get_migraphx_root
()
model
=
f
"
{
root
}
/test/onnx/gemm_softmax_gemm_half.onnx"
b
,
m
,
n
,
k
,
o
=
config
b
,
m
,
n
,
k
,
o
=
config
cmd
=
f
"
..
/build/bin/driver perf
{
model
}
--input-dim @a
{
b
}
{
m
}
{
k
}
@b
{
b
}
{
k
}
{
n
}
@b1
{
b
}
{
n
}
{
o
}
--fp16"
cmd
=
f
"
{
root
}
/build/bin/driver perf
{
model
}
--input-dim @a
{
b
}
{
m
}
{
k
}
@b
{
b
}
{
k
}
{
n
}
@b1
{
b
}
{
n
}
{
o
}
--fp16
--exhaustive-tune
"
use_CK
=
"1"
if
provider
==
GEMM_Provider
.
CK
else
"0"
use_CK
=
"1"
if
provider
==
GEMM_Provider
.
CK
else
"0"
use_MLIR
=
"1"
if
provider
==
GEMM_Provider
.
MLIR
else
"0"
use_MLIR
=
"1"
if
provider
==
GEMM_Provider
.
MLIR
else
"0"
...
@@ -109,34 +149,38 @@ def get_gemm_softmax_gemm_time(config, provider, timeout):
...
@@ -109,34 +149,38 @@ def get_gemm_softmax_gemm_time(config, provider, timeout):
summary
=
re
.
findall
(
"Summary.*"
,
str
(
out
.
stdout
))[
0
].
replace
(
"
\\
n"
,
"
\n
"
)
summary
=
re
.
findall
(
"Summary.*"
,
str
(
out
.
stdout
))[
0
].
replace
(
"
\\
n"
,
"
\n
"
)
total_time
=
re
.
findall
(
"Total time: \d+\.\d*"
,
summary
)[
0
]
total_time
=
re
.
findall
(
"Total time: \d+\.\d*"
,
summary
)[
0
]
total_time
=
total_time
.
replace
(
"Total time: "
,
""
)
total_time
=
total_time
.
replace
(
"Total time: "
,
""
)
total_time
=
float
(
total_time
)
print
(
f
"
{
provider
.
name
}
finished in
{
total_time
}
"
)
return
float
(
total_time
)
return
total_time
def
run_gemm_perf
(
batches
,
sizes
,
fp16
):
def
run_gemm_perf
(
batches
,
sizes
,
fp16
,
timeout
):
prec_str
=
"
half
"
if
fp16
else
"int"
prec_str
=
"
fp16
"
if
fp16
else
"int
8
"
for
b
in
batches
:
for
b
in
batches
:
out
=
CSVFile
(
f
"gemm_perf_
{
prec_str
}
_
{
b
}
.csv"
)
out
=
CSVFile
(
f
"gemm_perf_
{
prec_str
}
_
{
b
}
.csv"
)
out
.
write_row
([
get_device_name
(),
datetime
.
datetime
.
now
()])
out
.
write_row
([
get_device_name
(),
datetime
.
datetime
.
now
()])
out
.
write_row
([
"batch_size"
,
"m"
,
"n"
,
"k"
,
"CK Total Time (ms)"
,
"rocBLAS Total Time (ms)"
,
"MLIR Total Time (ms)"
])
out
.
write_row
([
"batch_size"
,
"m"
,
"n"
,
"k"
,
"CK Total Time (ms)"
,
"rocBLAS Total Time (ms)"
,
"MLIR Total Time (ms)"
])
for
shape
in
[(
m
,
n
,
k
)
for
m
in
sizes
for
n
in
sizes
for
k
in
sizes
]
:
for
shape
in
sizes
:
config
=
(
b
,)
+
shape
config
=
(
b
,)
+
shape
ck_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
CK
)
print
(
"Running gemm with config: {0}, {1}, {2}"
.
format
(
*
config
))
rb_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
ROCBLAS
)
ck_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
CK
,
timeout
)
mlir_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
MLIR
)
rb_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
ROCBLAS
,
timeout
)
mlir_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
MLIR
,
timeout
)
out
.
write_row
(
list
(
config
)
+
[
ck_time
,
rb_time
,
mlir_time
])
out
.
write_row
(
list
(
config
)
+
[
ck_time
,
rb_time
,
mlir_time
])
def
run_gemm_softmax_gemm_perf
(
batches
,
sizes
):
def
run_gemm_softmax_gemm_perf
(
batches
,
sizes
,
timeout
):
for
b
in
batches
:
for
b
in
batches
:
out
=
CSVFile
(
f
"gemm_softmax_gemm_perf_fp16_
{
b
}
.csv"
)
out
=
CSVFile
(
f
"gemm_softmax_gemm_perf_fp16_
{
b
}
.csv"
)
out
.
write_row
([
get_device_name
(),
datetime
.
datetime
.
now
()])
out
.
write_row
([
get_device_name
(),
datetime
.
datetime
.
now
()])
out
.
write_row
([
"batch_size"
,
"m"
,
"n"
,
"k"
,
"o"
,
"CK Total Time (ms)"
,
"rocBLAS Total Time (ms)"
,
"MLIR Total Time (ms)"
])
out
.
write_row
([
"batch_size"
,
"m"
,
"n"
,
"k"
,
"o"
,
"CK Total Time (ms)"
,
"rocBLAS Total Time (ms)"
,
"MLIR Total Time (ms)"
])
for
shape
in
[(
m
,
n
,
k
,
o
)
for
m
in
sizes
for
n
in
sizes
for
k
in
sizes
for
o
in
sizes
]
:
for
shape
in
sizes
:
config
=
(
b
,)
+
shape
config
=
(
b
,)
+
shape
ck_time
=
get_gemm_time
(
config
,
GEMM_Provider
.
CK
)
print
(
"Running gemm-softmax-gemm with config: {0}, {1}, {2}, {3}"
.
format
(
*
config
))
rb_time
=
get_gemm_time
(
config
,
GEMM_Provider
.
ROCBLAS
)
ck_time
=
get_gemm_softmax_gemm_time
(
config
,
GEMM_Provider
.
CK
,
timeout
)
mlir_time
=
get_gemm_time
(
config
,
GEMM_Provider
.
MLIR
)
rb_time
=
get_gemm_softmax_gemm_time
(
config
,
GEMM_Provider
.
ROCBLAS
,
timeout
)
mlir_time
=
get_gemm_softmax_gemm_time
(
config
,
GEMM_Provider
.
MLIR
,
timeout
)
out
.
write_row
(
list
(
config
)
+
[
ck_time
,
rb_time
,
mlir_time
])
out
.
write_row
(
list
(
config
)
+
[
ck_time
,
rb_time
,
mlir_time
])
...
@@ -150,7 +194,11 @@ if __name__ == "__main__":
...
@@ -150,7 +194,11 @@ if __name__ == "__main__":
timeout
=
int
(
args
.
timeout
)
timeout
=
int
(
args
.
timeout
)
if
args
.
gemm
:
if
args
.
gemm
:
gemm_sizes
=
[(
m
,
n
,
k
)
for
m
in
sizes
for
n
in
sizes
for
k
in
sizes
]
gemm_sizes
=
[(
m
,
n
,
k
)
for
m
in
sizes
for
n
in
sizes
for
k
in
sizes
]
run_gemm_perf
(
args
.
batches
,
gemm_sizes
,
fp16
,
timeout
)
if
verify_format
(
args
.
lens_from_file
,
True
):
gemm_sizes
=
parse_lens
(
args
.
lens_from_file
)
run_gemm_perf
(
args
.
batch_sizes
,
gemm_sizes
,
fp16
,
timeout
)
if
args
.
gemm_softmax_gemm
:
if
args
.
gemm_softmax_gemm
:
gemm_softmax_gemm_sizes
=
[(
m
,
n
,
k
,
o
)
for
m
in
sizes
for
n
in
sizes
for
k
in
sizes
for
o
in
sizes
]
gemm_softmax_gemm_sizes
=
[(
m
,
n
,
k
,
o
)
for
m
in
sizes
for
n
in
sizes
for
k
in
sizes
for
o
in
sizes
]
run_gemm_softmax_gemm_perf
(
args
.
batches
,
gemm_sizes
,
timeout
)
if
verify_format
(
args
.
lens_from_file
,
False
):
gemm_softmax_gemm_sizes
=
parse_lens
(
args
.
lens_from_file
)
run_gemm_softmax_gemm_perf
(
args
.
batch_sizes
,
gemm_softmax_gemm_sizes
,
timeout
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment