Unverified Commit c65ab678 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Change check_shapes to templated class (#2011)

parent ae4cdf5a
...@@ -34,21 +34,37 @@ ...@@ -34,21 +34,37 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Check that deduced type is incrementable, dereferencable, and comparable
template <class, class = void>
struct is_iterator
{
};
template <class T>
struct is_iterator<T,
std::void_t<decltype(++std::declval<T&>()),
decltype(*std::declval<T&>()),
decltype(std::declval<T&>() == std::declval<T&>())>> : std::true_type
{
};
template <class Iterator>
struct check_shapes struct check_shapes
{ {
const shape* begin; static_assert(is_iterator<Iterator>{}, "CHECK_SHAPES: Deduced type must be an iterator");
const shape* end; Iterator begin;
Iterator end;
std::string name; std::string name;
bool dynamic_allowed; bool dynamic_allowed;
check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false) check_shapes(Iterator b, Iterator e, const std::string& n, const bool d = false)
: begin(b), end(e), name(n), dynamic_allowed(d) : begin(b), end(e), name(n), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
} }
template <class Op> template <class Op>
check_shapes(const shape* b, const shape* e, const Op& op, const bool d = false) check_shapes(Iterator b, Iterator e, const Op& op, const bool d = false)
: begin(b), end(e), name(op.name()), dynamic_allowed(d) : begin(b), end(e), name(op.name()), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
...@@ -56,7 +72,7 @@ struct check_shapes ...@@ -56,7 +72,7 @@ struct check_shapes
template <class Op> template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false) check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.data()), end(s.data() + s.size()), name(op.name()), dynamic_allowed(d) : begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
} }
...@@ -81,8 +97,6 @@ struct check_shapes ...@@ -81,8 +97,6 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return 0; return 0;
assert(begin != nullptr);
assert(end != nullptr);
return end - begin; return end - begin;
} }
...@@ -131,8 +145,6 @@ struct check_shapes ...@@ -131,8 +145,6 @@ struct check_shapes
*/ */
const check_shapes& only_dims(std::size_t n) const const check_shapes& only_dims(std::size_t n) const
{ {
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() != n) if(begin->max_lens().size() != n)
...@@ -148,8 +160,6 @@ struct check_shapes ...@@ -148,8 +160,6 @@ struct check_shapes
*/ */
const check_shapes& max_ndims(std::size_t n) const const check_shapes& max_ndims(std::size_t n) const
{ {
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() > n) if(begin->max_lens().size() > n)
...@@ -166,8 +176,6 @@ struct check_shapes ...@@ -166,8 +176,6 @@ struct check_shapes
*/ */
const check_shapes& min_ndims(std::size_t n) const const check_shapes& min_ndims(std::size_t n) const
{ {
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() < n) if(begin->max_lens().size() < n)
...@@ -330,8 +338,6 @@ struct check_shapes ...@@ -330,8 +338,6 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return true; return true;
assert(begin != nullptr);
assert(end != nullptr);
auto&& key = f(*begin); auto&& key = f(*begin);
return this->all_of([&](const shape& s) { return f(s) == key; }); return this->all_of([&](const shape& s) { return f(s) == key; });
} }
...@@ -341,8 +347,6 @@ struct check_shapes ...@@ -341,8 +347,6 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return true; return true;
assert(begin != nullptr);
assert(end != nullptr);
return std::all_of(begin, end, p); return std::all_of(begin, end, p);
} }
...@@ -351,17 +355,13 @@ struct check_shapes ...@@ -351,17 +355,13 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return false; return false;
assert(begin != nullptr);
assert(end != nullptr);
return std::any_of(begin, end, p); return std::any_of(begin, end, p);
} }
const shape* get(long i) const Iterator get(long i) const
{ {
if(i >= size()) if(i >= size())
MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds"); MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds");
assert(begin != nullptr);
assert(end != nullptr);
if(i < 0) if(i < 0)
return end - i; return end - i;
return begin + i; return begin + i;
...@@ -394,6 +394,11 @@ struct check_shapes ...@@ -394,6 +394,11 @@ struct check_shapes
} }
}; };
// Deduction guide for std::vector constructor
template <class Op>
check_shapes(const std::vector<shape>&, const Op&, bool d = false)
-> check_shapes<std::vector<shape>::const_iterator>;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot> ...@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
MIGRAPHX_DNNL_PREFIX(ARG_BIAS)}; MIGRAPHX_DNNL_PREFIX(ARG_BIAS)};
} }
void required(const check_shapes& cs) const { cs.not_broadcasted(); } template <class T>
void required(const check_shapes<T>& cs) const
{
cs.not_broadcasted();
}
dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{ {
......
...@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive> ...@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
} }
// dnnl has some issues with non-packed inputs // dnnl has some issues with non-packed inputs
void required(const check_shapes& cs) const { cs.packed_or_broadcasted(); } template <class T>
void required(const check_shapes<T>& cs) const
{
cs.packed_or_broadcasted();
}
std::string name() const { return "dnnl::" + op.name(); } std::string name() const { return "dnnl::" + op.name(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment