Unverified commit e8ae23b1, authored by turneram and committed by GitHub

Prefix scan operator (#797)



* Add scan struct; add initial tests; initial algorithm by cases; refactor into one algorithm; clean up code

* Rename; restructure; begin adding additional attributes

* refactor to use shape_for_each; temporarily drop reverse mode

* Add back reverse mode with shape_for_each_reverse; update tests; add axis bounds check

* Begin adding to onnx parser

* Add to onnx parser

* Fix onnx test

* Fix CI warnings

* Update algorithm to use slice+par_for; update gen_onnx; remove .o files; remove redundant axis normalizing

* Add exclusive mode

* Add reverse mode

* Remove .pyc file

* Fix warning

* Remove shape_for_each_reverse; clean up pointer usage for exclusive cases

* Remove unused variable

* Fix onnx test

* Add test case to op_shape_test

* Formatting

* Formatting

* Fix tidy warning

* Formatting

* Formatting

* Formatting

* Increase code coverage

* Formatting

* refine the script for creating the cumsum onnx file

* Alphabetize includes for operators.hpp

* Revise onnx test

* Remove redundant bounds check

* Formatting and style

* Alphabetize tests

* Remove duplicate tests from merge

* Fix tidy warning for sub_test
Co-authored-by: Shucai Xiao <Shucai.Xiao@amd.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 658cdab0
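
The operator added here computes a cumulative (prefix) sum along one axis, with optional exclusive and reverse modes. A minimal standalone sketch of those three semantics for a 1-D input, assuming only the C++ standard library (illustration only, not the MIGraphX implementation):

// Inclusive, exclusive, and reverse prefix sums of a 1-D input.
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    std::vector<float> x{1, 2, 3, 4, 5, 6};

    // Inclusive: y[i] = x[0] + ... + x[i]
    std::vector<float> inclusive(x.size());
    std::partial_sum(x.begin(), x.end(), inclusive.begin()); // 1 3 6 10 15 21

    // Exclusive: shift right by one element and seed with 0
    std::vector<float> exclusive(x.size(), 0.0f);
    std::partial_sum(x.begin(), x.end() - 1, exclusive.begin() + 1); // 0 1 3 6 10 15

    // Reverse (inclusive): accumulate from the back, i.e. suffix sums
    std::vector<float> reversed(x.size());
    std::partial_sum(x.rbegin(), x.rend(), reversed.rbegin()); // 21 20 18 15 11 6

    for(auto v : reversed)
        std::printf("%g ", v);
    std::printf("\n");
}
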
@@ -121,6 +121,7 @@ register_migraphx_ops(
pad
pooling
pow
prefix_scan_sum
prelu
quant_convolution
quant_dot
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
template <class Derived>
struct prefix_scan_op : op_name<Derived>
{
int64_t axis;
bool exclusive = false;
bool reverse = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.axis, "axis"), f(self.exclusive, "exclusive"), f(self.reverse, "reverse"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.at(0);
}
argument compute(const shape&, std::vector<argument> args) const
{
argument result = args[0];
auto s = result.get_shape();
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens();
lens[axis] = 1;
auto batch = shape{s.type(), lens, s.strides()};
auto& self = static_cast<const Derived&>(*this);
result.visit([&](auto output) {
using type = decltype(output);
par_for(batch.elements(), [&](auto i) {
auto* start = output.data() + batch.index(i);
type x{slice, start};
if(reverse)
{
if(exclusive)
{
std::copy(++x.begin(), x.end(), x.begin());
x.back() = 0;
}
std::partial_sum(std::make_reverse_iterator(x.end()),
std::make_reverse_iterator(x.begin()),
std::make_reverse_iterator(x.end()),
self.op());
}
else
{
if(exclusive)
{
std::copy_backward(x.begin(), --x.end(), x.end());
x.front() = 0;
}
std::partial_sum(x.begin(), x.end(), x.begin(), self.op());
}
});
});
return result;
}
auto init() const {}
prefix_scan_op() : axis(0) {}
prefix_scan_op(int64_t ax) : axis(ax) {}
prefix_scan_op(int64_t ax, bool excl) : axis(ax), exclusive(excl) {}
prefix_scan_op(int64_t ax, bool excl, bool rev) : axis(ax), exclusive(excl), reverse(rev) {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
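
The compute() above scans each 1-D slice along axis independently: batch is the input shape with lens[axis] forced to 1, so every element of batch is the starting offset of one strided slice covered by the slice shape. A standalone sketch of the same decomposition for a dense row-major tensor stored in a plain std::vector (the function name and layout assumptions here are illustrative, not MIGraphX APIs):

// Inclusive prefix sum along `axis` of a dense row-major tensor, using the
// same batch/slice idea as prefix_scan_op::compute(): fix the scan axis to 0,
// enumerate every remaining multi-index, and scan one strided slice per index.
#include <cstddef>
#include <vector>

void prefix_scan_sum_ref(std::vector<float>& data,
                         const std::vector<std::size_t>& lens,
                         std::size_t axis)
{
    // Row-major strides.
    std::vector<std::size_t> strides(lens.size(), 1);
    for(std::size_t d = lens.size() - 1; d > 0; --d)
        strides[d - 1] = strides[d] * lens[d];

    const std::size_t n    = lens[axis];    // slice length
    const std::size_t step = strides[axis]; // distance between slice elements

    std::vector<std::size_t> idx(lens.size(), 0); // batch multi-index (idx[axis] stays 0)
    bool done = false;
    while(not done)
    {
        // Starting offset of this slice.
        std::size_t start = 0;
        for(std::size_t d = 0; d < lens.size(); ++d)
            start += idx[d] * strides[d];

        // In-place inclusive scan over the strided slice.
        for(std::size_t k = 1; k < n; ++k)
            data[start + k * step] += data[start + (k - 1) * step];

        // Advance the batch index like an odometer, skipping the scan axis.
        done = true;
        for(std::size_t d = lens.size(); d-- > 0;)
        {
            if(d == axis)
                continue;
            if(++idx[d] < lens[d])
            {
                done = false;
                break;
            }
            idx[d] = 0;
        }
    }
}
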
#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_INCLUSIVE_SUM_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCAN_INCLUSIVE_SUM_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/prefix_scan_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct prefix_scan_sum : prefix_scan_op<prefix_scan_sum>
{
prefix_scan_sum() {}
prefix_scan_sum(int64_t ax) : prefix_scan_op(ax) {}
prefix_scan_sum(int64_t ax, bool excl) : prefix_scan_op(ax, excl) {}
prefix_scan_sum(int64_t ax, bool excl, bool rev) : prefix_scan_op(ax, excl, rev) {}
auto op() const
{
return [](auto x, auto y) { return x + y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
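
Because prefix_scan_op is a CRTP base, a derived operator only has to supply its binary op(). A hypothetical product scan (not part of this PR, and it would still need an entry in register_migraphx_ops) could reuse the axis/exclusive/reverse plumbing as sketched below. One caveat: the base writes 0 as the exclusive-mode padding value, which is the identity for addition but not for multiplication, so a non-additive scan would also need a way to override that identity.

// Hypothetical derived scan (illustration only; not added by this PR).
#include <migraphx/op/prefix_scan_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct prefix_scan_prod : prefix_scan_op<prefix_scan_prod>
{
    prefix_scan_prod() {}
    prefix_scan_prod(int64_t ax) : prefix_scan_op(ax) {}
    prefix_scan_prod(int64_t ax, bool excl, bool rev) : prefix_scan_op(ax, excl, rev) {}

    // The only piece a derived scan must provide: the binary accumulation op.
    auto op() const
    {
        return [](auto x, auto y) { return x * y; };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
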
@@ -58,10 +58,11 @@
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/prefix_scan_sum.hpp>
#include <migraphx/op/prelu.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/recip.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/op/reduce_mean.hpp>
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
instruction_ref parse_prefix_scan_oper(const std::string& op_name,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args)
{
migraphx::argument in = args[1]->eval();
check_arg_empty(in, "PARSE_PREFIX_SCAN: axis - dynamic shape not supported");
std::vector<int64_t> axis_in;
in.visit([&](auto input) { axis_in.assign(input.begin(), input.end()); });
int64_t axis = axis_in[0];
bool exclusive = false;
bool reverse = false;
if(contains(info.attributes, "exclusive"))
{
exclusive = parser.parse_value(info.attributes.at("exclusive")).at<bool>();
}
if(contains(info.attributes, "reverse"))
{
reverse = parser.parse_value(info.attributes.at("reverse")).at<bool>();
}
return info.add_instruction(
make_op(op_name, {{"axis", axis}, {"exclusive", exclusive}, {"reverse", reverse}}),
args[0]);
}
struct parse_prefix_scan_op : op_parser<parse_prefix_scan_op>
{
std::vector<op_desc> operators() const { return {{"CumSum", "prefix_scan_sum"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
return parse_prefix_scan_oper(opd.op_name, parser, std::move(info), std::move(args));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -2642,6 +2642,23 @@ def pow_i64_fp32_test():
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def prefix_scan_sum_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2])
axis_val = np.array([0])
axis_tensor = helper.make_tensor(name="axis",
data_type=TensorProto.INT32,
dims=axis_val.shape,
vals=axis_val.astype(int))
node = onnx.helper.make_node('CumSum',
inputs=['x', 'axis'],
outputs=['y'],
exclusive=1,
reverse=1)
return ([node], [x], [y], [axis_tensor])
@onnx_test
def prelu_brcst_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
......
@@ -2278,6 +2278,21 @@ TEST_CASE(pow_i64_fp32_test)
EXPECT(p == prog);
}
TEST_CASE(prefix_scan_sum)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {1}, {1}}, {0}});
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto ret = mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", true}, {"reverse", true}}),
l0);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("prefix_scan_sum_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(prelu_brcst_test)
{
migraphx::program p;
......
@@ -1532,4 +1532,21 @@ TEST_CASE(lstm)
}
}
TEST_CASE(prefix_scan_sum)
{
{
migraphx::shape s{migraphx::shape::float_type, {1, 2, 3}};
throws_shape(
migraphx::make_op("prefix_scan_sum", {{"axis", 3}, {"exclusive", 0}, {"reverse", 0}}),
s);
}
{
migraphx::shape s{migraphx::shape::float_type, {1, 2}};
throws_shape(
migraphx::make_op("prefix_scan_sum", {{"axis", -3}, {"exclusive", 0}, {"reverse", 0}}),
s);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -2727,6 +2727,371 @@ TEST_CASE(pow_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(prefix_scan_sum_1d)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}),
l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0, 3.0, 6.0, 10.0, 15.0, 21.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_2d)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0};
EXPECT(results_vector == gold);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0};
EXPECT(results_vector == gold);
}
}
TEST_CASE(prefix_scan_sum_3d)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0,
2.0,
3.0,
1.0,
2.0,
3.0,
1.0,
2.0,
3.0,
2.0,
4.0,
6.0,
2.0,
4.0,
6.0,
2.0,
4.0,
6.0};
EXPECT(results_vector == gold);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0,
2.0,
3.0,
2.0,
4.0,
6.0,
3.0,
6.0,
9.0,
1.0,
2.0,
3.0,
2.0,
4.0,
6.0,
3.0,
6.0,
9.0};
EXPECT(results_vector == gold);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 2}, {"exclusive", false}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0,
3.0,
6.0,
1.0,
3.0,
6.0,
1.0,
3.0,
6.0,
1.0,
3.0,
6.0,
1.0,
3.0,
6.0,
1.0,
3.0,
6.0};
EXPECT(results_vector == gold);
}
}
TEST_CASE(prefix_scan_sum_exclusive)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {8}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 1, 2, 3, 4}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", true}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.0, 1.0, 3.0, 6.0, 10.0, 11.0, 13.0, 16.0};
EXPECT(results_vector == gold);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", true}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.0,
0.0,
0.0,
1.0,
2.0,
3.0,
2.0,
4.0,
6.0,
0.0,
0.0,
0.0,
1.0,
2.0,
3.0,
2.0,
4.0,
6.0};
EXPECT(results_vector == gold);
}
}
TEST_CASE(prefix_scan_sum_exclusive_reverse)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", true}, {"reverse", true}}),
l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{20.0, 18.0, 15.0, 11.0, 6.0, 0.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_negative_axis)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", -3}, {"exclusive", false}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0,
2.0,
3.0,
1.0,
2.0,
3.0,
1.0,
2.0,
3.0,
2.0,
4.0,
6.0,
2.0,
4.0,
6.0,
2.0,
4.0,
6.0};
EXPECT(results_vector == gold);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", -2}, {"exclusive", false}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0,
2.0,
3.0,
2.0,
4.0,
6.0,
3.0,
6.0,
9.0,
1.0,
2.0,
3.0,
2.0,
4.0,
6.0,
3.0,
6.0,
9.0};
EXPECT(results_vector == gold);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", -1}, {"exclusive", false}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0,
3.0,
6.0,
1.0,
3.0,
6.0,
1.0,
3.0,
6.0,
1.0,
3.0,
6.0,
1.0,
3.0,
6.0,
1.0,
3.0,
6.0};
EXPECT(results_vector == gold);
}
}
TEST_CASE(prefix_scan_sum_reverse)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {8}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 1, 2, 3, 4}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum",
{{"axis", 0}, {"exclusive", false}, {"reverse", true}}),
l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{20.0, 19.0, 17.0, 14.0, 10.0, 9.0, 7.0, 4.0};
EXPECT(results_vector == gold);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 1, 2, 3, 4}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum",
{{"axis", 0}, {"exclusive", false}, {"reverse", true}}),
l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.0, 4.0, 6.0, 8.0, 1.0, 2.0, 3.0, 4.0};
EXPECT(results_vector == gold);
}
}
TEST_CASE(prelu_test)
{
migraphx::program p;
......