Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4b189c12
Commit
4b189c12
authored
Oct 18, 2023
by
Alan Turner
Browse files
Add flag to use exclusively CK gemms
parent
70b7a68f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
3 deletions
+6
-3
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+3
-1
tools/gemm_perf.py
tools/gemm_perf.py
+3
-2
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
4b189c12
...
@@ -34,6 +34,8 @@ struct module;
...
@@ -34,6 +34,8 @@ struct module;
namespace
gpu
{
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_USE_CK_ONLY
);
struct
ck_gemm
struct
ck_gemm
{
{
operation
op
=
make_op
(
"dot"
);
operation
op
=
make_op
(
"dot"
);
...
@@ -106,7 +108,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...
@@ -106,7 +108,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from CK
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
// To-do: Investigate a more precise strategy
return
k
<=
2048
;
return
k
<=
2048
or
enabled
(
MIGRAPHX_USE_CK_ONLY
{})
;
}
}
struct
find_ck_gemm_pointwise
struct
find_ck_gemm_pointwise
...
...
tools/gemm_perf.py
View file @
4b189c12
...
@@ -115,9 +115,10 @@ def get_gemm_time(config, fp16, provider, timeout):
...
@@ -115,9 +115,10 @@ def get_gemm_time(config, fp16, provider, timeout):
env
=
dict
(
os
.
environ
,
env
=
dict
(
os
.
environ
,
MIGRAPHX_ENABLE_CK
=
use_CK
,
MIGRAPHX_ENABLE_CK
=
use_CK
,
MIGRAPHX_ENABLE_MLIR
=
use_MLIR
,
MIGRAPHX_ENABLE_MLIR
=
use_MLIR
,
MIGRAPHX_USE_CK_ONLY
=
"1"
,
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
=
"dot"
))
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
=
"dot"
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"
{
provider
.
name
}
encountered an
d
exception
{
e
}
"
)
print
(
f
"
{
provider
.
name
}
encountered an exception
:
{
e
}
"
)
return
-
100.0
return
-
100.0
if
verify_output
(
str
(
out
.
stdout
),
provider
):
if
verify_output
(
str
(
out
.
stdout
),
provider
):
...
@@ -147,7 +148,7 @@ def get_gemm_softmax_gemm_time(config, provider, timeout):
...
@@ -147,7 +148,7 @@ def get_gemm_softmax_gemm_time(config, provider, timeout):
MIGRAPHX_ENABLE_MLIR
=
use_MLIR
,
MIGRAPHX_ENABLE_MLIR
=
use_MLIR
,
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
=
"dot"
))
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
=
"dot"
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"
{
provider
.
name
}
encountered an
d
exception
{
e
}
"
)
print
(
f
"
{
provider
.
name
}
encountered an exception
:
{
e
}
"
)
return
-
100.0
return
-
100.0
if
verify_output
(
str
(
out
.
stdout
),
provider
):
if
verify_output
(
str
(
out
.
stdout
),
provider
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment