Commit c497c12d authored by charlie's avatar charlie
Browse files

element_space, min,max,opt _lens change

parent 33e5534c
...@@ -65,13 +65,6 @@ struct shape ...@@ -65,13 +65,6 @@ struct shape
std::size_t max = 0; std::size_t max = 0;
std::size_t opt = 0; std::size_t opt = 0;
dynamic_dimension() = default;
dynamic_dimension(std::size_t i_min, std::size_t i_max, std::size_t i_opt)
: min(i_min), max(i_max), opt(i_opt)
{
}
bool is_fixed() const { return min == max; } bool is_fixed() const { return min == max; }
bool has_optimal() const { return opt != 0; } bool has_optimal() const { return opt != 0; }
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y) friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y)
...@@ -132,8 +125,8 @@ struct shape ...@@ -132,8 +125,8 @@ struct shape
std::size_t elements() const; std::size_t elements() const;
/*! /*!
* Return the number of total bytes used for storage of the tensor data. * Return the number of total bytes used for storage of the tensor data; includes subshapes.
* Includes subshapes. * For dynamic shape, returns the maximum number of bytes presuming a packed shape.
*/ */
std::size_t bytes() const; std::size_t bytes() const;
...@@ -145,9 +138,23 @@ struct shape ...@@ -145,9 +138,23 @@ struct shape
const std::vector<dynamic_dimension>& dyn_dims() const; const std::vector<dynamic_dimension>& dyn_dims() const;
std::vector<std::size_t> min_dyn_dims() const; /*!
std::vector<std::size_t> max_dyn_dims() const; * Minimum lengths for dynamic shape.
std::vector<std::size_t> opt_dyn_dims() const; * lens() for fixed shape.
*/
std::vector<std::size_t> min_lens() const;
/*!
* Maximum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> max_lens() const;
/*!
* Optimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> opt_lens() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
...@@ -169,7 +176,7 @@ struct shape ...@@ -169,7 +176,7 @@ struct shape
std::vector<std::size_t> multi(std::size_t i) const; std::vector<std::size_t> multi(std::size_t i) const;
void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const; void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;
/// Returns true if the shape is packed with no padding /// Returns true if the shape is packed (number of elements and buffer size the same) with no padding
bool packed() const; bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending /// Returns true is the shape has been transposed. That is the strides are not in descending
...@@ -288,7 +295,10 @@ struct shape ...@@ -288,7 +295,10 @@ struct shape
const std::vector<shape>& sub_shapes() const; const std::vector<shape>& sub_shapes() const;
/// size of the data buffer /*!
* Returns size of the data buffer.
* Assuming a packed shape, returns maximum size of the data buffer for dynamic shape.
*/
std::size_t element_space() const; std::size_t element_space() const;
private: private:
......
...@@ -58,7 +58,7 @@ struct shape_impl ...@@ -58,7 +58,7 @@ struct shape_impl
std::vector<shape> m_shapes = {}; std::vector<shape> m_shapes = {};
bool m_standard = false; bool m_standard = false;
std::vector<shape::shape::dynamic_dimension> m_dyn_dims = {}; std::vector<shape::dynamic_dimension> m_dyn_dims = {};
void calculate_strides() void calculate_strides()
{ {
...@@ -77,7 +77,13 @@ struct shape_impl ...@@ -77,7 +77,13 @@ struct shape_impl
{ {
if(not m_dyn_dims.empty()) if(not m_dyn_dims.empty())
{ {
MIGRAPHX_THROW("SHAPE: element_space() called on dynamic shape"); auto maxes = max_lens();
return std::accumulate(
maxes.begin(),
maxes.end(),
std::size_t{1},
std::multiplies<>()
);
} }
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
...@@ -106,6 +112,51 @@ struct shape_impl ...@@ -106,6 +112,51 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
std::vector<std::size_t> min_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(
m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x)
{
return x.min;
}
);
return ret;
}
std::vector<std::size_t> max_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(
m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x)
{
return x.max;
}
);
return ret;
}
std::vector<std::size_t> opt_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(
m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x)
{
return x.opt;
}
);
return ret;
}
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); } std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
}; };
...@@ -188,10 +239,6 @@ std::size_t shape::elements() const { return impl->elements(); } ...@@ -188,10 +239,6 @@ std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const std::size_t shape::bytes() const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: bytes() called on dynamic shape");
}
if(this->sub_shapes().empty()) if(this->sub_shapes().empty())
{ {
std::size_t n = 0; std::size_t n = 0;
...@@ -390,37 +437,31 @@ bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); } ...@@ -390,37 +437,31 @@ bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; } const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
std::vector<std::size_t> shape::min_dyn_dims() const std::vector<std::size_t> shape::min_lens() const
{ {
auto num_dims = dyn_dims().size(); if (not this->dynamic())
std::vector<std::size_t> ret(num_dims);
for(int i = 0; i < num_dims; ++i)
{ {
ret.at(i) = dyn_dims().at(i).min; return this->lens();
} }
return ret; return impl->min_lens();;
} }
std::vector<std::size_t> shape::max_dyn_dims() const std::vector<std::size_t> shape::max_lens() const
{ {
auto num_dims = dyn_dims().size(); if (not this->dynamic())
std::vector<std::size_t> ret(num_dims);
for(int i = 0; i < num_dims; ++i)
{ {
ret.at(i) = dyn_dims().at(i).max; return this->lens();
} }
return ret; return impl->max_lens();;
} }
std::vector<std::size_t> shape::opt_dyn_dims() const std::vector<std::size_t> shape::opt_lens() const
{ {
auto num_dims = dyn_dims().size(); if (not this->dynamic())
std::vector<std::size_t> ret(num_dims);
for(int i = 0; i < num_dims; ++i)
{ {
ret.at(i) = dyn_dims().at(i).opt; return this->lens();
} }
return ret; return impl->opt_lens();;
} }
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
...@@ -480,9 +521,9 @@ void migraphx_to_value(value& v, const shape& s) ...@@ -480,9 +521,9 @@ void migraphx_to_value(value& v, const shape& s)
if(s.dynamic()) if(s.dynamic())
{ {
result["dynamic"] = migraphx::to_value(s.dynamic()); result["dynamic"] = migraphx::to_value(s.dynamic());
result["min_dyn_dims"] = migraphx::to_value(s.min_dyn_dims()); result["min_lens"] = migraphx::to_value(s.min_lens());
result["max_dyn_dims"] = migraphx::to_value(s.max_dyn_dims()); result["max_lens"] = migraphx::to_value(s.max_lens());
result["opt_dyn_dims"] = migraphx::to_value(s.opt_dyn_dims()); result["opt_lens"] = migraphx::to_value(s.opt_lens());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes()); result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
} }
else else
...@@ -505,9 +546,9 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -505,9 +546,9 @@ void migraphx_from_value(const value& v, shape& s)
{ {
if(v.contains("dynamic")) if(v.contains("dynamic"))
{ {
auto mins = v.at("min_dyn_dims").to_vector<std::size_t>(); auto mins = v.at("min_lens").to_vector<std::size_t>();
auto maxes = v.at("max_dyn_dims").to_vector<std::size_t>(); auto maxes = v.at("max_lens").to_vector<std::size_t>();
auto opts = v.at("opt_dyn_dims").to_vector<std::size_t>(); auto opts = v.at("opt_lens").to_vector<std::size_t>();
assert(mins.size() == maxes.size() and maxes.size() == opts.size()); assert(mins.size() == maxes.size() and maxes.size() == opts.size());
auto num_dims = mins.size(); auto num_dims = mins.size();
std::vector<shape::dynamic_dimension> dyn_dims(num_dims); std::vector<shape::dynamic_dimension> dyn_dims(num_dims);
......
...@@ -56,9 +56,10 @@ TEST_CASE(test_shape_dynamic_fixed) ...@@ -56,9 +56,10 @@ TEST_CASE(test_shape_dynamic_fixed)
TEST_CASE(test_shape_dynamic_not_fixed) TEST_CASE(test_shape_dynamic_not_fixed)
{ {
migraphx::shape::dynamic_dimensions dims = {}; using migraphx::shape;
dims.emplace_back(2, 5, 2); std::vector<shape::dynamic_dimension> dims = {};
dims.emplace_back(2, 8, 0); dims.push_back(shape::dynamic_dimension{2, 5, 2});
dims.push_back(shape::dynamic_dimension{2, 8, 0});
migraphx::shape s{migraphx::shape::float_type, dims}; migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
...@@ -72,55 +73,61 @@ TEST_CASE(test_shape_dynamic_not_fixed) ...@@ -72,55 +73,61 @@ TEST_CASE(test_shape_dynamic_not_fixed)
TEST_CASE(test_shape_dynamic_compares) TEST_CASE(test_shape_dynamic_compares)
{ {
auto a = migraphx::shape::dynamic_dimension{2, 5, 2}; using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2};
auto b = a; auto b = a;
auto c = migraphx::shape::dynamic_dimension{2, 5, 2}; auto c = shape::dynamic_dimension{2, 5, 2};
auto d = migraphx::shape::dynamic_dimension{3, 8, 4}; auto d = shape::dynamic_dimension{3, 8, 4};
EXPECT(a == b); EXPECT(a == b);
EXPECT(a == c); EXPECT(a == c);
EXPECT(a != d); EXPECT(a != d);
migraphx::shape s0{migraphx::shape::float_type, {a, d}}; migraphx::shape s0{shape::float_type, {a, d}};
migraphx::shape s1 = s0; migraphx::shape s1 = s0;
migraphx::shape s2{migraphx::shape::float_type, {a, d}}; migraphx::shape s2{shape::float_type, {a, d}};
migraphx::shape s3{migraphx::shape::int32_type, {a}}; migraphx::shape s3{shape::int32_type, {a}};
EXPECT(s0 == s1); EXPECT(s0 == s1);
EXPECT(s0 == s2); EXPECT(s0 == s2);
EXPECT(s0 != s3); EXPECT(s0 != s3);
} }
TEST_CASE(test_shape_dynamic_bytes)
{
}
TEST_CASE(test_shape_dynamic_errors) TEST_CASE(test_shape_dynamic_errors)
{ {
std::vector<migraphx::shape::dynamic_dimension> dims = {}; using migraphx::shape;
dims.emplace_back(2, 5, 2); std::vector<shape::dynamic_dimension> dims = {};
dims.emplace_back(2, 8, 0); dims.push_back(shape::dynamic_dimension{2, 5, 2});
migraphx::shape s{migraphx::shape::float_type, dims}; dims.push_back(shape::dynamic_dimension{2, 8, 0});
EXPECT(test::throws([&] { s.element_space(); })); migraphx::shape s{shape::float_type, dims};
EXPECT(test::throws([&] { s.elements(); })); EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.bytes(); }));
EXPECT(test::throws([&] { s.index({0, 1}); })); EXPECT(test::throws([&] { s.index({0, 1}); }));
EXPECT(test::throws([&] { s.index(1); })); EXPECT(test::throws([&] { s.index(1); }));
EXPECT(test::throws([&] { s.with_lens({3, 5}); })); EXPECT(test::throws([&] { s.with_lens({3, 5}); }));
EXPECT(test::throws([&] { s.with_lens(migraphx::shape::float_type, {3, 5}); })); EXPECT(test::throws([&] { s.with_lens(shape::float_type, {3, 5}); }));
} }
TEST_CASE(test_shape_dynamic_serialize) TEST_CASE(test_shape_dynamic_serialize)
{ {
std::vector<migraphx::shape::dynamic_dimension> dims1 = {}; using migraphx::shape;
dims1.emplace_back(2, 5, 2); std::vector<shape::dynamic_dimension> dims1 = {};
dims1.emplace_back(2, 8, 0); dims1.push_back(shape::dynamic_dimension{2, 5, 2});
migraphx::shape s1{migraphx::shape::float_type, dims1}; dims1.push_back(shape::dynamic_dimension{2, 8, 0});
migraphx::shape s1{shape::float_type, dims1};
auto v1 = migraphx::to_value(s1); auto v1 = migraphx::to_value(s1);
std::vector<migraphx::shape::dynamic_dimension> dims2 = {}; std::vector<shape::dynamic_dimension> dims2 = {};
dims2.emplace_back(2, 5, 2); dims2.push_back(shape::dynamic_dimension{2, 5, 2});
migraphx::shape s2{migraphx::shape::uint64_type, dims2}; migraphx::shape s2{shape::uint64_type, dims2};
auto v2 = migraphx::to_value(s2); auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2); EXPECT(v1 != v2);
auto s3 = migraphx::from_value<migraphx::shape>(v1); auto s3 = migraphx::from_value<shape>(v1);
EXPECT(s3 == s1); EXPECT(s3 == s1);
auto s4 = migraphx::from_value<migraphx::shape>(v2); auto s4 = migraphx::from_value<shape>(v2);
EXPECT(s4 == s2); EXPECT(s4 == s2);
EXPECT(s3 != s4); EXPECT(s3 != s4);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment