Commit a0921a37 authored by Paul's avatar Paul
Browse files

Format

parent be27f5cb
......@@ -63,9 +63,8 @@ struct concat_compiler : compiler<concat_compiler>
static std::size_t get_min_elements(const std::vector<shape>& inputs)
{
auto it = std::min_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](auto s) {
return s.elements();
}));
auto it = std::min_element(
inputs.begin(), inputs.end(), by(std::less<>{}, [](auto s) { return s.elements(); }));
return it->elements();
}
......@@ -80,15 +79,13 @@ struct concat_compiler : compiler<concat_compiler>
auto vec = vectorize::elements(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params(
v,
compute_global_for(ctx,
get_min_elements(options.inputs) / vec.size, 256));
v, compute_global_for(ctx, get_min_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(vec)},
{"axis", v.at("axis").to<std::string>()}});
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
......
......@@ -31,35 +31,33 @@
namespace migraphx {
template<index_int Axis, class Output, class Input, class Start>
template <index_int Axis, class Output, class Input, class Start>
constexpr auto concat_slice(Output out, Input, Start)
{
constexpr auto lens = get_shape_c<Input>{}.lens;
constexpr auto lens = get_shape_c<Input>{}.lens;
constexpr auto strides = get_shape_c<Output>{}.strides;
constexpr auto offset = return_c([] {
constexpr auto offset = return_c([] {
constexpr auto output_shape = get_shape_c<Output>{};
return Start{} * output_shape.strides[Axis];
});
constexpr auto s = make_shape(lens, strides);
constexpr auto s = make_shape(lens, strides);
return make_tensor_view(&out[offset], s);
}
template<index_int Axis, class Input>
template <index_int Axis, class Input>
constexpr auto concat_ends(Input)
{
constexpr auto lens = get_shape_c<Input>{}.lens;
return _c<lens[Axis]>;
}
template<index_int Axis, class Output, class... Inputs>
template <index_int Axis, class Output, class... Inputs>
__device__ void concat(Output output, Inputs... inputs)
{
auto idx = make_index();
fold([&](auto start, auto input) {
auto y = concat_slice<Axis>(output, input, start);
idx.global_stride(input.get_shape().elements(), [&](auto i) {
y[i] = input[i];
});
idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; });
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment