Unverified Commit 548783c8 authored by turneram's avatar turneram Committed by GitHub
Browse files

Complete GPU implementation of CumSum op (#1094)

Add exclusive and reverse modes to the GPU implementation of prefix_scan_sum, completing support for the ONNX CumSum op.
parent e521fa3f
...@@ -38,13 +38,33 @@ struct prefix_scan_op : op_name<Derived> ...@@ -38,13 +38,33 @@ struct prefix_scan_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
return inputs.at(0); auto s = inputs.front();
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
} }
argument compute(const shape&, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto s = args[0].get_shape();
if(s == output_shape)
{ {
argument result = args[0].copy(); result = args[0].copy();
auto s = result.get_shape(); }
else
{
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(),
[&](auto i) { output[output_shape.index(i)] = input[s.index(i)]; });
});
s = output_shape;
}
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}}; auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens(); auto lens = s.lens();
lens[axis] = 1; lens[axis] = 1;
......
...@@ -53,6 +53,12 @@ __device__ void block_scan(index idx, Op op, T init, index_int n, Input input, O ...@@ -53,6 +53,12 @@ __device__ void block_scan(index idx, Op op, T init, index_int n, Input input, O
output); output);
} }
// Wraps the callable `f` so that the first argument it receives is mirrored
// within a range of length `n`: a call with position `pos` is forwarded to
// `f` as `n - pos - 1`. Any further arguments are passed through unchanged,
// and whatever `f` returns is returned to the caller.
//
// Both `n` and `f` are copied into the returned closure.
template <class F>
constexpr auto reverse_scan(index_int n, F f)
{
    return [n, f](auto pos, auto&&... rest) {
        // Position 0 maps to n - 1, position n - 1 maps to 0.
        return f(n - pos - 1, rest...);
    };
}
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/gpu/device/prefix_scan_sum.hpp> #include <migraphx/gpu/device/prefix_scan_sum.hpp>
#include <migraphx/gpu/device/scan.hpp> #include <migraphx/gpu/device/scan.hpp>
#include <migraphx/gpu/device/reduce_ops.hpp> #include <migraphx/gpu/device/reduce_ops.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
namespace migraphx { namespace migraphx {
...@@ -8,14 +9,91 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -8,14 +9,91 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis) void prefix_scan_sum(hipStream_t stream,
const argument& result,
const argument& arg,
int32_t axis,
bool exclusive,
bool reverse)
{ {
const index_int block_size = 256; const index_int max_block_size = 256;
const index_int n = arg.get_shape().lens()[axis]; const index_int n = arg.get_shape().lens()[axis];
auto rlens = result.get_shape().lens(); auto rlens = result.get_shape().lens();
rlens[axis] = 1; rlens[axis] = 1;
hip_visit_all(result, arg, result.get_shape().with_lens(rlens))( hip_visit_all(result, arg, result.get_shape().with_lens(rlens))(
[=](auto output, auto input, auto rshape) { [=](auto output, auto input, auto rshape) {
const index_int block_size = compute_block_size(rshape.elements(), max_block_size);
if(reverse and exclusive)
{
gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size);
auto compute_idx = [&](auto j) {
auto k = ridx;
k[axis] = j;
return k;
};
block_scan<max_block_size>(
idx,
sum{},
0,
n,
reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
reverse_scan(n, [&](auto j, auto x) {
if(j == n - 1)
output[compute_idx(j)] = 0;
if(j > 0)
output[compute_idx(j - 1)] = x;
}));
});
}
else if(reverse)
{
gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size);
auto compute_idx = [&](auto j) {
auto k = ridx;
k[axis] = j;
return k;
};
block_scan<max_block_size>(
idx,
sum{},
0,
n,
reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
reverse_scan(n, [&](auto j, auto x) { output[compute_idx(j)] = x; }));
});
}
else if(exclusive)
{
gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size);
auto compute_idx = [&](auto j) {
auto k = ridx;
k[axis] = j;
return k;
};
block_scan<max_block_size>(
idx,
sum{},
0,
n,
[&](auto j) { return input[compute_idx(j)]; },
[&](auto j, auto x) {
auto k = j + 1;
if(j == 0)
output[compute_idx(0)] = 0;
if(k < n)
output[compute_idx(k)] = x;
});
});
}
else
{
gs_launch(stream, rshape.elements() * block_size, block_size)( gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ { [=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size); const auto ridx = rshape.multi(i / block_size);
...@@ -24,7 +102,7 @@ void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& ...@@ -24,7 +102,7 @@ void prefix_scan_sum(hipStream_t stream, const argument& result, const argument&
k[axis] = j; k[axis] = j;
return k; return k;
}; };
block_scan<block_size>( block_scan<max_block_size>(
idx, idx,
sum{}, sum{},
0, 0,
...@@ -32,6 +110,7 @@ void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& ...@@ -32,6 +110,7 @@ void prefix_scan_sum(hipStream_t stream, const argument& result, const argument&
[&](auto j) { return input[compute_idx(j)]; }, [&](auto j) { return input[compute_idx(j)]; },
[&](auto j, auto x) { output[compute_idx(j)] = x; }); [&](auto j, auto x) { output[compute_idx(j)] = x; });
}); });
}
}); });
} }
......
...@@ -10,7 +10,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis); void prefix_scan_sum(hipStream_t stream,
const argument& result,
const argument& arg,
int32_t axis,
bool exclusive,
bool reverse);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -40,9 +40,8 @@ struct hip_prefix_scan_sum : oper<hip_prefix_scan_sum> ...@@ -40,9 +40,8 @@ struct hip_prefix_scan_sum : oper<hip_prefix_scan_sum>
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
if(op.exclusive or op.reverse) device::prefix_scan_sum(
MIGRAPHX_THROW("Exclusive and reverse scan not supported"); ctx.get_stream().get(), args[1], args[0], op.axis, op.exclusive, op.reverse);
device::prefix_scan_sum(ctx.get_stream().get(), args[1], args[0], op.axis);
return args[1]; return args[1];
} }
......
...@@ -9,10 +9,12 @@ struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_sm ...@@ -9,10 +9,12 @@ struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_sm
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{migraphx::shape::float_type, {1}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto xb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), x);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), x); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), xb);
return p; return p;
} }
}; };
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification case: prefix_scan_sum over axis 2 of a {3, 3, 3} float
// parameter with exclusive = true and reverse = false.
struct test_prefix_scan_sum_exclusive : verify_program<test_prefix_scan_sum_exclusive>
{
    // Builds the program under test: a single parameter "x" fed straight
    // into one prefix_scan_sum instruction.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        migraphx::shape input_shape{migraphx::shape::float_type, {3, 3, 3}};
        auto input = main_mod->add_parameter("x", input_shape);

        auto scan_op = migraphx::make_op(
            "prefix_scan_sum", {{"axis", 2}, {"exclusive", true}, {"reverse", false}});
        main_mod->add_instruction(scan_op, input);

        return prog;
    }
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification case: prefix_scan_sum over axis 0 of a {3, 3, 3} float
// parameter with both exclusive and reverse set to true.
struct test_prefix_scan_sum_exclusive_reverse
    : verify_program<test_prefix_scan_sum_exclusive_reverse>
{
    // Builds the program under test: a single parameter "x" fed straight
    // into one prefix_scan_sum instruction.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        migraphx::shape input_shape{migraphx::shape::float_type, {3, 3, 3}};
        auto input = main_mod->add_parameter("x", input_shape);

        auto scan_op = migraphx::make_op(
            "prefix_scan_sum", {{"axis", 0}, {"exclusive", true}, {"reverse", true}});
        main_mod->add_instruction(scan_op, input);

        return prog;
    }
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification case: prefix_scan_sum over axis 1 of a {3, 3, 3} float
// parameter with exclusive = false and reverse = true.
struct test_prefix_scan_sum_reverse : verify_program<test_prefix_scan_sum_reverse>
{
    // Builds the program under test: a single parameter "x" fed straight
    // into one prefix_scan_sum instruction.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        migraphx::shape input_shape{migraphx::shape::float_type, {3, 3, 3}};
        auto input = main_mod->add_parameter("x", input_shape);

        auto scan_op = migraphx::make_op(
            "prefix_scan_sum", {{"axis", 1}, {"exclusive", false}, {"reverse", true}});
        main_mod->add_instruction(scan_op, input);

        return prog;
    }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment