Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
297645e8
Commit
297645e8
authored
Mar 31, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/ck-gemm-fused-transpose' into ck-gsg
parents
ac7a0025
55b363c9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
27 deletions
+63
-27
src/include/migraphx/serialize.hpp
src/include/migraphx/serialize.hpp
+11
-5
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+4
-5
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+4
-2
tools/tune_ck.py
tools/tune_ck.py
+44
-15
No files found.
src/include/migraphx/serialize.hpp
View file @
297645e8
...
@@ -212,22 +212,28 @@ void from_value_impl(rank<6>, const value& v, optional<T>& x)
...
@@ -212,22 +212,28 @@ void from_value_impl(rank<6>, const value& v, optional<T>& x)
x
=
from_value
<
T
>
(
v
);
x
=
from_value
<
T
>
(
v
);
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{}
or
std
::
is_enum
<
T
>
{}
)
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{})
>
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
T
&
x
)
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
T
&
x
)
{
{
x
=
v
.
to
<
T
>
();
x
=
v
.
to
<
T
>
();
}
}
inline
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
T
&
x
)
{
x
=
v
.
to
<
T
>
();
}
inline
void
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
template
<
class
T
>
template
<
class
T
>
auto
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
auto
from_value_impl
(
rank
<
10
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
{
{
x
.
from_value
(
v
);
x
.
from_value
(
v
);
}
}
template
<
class
T
>
template
<
class
T
>
auto
from_value_impl
(
rank
<
1
0
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
auto
from_value_impl
(
rank
<
1
1
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
{
{
migraphx_from_value
(
v
,
x
);
migraphx_from_value
(
v
,
x
);
}
}
...
@@ -243,7 +249,7 @@ value to_value(const T& x)
...
@@ -243,7 +249,7 @@ value to_value(const T& x)
template
<
class
T
>
template
<
class
T
>
void
from_value
(
const
value
&
v
,
T
&
x
)
void
from_value
(
const
value
&
v
,
T
&
x
)
{
{
detail
::
from_value_impl
(
rank
<
1
0
>
{},
v
,
x
);
detail
::
from_value_impl
(
rank
<
1
1
>
{},
v
,
x
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/CMakeLists.txt
View file @
297645e8
...
@@ -186,20 +186,19 @@ if(MIGRAPHX_USE_HIPRTC)
...
@@ -186,20 +186,19 @@ if(MIGRAPHX_USE_HIPRTC)
message
(
STATUS
"MIGraphX is using hipRTC"
)
message
(
STATUS
"MIGraphX is using hipRTC"
)
target_compile_definitions
(
migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1
)
target_compile_definitions
(
migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1
)
else
()
else
()
message
(
STATUS
"MIGraphX is using HIP Clang"
)
# Get flags needed to compile hip
# Get flags needed to compile hip
include
(
TargetFlags
)
include
(
TargetFlags
)
message
(
STATUS
"HIP COMPILER FLAGS:
${
HIP_COMPILER_FLAGS
}
"
)
message
(
STATUS
"HIP COMPILER FLAGS:
${
HIP_COMPILER_FLAGS
}
"
)
target_flags
(
HIP_COMPILER_FLAGS hip::device
)
target_flags
(
HIP_COMPILER_FLAGS hip::device
)
message
(
STATUS
"HIP COMPILER FLAGS:
${
HIP_COMPILER_FLAGS
}
"
)
# Remove cuda arch flags
# Remove cuda arch flags
string
(
REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+
""
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
string
(
REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+
""
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
string
(
REGEX REPLACE --offload-arch=[a-z0-9:+-]+
""
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
string
(
REGEX REPLACE --offload-arch=[a-z0-9:+-]+
""
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
# Skip library paths since hip will incorrectly treat it as a source file
# Skip library paths since hip will incorrectly treat it as a source file
string
(
APPEND HIP_COMPILER_FLAGS
" "
)
string
(
APPEND HIP_COMPILER_FLAGS
" "
)
# Add ck includes
find_path
(
CK_INCLUDE_PATH ck/ck.hpp
)
message
(
STATUS
"CK path:
${
CK_INCLUDE_PATH
}
"
)
string
(
APPEND HIP_COMPILER_FLAGS
" -isystem
${
CK_INCLUDE_PATH
}
"
)
foreach
(
_unused RANGE 2
)
foreach
(
_unused RANGE 2
)
string
(
REGEX REPLACE
" /[^ ]+
\\
.(a|so) "
" "
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
string
(
REGEX REPLACE
" /[^ ]+
\\
.(a|so) "
" "
HIP_COMPILER_FLAGS
"
${
HIP_COMPILER_FLAGS
}
"
)
endforeach
()
endforeach
()
...
...
src/targets/gpu/fuse_ops.cpp
View file @
297645e8
...
@@ -681,6 +681,7 @@ struct find_contiguous_tranpose_precompile
...
@@ -681,6 +681,7 @@ struct find_contiguous_tranpose_precompile
{
{
return
match
::
name
(
"gpu::contiguous"
)(
match
::
arg
(
0
)(
return
match
::
name
(
"gpu::contiguous"
)(
match
::
arg
(
0
)(
match
::
name
(
"transpose"
)(
match
::
name
(
"transpose"
)(
match
::
used_once
(),
match
::
arg
(
0
)(
match
::
name
(
"gpu::precompile_op"
)(
match
::
used_once
()).
bind
(
"op"
)))
match
::
arg
(
0
)(
match
::
name
(
"gpu::precompile_op"
)(
match
::
used_once
()).
bind
(
"op"
)))
.
bind
(
"transpose"
)));
.
bind
(
"transpose"
)));
}
}
...
@@ -693,12 +694,13 @@ struct find_contiguous_tranpose_precompile
...
@@ -693,12 +694,13 @@ struct find_contiguous_tranpose_precompile
auto
transpose
=
r
.
instructions
[
"transpose"
];
auto
transpose
=
r
.
instructions
[
"transpose"
];
auto
perm
=
transpose
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
perm
=
transpose
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
iperm
=
invert_permutation
(
perm
);
auto
iperm
=
invert_permutation
(
perm
);
auto
s
=
auto
s
=
shape
::
from_permutation
(
shape
::
from_permutation
(
op_ins
->
get_shape
().
type
(),
op_ins
->
get_shape
().
lens
(),
i
perm
);
op_ins
->
get_shape
().
type
(),
op_ins
->
get_shape
().
lens
(),
perm
);
// perm or iperm?
auto
v
=
op_ins
->
get_operator
().
to_value
();
auto
v
=
op_ins
->
get_operator
().
to_value
();
v
[
"output_shape"
]
=
to_value
(
s
);
v
[
"output_shape"
]
=
to_value
(
s
);
auto
new_op
=
make_op
(
"gpu::precompile_op"
,
v
);
auto
new_op
=
make_op
(
"gpu::precompile_op"
,
v
);
m
.
replace_instruction
(
op_ins
,
new_op
,
op_ins
->
inputs
(),
op_ins
->
module_inputs
());
m
.
replace_instruction
(
op_ins
,
new_op
,
op_ins
->
inputs
(),
op_ins
->
module_inputs
());
assert
(
ins
->
get_shape
()
==
transpose
->
get_shape
());
m
.
replace_instruction
(
ins
,
transpose
);
m
.
replace_instruction
(
ins
,
transpose
);
}
}
};
};
...
...
tools/tune_ck.py
View file @
297645e8
import
os
,
json
,
subprocess
,
tempfile
,
sys
,
argparse
,
contextlib
import
os
,
json
,
subprocess
,
tempfile
,
sys
,
argparse
,
contextlib
,
multiprocessing
,
multiprocessing
.
dummy
ck_function
=
-
1
ck_function
=
-
1
...
@@ -23,10 +23,14 @@ def pretty_print(obj):
...
@@ -23,10 +23,14 @@ def pretty_print(obj):
def
run_driver
(
b
):
def
run_driver
(
b
):
print
(
b
)
print
(
b
)
with
tmp_file
(
lambda
tf
:
json
.
dump
(
b
,
tf
))
as
tf
:
with
tmp_file
(
lambda
tf
:
json
.
dump
(
b
,
tf
))
as
tf
:
if
not
os
.
path
.
exists
(
'./bin/gpu-driver'
):
print
(
"./bin/gpu-driver not found"
)
os
.
abort
()
cp
=
subprocess
.
run
(
'./bin/gpu-driver {}'
.
format
(
tf
),
cp
=
subprocess
.
run
(
'./bin/gpu-driver {}'
.
format
(
tf
),
capture_output
=
True
,
capture_output
=
True
,
check
=
True
,
shell
=
True
)
shell
=
True
)
print
(
cp
.
stderr
.
decode
())
cp
.
check_returncode
()
for
line
in
cp
.
stdout
.
decode
().
split
(
"
\n
"
):
for
line
in
cp
.
stdout
.
decode
().
split
(
"
\n
"
):
s
=
line
.
strip
()
s
=
line
.
strip
()
if
not
s
:
if
not
s
:
...
@@ -45,23 +49,29 @@ def get_device_time(s):
...
@@ -45,23 +49,29 @@ def get_device_time(s):
return
convert_to_float
(
fields
[
-
1
].
strip
())
return
convert_to_float
(
fields
[
-
1
].
strip
())
def
benchmark_ck
(
config
,
name
,
tuning
):
def
run_driver_ck
(
config
,
tuning
,
iterations
):
try
:
b
=
{
b
=
{
'settings'
:
{
'settings'
:
{
'iterations'
:
iterations
'iterations'
:
100
},
},
'compile_op'
:
{
'compile_op'
:
{
'name'
:
'ck_gemm'
,
'name'
:
name
,
'check'
:
True
,
'check'
:
True
,
'tuning_val'
:
tuning
,
'tuning_val'
:
tuning
,
'inputs'
:
config
'inputs'
:
config
}
}
}
for
line
in
run_driver
(
b
):
}
return
run_driver
(
b
)
def
benchmark_ck
(
config
,
tuning
):
try
:
for
line
in
run_driver_ck
(
config
,
tuning
,
100
):
dtime
=
get_device_time
(
line
)
dtime
=
get_device_time
(
line
)
print
(
dtime
)
print
(
dtime
)
return
float
(
dtime
)
return
float
(
dtime
)
print
(
"Failed"
)
sys
.
exit
(
1
)
except
:
except
:
return
sys
.
float_info
.
max
return
sys
.
float_info
.
max
...
@@ -86,6 +96,19 @@ def parse_log(f):
...
@@ -86,6 +96,19 @@ def parse_log(f):
yield
(
config
,
'ck_gemm_softmax_gemm'
)
yield
(
config
,
'ck_gemm_softmax_gemm'
)
def
precompile
(
x
):
try
:
list
(
run_driver_ck
(
x
[
0
],
x
[
1
],
0
))
except
:
pass
def
precompile_log
(
f
,
n
):
solutions
=
((
config
,
i
)
for
config
in
parse_log
(
f
)
for
i
in
range
(
n
))
with
multiprocessing
.
Pool
(
24
)
as
p
:
list
(
p
.
imap
(
precompile
,
solutions
))
def
benchmark_log
(
f
,
n
):
def
benchmark_log
(
f
,
n
):
result
=
[]
result
=
[]
for
config
,
name
in
parse_log
(
f
):
for
config
,
name
in
parse_log
(
f
):
...
@@ -107,12 +130,18 @@ def parse_args():
...
@@ -107,12 +130,18 @@ def parse_args():
type
=
str
,
type
=
str
,
metavar
=
'file'
,
metavar
=
'file'
,
help
=
'Output json file to save tunings'
)
help
=
'Output json file to save tunings'
)
parser
.
add_argument
(
'--precompile'
,
'-p'
,
action
=
'store_true'
,
help
=
'Precompile kernels first in parallel'
)
parser
.
add_argument
(
'-n'
,
type
=
int
,
help
=
'Number of instances to tune'
)
parser
.
add_argument
(
'-n'
,
type
=
int
,
help
=
'Number of instances to tune'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
def
run
(
args
):
def
run
(
args
):
if
(
args
.
precompile
):
precompile_log
(
args
.
log
,
args
.
n
)
tuned
=
benchmark_log
(
args
.
log
,
args
.
n
)
tuned
=
benchmark_log
(
args
.
log
,
args
.
n
)
json
.
dump
(
tuned
,
open
(
args
.
out
,
'w+'
))
json
.
dump
(
tuned
,
open
(
args
.
out
,
'w+'
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment