Commit fd212a64 authored by Paul's avatar Paul
Browse files

Fix concat with transposed inputs

parent 60629d87
......@@ -20,10 +20,12 @@ argument concat(hipStream_t stream,
auto&& arg = args[j];
std::size_t nelements = arg.get_shape().elements();
auto offset = offsets[j];
hip_visit_all(args.back(), arg)([&](auto output, auto input) {
shape arg_shape{arg.get_shape().type(), arg.get_shape().lens()};
hip_visit_all(args.back(), arg, arg_shape)([&](auto output, auto input, auto input_shape) {
gs_launch(stream, nelements)([=](auto i) {
auto idx = output.get_shape().index(input.get_shape().multi(i));
output.data()[idx + offset] = input.data()[i];
auto input_idx = input_shape.multi(i);
auto idx = output.get_shape().index(input_idx);
output.data()[idx + offset] = input[input_idx];
});
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment