Commit 39ca6601 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change one of the shape constructors to simplify code.

parent 4c1a1d63
...@@ -768,7 +768,7 @@ struct gather ...@@ -768,7 +768,7 @@ struct gather
// for scalar output // for scalar output
if(lens.empty()) if(lens.empty())
{ {
return {type, {1}, {0}}; return {type};
} }
return {type, lens}; return {type, lens};
......
...@@ -439,7 +439,7 @@ struct onnx_parser ...@@ -439,7 +439,7 @@ struct onnx_parser
// if dim_size is 0, it is a scalar // if dim_size is 0, it is a scalar
if(dim_size == 0) if(dim_size == 0)
{ {
migraphx::shape scalar_shape{v.get_shape().type(), {1}, {0}}; migraphx::shape scalar_shape{v.get_shape().type()};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()}); return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
} }
......
...@@ -19,7 +19,7 @@ struct shape_impl ...@@ -19,7 +19,7 @@ struct shape_impl
shape_impl() : m_type(shape::float_type), m_standard(false) {} shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {} shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {}
shape_impl(shape::type_t t, std::vector<std::size_t> l) shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true) : m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
......
...@@ -173,7 +173,7 @@ TEST_CASE(gather_test) ...@@ -173,7 +173,7 @@ TEST_CASE(gather_test)
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data}); auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index // scalar index
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}}; migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{0}; std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1; int axis = -1;
...@@ -194,7 +194,7 @@ TEST_CASE(gather_test) ...@@ -194,7 +194,7 @@ TEST_CASE(gather_test)
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto a0 = p.add_literal(migraphx::literal{s, data}); auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index // scalar index
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}}; migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{0}; std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1; int axis = -1;
......
...@@ -1074,7 +1074,7 @@ struct test_gather_scalar_output ...@@ -1074,7 +1074,7 @@ struct test_gather_scalar_output
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}}; migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1}; std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s); auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
...@@ -1090,7 +1090,7 @@ struct test_gather_scalar_index ...@@ -1090,7 +1090,7 @@ struct test_gather_scalar_index
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}}; migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1}; std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s); auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
......
...@@ -524,7 +524,7 @@ TEST_CASE(constant_test) ...@@ -524,7 +524,7 @@ TEST_CASE(constant_test)
TEST_CASE(constant_test_scalar) TEST_CASE(constant_test_scalar)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}, {0}}, {1}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {1}});
auto prog = migraphx::parse_onnx("constant_scalar.onnx"); auto prog = migraphx::parse_onnx("constant_scalar.onnx");
EXPECT(p == prog); EXPECT(p == prog);
......
...@@ -263,7 +263,7 @@ TEST_CASE(gather) ...@@ -263,7 +263,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}}; migraphx::shape indices{migraphx::shape::int32_type};
int axis = -4; int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
...@@ -273,7 +273,7 @@ TEST_CASE(gather) ...@@ -273,7 +273,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}}; migraphx::shape indices{migraphx::shape::int32_type};
int axis = 3; int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
...@@ -283,9 +283,9 @@ TEST_CASE(gather) ...@@ -283,9 +283,9 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {3}}; migraphx::shape input{migraphx::shape::float_type, {3}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}}; migraphx::shape indices{migraphx::shape::int32_type};
int axis = 0; int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1}, {0}}, expect_shape(migraphx::shape{migraphx::shape::float_type},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
input, input,
indices); indices);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment