"tools/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "102c6bdbc7855ea5c4818626df050b5202704cb1"
Commit 7c63b13b authored by charlie's avatar charlie
Browse files

Dynamic shape tests

parent dfa26315
...@@ -220,6 +220,7 @@ std::size_t shape::index(std::initializer_list<std::size_t> l) const ...@@ -220,6 +220,7 @@ std::size_t shape::index(std::initializer_list<std::size_t> l) const
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
} }
std::size_t shape::index(const std::vector<std::size_t>& l) const std::size_t shape::index(const std::vector<std::size_t>& l) const
{ {
if(this->dynamic()) if(this->dynamic())
...@@ -230,6 +231,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const ...@@ -230,6 +231,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
} }
std::size_t shape::index(std::size_t i) const std::size_t shape::index(std::size_t i) const
{ {
if(this->dynamic()) if(this->dynamic())
......
...@@ -75,6 +75,61 @@ TEST_CASE(test_shape_dynamic_not_fixed) ...@@ -75,6 +75,61 @@ TEST_CASE(test_shape_dynamic_not_fixed)
EXPECT(s.dyn_dims().at(0).has_optimal()); EXPECT(s.dyn_dims().at(0).has_optimal());
} }
TEST_CASE(test_shape_dynamic_compares)
{
auto a = migraphx::shape::dynamic_dimension{2, 5, 2};
auto b = a;
auto c = migraphx::shape::dynamic_dimension{2, 5, 2};
auto d = migraphx::shape::dynamic_dimension{3, 8, 4};
EXPECT(a == b);
EXPECT(a == c);
EXPECT(a != d);
migraphx::shape s0{migraphx::shape::float_type, {a, d}};
migraphx::shape s1 = s0;
migraphx::shape s2{migraphx::shape::float_type, {a, d}};
migraphx::shape s3{migraphx::shape::int32_type, {a}};
EXPECT(s0 == s1);
EXPECT(s0 == s2);
EXPECT(s0 != s3);
}
TEST_CASE(test_shape_dynamic_errors)
{
std::vector<migraphx::shape::dynamic_dimension> dims = {};
dims.emplace_back(migraphx::shape::dynamic_dimension{2, 5, 2});
dims.emplace_back(migraphx::shape::dynamic_dimension{2, 8, 0});
migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(test::throws([&] { s.element_space(); }));
EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.bytes(); }));
EXPECT(test::throws([&] { s.index({0, 1}); }));
EXPECT(test::throws([&] { s.index(1); }));
EXPECT(test::throws([&] { s.with_lens({3, 5}); }));
EXPECT(test::throws([&] { s.with_lens(migraphx::shape::float_type, {3, 5}); }));
}
TEST_CASE(test_shape_dynamic_serialize)
{
std::vector<migraphx::shape::dynamic_dimension> dims1 = {};
dims1.emplace_back(migraphx::shape::dynamic_dimension{2, 5, 2});
dims1.emplace_back(migraphx::shape::dynamic_dimension{2, 8, 0});
migraphx::shape s1{migraphx::shape::float_type, dims1};
auto v1 = migraphx::to_value(s1);
std::vector<migraphx::shape::dynamic_dimension> dims2 = {};
dims2.emplace_back(migraphx::shape::dynamic_dimension{2, 5, 2});
migraphx::shape s2{migraphx::shape::uint64_type, dims2};
auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2);
auto s3 = migraphx::from_value<migraphx::shape>(v1);
EXPECT(s3 == s1);
auto s4 = migraphx::from_value<migraphx::shape>(v2);
EXPECT(s4 == s2);
EXPECT(s3 != s4);
}
TEST_CASE(test_shape_packed) TEST_CASE(test_shape_packed)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {2, 1}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}, {2, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment