Commit b43f4184 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add dynamic dimension in the shape class

parent 8e4d622f
...@@ -54,6 +54,20 @@ struct shape ...@@ -54,6 +54,20 @@ struct shape
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_GET_TYPE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_GET_TYPE)
#undef MIGRAPHX_SHAPE_GENERATE_GET_TYPE #undef MIGRAPHX_SHAPE_GENERATE_GET_TYPE
// dynamic dimension to support dynamic shape
struct dynamic_dimension
{
std::size_t min = 0;
std::size_t max = 0;
std::size_t opt = 0;
// is the dimension fixed
bool is_fixed() const;
// does the dimension have an optimal size
bool has_optimal() const;
};
template <class T> template <class T>
struct get_type<const T> : get_type<T> struct get_type<const T> : get_type<T>
{ {
...@@ -84,6 +98,9 @@ struct shape ...@@ -84,6 +98,9 @@ struct shape
shape(const std::vector<shape>& subs); shape(const std::vector<shape>& subs);
// constructor for dynamic shape
shape(type_t t, std::vector<dynamic_dimension> dims);
static shape static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm); from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
type_t type() const; type_t type() const;
......
...@@ -50,6 +50,7 @@ struct shape_impl ...@@ -50,6 +50,7 @@ struct shape_impl
std::vector<std::size_t> m_strides = {}; std::vector<std::size_t> m_strides = {};
std::vector<shape> m_shapes = {}; std::vector<shape> m_shapes = {};
bool m_standard = false; bool m_standard = false;
std::vector<shape::dynamic_dimension> dynamic_dims = {};
void calculate_strides() void calculate_strides()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment