Commit d2549384 authored by Khalique's avatar Khalique
Browse files

manual merge

parents 67048d04 ab6cd9d3
#include <migraph/common_subexpression_elimination.hpp> #include <migraphx/common_subexpression_elimination.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraph/functional.hpp> #include <migraphx/functional.hpp>
#include <unordered_set> #include <unordered_set>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Range> template <class Range>
void cse_range(program& p, Range&& r) void cse_range(program& p, Range&& r)
...@@ -35,5 +35,5 @@ void cse_range(program& p, Range&& r) ...@@ -35,5 +35,5 @@ void cse_range(program& p, Range&& r)
void common_subexpression_elimination::apply(program& p) const { cse_range(p, iterator_for(p)); } void common_subexpression_elimination::apply(program& p) const { cse_range(p, iterator_for(p)); }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/constant_propagate.hpp> #include <migraphx/constant_propagate.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraph/literal.hpp> #include <migraphx/literal.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct match_const_add struct match_const_add
{ {
...@@ -13,7 +13,7 @@ struct match_const_add ...@@ -13,7 +13,7 @@ struct match_const_add
return match::name("add")(match::args(match::name("@literal"), match::name("@literal"))); return match::name("add")(match::args(match::name("@literal"), match::name("@literal")));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto arg1 = ins->inputs().at(0)->get_literal(); auto arg1 = ins->inputs().at(0)->get_literal();
...@@ -26,5 +26,5 @@ struct match_const_add ...@@ -26,5 +26,5 @@ struct match_const_add
void constant_propagate::apply(program& p) const { match::find_matches(p, match_const_add{}); } void constant_propagate::apply(program& p) const { match::find_matches(p, match_const_add{}); }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/functional.hpp> #include <migraphx/functional.hpp>
#include <migraph/ranges.hpp> #include <migraphx/ranges.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Range, class Iterator> template <class Range, class Iterator>
std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last) std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
...@@ -62,5 +62,5 @@ void dead_code_elimination::apply(program& p) const ...@@ -62,5 +62,5 @@ void dead_code_elimination::apply(program& p) const
p.remove_instructions(std::next(last), p.end()); p.remove_instructions(std::next(last), p.end());
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/eliminate_allocation.hpp> #include <migraphx/eliminate_allocation.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraph/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraph/pass_config.hpp> #include <migraphx/pass_config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_allocation::apply(program& p) const void eliminate_allocation::apply(program& p) const
{ {
assert(alignment > 0); assert(alignment > 0);
if(!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{}))
return;
std::size_t n = 0; std::size_t n = 0;
std::vector<std::pair<instruction_ref, std::size_t>> allocs; std::vector<std::pair<instruction_ref, std::size_t>> allocs;
...@@ -27,15 +25,18 @@ void eliminate_allocation::apply(program& p) const ...@@ -27,15 +25,18 @@ void eliminate_allocation::apply(program& p) const
std::size_t padding = (alignment - (size % alignment)) % alignment; std::size_t padding = (alignment - (size % alignment)) % alignment;
n += size + padding; n += size + padding;
} }
auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}}); if(n > 0)
for(auto&& pp : allocs)
{ {
auto ins = pp.first; auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}});
auto s = ins->get_shape(); for(auto&& pp : allocs)
auto offset = pp.second; {
p.replace_instruction(ins, op::load{s, offset}, mem); auto ins = pp.first;
auto s = ins->get_shape();
auto offset = pp.second;
p.replace_instruction(ins, op::load{s, offset}, mem);
}
} }
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <iterator> #include <iterator>
#include <migraph/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/dfor.hpp> #include <migraphx/dfor.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_concat::apply(program& p) const void eliminate_concat::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
...@@ -36,14 +36,17 @@ void eliminate_concat::apply(program& p) const ...@@ -36,14 +36,17 @@ void eliminate_concat::apply(program& p) const
// Where are the allocations for the tensors to be concatenated? // Where are the allocations for the tensors to be concatenated?
std::vector<instruction_ref> allocations; std::vector<instruction_ref> allocations;
for(auto ins2 = ins->inputs().begin(); ins2 != ins->inputs().end() - 1; ins2++) std::transform(
{ ins->inputs().begin(),
auto last2 = (*ins2)->inputs().back(); std::prev(ins->inputs().end()),
if(last2->name() == concat_opt.allocate()) std::back_inserter(allocations),
{ [&](instruction_ref x) { return instruction::get_output_alias(x, true); });
allocations.push_back(last2);
} if(std::any_of(allocations.begin(), allocations.end(), [&](auto x) {
} return x->name() != concat_opt.allocate();
}))
continue;
// Need to sort the allocations, so that we know where to // Need to sort the allocations, so that we know where to
// insert the "super"-allocation // insert the "super"-allocation
std::sort( std::sort(
...@@ -51,21 +54,21 @@ void eliminate_concat::apply(program& p) const ...@@ -51,21 +54,21 @@ void eliminate_concat::apply(program& p) const
return std::distance(p.begin(), x) < std::distance(p.begin(), y); return std::distance(p.begin(), x) < std::distance(p.begin(), y);
}); });
// Move "super" allocation to the front // Move "super" allocation to the front
auto first = allocations.front(); auto first = allocations.front();
auto super = p.move_instruction(last, first); auto super = p.move_instruction(last, first);
// Replace each allocation with a load
std::size_t offset = 0; std::size_t offset = 0;
for(auto x : allocations) for(auto alloc : allocations)
{ {
migraph::op::load op{x->get_shape(), offset}; op::load op{alloc->get_shape(), offset};
// migraph::op::load op{x->get_shape(), 0}; p.replace_instruction(alloc, op, {super});
p.replace_instruction(x, op, {super}); offset += alloc->get_shape().bytes();
offset += x->get_shape().bytes();
} }
std::vector<instruction_ref> args = {super}; std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args)); std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraph::op::identity{}, args); p.replace_instruction(ins, migraphx::op::identity{}, args);
} }
} }
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraph/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <utility> #include <utility>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args) bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{ {
...@@ -47,5 +47,5 @@ void eliminate_contiguous::apply(program& p) const ...@@ -47,5 +47,5 @@ void eliminate_contiguous::apply(program& p) const
} }
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/env.hpp> #include <migraphx/env.hpp>
#include <migraph/ranges.hpp> #include <migraphx/ranges.hpp>
#include <cstdlib> #include <cstdlib>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool enabled(const char* name) bool enabled(const char* name)
{ {
...@@ -30,5 +30,5 @@ std::vector<std::string> env(const char* name) ...@@ -30,5 +30,5 @@ std::vector<std::string> env(const char* name)
return {{p}}; return {{p}};
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/dfor.hpp> #include <migraphx/dfor.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void fwd_conv_batchnorm_rewrite::apply(program& p) const void fwd_conv_batchnorm_rewrite::apply(program& p) const
{ {
...@@ -67,5 +67,5 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -67,5 +67,5 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
} }
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/generate.hpp> #include <migraphx/generate.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
argument generate_argument(shape s, unsigned long seed) argument generate_argument(shape s, unsigned long seed)
{ {
...@@ -31,5 +31,5 @@ literal abs(literal l) ...@@ -31,5 +31,5 @@ literal abs(literal l)
return transform(std::move(l), [](auto x) { return std::fabs(x); }); return transform(std::move(l), [](auto x) { return std::fabs(x); });
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#ifndef MIGRAPH_GUARD_CONFIG_HPP
#define MIGRAPH_GUARD_CONFIG_HPP
namespace migraph {
#if !defined(MIGRAPH_USE_CLANG_TIDY) && !defined(DOXYGEN)
#define MIGRAPH_INLINE_NS version_1
#endif
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_FALLTHROUGH_HPP
#define MIGRAPH_GUARD_FALLTHROUGH_HPP
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
#ifdef __clang__
#define MIGRAPH_FALLTHROUGH [[clang::fallthrough]]
#else
#define MIGRAPH_FALLTHROUGH
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#include <migraph/program.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
/// Create a program from an onnx file
program parse_onnx(const std::string& name);
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_PASS_CONFIG_HPP
#define MIGRAPH_GUARD_PASS_CONFIG_HPP
#include <migraph/env.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_DISABLE_MEMORY_COLORING)
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif // MIGRAPH_GUARD_PASS_CONFIG_HPP
#ifndef MIGRAPH_GUARD_RTGLIB_RANK_HPP
#define MIGRAPH_GUARD_RTGLIB_RANK_HPP
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_ARGUMENT_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_ARGUMENT_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_ARGUMENT_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_ARGUMENT_HPP
#include <migraph/shape.hpp> #include <migraphx/shape.hpp>
#include <migraph/raw_data.hpp> #include <migraphx/raw_data.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
#include <functional> #include <functional>
#include <utility> #include <utility>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/** /**
* @brief Arguments passed to instructions * @brief Arguments passed to instructions
...@@ -47,7 +47,7 @@ struct argument : raw_data<argument> ...@@ -47,7 +47,7 @@ struct argument : raw_data<argument>
shape m_shape; shape m_shape;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP #define MIGRAPHX_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace detail { namespace detail {
...@@ -34,7 +34,7 @@ detail::auto_any_caster<T> auto_any_cast(T& x) ...@@ -34,7 +34,7 @@ detail::auto_any_caster<T> auto_any_cast(T& x)
return {x}; return {x};
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP #define MIGRAPHX_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
#include <string> #include <string>
#include <migraph/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -16,7 +16,7 @@ struct auto_contiguous ...@@ -16,7 +16,7 @@ struct auto_contiguous
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_BUILTIN_HPP #ifndef MIGRAPHX_GUARD_BUILTIN_HPP
#define MIGRAPH_GUARD_BUILTIN_HPP #define MIGRAPHX_GUARD_BUILTIN_HPP
#include <migraph/context.hpp> #include <migraphx/context.hpp>
#include <migraph/errors.hpp> #include <migraphx/errors.hpp>
#include <migraph/argument.hpp> #include <migraphx/argument.hpp>
#include <migraph/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace builtin { namespace builtin {
struct literal struct literal
{ {
std::string name() const { return "@literal"; } std::string name() const { return "@literal"; }
shape compute_shape(const std::vector<shape>&) const { MIGRAPH_THROW("builtin"); } shape compute_shape(const std::vector<shape>&) const { MIGRAPHX_THROW("builtin"); }
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("builtin"); MIGRAPHX_THROW("builtin");
} }
}; };
...@@ -36,7 +36,7 @@ struct outline ...@@ -36,7 +36,7 @@ struct outline
shape compute_shape(const std::vector<shape>&) const { return s; } shape compute_shape(const std::vector<shape>&) const { return s; }
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("builtin"); MIGRAPHX_THROW("builtin");
} }
}; };
...@@ -51,10 +51,10 @@ struct param ...@@ -51,10 +51,10 @@ struct param
} }
std::string name() const { return "@param"; } std::string name() const { return "@param"; }
shape compute_shape(const std::vector<shape>&) const { MIGRAPH_THROW("builtin"); } shape compute_shape(const std::vector<shape>&) const { MIGRAPHX_THROW("builtin"); }
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("builtin"); MIGRAPHX_THROW("builtin");
} }
friend std::ostream& operator<<(std::ostream& os, const param& op) friend std::ostream& operator<<(std::ostream& os, const param& op)
{ {
...@@ -64,7 +64,7 @@ struct param ...@@ -64,7 +64,7 @@ struct param
}; };
} // namespace builtin } // namespace builtin
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP #define MIGRAPHX_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> template <class T>
struct check_context struct check_context
...@@ -15,11 +15,19 @@ struct check_context ...@@ -15,11 +15,19 @@ struct check_context
std::string name() const { return "check_context"; } std::string name() const { return "check_context"; }
shape compute_shape(const std::vector<shape>&) const { return {}; } shape compute_shape(const std::vector<shape>&) const { return {}; }
argument compute(context& ctx, const shape&, const std::vector<argument>&) const argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{
this->check(ctx);
return {};
}
void finalize(context& ctx, const shape&, const std::vector<shape>&) const
{
this->check(ctx);
}
void check(context& ctx) const
{ {
T* x = any_cast<T>(&ctx); T* x = any_cast<T>(&ctx);
if(x == nullptr) if(x == nullptr)
MIGRAPH_THROW(std::string("Unexpected context type: ") + ctx.type_id().name()); MIGRAPHX_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
return {};
} }
}; };
...@@ -27,7 +35,7 @@ struct check_context ...@@ -27,7 +35,7 @@ struct check_context
void apply(program& p) const { p.insert_instruction(p.begin(), op{}); } void apply(program& p) const { p.insert_instruction(p.begin(), op{}); }
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP #define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraph/shape.hpp> #include <migraphx/shape.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
#include <algorithm> #include <algorithm>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct check_shapes struct check_shapes
{ {
...@@ -46,8 +46,8 @@ struct check_shapes ...@@ -46,8 +46,8 @@ struct check_shapes
const check_shapes& has(std::size_t n) const const check_shapes& has(std::size_t n) const
{ {
if(size() != n) if(size() != n)
MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
" but given " + std::to_string(size())); " but given " + std::to_string(size()));
return *this; return *this;
} }
...@@ -58,7 +58,7 @@ struct check_shapes ...@@ -58,7 +58,7 @@ struct check_shapes
if(begin != end) if(begin != end)
{ {
if(begin->lens().size() != n) if(begin->lens().size() != n)
MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported"); MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
} }
return *this; return *this;
} }
...@@ -66,56 +66,56 @@ struct check_shapes ...@@ -66,56 +66,56 @@ struct check_shapes
const check_shapes& same_shape() const const check_shapes& same_shape() const
{ {
if(!this->same([](const shape& s) { return s; })) if(!this->same([](const shape& s) { return s; }))
MIGRAPH_THROW(prefix() + "Shapes do not match"); MIGRAPHX_THROW(prefix() + "Shapes do not match");
return *this; return *this;
} }
const check_shapes& same_type() const const check_shapes& same_type() const
{ {
if(!this->same([](const shape& s) { return s.type(); })) if(!this->same([](const shape& s) { return s.type(); }))
MIGRAPH_THROW(prefix() + "Types do not match"); MIGRAPHX_THROW(prefix() + "Types do not match");
return *this; return *this;
} }
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.lens(); })) if(!this->same([](const shape& s) { return s.lens(); }))
MIGRAPH_THROW(prefix() + "Dimensions do not match"); MIGRAPHX_THROW(prefix() + "Dimensions do not match");
return *this; return *this;
} }
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(!this->same([](const shape& s) { return s.lens().size(); })) if(!this->same([](const shape& s) { return s.lens().size(); }))
MIGRAPH_THROW(prefix() + "Number of dimensions do not match"); MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this; return *this;
} }
const check_shapes& standard() const const check_shapes& standard() const
{ {
if(!this->all_of([](const shape& s) { return s.standard(); })) if(!this->all_of([](const shape& s) { return s.standard(); }))
MIGRAPH_THROW(prefix() + "Shapes are not in standard layout"); MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout");
return *this; return *this;
} }
const check_shapes& packed() const const check_shapes& packed() const
{ {
if(!this->all_of([](const shape& s) { return s.packed(); })) if(!this->all_of([](const shape& s) { return s.packed(); }))
MIGRAPH_THROW(prefix() + "Shapes are not packed"); MIGRAPHX_THROW(prefix() + "Shapes are not packed");
return *this; return *this;
} }
const check_shapes& not_transposed() const const check_shapes& not_transposed() const
{ {
if(!this->all_of([](const shape& s) { return not s.transposed(); })) if(!this->all_of([](const shape& s) { return not s.transposed(); }))
MIGRAPH_THROW(prefix() + "Shapes are transposed"); MIGRAPHX_THROW(prefix() + "Shapes are transposed");
return *this; return *this;
} }
const check_shapes& not_broadcasted() const const check_shapes& not_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return not s.broadcasted(); })) if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
MIGRAPH_THROW(prefix() + "Shapes are broadcasted"); MIGRAPHX_THROW(prefix() + "Shapes are broadcasted");
return *this; return *this;
} }
...@@ -143,7 +143,7 @@ struct check_shapes ...@@ -143,7 +143,7 @@ struct check_shapes
const shape* get(long i) const shape* get(long i)
{ {
if(i >= size()) if(i >= size())
MIGRAPH_THROW(prefix() + "Accessing shape out of bounds"); MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds");
assert(begin != nullptr); assert(begin != nullptr);
assert(end != nullptr); assert(end != nullptr);
if(i < 0) if(i < 0)
...@@ -156,7 +156,7 @@ struct check_shapes ...@@ -156,7 +156,7 @@ struct check_shapes
check_shapes slice(long start, long last) { return {get(start), get(last), name}; } check_shapes slice(long start, long last) { return {get(start), get(last), name}; }
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment