Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c2263671
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "651677f5645831fe748a7eb7fe59a8b05c921a52"
Commit
c2263671
authored
Sep 06, 2022
by
Paul
Browse files
Generate pointwise post operator
parent
72011beb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
4 deletions
+33
-4
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+14
-0
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+19
-4
No files found.
src/targets/gpu/compile_hip.cpp
View file @
c2263671
...
@@ -299,6 +299,20 @@ std::string enum_params(std::size_t count, std::string param)
...
@@ -299,6 +299,20 @@ std::string enum_params(std::size_t count, std::string param)
return
join_strings
(
items
,
","
);
return
join_strings
(
items
,
","
);
}
}
// std::string enum_params(std::size_t count, std::initializer_list<std::string> params)
// {
// std::vector<std::string> items(count);
// transform(range(count), items.begin(), [&](auto i) {
// auto idx = std::to_string(i);
// std::vector<std::string> eparams(params.size());
// transform(params, eparams.begin(), [&](const std::string& s) {
// return s + i;
// });
// return join_strings(eparams, " ");
// });
// return join_strings(items, ",");
// }
#endif // MIGRAPHX_USE_HIPRTC
#endif // MIGRAPHX_USE_HIPRTC
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/jit/concat.cpp
View file @
c2263671
...
@@ -43,12 +43,14 @@ static const char* const concat_kernel = R"__migraphx__(
...
@@ -43,12 +43,14 @@ static const char* const concat_kernel = R"__migraphx__(
namespace migraphx {
namespace migraphx {
${preamble}
extern "C" {
extern "C" {
__global__ void ${kernel}(${params})
__global__ void ${kernel}(${params})
{
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y,
${concat_params},
auto... xs) {
concat<${axis}>(
xs...)(op::id{}, y
);
concat<${axis}>(
${concat_args})(${post}, y, xs...
);
});
});
}
}
...
@@ -71,27 +73,40 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -71,27 +73,40 @@ struct concat_compiler : compiler<concat_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
// TODO: Use reduce_dims
auto
num_of_concat_inputs
=
v
.
get
(
"concat_inputs"
,
inputs
.
size
()
-
1
);
hip_compile_options
options
;
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
params
=
"-Wno-float-equal"
;
options
.
params
=
"-Wno-float-equal"
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
options
.
inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
options
.
inputs
);
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
get_concat_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
v
,
compute_global_for
(
ctx
,
get_concat_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
auto
src
=
interpolate_string
(
concat_kernel
,
auto
src
=
interpolate_string
(
concat_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"concat_params"
,
enum_params
(
num_of_concat_inputs
,
"auto concat_x"
)},
{
"concat_args"
,
enum_params
(
num_of_concat_inputs
,
"concat_x"
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
auto
v
=
op
.
to_value
();
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"concat_inputs"
]
=
ins
->
inputs
().
size
()
-
pm
->
get_parameter_names
().
size
()
-
1
;
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_concat"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_concat)"
;
v
[
"kernel"
]
=
"concat_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment