Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
70b7a68f
Commit
70b7a68f
authored
Oct 16, 2023
by
Alan Turner
Browse files
Further improvements
parent
6cb25a15
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
21 deletions
+25
-21
tools/gemm_perf.py
tools/gemm_perf.py
+25
-21
No files found.
tools/gemm_perf.py
View file @
70b7a68f
#%matplotlib
import
subprocess
,
csv
,
re
,
datetime
,
argparse
,
os
,
enum
import
subprocess
,
csv
,
re
,
datetime
,
argparse
,
os
,
enum
from
subprocess
import
STDOUT
def
parse_args
():
def
parse_args
():
...
@@ -93,11 +91,11 @@ def get_total_time(output):
...
@@ -93,11 +91,11 @@ def get_total_time(output):
def
verify_output
(
output
,
provider
):
def
verify_output
(
output
,
provider
):
summary
=
re
.
findall
(
"Summary.*"
,
output
)[
0
].
replace
(
"
\\
n"
,
"
\n
"
)
summary
=
re
.
findall
(
"Summary.*"
,
output
)[
0
].
replace
(
"
\\
n"
,
"
\n
"
)
if
provider
==
GEMM_Provider
.
CK
:
if
provider
==
GEMM_Provider
.
CK
:
return
bool
(
re
.
mat
ch
(
".*ck_gemm.*"
,
summary
))
return
bool
(
re
.
sear
ch
(
".*ck_gemm.*"
,
summary
))
if
provider
==
GEMM_Provider
.
ROCBLAS
:
if
provider
==
GEMM_Provider
.
ROCBLAS
:
return
bool
(
re
.
mat
ch
(
".*gpu::gemm.*"
,
summary
))
return
bool
(
re
.
search
(
".*gpu::gemm.*"
,
summary
))
or
bool
(
re
.
sear
ch
(
".*gpu::
quant_
gemm.*"
,
summary
))
if
provider
==
GEMM_Provider
.
MLIR
:
if
provider
==
GEMM_Provider
.
MLIR
:
return
bool
(
re
.
match
(
".*MLIR
.*"
,
summary
))
#######
return
bool
(
re
.
search
(
".*mlir_dot.*"
,
summary
))
or
bool
(
re
.
search
(
".*mlir_quant_dot
.*"
,
summary
))
def
get_gemm_time
(
config
,
fp16
,
provider
,
timeout
):
def
get_gemm_time
(
config
,
fp16
,
provider
,
timeout
):
...
@@ -116,15 +114,19 @@ def get_gemm_time(config, fp16, provider, timeout):
...
@@ -116,15 +114,19 @@ def get_gemm_time(config, fp16, provider, timeout):
timeout
=
timeout
,
timeout
=
timeout
,
env
=
dict
(
os
.
environ
,
env
=
dict
(
os
.
environ
,
MIGRAPHX_ENABLE_CK
=
use_CK
,
MIGRAPHX_ENABLE_CK
=
use_CK
,
MIGRAPHX_ENABLE_MLIR
=
use_MLIR
))
MIGRAPHX_ENABLE_MLIR
=
use_MLIR
,
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
=
"dot"
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"
{
provider
.
name
}
encountered and exception
{
e
}
"
)
return
-
100.0
return
-
100.0
verify_output
(
str
(
out
.
stdout
),
provider
)
if
verify_output
(
str
(
out
.
stdout
),
provider
):
total_time
=
get_total_time
(
str
(
out
.
stdout
))
total_time
=
get_total_time
(
str
(
out
.
stdout
))
print
(
f
"
{
provider
.
name
}
finished in
{
total_time
}
"
)
print
(
f
"
{
provider
.
name
}
total time:
{
total_time
}
ms"
)
return
total_time
return
total_time
else
:
print
(
f
"
{
provider
.
name
}
was not found in performance summary"
)
return
-
100.0
def
get_gemm_softmax_gemm_time
(
config
,
provider
,
timeout
):
def
get_gemm_softmax_gemm_time
(
config
,
provider
,
timeout
):
...
@@ -142,17 +144,19 @@ def get_gemm_softmax_gemm_time(config, provider, timeout):
...
@@ -142,17 +144,19 @@ def get_gemm_softmax_gemm_time(config, provider, timeout):
timeout
=
timeout
,
timeout
=
timeout
,
env
=
dict
(
os
.
environ
,
env
=
dict
(
os
.
environ
,
MIGRAPHX_ENABLE_CK
=
use_CK
,
MIGRAPHX_ENABLE_CK
=
use_CK
,
MIGRAPHX_ENABLE_MLIR
=
use_MLIR
))
MIGRAPHX_ENABLE_MLIR
=
use_MLIR
,
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
=
"dot"
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"
{
provider
.
name
}
encountered and exception
{
e
}
"
)
return
-
100.0
return
-
100.0
summary
=
re
.
findall
(
"Summary.*"
,
str
(
out
.
stdout
))[
0
].
replace
(
"
\\
n"
,
"
\n
"
)
if
verify_output
(
str
(
out
.
stdout
),
provider
):
total_time
=
re
.
findall
(
"T
otal
time
: \d+\.\d*"
,
summary
)[
0
]
total_time
=
get_t
otal
_
time
(
str
(
out
.
stdout
))
total_time
=
total
_
time
.
replace
(
"T
otal
time
: "
,
"
"
)
print
(
f
"
{
provider
.
name
}
total
time
:
{
t
otal
_
time
}
ms
"
)
total_time
=
float
(
total_time
)
return
total_time
print
(
f
"
{
provider
.
name
}
finished in
{
total_time
}
"
)
else
:
print
(
f
"
{
provider
.
name
}
was not found in performance summary"
)
return
total_time
return
-
100.0
def
run_gemm_perf
(
batches
,
sizes
,
fp16
,
timeout
):
def
run_gemm_perf
(
batches
,
sizes
,
fp16
,
timeout
):
...
@@ -163,7 +167,7 @@ def run_gemm_perf(batches, sizes, fp16, timeout):
...
@@ -163,7 +167,7 @@ def run_gemm_perf(batches, sizes, fp16, timeout):
out
.
write_row
([
"batch_size"
,
"m"
,
"n"
,
"k"
,
"CK Total Time (ms)"
,
"rocBLAS Total Time (ms)"
,
"MLIR Total Time (ms)"
])
out
.
write_row
([
"batch_size"
,
"m"
,
"n"
,
"k"
,
"CK Total Time (ms)"
,
"rocBLAS Total Time (ms)"
,
"MLIR Total Time (ms)"
])
for
shape
in
sizes
:
for
shape
in
sizes
:
config
=
(
b
,)
+
shape
config
=
(
b
,)
+
shape
print
(
"Running gemm with config: {0}, {1}, {2}"
.
format
(
*
config
))
print
(
"Running
{prec}
gemm with config: {0}, {1}, {2}
, {3}
"
.
format
(
prec
=
prec_str
,
*
config
))
ck_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
CK
,
timeout
)
ck_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
CK
,
timeout
)
rb_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
ROCBLAS
,
timeout
)
rb_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
ROCBLAS
,
timeout
)
mlir_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
MLIR
,
timeout
)
mlir_time
=
get_gemm_time
(
config
,
fp16
,
GEMM_Provider
.
MLIR
,
timeout
)
...
@@ -177,7 +181,7 @@ def run_gemm_softmax_gemm_perf(batches, sizes, timeout):
...
@@ -177,7 +181,7 @@ def run_gemm_softmax_gemm_perf(batches, sizes, timeout):
out
.
write_row
([
"batch_size"
,
"m"
,
"n"
,
"k"
,
"o"
,
"CK Total Time (ms)"
,
"rocBLAS Total Time (ms)"
,
"MLIR Total Time (ms)"
])
out
.
write_row
([
"batch_size"
,
"m"
,
"n"
,
"k"
,
"o"
,
"CK Total Time (ms)"
,
"rocBLAS Total Time (ms)"
,
"MLIR Total Time (ms)"
])
for
shape
in
sizes
:
for
shape
in
sizes
:
config
=
(
b
,)
+
shape
config
=
(
b
,)
+
shape
print
(
"Running gemm-softmax-gemm with config: {0}, {1}, {2}, {3}"
.
format
(
*
config
))
print
(
"Running
fp16
gemm-softmax-gemm with config: {0}, {1}, {2}, {3}
, {4}
"
.
format
(
*
config
))
ck_time
=
get_gemm_softmax_gemm_time
(
config
,
GEMM_Provider
.
CK
,
timeout
)
ck_time
=
get_gemm_softmax_gemm_time
(
config
,
GEMM_Provider
.
CK
,
timeout
)
rb_time
=
get_gemm_softmax_gemm_time
(
config
,
GEMM_Provider
.
ROCBLAS
,
timeout
)
rb_time
=
get_gemm_softmax_gemm_time
(
config
,
GEMM_Provider
.
ROCBLAS
,
timeout
)
mlir_time
=
get_gemm_softmax_gemm_time
(
config
,
GEMM_Provider
.
MLIR
,
timeout
)
mlir_time
=
get_gemm_softmax_gemm_time
(
config
,
GEMM_Provider
.
MLIR
,
timeout
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment