Unverified Commit 2d4dcc47 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Non std shape auto contiguous (#1001)

Resolves a problem in parsing the ssd-10 model.

The problem is, after inserting contiguous in the auto_contiguous pass, standard output shape of some operators becomes non-standard. Then, if the next operator requires standard input shape, an exception is throw.

For example, if we pass the following model:
Input (standard shape) -> transpose (transposed) -> softmax (transposed) -> transpose (standard) -> gather.
It works fine, and no contiguous is required.

In the auto_contiguous pass, a contiguous is inserted after the first transpose. Then we need to replace the first transpose with the contiguous and recompute all shapes. When it comes to the gather operator, its input is a transposed shape, and an exception is thrown.

The solution is in the recompute_shape() function. If it is called by the auto_contiguous pass and shape of an instruction is changed, and the shape is non_standard, we do not recompute shape of its output. The reason is: since its output shape is non_standard, a contiguous op will be added after the instruction, which will recompute shape for later operators.
parent 2788f647
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
...@@ -38,7 +38,7 @@ struct gather ...@@ -38,7 +38,7 @@ struct gather
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).standard(); check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
auto type = inputs[0].type(); auto type = inputs[0].type();
lens.erase(lens.begin() + axis); lens.erase(lens.begin() + axis);
......
File mode changed from 100755 to 100644
...@@ -101,4 +101,38 @@ TEST_CASE(after_param_broadcast) ...@@ -101,4 +101,38 @@ TEST_CASE(after_param_broadcast)
EXPECT(not m.get_output_shapes().back().broadcasted()); EXPECT(not m.get_output_shapes().back().broadcasted());
} }
TEST_CASE(two_transpose_gather)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto ind = m1.add_parameter("ind", {migraphx::shape::float_type, {2, 3}});
auto td = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
auto sd = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), td);
auto bd =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd);
auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), bd, ind);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto ind = m2.add_parameter("ind", {migraphx::shape::float_type, {2, 3}});
auto td = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td);
auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd);
auto bd =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd);
auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd);
auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1473,6 +1473,32 @@ TEST_CASE(fp32_fp16_test) ...@@ -1473,6 +1473,32 @@ TEST_CASE(fp32_fp16_test)
test_case({"add"}); test_case({"add"});
} }
TEST_CASE(gather_non_std_test)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {0.5f, 3.5f, 6.5f, 1.5f, 4.5f, 7.5f, 2.5f, 2.5f, 8.5f};
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto d = mm->add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{-3, -3, -1, -1};
auto ind = mm->add_literal(migraphx::literal{s_indices, indices});
auto td = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d);
auto tind =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), ind);
mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), td, tind);
auto result = p.eval({}).back();
std::vector<float> golden = {
0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f, 0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f};
std::vector<float> res_data;
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
}
TEST_CASE(gather_test) TEST_CASE(gather_test)
{ {
{ {
...@@ -2784,7 +2810,6 @@ TEST_CASE(nms_not_center_test) ...@@ -2784,7 +2810,6 @@ TEST_CASE(nms_not_center_test)
auto output = p.eval({}).back(); auto output = p.eval({}).back();
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::cout << "output = " << output << std::endl;
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
...@@ -2818,7 +2843,6 @@ TEST_CASE(nms_test) ...@@ -2818,7 +2843,6 @@ TEST_CASE(nms_test)
auto output = p.eval({}).back(); auto output = p.eval({}).back();
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::cout << "output = " << output << std::endl;
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_nonstd_gather : verify_program<test_nonstd_gather>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 1, 0, 2};
auto d = mm->add_parameter("data", s);
auto td = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d);
auto ind = mm->add_literal(migraphx::literal{s_indices, indices});
auto tind =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), ind);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), td, tind);
mm->add_return({r});
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment