Unverified Commit 04065c64 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Step op (#839)



* add the operator step

* clang formatJ

* add unit tests

* clang format

* add more unit test for step op

* clang format

* add more unit tests

* clang format

* fix review comments

* clang format

* rename two unit tests
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent de69008b
......@@ -151,6 +151,7 @@ register_migraphx_ops(
sqdiff
sqrt
squeeze
step
sub
tanh
tan
......
#ifndef MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#define MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#include "migraphx/stringutils.hpp"
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct step
{
std::vector<int64_t> axes;
std::vector<int64_t> steps;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"), f(self.steps, "steps"));
}
std::string name() const { return "step"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0);
auto in_lens = input.lens();
auto t = input.type();
if(axes.size() != steps.size())
{
MIGRAPHX_THROW("STEP: attribute axes {" + to_string_range(axes) +
"} has different dimensions from step {" + to_string_range(steps) +
"}.");
}
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return axis >= in_lens.size(); }))
{
MIGRAPHX_THROW("STEP: axis value is out of range!");
}
auto lens = in_lens;
auto strides = input.strides();
for(auto i : range(axes.size()))
{
auto axis = axes[i];
auto step = steps[i];
lens[axis] = (in_lens[axis] + step - 1) / step;
strides[axis] *= step;
}
return {t, lens, strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -88,6 +88,7 @@
#include <migraphx/op/sqrt.hpp>
#include <migraphx/op/sqdiff.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/step.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/op/tan.hpp>
......
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/recip.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
......@@ -670,19 +663,6 @@ struct find_add_convs
return x.stride[0] / y.stride[0];
}
static shape compute_stride_shape(const shape& input, std::size_t n)
{
return {input.type(),
{input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(1, (input.lens()[2] - 1) / n + 1)),
std::size_t(std::max<std::ptrdiff_t>(1, (input.lens()[3] - 1) / n + 1))},
{input.strides()[0],
input.strides()[1],
input.strides()[2] * n,
input.strides()[3] * n}};
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
......@@ -713,11 +693,7 @@ struct find_add_convs
return;
new_op = a_op;
b_input = p.insert_instruction(
ins,
make_op(
"as_shape",
{{"shape", to_value(compute_stride_shape(b_input->get_shape(), n))}}),
b_input);
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), b_input);
}
else if(b_op.stride < a_op.stride)
{
......@@ -726,11 +702,7 @@ struct find_add_convs
return;
new_op = b_op;
a_input = p.insert_instruction(
ins,
make_op(
"as_shape",
{{"shape", to_value(compute_stride_shape(a_input->get_shape(), n))}}),
a_input);
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), a_input);
}
else
return;
......
......@@ -1549,4 +1549,23 @@ TEST_CASE(prefix_scan_sum)
}
}
TEST_CASE(step_test)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 2, 4}};
{
migraphx::shape s2{migraphx::shape::float_type, {1, 1, 2}, {8, 8, 3}};
expect_shape(s2, migraphx::make_op("step", {{"axes", {1, 2}}, {"steps", {2, 3}}}), s1);
}
{
migraphx::shape s{migraphx::shape::float_type, {1, 2, 4}};
throws_shape(migraphx::make_op("step", {{"axes", {1, 2}}, {"steps", {1}}}), s1);
}
{
migraphx::shape s{migraphx::shape::float_type, {1, 2, 4}};
throws_shape(migraphx::make_op("step", {{"axes", {2, 3}}, {"steps", {2, 3}}}), s1);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -3847,6 +3847,42 @@ TEST_CASE(squeeze_test)
}
}
TEST_CASE(step_test)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(2 * 4 * 6);
std::iota(data.begin(), data.end(), 2);
migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}};
auto l0 = mm->add_literal(migraphx::literal{s1, data});
auto r = mm->add_instruction(
migraphx::make_op("step", {{"axes", {0, 2, 3}}, {"steps", {2, 2, 3}}}), l0);
mm->add_return({r});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
migraphx::shape s2{migraphx::shape::float_type, {1, 1, 2, 2}};
EXPECT(result.get_shape() == s2);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(2 * 4 * 6);
std::iota(data.begin(), data.end(), 2);
migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}};
auto l0 = mm->add_literal(migraphx::literal{s1, data});
auto tl = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
auto r = mm->add_instruction(
migraphx::make_op("step", {{"axes", {0, 1, 2}}, {"steps", {2, 2, 3}}}), tl);
mm->add_return({r});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
migraphx::shape s2{migraphx::shape::float_type, {1, 2, 2, 1}};
EXPECT(result.get_shape() == s2);
}
}
TEST_CASE(sub_test)
{
migraphx::program p;
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_step : verify_program<test_step>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}};
auto l0 = mm->add_parameter("x", s1);
auto r = mm->add_instruction(
migraphx::make_op("step", {{"axes", {0, 2, 3}}, {"steps", {2, 2, 3}}}), l0);
mm->add_return({r});
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_step_broadcast_transpose : verify_program<test_step_broadcast_transpose>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1, 6}};
auto l0 = mm->add_parameter("x", s1);
auto ml = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 1, 4, 6}}}), l0);
auto tl = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ml);
auto r = mm->add_instruction(
migraphx::make_op("step", {{"axes", {0, 1, 2}}, {"steps", {2, 2, 3}}}), tl);
mm->add_return({r});
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_step_transpose : verify_program<test_step_transpose>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}};
auto l0 = mm->add_parameter("x", s1);
auto tl = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
auto r = mm->add_instruction(
migraphx::make_op("step", {{"axes", {0, 1, 2}}, {"steps", {2, 2, 3}}}), tl);
mm->add_return({r});
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment