Commit a0921a37 authored by Paul
Browse files

Format

parent be27f5cb
...@@ -63,9 +63,8 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -63,9 +63,8 @@ struct concat_compiler : compiler<concat_compiler>
static std::size_t get_min_elements(const std::vector<shape>& inputs) static std::size_t get_min_elements(const std::vector<shape>& inputs)
{ {
auto it = std::min_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](auto s) { auto it = std::min_element(
return s.elements(); inputs.begin(), inputs.end(), by(std::less<>{}, [](auto s) { return s.elements(); }));
}));
return it->elements(); return it->elements();
} }
...@@ -80,15 +79,13 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -80,15 +79,13 @@ struct concat_compiler : compiler<concat_compiler>
auto vec = vectorize::elements(axis, options.virtual_inputs); auto vec = vectorize::elements(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "concat_kernel"); options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params( options.set_launch_params(
v, v, compute_global_for(ctx, get_min_elements(options.inputs) / vec.size, 256));
compute_global_for(ctx,
get_min_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel, auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(vec)}, {"transformers", make_transformer_args(vec)},
{"axis", v.at("axis").to<std::string>()}}); {"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -31,35 +31,33 @@ ...@@ -31,35 +31,33 @@
namespace migraphx { namespace migraphx {
template<index_int Axis, class Output, class Input, class Start> template <index_int Axis, class Output, class Input, class Start>
constexpr auto concat_slice(Output out, Input, Start) constexpr auto concat_slice(Output out, Input, Start)
{ {
constexpr auto lens = get_shape_c<Input>{}.lens; constexpr auto lens = get_shape_c<Input>{}.lens;
constexpr auto strides = get_shape_c<Output>{}.strides; constexpr auto strides = get_shape_c<Output>{}.strides;
constexpr auto offset = return_c([] { constexpr auto offset = return_c([] {
constexpr auto output_shape = get_shape_c<Output>{}; constexpr auto output_shape = get_shape_c<Output>{};
return Start{} * output_shape.strides[Axis]; return Start{} * output_shape.strides[Axis];
}); });
constexpr auto s = make_shape(lens, strides); constexpr auto s = make_shape(lens, strides);
return make_tensor_view(&out[offset], s); return make_tensor_view(&out[offset], s);
} }
template<index_int Axis, class Input> template <index_int Axis, class Input>
constexpr auto concat_ends(Input) constexpr auto concat_ends(Input)
{ {
constexpr auto lens = get_shape_c<Input>{}.lens; constexpr auto lens = get_shape_c<Input>{}.lens;
return _c<lens[Axis]>; return _c<lens[Axis]>;
} }
// Device function that copies every input tensor into `output`, placing the
// inputs one after another along Axis.
//
// The fold starts at _c<0> and, for each input, advances the running start
// position by that input's length along Axis (concat_ends).
template <index_int Axis, class Output, class... Inputs>
__device__ void concat(Output output, Inputs... inputs)
{
    auto idx = make_index();
    fold([&](auto start, auto input) {
        // View of the output region that this input is written to.
        auto dst = concat_slice<Axis>(output, input, start);
        // Element indices are handed out by idx.global_stride; each index
        // copies one element from the input into the slice view.
        idx.global_stride(input.get_shape().elements(), [&](auto i) { dst[i] = input[i]; });
        return start + concat_ends<Axis>(input);
    })(_c<0>, inputs...);
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment