Commit fd212a64 authored by Paul
Browse files

Fix concat with transposed inputs

parent 60629d87
...@@ -20,10 +20,12 @@ argument concat(hipStream_t stream, ...@@ -20,10 +20,12 @@ argument concat(hipStream_t stream,
auto&& arg = args[j]; auto&& arg = args[j];
std::size_t nelements = arg.get_shape().elements(); std::size_t nelements = arg.get_shape().elements();
auto offset = offsets[j]; auto offset = offsets[j];
hip_visit_all(args.back(), arg)([&](auto output, auto input) { shape arg_shape{arg.get_shape().type(), arg.get_shape().lens()};
hip_visit_all(args.back(), arg, arg_shape)([&](auto output, auto input, auto input_shape) {
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) {
auto idx = output.get_shape().index(input.get_shape().multi(i)); auto input_idx = input_shape.multi(i);
output.data()[idx + offset] = input.data()[i]; auto idx = output.get_shape().index(input_idx);
output.data()[idx + offset] = input[input_idx];
}); });
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment