"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "75f5ed4ac1bcb5994f0d5e3e2d34790791e3d6a0"
Commit 923d1348 authored by Paul's avatar Paul
Browse files

Handle post pointwise

parent d4223287
...@@ -38,6 +38,7 @@ using namespace migraphx::gpu::gen; // NOLINT ...@@ -38,6 +38,7 @@ using namespace migraphx::gpu::gen; // NOLINT
static const char* const concat_kernel = R"__migraphx__( static const char* const concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp> #include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp> #include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp> #include <args.hpp>
namespace migraphx { namespace migraphx {
...@@ -47,7 +48,7 @@ extern "C" { ...@@ -47,7 +48,7 @@ extern "C" {
__global__ void ${kernel}(${params}) __global__ void ${kernel}(${params})
{ {
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) { transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
concat<${axis}>(y, xs...); concat<${axis}>(xs...)(op::id{}, y);
}); });
} }
......
...@@ -44,6 +44,14 @@ constexpr auto concat_slice(Output out, Input, Start) ...@@ -44,6 +44,14 @@ constexpr auto concat_slice(Output out, Input, Start)
return make_tensor_view(&out[offset], s); return make_tensor_view(&out[offset], s);
} }
template <index_int Axis, class Input, class Start, class... Ts>
constexpr auto concat_slices(Input input, Start start, Ts... xs)
{
return [=](auto f) {
f(concat_slice<Axis>(xs, input, start)...);
};
}
template <index_int Axis, class Input> template <index_int Axis, class Input>
constexpr auto concat_ends(Input) constexpr auto concat_ends(Input)
{ {
...@@ -51,15 +59,18 @@ constexpr auto concat_ends(Input) ...@@ -51,15 +59,18 @@ constexpr auto concat_ends(Input)
return _c<lens[Axis]>; return _c<lens[Axis]>;
} }
template <index_int Axis, class Output, class... Inputs> template <index_int Axis, class... Inputs>
__device__ void concat(Output output, Inputs... inputs) __device__ auto concat(Inputs... inputs)
{ {
return [=](auto f, auto... ts) {
auto idx = make_index(); auto idx = make_index();
fold([&](auto start, auto input) { fold([&](auto start, auto input) {
auto y = concat_slice<Axis>(output, input, start); concat_slices<Axis>(input, start, ts...)([&](auto y, auto... xs) {
idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; }); idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = f(input[i], xs[i]...); });
});
return start + concat_ends<Axis>(input); return start + concat_ends<Axis>(input);
})(_c<0>, inputs...); })(_c<0>, inputs...);
};
} }
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment