Unverified Commit c65ab678 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Change check_shapes to templated class (#2011)

parent ae4cdf5a
......@@ -34,21 +34,37 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Check that deduced type is incrementable, dereferencable, and comparable
template <class, class = void>
struct is_iterator
{
};
template <class T>
struct is_iterator<T,
std::void_t<decltype(++std::declval<T&>()),
decltype(*std::declval<T&>()),
decltype(std::declval<T&>() == std::declval<T&>())>> : std::true_type
{
};
template <class Iterator>
struct check_shapes
{
const shape* begin;
const shape* end;
static_assert(is_iterator<Iterator>{}, "CHECK_SHAPES: Deduced type must be an iterator");
Iterator begin;
Iterator end;
std::string name;
bool dynamic_allowed;
check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false)
check_shapes(Iterator b, Iterator e, const std::string& n, const bool d = false)
: begin(b), end(e), name(n), dynamic_allowed(d)
{
check_dynamic();
}
template <class Op>
check_shapes(const shape* b, const shape* e, const Op& op, const bool d = false)
check_shapes(Iterator b, Iterator e, const Op& op, const bool d = false)
: begin(b), end(e), name(op.name()), dynamic_allowed(d)
{
check_dynamic();
......@@ -56,7 +72,7 @@ struct check_shapes
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.data()), end(s.data() + s.size()), name(op.name()), dynamic_allowed(d)
: begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d)
{
check_dynamic();
}
......@@ -81,8 +97,6 @@ struct check_shapes
{
if(begin == end)
return 0;
assert(begin != nullptr);
assert(end != nullptr);
return end - begin;
}
......@@ -131,8 +145,6 @@ struct check_shapes
*/
const check_shapes& only_dims(std::size_t n) const
{
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end)
{
if(begin->max_lens().size() != n)
......@@ -148,8 +160,6 @@ struct check_shapes
*/
const check_shapes& max_ndims(std::size_t n) const
{
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end)
{
if(begin->max_lens().size() > n)
......@@ -166,8 +176,6 @@ struct check_shapes
*/
const check_shapes& min_ndims(std::size_t n) const
{
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end)
{
if(begin->max_lens().size() < n)
......@@ -330,8 +338,6 @@ struct check_shapes
{
if(begin == end)
return true;
assert(begin != nullptr);
assert(end != nullptr);
auto&& key = f(*begin);
return this->all_of([&](const shape& s) { return f(s) == key; });
}
......@@ -341,8 +347,6 @@ struct check_shapes
{
if(begin == end)
return true;
assert(begin != nullptr);
assert(end != nullptr);
return std::all_of(begin, end, p);
}
......@@ -351,17 +355,13 @@ struct check_shapes
{
if(begin == end)
return false;
assert(begin != nullptr);
assert(end != nullptr);
return std::any_of(begin, end, p);
}
const shape* get(long i) const
Iterator get(long i) const
{
if(i >= size())
MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds");
assert(begin != nullptr);
assert(end != nullptr);
if(i < 0)
return end - i;
return begin + i;
......@@ -394,6 +394,11 @@ struct check_shapes
}
};
// Deduction guide for std::vector constructor
template <class Op>
check_shapes(const std::vector<shape>&, const Op&, bool d = false)
-> check_shapes<std::vector<shape>::const_iterator>;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
MIGRAPHX_DNNL_PREFIX(ARG_BIAS)};
}
void required(const check_shapes& cs) const { cs.not_broadcasted(); }
template <class T>
void required(const check_shapes<T>& cs) const
{
cs.not_broadcasted();
}
dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
......
......@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
}
// dnnl has some issues with non-packed inputs
void required(const check_shapes& cs) const { cs.packed_or_broadcasted(); }
template <class T>
void required(const check_shapes<T>& cs) const
{
cs.packed_or_broadcasted();
}
std::string name() const { return "dnnl::" + op.name(); }
shape compute_shape(std::vector<shape> inputs) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment