Unverified Commit e64b773f authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Support nonstandard shapes for the Squeeze Op (#1068)

Support slice, broadcast and transpose shapes for the squeeze op.
parent a30ec101
......@@ -37,43 +37,49 @@ struct squeeze
std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty())
{
std::copy_if(old_lens.begin(),
old_lens.end(),
std::back_inserter(new_lens),
[](auto len) { return len != 1; });
for(auto i : range(old_lens.size()))
{
if(old_lens[i] != 1)
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
}
else
{
for(std::size_t i = 0; i < old_lens.size(); i++)
for(auto i : range(old_lens.size()))
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
}
if(new_lens.empty())
{
return shape{type};
}
else
{
return shape{type, new_lens};
return shape{type, new_lens, new_strides};
}
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
......
......@@ -1446,6 +1446,27 @@ TEST_CASE(test_squeeze_all)
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_squeeze_transpose)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 1}, {4, 1, 4}};
migraphx::shape s2{migraphx::shape::float_type, {4, 4}, {4, 1}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_squeeze_multibroadcast)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_squeeze_slice)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 6, 1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_squeeze_negative_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
......
......@@ -3,21 +3,17 @@
#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/pass_manager.hpp>
#include "test.hpp"
TEST_CASE(argmax_test_nonstd_shape)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
auto* mm = p.get_main_module();
auto dl = mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}));
auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), dl_trans);
......@@ -35,12 +31,8 @@ TEST_CASE(argmax_test_nonstd_shape)
TEST_CASE(argmin_test_nonstd_shape)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
auto* mm = p.get_main_module();
auto dl = mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}));
auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl_trans);
......@@ -55,4 +47,62 @@ TEST_CASE(argmin_test_nonstd_shape)
EXPECT(migraphx::verify_range(result_vec, res_gold_vec));
}
TEST_CASE(squeeze_transpose_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 1, 3, 1, 3}}));
auto l0_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0, 4}}}), l0);
mm->add_instruction(migraphx::make_op("squeeze"), l0_trans);
auto p_uncompiled = p;
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
// contiguous is required to read the values in standard shaped order
auto tr_op = migraphx::make_op("contiguous");
auto std_expected_result = tr_op.compute(result.get_shape(), {expected_result});
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 3}});
EXPECT(result == std_expected_result);
}
TEST_CASE(squeeze_multibroadcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 3, 1, 3}}));
auto l0_brcst = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 3, 4, 3}}}), l0);
mm->add_instruction(migraphx::make_op("squeeze"), l0_brcst);
auto p_uncompiled = p;
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
auto tr_op = migraphx::make_op("contiguous");
auto std_expected_result = tr_op.compute(result.get_shape(), {expected_result});
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 3, 4, 3}});
EXPECT(result == std_expected_result);
}
TEST_CASE(squeeze_slice_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 3, 4, 3}}));
auto l0_slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), l0);
mm->add_instruction(migraphx::make_op("squeeze"), l0_slice);
auto p_uncompiled = p;
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
auto tr_op = migraphx::make_op("contiguous");
auto std_expected_result = tr_op.compute(result.get_shape(), {expected_result});
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 3}});
EXPECT(result == std_expected_result);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment