Commit a0921a37 authored by Paul's avatar Paul
Browse files

Format

parent be27f5cb
...@@ -63,9 +63,8 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -63,9 +63,8 @@ struct concat_compiler : compiler<concat_compiler>
static std::size_t get_min_elements(const std::vector<shape>& inputs) static std::size_t get_min_elements(const std::vector<shape>& inputs)
{ {
auto it = std::min_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](auto s) { auto it = std::min_element(
return s.elements(); inputs.begin(), inputs.end(), by(std::less<>{}, [](auto s) { return s.elements(); }));
}));
return it->elements(); return it->elements();
} }
...@@ -80,9 +79,7 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -80,9 +79,7 @@ struct concat_compiler : compiler<concat_compiler>
auto vec = vectorize::elements(axis, options.virtual_inputs); auto vec = vectorize::elements(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "concat_kernel"); options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params( options.set_launch_params(
v, v, compute_global_for(ctx, get_min_elements(options.inputs) / vec.size, 256));
compute_global_for(ctx,
get_min_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel, auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
......
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
namespace migraphx { namespace migraphx {
template<index_int Axis, class Output, class Input, class Start> template <index_int Axis, class Output, class Input, class Start>
constexpr auto concat_slice(Output out, Input, Start) constexpr auto concat_slice(Output out, Input, Start)
{ {
constexpr auto lens = get_shape_c<Input>{}.lens; constexpr auto lens = get_shape_c<Input>{}.lens;
...@@ -44,22 +44,20 @@ constexpr auto concat_slice(Output out, Input, Start) ...@@ -44,22 +44,20 @@ constexpr auto concat_slice(Output out, Input, Start)
return make_tensor_view(&out[offset], s); return make_tensor_view(&out[offset], s);
} }
template<index_int Axis, class Input> template <index_int Axis, class Input>
constexpr auto concat_ends(Input) constexpr auto concat_ends(Input)
{ {
constexpr auto lens = get_shape_c<Input>{}.lens; constexpr auto lens = get_shape_c<Input>{}.lens;
return _c<lens[Axis]>; return _c<lens[Axis]>;
} }
template<index_int Axis, class Output, class... Inputs> template <index_int Axis, class Output, class... Inputs>
__device__ void concat(Output output, Inputs... inputs) __device__ void concat(Output output, Inputs... inputs)
{ {
auto idx = make_index(); auto idx = make_index();
fold([&](auto start, auto input) { fold([&](auto start, auto input) {
auto y = concat_slice<Axis>(output, input, start); auto y = concat_slice<Axis>(output, input, start);
idx.global_stride(input.get_shape().elements(), [&](auto i) { idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; });
y[i] = input[i];
});
return start + concat_ends<Axis>(input); return start + concat_ends<Axis>(input);
})(_c<0>, inputs...); })(_c<0>, inputs...);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment