Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b7e80b6e
Commit
b7e80b6e
authored
Oct 19, 2022
by
Paul
Browse files
Format
parent
e34cb7c1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
40 additions
and
38 deletions
+40
-38
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+13
-10
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+17
-18
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+5
-5
test/verify/gemm_add_relu.cpp
test/verify/gemm_add_relu.cpp
+3
-3
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
b7e80b6e
...
@@ -38,7 +38,7 @@ struct ck_gemm
...
@@ -38,7 +38,7 @@ struct ck_gemm
MIGRAPHX_THROW
(
"should have at least two inputs."
);
MIGRAPHX_THROW
(
"should have at least two inputs."
);
auto
a
=
inputs
[
0
];
auto
a
=
inputs
[
0
];
auto
b
=
inputs
[
1
];
auto
b
=
inputs
[
1
];
for
(
const
auto
&
input
:
inputs
)
for
(
const
auto
&
input
:
inputs
)
check_gemm_shape
(
input
);
check_gemm_shape
(
input
);
return
op
.
compute_shape
({
a
,
b
});
return
op
.
compute_shape
({
a
,
b
});
}
}
...
@@ -55,7 +55,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...
@@ -55,7 +55,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
if
(
a
.
lens
().
size
()
>
2
or
b
.
lens
().
size
()
>
2
)
if
(
a
.
lens
().
size
()
>
2
or
b
.
lens
().
size
()
>
2
)
return
false
;
return
false
;
if
(
a
.
lens
()[
1
]
>=
2048
)
if
(
a
.
lens
()[
1
]
>=
2048
)
return
false
;
return
false
;
return
(
a
.
lens
()[
0
]
%
8
==
0
and
a
.
lens
()[
1
]
%
8
==
0
and
b
.
lens
()[
0
]
%
8
==
0
and
return
(
a
.
lens
()[
0
]
%
8
==
0
and
a
.
lens
()[
1
]
%
8
==
0
and
b
.
lens
()[
0
]
%
8
==
0
and
b
.
lens
()[
1
]
%
8
==
0
);
b
.
lens
()[
1
]
%
8
==
0
);
...
@@ -64,8 +64,10 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...
@@ -64,8 +64,10 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
struct
find_ck_gemm
struct
find_ck_gemm
{
{
// Find a gemm followed by a pointwise operation.
// Find a gemm followed by a pointwise operation.
auto
matcher
()
const
{
auto
matcher
()
const
auto
gemm
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm"
)));
{
auto
gemm
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm"
)));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
gemm
.
bind
(
"x"
)));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
gemm
.
bind
(
"x"
)));
}
}
...
@@ -77,17 +79,18 @@ struct find_ck_gemm
...
@@ -77,17 +79,18 @@ struct find_ck_gemm
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
auto
names
=
pm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
sort
(
names
.
begin
(),
names
.
end
());
auto
inputs
=
ins
->
inputs
();
auto
inputs
=
ins
->
inputs
();
auto
gemm_it
=
std
::
find
(
inputs
.
begin
(),
inputs
.
end
(),
x_ins
);
auto
gemm_it
=
std
::
find
(
inputs
.
begin
(),
inputs
.
end
(),
x_ins
);
auto
gemm_idx
=
gemm_it
-
inputs
.
begin
();
auto
gemm_idx
=
gemm_it
-
inputs
.
begin
();
assert
(
gemm_it
!=
inputs
.
end
());
assert
(
gemm_it
!=
inputs
.
end
());
if
(
gemm_idx
!=
0
)
if
(
gemm_idx
!=
0
)
{
{
// std::swap(inputs[0], inputs[gemm_idx]);
// std::swap(inputs[0], inputs[gemm_idx]);
auto
first_param
=
pm
->
get_parameter
(
names
[
0
]);
auto
first_param
=
pm
->
get_parameter
(
names
[
0
]);
auto
gemm_param
=
pm
->
get_parameter
(
names
[
gemm_idx
]);
auto
gemm_param
=
pm
->
get_parameter
(
names
[
gemm_idx
]);
auto
new_gemm_param
=
pm
->
add_parameter
(
names
[
0
]
+
".0"
,
gemm_param
->
get_shape
());
auto
new_gemm_param
=
pm
->
add_parameter
(
names
[
0
]
+
".0"
,
gemm_param
->
get_shape
());
auto
new_first_param
=
pm
->
add_parameter
(
names
[
gemm_idx
]
+
".0"
,
first_param
->
get_shape
());
auto
new_first_param
=
pm
->
add_parameter
(
names
[
gemm_idx
]
+
".0"
,
first_param
->
get_shape
());
pm
->
replace_instruction
(
gemm_param
,
new_gemm_param
);
pm
->
replace_instruction
(
gemm_param
,
new_gemm_param
);
pm
->
replace_instruction
(
first_param
,
new_first_param
);
pm
->
replace_instruction
(
first_param
,
new_first_param
);
pm
->
remove_instruction
(
first_param
);
pm
->
remove_instruction
(
first_param
);
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
b7e80b6e
...
@@ -139,7 +139,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -139,7 +139,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
return
shape
::
cpp_type
(
s
.
type
());
return
shape
::
cpp_type
(
s
.
type
());
}
}
template
<
class
Iterator
,
class
F
>
template
<
class
Iterator
,
class
F
>
static
std
::
string
ck_tuple
(
Iterator
start
,
Iterator
last
,
F
f
)
static
std
::
string
ck_tuple
(
Iterator
start
,
Iterator
last
,
F
f
)
{
{
std
::
vector
<
std
::
string
>
s
;
std
::
vector
<
std
::
string
>
s
;
...
@@ -158,24 +158,23 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -158,24 +158,23 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto
m
=
c_shape
.
lens
().
front
();
auto
m
=
c_shape
.
lens
().
front
();
auto
n
=
c_shape
.
lens
().
back
();
auto
n
=
c_shape
.
lens
().
back
();
auto
i
=
v
.
get
(
"tuning_val"
,
get_tuning_for
(
inputs
));
auto
i
=
v
.
get
(
"tuning_val"
,
get_tuning_for
(
inputs
));
auto
instance
=
get_instance
(
i
,
[
&
](
const
auto
&
x
)
->
bool
{
auto
instance
=
get_instance
(
i
,
[
&
](
const
auto
&
x
)
->
bool
{
return
get_layout
(
a_shape
)
==
x
[
0
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
return
get_layout
(
a_shape
)
==
x
[
0
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
get_layout
(
c_shape
)
==
x
[
3
]
and
get_type
(
a_shape
)
==
x
[
4
]
and
get_layout
(
c_shape
)
==
x
[
3
]
and
get_type
(
a_shape
)
==
x
[
4
]
and
get_type
(
b_shape
)
==
x
[
5
]
and
get_type
(
c_shape
)
==
x
[
9
];
get_type
(
b_shape
)
==
x
[
5
]
and
get_type
(
c_shape
)
==
x
[
9
];
});
});
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
if
(
v
.
contains
(
"post"
))
if
(
v
.
contains
(
"post"
))
{
{
assert
(
instance
[
2
]
==
"ck::Tuple<>"
);
assert
(
instance
[
2
]
==
"ck::Tuple<>"
);
instance
[
2
]
=
ck_tuple
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
&
get_layout
);
instance
[
2
]
=
ck_tuple
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
&
get_layout
);
assert
(
instance
[
8
]
==
"ck::Tuple<>"
);
assert
(
instance
[
8
]
==
"ck::Tuple<>"
);
instance
[
8
]
=
ck_tuple
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
&
get_type
);
instance
[
8
]
=
ck_tuple
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
&
get_type
);
assert
(
instance
[
12
]
==
"ck_passthrough"
);
assert
(
instance
[
12
]
==
"ck_passthrough"
);
instance
[
12
]
=
v
.
at
(
"post"
).
to
<
std
::
string
>
();
instance
[
12
]
=
v
.
at
(
"post"
).
to
<
std
::
string
>
();
}
}
hip_compile_options
options
;
hip_compile_options
options
;
auto
block_size
=
get_block_size
(
instance
);
auto
block_size
=
get_block_size
(
instance
);
auto
grid_size
=
get_grid_size
(
instance
,
m
,
n
);
auto
grid_size
=
get_grid_size
(
instance
,
m
,
n
);
...
@@ -185,28 +184,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -185,28 +184,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"ck_gemm_kernel"
);
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"ck_gemm_kernel"
);
options
.
virtual_inputs
=
inputs
;
options
.
virtual_inputs
=
inputs
;
auto
src
=
interpolate_string
(
ck_gemm_kernel
,
{
auto
src
=
interpolate_string
(
ck_gemm_kernel
,
{
"instance"
,
join_strings
(
instance
,
","
)},
{{
"instance"
,
join_strings
(
instance
,
","
)},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"kernel"
,
options
.
kernel_name
}
{
"kernel"
,
options
.
kernel_name
}});
});
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
v
[
"kernel"
]
=
"ck_gemm_kernel"
;
v
[
"kernel"
]
=
"ck_gemm_kernel"
;
if
(
not
ins
->
module_inputs
().
empty
())
if
(
not
ins
->
module_inputs
().
empty
())
{
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_ck_gemm_function"
)
+
"
\n
MIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);"
;
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_ck_gemm_function"
)
+
v
[
"post"
]
=
"ck_function_adaptor<post_ck_gemm>"
;
"
\n
MIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);"
;
v
[
"post"
]
=
"ck_function_adaptor<post_ck_gemm>"
;
v
[
"kernel"
]
=
"ck_gemm_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
v
[
"kernel"
]
=
"ck_gemm_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
}
auto
shapes
=
to_shapes
(
ins
->
inputs
());
auto
shapes
=
to_shapes
(
ins
->
inputs
());
return
action_decorate
(
replace
(
compile_op
(
ctx
,
shapes
,
v
)),
[
=
]
{
return
action_decorate
(
replace
(
compile_op
(
ctx
,
shapes
,
v
)),
[
=
]
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
View file @
b7e80b6e
...
@@ -50,13 +50,13 @@ constexpr bool is_row_major()
...
@@ -50,13 +50,13 @@ constexpr bool is_row_major()
template
<
class
T
>
template
<
class
T
>
using
to_ck_type
=
typename
detail
::
to_ck_type_impl
<
T
>::
type
;
using
to_ck_type
=
typename
detail
::
to_ck_type_impl
<
T
>::
type
;
template
<
class
T
>
template
<
class
T
>
constexpr
auto
to_ck_pointer
(
T
*
x
)
constexpr
auto
to_ck_pointer
(
T
*
x
)
{
{
return
static_cast
<
to_ck_type
<
T
>*>
(
x
);
return
static_cast
<
to_ck_type
<
T
>*>
(
x
);
}
}
template
<
class
T
>
template
<
class
T
>
constexpr
auto
to_ck_const_pointer
(
const
T
*
x
)
constexpr
auto
to_ck_const_pointer
(
const
T
*
x
)
{
{
return
static_cast
<
const
to_ck_type
<
T
>*>
(
x
);
return
static_cast
<
const
to_ck_type
<
T
>*>
(
x
);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
b7e80b6e
...
@@ -36,12 +36,12 @@
...
@@ -36,12 +36,12 @@
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT_CLASS(name, ...) \
#define MIGRAPHX_LIFT_CLASS(name, ...)
\
struct name \
struct name
\
{ \
{
\
template<class... PrivateLiftTs> \
template
<class... PrivateLiftTs>
\
constexpr auto operator()(PrivateLiftTs&&... private_lisft_xs) const MIGRAPHX_RETURNS( \
constexpr auto operator()(PrivateLiftTs&&... private_lisft_xs) const MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...)) \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
\
}
}
namespace
migraphx
{
namespace
migraphx
{
...
...
test/verify/gemm_add_relu.cpp
View file @
b7e80b6e
...
@@ -33,9 +33,9 @@ struct gemm_add_relu : verify_program<gemm_add_relu>
...
@@ -33,9 +33,9 @@ struct gemm_add_relu : verify_program<gemm_add_relu>
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
a
=
mm
->
add_parameter
(
"1"
,
{
migraphx
::
shape
::
half_type
,
{
16
,
8
}});
auto
a
=
mm
->
add_parameter
(
"1"
,
{
migraphx
::
shape
::
half_type
,
{
16
,
8
}});
auto
b
=
mm
->
add_parameter
(
"2"
,
{
migraphx
::
shape
::
half_type
,
{
8
,
32
}});
auto
b
=
mm
->
add_parameter
(
"2"
,
{
migraphx
::
shape
::
half_type
,
{
8
,
32
}});
auto
c
=
mm
->
add_parameter
(
"3"
,
{
migraphx
::
shape
::
half_type
,
{
16
,
32
}});
auto
c
=
mm
->
add_parameter
(
"3"
,
{
migraphx
::
shape
::
half_type
,
{
16
,
32
}});
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
c
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
c
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment