Commit 235a463f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'gather_operator' into extend_gemm_op

parents 828017fb 39ca6601
......@@ -768,7 +768,7 @@ struct gather
// for scalar output
if(lens.empty())
{
return {type, {1}, {0}};
return {type};
}
return {type, lens};
......
......@@ -439,7 +439,7 @@ struct onnx_parser
// if dim_size is 0, it is a scalar
if(dim_size == 0)
{
migraphx::shape scalar_shape{v.get_shape().type(), {1}, {0}};
migraphx::shape scalar_shape{v.get_shape().type()};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
}
......
......@@ -19,7 +19,7 @@ struct shape_impl
shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {}
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
......
......@@ -173,7 +173,7 @@ TEST_CASE(gather_test)
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
......@@ -194,7 +194,7 @@ TEST_CASE(gather_test)
migraphx::shape s{migraphx::shape::float_type, {3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
......
......@@ -1112,7 +1112,7 @@ struct test_gather_scalar_output
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
......@@ -1128,7 +1128,7 @@ struct test_gather_scalar_index
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
......
......@@ -524,7 +524,7 @@ TEST_CASE(constant_test)
TEST_CASE(constant_test_scalar)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}, {0}}, {1}});
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {1}});
auto prog = migraphx::parse_onnx("constant_scalar.onnx");
EXPECT(p == prog);
......
......@@ -263,7 +263,7 @@ TEST_CASE(gather)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
migraphx::op::gather{axis},
......@@ -273,7 +273,7 @@ TEST_CASE(gather)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4}},
migraphx::op::gather{axis},
......@@ -283,9 +283,9 @@ TEST_CASE(gather)
{
migraphx::shape input{migraphx::shape::float_type, {3}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1}, {0}},
expect_shape(migraphx::shape{migraphx::shape::float_type},
migraphx::op::gather{axis},
input,
indices);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment