Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
......@@ -14,6 +14,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct value;
struct shape_impl;
struct shape
......@@ -22,6 +23,7 @@ struct shape
// Add new types here
// clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(bool_type, bool) \
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
......@@ -33,12 +35,12 @@ struct shape
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
// clang-format on
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
{
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) tuple_type
};
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
......@@ -57,6 +59,11 @@ struct shape
{
};
static const std::vector<type_t>& types();
static std::string name(type_t t);
static std::string cpp_type(type_t t);
shape();
shape(type_t t);
shape(type_t t, std::vector<std::size_t> l);
......@@ -75,6 +82,10 @@ struct shape
{
}
shape(const std::vector<shape>& subs);
static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
type_t type() const;
const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const;
......@@ -93,13 +104,14 @@ struct shape
{
assert(std::distance(start, last) <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(start, last, this->strides().begin(), std::size_t{0});
return std::inner_product(start, last, this->strides().begin(), std::size_t{0}); // NOLINT
}
/// Map element index to space index
std::size_t index(std::size_t i) const;
std::vector<std::size_t> multi(std::size_t i) const;
void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;
/// Returns true if the shape is packed with no padding
bool packed() const;
......@@ -114,6 +126,13 @@ struct shape
/// Returns true if all strides are equal to 0 (scalar tensor)
bool scalar() const;
shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const;
shape with_type(type_t t) const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
......@@ -121,50 +140,58 @@ struct shape
template <class T>
struct as
{
using type = T;
using type = std::conditional_t<std::is_same<T, bool>{}, int8_t, T>;
type max() const { return std::numeric_limits<type>::max(); }
type min() const { return std::numeric_limits<type>::lowest(); }
template <class U>
T operator()(U u) const
type operator()(U u) const
{
return T(u);
return type(u);
}
template <class U>
T* operator()(U* u) const
type* operator()(U* u) const
{
return static_cast<T*>(u);
return static_cast<type*>(u);
}
template <class U>
const T* operator()(const U* u) const
const type* operator()(const U* u) const
{
return static_cast<T*>(u);
return static_cast<type*>(u);
}
T operator()() const { return {}; }
type operator()() const { return {}; }
std::size_t size(std::size_t n = 1) const { return sizeof(T) * n; }
std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; }
template <class U>
T* from(U* buffer, std::size_t n = 0) const
type* from(U* buffer, std::size_t n = 0) const
{
return reinterpret_cast<T*>(buffer) + n;
return reinterpret_cast<type*>(buffer) + n;
}
template <class U>
const T* from(const U* buffer, std::size_t n = 0) const
const type* from(const U* buffer, std::size_t n = 0) const
{
return reinterpret_cast<const T*>(buffer) + n;
return reinterpret_cast<const type*>(buffer) + n;
}
type_t type_enum() const { return get_type<T>{}; }
type_t type_enum() const { return get_type<type>{}; }
};
template <class Visitor>
void visit_type(Visitor v) const
template <class Visitor, class TupleVisitor>
static void visit(type_t t, Visitor v, TupleVisitor tv)
{
switch(this->type())
switch(t)
{
case tuple_type: {
tv();
return;
}
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
......@@ -173,6 +200,18 @@ struct shape
MIGRAPHX_THROW("Unknown type");
}
template <class Visitor>
static void visit(type_t t, Visitor v)
{
return visit(t, v, [] { MIGRAPHX_THROW("Tuple cannot be visited."); });
}
template <class... Visitors>
void visit_type(Visitors... vs) const
{
visit(this->type(), vs...);
}
template <class Visitor>
static void visit_types(Visitor v)
{
......@@ -181,13 +220,21 @@ struct shape
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
}
private:
std::shared_ptr<const shape_impl> impl;
std::string type_string() const;
static type_t parse_type(const std::string& s);
const std::vector<shape>& sub_shapes() const;
std::size_t element_space() const;
std::string type_string() const;
private:
shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl;
};
void migraphx_to_value(value& v, const shape& s);
void migraphx_from_value(const value& v, shape& s);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -7,7 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* Simplify many algebraic instructions to more efficient versions.
......@@ -15,7 +15,7 @@ struct program;
struct simplify_algebra
{
std::string name() const { return "simplify_algebra"; }
void apply(program& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_QDQ_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_QDQ_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Inserts quantized operators in place of dq->quantizable_op->q
* then removes remaining fake quantization (q->dq pairs)
*/
struct simplify_qdq
{
std::string name() const { return "simplify_qdq"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -8,7 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* Eliminate redundant reshapes.
......@@ -16,7 +16,7 @@ struct program;
struct simplify_reshapes
{
std::string name() const { return "simplify_reshapes"; }
void apply(program& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_STREAM_MODEL_HPP
#define MIGRAPHX_GUARD_STREAM_MODEL_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN
/// An interface for target-dependent model for the scheduler
struct stream_model
{
/// Get the number of streams used in the program
std::size_t get_nstream() const;
/// Get stream for instruction
std::size_t get_stream(instruction_ref ins) const;
/// Get unique event id for instruction
std::size_t get_event_id(instruction_ref ins) const;
/// Returns true if instruction has a stream assignment
bool has_stream(instruction_ref ins) const;
/// Returns true if the instruction records the event
bool is_record(instruction_ref ins) const;
/// Returns true if the instruction wait on the event
bool is_wait(instruction_ref ins) const;
};
#else
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct stream_model
{
//
std::size_t get_nstream() const;
//
std::size_t get_stream(instruction_ref ins) const;
//
std::size_t get_event_id(instruction_ref ins) const;
//
bool has_stream(instruction_ref ins) const;
//
bool is_record(instruction_ref ins) const;
//
bool is_wait(instruction_ref ins) const;
};
#else
struct stream_model
{
// Constructors
stream_model() = default;
template <typename PrivateDetailTypeErasedT>
stream_model(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
stream_model& operator=(PrivateDetailTypeErasedT value)
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
stream_model rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::size_t get_nstream() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_nstream();
}
std::size_t get_stream(instruction_ref ins) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_stream(ins);
}
std::size_t get_event_id(instruction_ref ins) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_event_id(ins);
}
bool has_stream(instruction_ref ins) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().has_stream(ins);
}
bool is_record(instruction_ref ins) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_record(ins);
}
bool is_wait(instruction_ref ins) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_wait(ins);
}
friend bool is_shared(const stream_model& private_detail_x,
const stream_model& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::size_t get_nstream() const = 0;
virtual std::size_t get_stream(instruction_ref ins) const = 0;
virtual std::size_t get_event_id(instruction_ref ins) const = 0;
virtual bool has_stream(instruction_ref ins) const = 0;
virtual bool is_record(instruction_ref ins) const = 0;
virtual bool is_wait(instruction_ref ins) const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
std::size_t get_nstream() const override { return private_detail_te_value.get_nstream(); }
std::size_t get_stream(instruction_ref ins) const override
{
return private_detail_te_value.get_stream(ins);
}
std::size_t get_event_id(instruction_ref ins) const override
{
return private_detail_te_value.get_event_id(ins);
}
bool has_stream(instruction_ref ins) const override
{
return private_detail_te_value.has_stream(ins);
}
bool is_record(instruction_ref ins) const override
{
return private_detail_te_value.is_record(ins);
}
bool is_wait(instruction_ref ins) const override
{
return private_detail_te_value.is_wait(ins);
}
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const stream_model* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(stream_model* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(stream_model& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const stream_model& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -5,11 +5,22 @@
#include <numeric>
#include <string>
#include <sstream>
#include <unordered_map>
#include <vector>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__
#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__)
template <class F>
auto with_char(F f)
{
return [=](unsigned char c) -> bool { return f(c); };
}
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
{
......@@ -43,17 +54,29 @@ inline std::string join_strings(Strings strings, const std::string& delim)
});
}
inline std::vector<std::string> split_string(const std::string& s, char delim)
{
std::vector<std::string> elems;
std::stringstream ss(s + ' ');
std::string item;
while(std::getline(ss, item, delim))
{
elems.push_back(item);
}
return elems;
}
template <class F>
std::string trim(const std::string& s, F f)
{
auto start = std::find_if_not(s.begin(), s.end(), f);
auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base();
return std::string(start, last);
return {start, last};
}
inline std::string trim(const std::string& s)
{
return trim(s, [](int c) { return std::isspace(c); });
return trim(s, [](unsigned char c) { return std::isspace(c); });
}
template <class F>
......@@ -83,6 +106,44 @@ inline std::string remove_prefix(std::string s, const std::string& prefix)
return s;
}
template <class F>
inline std::string
interpolate_string(const std::string& input, F f, std::string start = "${", std::string end = "}")
{
std::string result = "";
result.reserve(input.size());
auto it = input.begin();
while(it != input.end())
{
auto next_start = std::search(it, input.end(), start.begin(), start.end());
auto next_end = std::search(next_start, input.end(), end.begin(), end.end());
result.append(it, next_start);
if(next_start == input.end())
break;
auto r = f(next_start + start.size(), next_end);
result.append(r.begin(), r.end());
it = next_end + end.size();
}
return result;
}
inline std::string interpolate_string(const std::string& input,
const std::unordered_map<std::string, std::string>& vars,
std::string start = "${",
std::string end = "}")
{
return interpolate_string(
input,
[&](auto start_it, auto last_it) {
auto key = trim({start_it, last_it});
auto it = vars.find(key);
if(it == vars.end())
throw std::runtime_error("Unknown key: " + key);
return it->second;
},
std::move(start),
std::move(end));
}
template <class Iterator>
inline std::string to_string_range(Iterator start, Iterator last)
{
......@@ -108,7 +169,8 @@ inline std::string to_string_range(const std::initializer_list<T>& r)
}
template <class T>
inline std::string to_string(const T& x)
inline auto to_string(const T& x)
-> decltype((std::declval<std::stringstream>() << x), std::string{})
{
std::stringstream ss;
ss << x;
......
......@@ -82,20 +82,26 @@ argument copy_from_target(T&, const argument& arg)
return arg;
}
/*
* Type-erased interface for:
*
* struct target
* {
* std::string name() const;
* std::vector<pass> get_passes(context& ctx,const compile_options& options) const;
* context get_context() const;
* argument copy_to(const argument& input) const;
* argument copy_from(const argument& input) const;
* argument allocate(const shape& s) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct target
{
//
std::string name() const;
//
std::vector<pass> get_passes(context& ctx, const compile_options& options) const;
//
context get_context() const;
// (optional)
argument copy_to(const argument& input) const;
// (optional)
argument copy_from(const argument& input) const;
// (optional)
argument allocate(const shape& s) const;
};
#else
struct target
{
......@@ -115,11 +121,17 @@ struct target
template <typename PrivateDetailTypeErasedT>
target& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
target rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this;
}
......@@ -127,7 +139,7 @@ struct target
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
......@@ -138,7 +150,7 @@ struct target
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
......@@ -376,6 +388,7 @@ inline const ValueType& any_cast(const target& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
......@@ -4,6 +4,7 @@
#include <migraphx/shape.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/iota_iterator.hpp>
#include <migraphx/config.hpp>
#include <iostream>
......@@ -20,10 +21,24 @@ T as_number(T x)
inline int32_t as_number(int8_t x) { return static_cast<int32_t>(x); }
inline uint32_t as_number(uint8_t x) { return static_cast<uint32_t>(x); }
template <class T>
struct tensor_view_iterator_read
{
T* view;
auto& operator()(std::size_t n) const
{
assert(view != nullptr);
return (*view)[n];
}
};
template <class T>
struct tensor_view
{
using value_type = T;
using iterator = basic_iota_iterator<tensor_view_iterator_read<tensor_view<T>>, std::size_t>;
using const_iterator =
basic_iota_iterator<tensor_view_iterator_read<const tensor_view<T>>, std::size_t>;
tensor_view() : m_data(nullptr) {}
tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {}
......@@ -56,12 +71,16 @@ struct tensor_view
template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
const T& operator()(Iterator start, Iterator last) const
{
assert(std::distance(start, last) > 0);
assert(std::all_of(start, last, [](auto x) { return x >= 0; }));
return m_data[m_shape.index(start, last)];
}
template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
T& operator()(Iterator start, Iterator last)
{
assert(std::distance(start, last) > 0);
assert(std::all_of(start, last, [](auto x) { return x >= 0; }));
return m_data[m_shape.index(start, last)];
}
......@@ -101,36 +120,13 @@ struct tensor_view
return m_data[m_shape.index(this->size() - 1)];
}
// TODO: Add iterators so it can handle nonstandard tensors
T* begin()
{
assert(this->m_shape.standard() or this->empty());
return m_data;
}
iterator begin() { return {0, {this}}; }
T* end()
{
assert(this->m_shape.standard() or this->empty());
if(this->empty())
return m_data;
else
return m_data + this->size();
}
iterator end() { return {this->size(), {this}}; }
const T* begin() const
{
assert(this->m_shape.standard() or this->empty());
return m_data;
}
const_iterator begin() const { return {0, {this}}; }
const T* end() const
{
assert(this->m_shape.standard() or this->empty());
if(this->empty())
return m_data;
else
return m_data + this->size();
}
const_iterator end() const { return {this->size(), {this}}; }
template <class U = T>
std::vector<U> to_vector() const
......
......@@ -7,8 +7,18 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in tf options to parser
struct tf_options
{
bool is_nhwc = false;
unsigned int batch_size = 1;
/// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
std::vector<std::string> output_node_names = {};
};
/// Create a program from a tf pb file (default is nhwc format)
program parse_tf(const std::string& name, bool is_nhwc);
program parse_tf(const std::string& name, const tf_options& options = tf_options{});
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -7,13 +7,23 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct timer
{
std::chrono::time_point<std::chrono::steady_clock> start = std::chrono::steady_clock::now();
template <class Duration>
auto record() const
{
auto finish = std::chrono::steady_clock::now();
return std::chrono::duration_cast<Duration>(finish - start).count();
}
};
template <class Duration, class F>
auto time(F f)
{
auto start = std::chrono::steady_clock::now();
timer t{};
f();
auto finish = std::chrono::steady_clock::now();
return std::chrono::duration_cast<Duration>(finish - start).count();
return t.record<Duration>();
}
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_TMP_DIR_HPP
#define MIGRAPHX_GUARD_RTGLIB_TMP_DIR_HPP
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct tmp_dir
{
fs::path path;
tmp_dir(const std::string& prefix = "");
void execute(const std::string& exe, const std::string& args) const;
tmp_dir(tmp_dir const&) = delete;
tmp_dir& operator=(tmp_dir const&) = delete;
~tmp_dir();
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#include <utility>
#include <cstdint>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
inline int tune_axis(const int n_dim, const int axis, const std::string& op_name = "OPERATOR")
{
if(axis >= n_dim || std::abs(axis) > n_dim)
{
MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range.");
}
return (axis < 0) ? axis + n_dim : axis;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -8,30 +8,32 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name()
std::string compute_type_name()
{
static std::string name;
if(name.empty())
{
std::string name;
#ifdef _MSC_VER
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
name = __PRETTY_FUNCTION__;
name = __PRETTY_FUNCTION__;
auto begin = name.find(parameter_name) + sizeof(parameter_name);
auto begin = name.find(parameter_name) + sizeof(parameter_name);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
auto length = name.find_last_of(",") - begin;
auto length = name.find_last_of(",") - begin;
#else
auto length = name.find_first_of("];", begin) - begin;
auto length = name.find_first_of("];", begin) - begin;
#endif
name = name.substr(begin, length);
name = name.substr(begin, length);
#endif
}
return name;
}
template <class T>
const std::string& get_type_name()
{
static const std::string name = compute_type_name<T>();
return name;
}
......
......@@ -30,6 +30,12 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
template <class T>
using accumulator_type =
std::conditional_t<is_floating_point<T>{},
double,
std::conditional_t<is_signed<T>{}, std::int64_t, std::uint64_t>>;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_VALUE_HPP
#define MIGRAPHX_GUARD_RTGLIB_VALUE_HPP
#include <migraphx/config.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/rank.hpp>
#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <type_traits>
#include <tuple>
#include <unordered_map>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct value_base_impl;
template <class To, class = void>
struct value_converter
{
template <class T = To>
static auto apply(const std::string& x)
-> decltype((std::declval<std::stringstream&>() >> std::declval<T&>()), To{})
{
To result;
std::stringstream ss;
ss.str(x);
ss >> result;
if(ss.fail())
throw std::runtime_error("Failed to parse: " + x);
return result;
}
template <class From, MIGRAPHX_REQUIRES(std::is_convertible<From, To>{})>
static To apply(const From& x)
{
return To(x);
}
};
template <class To>
struct value_converter<To, MIGRAPHX_CLASS_REQUIRES(std::is_enum<To>{})>
{
template <class From>
static auto apply(const From& x)
-> decltype(static_cast<To>(value_converter<std::underlying_type_t<To>>::apply(x)))
{
return static_cast<To>(value_converter<std::underlying_type_t<To>>::apply(x));
}
};
template <>
struct value_converter<std::string>
{
static const std::string& apply(const std::string& x) { return x; }
template <class From>
static auto apply(const From& x)
-> decltype(std::declval<std::stringstream&>() << x, std::string())
{
std::stringstream ss;
ss << x;
if(ss.fail())
throw std::runtime_error("Failed to parse");
return ss.str();
}
};
template <class T, class U>
struct value_converter<std::pair<T, U>>
{
template <class Key, class From>
static auto apply(const std::pair<Key, From>& x)
-> decltype(std::pair<T, U>(x.first, value_converter<U>::apply(x.second)))
{
return std::pair<T, U>(x.first, value_converter<U>::apply(x.second));
}
};
template <class To, class From>
To try_convert_value(const From& x);
namespace detail {
template <class To, class Key, class From>
To try_convert_value_impl(rank<1>, const std::pair<Key, From>& x)
{
return try_convert_value<To>(x.second);
}
template <class To, class From>
auto try_convert_value_impl(rank<2>, const From& x) -> decltype(value_converter<To>::apply(x))
{
return value_converter<To>::apply(x);
}
template <class To, MIGRAPHX_REQUIRES(not std::is_same<To, std::nullptr_t>{})>
To try_convert_value_impl(rank<3>, std::nullptr_t)
{
MIGRAPHX_THROW("Incompatible values: null -> " + get_type_name<To>());
}
template <class To, class From>
To try_convert_value_impl(rank<0>, const From& x)
{
MIGRAPHX_THROW("Incompatible values: " + get_type_name(x) + " -> " + get_type_name<To>());
}
} // namespace detail
template <class To, class From>
To try_convert_value(const From& x)
{
return detail::try_convert_value_impl<To>(rank<3>{}, x);
}
struct value
{
// clang-format off
#define MIGRAPHX_VISIT_VALUE_TYPES(m) \
m(int64, std::int64_t) \
m(uint64, std::uint64_t) \
m(float, double) \
m(string, std::string) \
m(bool, bool) \
m(binary, value::binary)
// clang-format on
enum type_t
{
#define MIGRAPHX_VALUE_GENERATE_ENUM_TYPE(vt, cpp_type) vt##_type,
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_ENUM_TYPE) object_type,
array_type,
null_type
#undef MIGRAPHX_VALUE_GENERATE_ENUM_TYPE
};
using iterator = value*;
using const_iterator = const value*;
using value_type = value;
using key_type = std::string;
using mapped_type = value;
using reference = value_type&;
using const_reference = const value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
using array = std::vector<value>;
using object = std::unordered_map<std::string, value>;
struct binary : std::vector<std::uint8_t>
{
using base = std::vector<std::uint8_t>;
binary() {}
template <class Container,
MIGRAPHX_REQUIRES(sizeof(*std::declval<Container>().begin()) == 1)>
explicit binary(const Container& c) : base(c.begin(), c.end())
{
}
template <class T>
binary(T* data, std::size_t s) : base(data, data + s)
{
}
explicit binary(std::size_t s) : base(s) {}
};
value() = default;
value(const value& rhs);
value& operator=(value rhs);
value(const std::string& pkey, const value& rhs);
value(const std::initializer_list<value>& i);
value(const std::vector<value>& v, bool array_on_empty = true);
value(const std::unordered_map<std::string, value>& m);
value(const std::string& pkey, const std::vector<value>& v, bool array_on_empty = true);
value(const std::string& pkey, const std::unordered_map<std::string, value>& m);
value(const std::string& pkey, std::nullptr_t);
value(std::nullptr_t);
value(const char* i);
value(const std::string& pkey, const char* i);
#define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \
value(cpp_type i); \
value(const std::string& pkey, cpp_type i); \
value& operator=(cpp_type rhs); \
bool is_##vt() const; \
const cpp_type& get_##vt() const; \
const cpp_type* if_##vt() const;
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS)
template <class T>
using literal_to_string = std::conditional_t<(std::is_convertible<T, const char*>{} and
std::is_convertible<T, std::string>{}),
std::string,
T>;
template <class T>
using pick_numeric = std::conditional_t<
std::is_floating_point<T>{},
double,
std::conditional_t<std::is_signed<T>{},
std::int64_t,
std::conditional_t<std::is_unsigned<T>{}, std::uint64_t, T>>>;
template <class T>
using pick = pick_numeric<typename std::conditional_t<std::is_enum<T>{},
std::underlying_type<T>,
std::enable_if<true, T>>::type>;
template <class T>
using is_pickable =
bool_c<((std::is_arithmetic<T>{} or std::is_enum<T>{}) and not std::is_pointer<T>{})>;
template <class T>
using range_value = std::decay_t<decltype(std::declval<T>().end(), *std::declval<T>().begin())>;
template <class T>
using is_generic_range =
bool_c<(std::is_convertible<range_value<T>, value>{} and
not std::is_convertible<T, array>{} and not std::is_convertible<T, object>{})>;
template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
value(const T& r) : value(from_values(r))
{
}
template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
value(const std::string& pkey, const T& r) : value(pkey, from_values(r))
{
}
template <class T, MIGRAPHX_REQUIRES(is_pickable<T>{})>
value(T i) : value(static_cast<pick<T>>(i))
{
}
template <class T, MIGRAPHX_REQUIRES(is_pickable<T>{})>
value(const std::string& pkey, T i) : value(pkey, static_cast<pick<T>>(i))
{
}
template <class T, class U, class = decltype(value(T{}, U{}))>
value(const std::pair<T, U>& p) : value(p.first, p.second)
{
}
template <class T, MIGRAPHX_REQUIRES(is_pickable<T>{})>
value& operator=(T rhs)
{
return *this = static_cast<pick<T>>(rhs); // NOLINT
}
template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
value& operator=(T rhs)
{
return *this = from_values(rhs); // NOLINT
}
value& operator=(const char* c);
value& operator=(std::nullptr_t);
value& operator=(const std::initializer_list<value>& i);
bool is_array() const;
const std::vector<value>& get_array() const;
const std::vector<value>* if_array() const;
bool is_object() const;
const std::vector<value>& get_object() const;
const std::vector<value>* if_object() const;
bool is_null() const;
const std::string& get_key() const;
value* find(const std::string& pkey);
const value* find(const std::string& pkey) const;
bool contains(const std::string& pkey) const;
std::size_t size() const;
bool empty() const;
const value* data() const;
value* data();
value* begin();
const value* begin() const;
value* end();
const value* end() const;
value& front();
const value& front() const;
value& back();
const value& back() const;
value& at(std::size_t i);
const value& at(std::size_t i) const;
value& at(const std::string& pkey);
const value& at(const std::string& pkey) const;
value& operator[](std::size_t i);
const value& operator[](std::size_t i) const;
value& operator[](const std::string& pkey);
void clear();
void resize(std::size_t n);
void resize(std::size_t n, const value& v);
std::pair<value*, bool> insert(const value& v);
value* insert(const value* pos, const value& v);
template <class... Ts>
std::pair<value*, bool> emplace(Ts&&... xs)
{
return insert(value(std::forward<Ts>(xs)...));
}
template <class... Ts>
value* emplace(const value* pos, Ts&&... xs)
{
return insert(pos, value(std::forward<Ts>(xs)...));
}
void push_back(const value& v) { insert(end(), v); }
void push_front(const value& v) { insert(begin(), v); }
value with_key(const std::string& pkey) const;
value without_key() const;
template <class Visitor>
void visit(Visitor v) const
{
switch(this->get_type())
{
case null_type: {
std::nullptr_t null{};
if(this->key.empty())
v(null);
else
v(std::make_pair(this->get_key(), std::ref(null)));
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
case vt##_type: { \
if(this->key.empty()) \
v(this->get_##vt()); \
else \
v(std::make_pair(this->get_key(), std::ref(this->get_##vt()))); \
return; \
}
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE)
MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, )
}
MIGRAPHX_THROW("Unknown type");
}
// Visit value without key
template <class Visitor>
void visit_value(Visitor v) const
{
switch(this->get_type())
{
case null_type: {
std::nullptr_t null{};
v(null);
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
case vt##_type: { \
v(this->get_##vt()); \
return; \
}
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, )
}
MIGRAPHX_THROW("Unknown type");
}
template <class To>
To to() const
{
To result;
this->visit([&](auto y) { result = try_convert_value<To>(y); });
return result;
}
template <class To>
literal_to_string<To> value_or(const To& default_value) const
{
if(this->is_null())
return default_value;
return to<literal_to_string<To>>();
}
template <class To>
std::vector<To> to_vector() const
{
std::vector<To> result;
const auto& values = is_object() ? get_object() : get_array();
result.reserve(values.size());
std::transform(values.begin(), values.end(), std::back_inserter(result), [&](auto v) {
return v.template to<To>();
});
return result;
}
template <class To>
literal_to_string<To> get(const std::string& pkey, const To& default_value) const
{
const auto* v = find(pkey);
if(v == this->end())
return default_value;
return v->to<literal_to_string<To>>();
}
template <class To>
std::vector<To> get(const std::string& pkey, const std::vector<To>& default_value) const
{
const auto* v = find(pkey);
if(v == this->end())
return default_value;
return v->to_vector<To>();
}
template <class To>
std::vector<literal_to_string<To>> get(const std::string& pkey,
const std::initializer_list<To>& default_value) const
{
return get(pkey,
std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()});
}
friend bool operator==(const value& x, const value& y);
friend bool operator!=(const value& x, const value& y);
friend bool operator<(const value& x, const value& y);
friend bool operator<=(const value& x, const value& y);
friend bool operator>(const value& x, const value& y);
friend bool operator>=(const value& x, const value& y);
friend std::ostream& operator<<(std::ostream& os, const value& d);
void debug_print(bool show_type = false) const;
type_t get_type() const;
private:
template <class T>
std::vector<value> from_values(const T& r)
{
std::vector<value> v;
std::transform(
r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
return v;
}
std::shared_ptr<value_base_impl> x;
std::string key;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -147,7 +147,7 @@ std::size_t mismatch_diff(R1&& r1, R2&& r2, T diff)
}
template <class R1, class R2>
double rms_range(R1&& r1, R2&& r2)
double rms_range(const R1& r1, const R2& r2)
{
std::size_t n = range_distance(r1);
if(n == range_distance(r2))
......@@ -164,11 +164,10 @@ double rms_range(R1&& r1, R2&& r2)
}
template <class R1, class R2>
bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = nullptr)
bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out_error = nullptr)
{
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2);
// cppcheck-suppress uninitvar
if(out_error != nullptr)
*out_error = error;
return error <= threshold;
......
......@@ -8,81 +8,10 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
inline bool verify_args(const std::string& name,
const argument& cpu_arg,
const argument& gpu_arg,
double tolerance = 80)
{
bool passed = true;
visit_all(cpu_arg, gpu_arg)([&](auto cpu, auto gpu) {
double error;
passed = verify_range(cpu, gpu, tolerance, &error);
if(not passed)
{
// TODO: Check for nans
std::cout << "FAILED: " << name << std::endl;
std::cout << "error: " << error << std::endl;
if(cpu.size() < 32)
std::cout << "cpu:" << cpu << std::endl;
if(gpu.size() < 32)
std::cout << "gpu:" << gpu << std::endl;
if(range_zero(cpu))
std::cout << "Cpu data is all zeros" << std::endl;
if(range_zero(gpu))
std::cout << "Gpu data is all zeros" << std::endl;
auto mxdiff = max_diff(cpu, gpu);
std::cout << "Max diff: " << mxdiff << std::endl;
auto idx = mismatch_idx(cpu, gpu, float_equal);
if(idx < range_distance(cpu))
{
std::cout << "Mismatch at " << idx << ": " << cpu[idx] << " != " << gpu[idx]
<< std::endl;
}
auto cpu_nan_idx = find_idx(cpu, not_finite);
if(cpu_nan_idx >= 0)
std::cout << "Non finite number found in cpu at " << cpu_nan_idx << ": "
<< cpu[cpu_nan_idx] << std::endl;
auto gpu_nan_idx = find_idx(gpu, not_finite);
if(gpu_nan_idx >= 0)
std::cout << "Non finite number found in gpu at " << gpu_nan_idx << ": "
<< gpu[gpu_nan_idx] << std::endl;
std::cout << std::endl;
}
else
{
if(range_zero(cpu))
std::cout << "Cpu data is all zeros" << std::endl;
if(range_zero(gpu))
std::cout << "Gpu data is all zeros" << std::endl;
// auto mxdiff = max_diff(cpu, gpu);
// std::cout << "Max diff: " << mxdiff << std::endl;
// auto idx = mismatch_idx(cpu, gpu, float_equal);
// if(idx < range_distance(cpu))
// {
// std::cout << "Mismatch at " << idx << ": " << cpu[idx] << " != " << gpu[idx]
// << std::endl;
// }
auto cpu_nan_idx = find_idx(cpu, not_finite);
if(cpu_nan_idx >= 0)
std::cout << "Non finite number found in cpu at " << cpu_nan_idx << ": "
<< cpu[cpu_nan_idx] << std::endl;
auto gpu_nan_idx = find_idx(gpu, not_finite);
if(gpu_nan_idx >= 0)
std::cout << "Non finite number found in gpu at " << gpu_nan_idx << ": "
<< gpu[gpu_nan_idx] << std::endl;
// std::cout << std::endl;
}
});
return passed;
}
bool verify_args(const std::string& name,
const argument& ref_arg,
const argument& target_arg,
double tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#include <migraphx/inline_module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
const auto& mod_inputs = ins->module_inputs();
module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
auto mod_outputs = m.insert_module_instructions(ins, smod);
auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size());
for(const auto& out : ins_outputs)
{
auto val = out->get_operator().to_value();
assert(val.contains("index"));
auto index = val.at("index").to<std::size_t>();
m.replace_instruction(out, mod_outputs.at(index));
}
}
void inline_module::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "if")
continue;
auto arg_cond = ins->inputs().front()->eval();
if(not arg_cond.empty())
{
bool cond = arg_cond.at<bool>();
inline_submodule(m, ins, cond);
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/insert_pad.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m)
{
auto op = ins->get_operator();
auto val = op.to_value();
auto op_padding = val.at("padding").to_vector<size_t>();
auto kdims = input->get_shape().lens().size() - 2;
if(std::equal(op_padding.begin(),
op_padding.begin() + kdims,
op_padding.begin() + kdims,
op_padding.end()))
return;
std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
std::vector<size_t> pads_l(op_padding.begin(), op_padding.begin() + kdims);
std::vector<size_t> pads_r(op_padding.begin() + kdims, op_padding.end());
op_padding = std::vector<size_t>(kdims * 2, 0);
op.from_value({{"padding", op_padding}});
std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);
auto pad_op = m.insert_instruction(ins, op::pad{padding}, input);
auto new_inputs = ins->inputs();
new_inputs.front() = pad_op;
m.replace_instruction(ins, op, new_inputs);
}
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == op::pooling_mode::average)
{
return;
}
auto kdims = input->get_shape().lens().size() - 2;
if(std::equal(op.padding.begin(),
op.padding.begin() + kdims,
op.padding.begin() + kdims,
op.padding.end()))
return;
std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
op.padding = std::vector<size_t>(kdims * 2, 0);
std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);
// maxpool uses lowest value for padding
float pad_val = std::numeric_limits<float>::lowest();
auto pad_op = m.insert_instruction(ins, op::pad{padding, pad_val}, input);
auto new_inputs = ins->inputs();
new_inputs.front() = pad_op;
m.replace_instruction(ins, op, new_inputs);
}
void insert_pad::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
const std::string& op_name = ins->name();
if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling")
continue;
auto input = ins->inputs().front();
if(op_name == "convolution" or op_name == "im2col")
update_op(input, ins, m);
else if(op_name == "pooling")
update_pooling(input, ins, m);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/erase.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
auto equal_to(const T& x)
{
return [&](const T& y) { return std::equal_to<T>{}(x, y); };
}
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
}
instruction::instruction(operation o,
shape r,
std::vector<instruction_ref> args,
std::vector<module_ref> modules)
: op(std::move(o)),
result(std::move(r)),
arguments(std::move(args)),
module_args(std::move(modules))
{
}
instruction::instruction(literal l)
: op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
{
......@@ -22,6 +41,9 @@ void instruction::replace(const shape& r)
result = r;
for(auto&& ins : output)
{
if(ins->name() == "@return")
continue;
assert(ins->name().front() != '@');
ins->recompute_shape();
}
......@@ -30,11 +52,12 @@ void instruction::replace(const shape& r)
void instruction::replace(operation o)
{
op = std::move(o);
normalized = false;
op = std::move(o);
recompute_shape();
}
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
void instruction::recompute_shape() { replace(compute_shape(op, arguments, module_args)); }
void instruction::clear_arguments()
{
......@@ -43,6 +66,7 @@ void instruction::clear_arguments()
arg->remove_output(*this);
}
arguments.clear();
module_args.clear();
}
bool operator==(const instruction& i, instruction_ref ref)
......@@ -50,12 +74,17 @@ bool operator==(const instruction& i, instruction_ref ref)
return std::addressof(i) == std::addressof(*ref);
}
bool instruction::valid(instruction_ref start) const
bool instruction::valid(instruction_ref start, bool check_order) const
{
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
return self != i->outputs().end() &&
std::distance(start, i) < std::distance(start, *self);
bool ret = self != i->outputs().end();
if(check_order)
{
// check arguments for this instruction before this instruction
ret = ret and (std::distance(start, i) < std::distance(start, *self));
}
return ret;
});
}
......@@ -70,18 +99,24 @@ bool instruction::valid() const
{
computed = result;
}
else if(op.name() == "@return")
{
computed = {};
}
else
{
try
{
computed = compute_shape(op, arguments);
computed = compute_shape(op, arguments, module_args);
}
catch(migraphx::exception&)
{
return false;
}
}
return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return (result == computed) &&
std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
});
}
......@@ -99,11 +134,19 @@ std::string instruction::name() const { return op.name(); }
const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }
const std::vector<module_ref>& instruction::module_inputs() const { return module_args; }
const std::vector<instruction_ref>& instruction::outputs() const { return output; }
bool operator==(const instruction& x, const instruction& y)
{
if(std::tie(x.result, x.op, x.arguments) != std::tie(y.result, y.op, y.arguments))
if(not std::equal(x.arguments.begin(),
x.arguments.end(),
y.arguments.begin(),
y.arguments.end(),
std::equal_to<instruction_ref>{}))
return false;
if(std::tie(x.result, x.op, x.module_args) != std::tie(y.result, y.op, y.module_args))
return false;
if(x.name() == "@literal")
return x.lit == y.lit;
......@@ -120,7 +163,7 @@ bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref);
void instruction::add_output(instruction_ref ins)
{
if(std::find(output.begin(), output.end(), ins) == output.end())
if(std::find_if(output.begin(), output.end(), equal_to(ins)) == output.end())
output.push_back(ins);
}
......@@ -139,6 +182,13 @@ void instruction::replace_argument(instruction_ref ins,
ins->recompute_shape();
}
void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod)
{
ins->replace_mod_argument(old, new_mod);
backreference(ins);
ins->recompute_shape();
}
void instruction::replace(instruction_ref ins,
operation o,
const shape& r,
......@@ -148,26 +198,87 @@ void instruction::replace(instruction_ref ins,
backreference(ins);
}
void instruction::replace(instruction_ref ins,
operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
ins->replace(std::move(o), r, std::move(args), std::move(module_args));
backreference(ins);
}
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
op = std::move(o);
normalized = false;
op = std::move(o);
replace(r);
replace(std::move(args));
}
void instruction::replace(operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> mdl_args)
{
op = std::move(o);
replace(r);
replace(std::move(args), std::move(mdl_args));
}
void instruction::replace_refs(
instruction_ref ins,
const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
const std::unordered_map<module_ref, module_ref>& map_mods)
{
const auto& args = ins->inputs();
for(const auto& arg : args)
{
if(contains(map_insts, arg))
{
instruction::replace_argument(ins, arg, map_insts.at(arg));
}
}
const auto& module_args = ins->module_inputs();
if(module_args.empty())
return;
for(const auto& mod : module_args)
{
if(contains(map_mods, mod))
{
instruction::replace_mod_argument(ins, mod, map_mods.at(mod));
}
}
}
void instruction::replace(std::vector<instruction_ref> args)
{
clear_arguments();
arguments = std::move(args);
}
void instruction::replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args)
{
clear_arguments();
arguments = std::move(args);
module_args = std::move(mdl_args);
}
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; }));
std::replace(arguments.begin(), arguments.end(), old, new_ins);
assert(std::any_of(arguments.begin(), arguments.end(), equal_to(old)));
std::replace_if(arguments.begin(), arguments.end(), equal_to(old), new_ins);
old->remove_output(*this);
}
void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
{
assert(std::any_of(module_args.begin(), module_args.end(), [&](auto i) { return i == old; }));
std::replace(module_args.begin(), module_args.end(), old, new_mod);
}
bool instruction::can_eval() const
{
if(op.name() == "@literal")
......@@ -200,7 +311,7 @@ argument instruction::eval(bool check_eval) const
this->inputs().end(),
std::back_inserter(args),
[](auto arg) { return arg->eval(false); });
return op.compute(result, args);
return normalized_operator().compute(result, args);
}
return {};
}
......@@ -211,6 +322,82 @@ void instruction::finalize(context& ctx)
this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}
void instruction::print(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
std::string arg_name = contains(names, arg) ? names.at(arg) : "?";
os << delim << arg_name;
delim = ',';
}
os << ")";
}
// print module inputs
if(!ins->module_inputs().empty())
{
std::string delim = ", [";
for(auto&& mod_arg : ins->module_inputs())
{
os << delim << mod_arg->name();
delim = ", ";
}
os << "]";
}
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
}
static void debug_name(std::ostream& os, const instruction& ins)
{
if(ins.name() == "@literal")
{
os << "@literal";
if(ins.get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins.get_literal() << "}";
}
else
{
os << ins.get_operator();
}
}
void instruction::debug_print() const
{
debug_name(std::cout, *this);
std::string delim = "(";
for(auto arg : this->inputs())
{
std::cout << delim;
debug_name(std::cout, *arg);
delim = ", ";
}
if(not this->inputs().empty())
std::cout << ")";
std::cout << " -> " << this->get_shape() << std::endl;
}
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{
auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
......@@ -221,6 +408,27 @@ instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
return get_output_alias(ins->inputs().at(i));
}
void instruction::set_normalized(bool value) { normalized = value; }
bool instruction::is_normalized() const { return normalized; }
bool instruction::need_normalization() const
{
return this->get_operator().need_normalization() and not normalized;
}
operation instruction::normalized_operator() const
{
operation o = this->get_operator();
if(this->need_normalization())
{
auto lens = this->inputs().front()->get_shape().lens();
if(!normalize_attributes(o, lens))
return this->get_operator();
}
return o;
}
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
......@@ -234,5 +442,38 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg
return op.compute_shape(to_shapes(args));
}
shape compute_shape(const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods)
{
if(mods.empty())
{
return op.compute_shape(to_shapes(args));
}
else
{
return op.compute_shape(to_shapes(args), mods);
}
}
std::vector<shape> try_compute_shape(const operation& op, const std::vector<shape>& inputs)
{
shape new_shape;
try
{
new_shape = op.compute_shape(inputs);
}
catch(...)
{
return {};
}
return {new_shape};
}
migraphx::instruction* as_address(const instruction_ref& ins) noexcept
{
return std::addressof(*ins);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment