Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
/** /**
* Replace instructions which take all literals with a literal of the computation. * Replace instructions which take all literals with a literal of the computation.
...@@ -15,7 +15,7 @@ struct program; ...@@ -15,7 +15,7 @@ struct program;
struct propagate_constant struct propagate_constant
{ {
std::string name() const { return "propagate_constant"; } std::string name() const { return "propagate_constant"; }
void apply(program& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -17,32 +17,10 @@ struct program; ...@@ -17,32 +17,10 @@ struct program;
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names = {"all"}); void quantize_fp16(program& prog, const std::vector<std::string>& ins_names = {"all"});
// insert the capture operator for the inputs of each operator to be quantized
// to int8
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names);
template <class T>
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_names)
{
static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} &&
std::is_lvalue_reference<T>{},
"Dangling reference to target!");
return capture_arguments_impl(prog, t, ins_names);
}
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
const std::vector<program::parameter_map>& calibration, const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", "convolution"}); const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#include <string>
#include <vector>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* quantize a program to fp16
*/
struct quantize_fp16_pass
{
std::vector<std::string> ins_names = {"all"};
std::string name() const { return "quantize_fp16"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#include <string>
#include <vector>
#include <functional>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* capture inputs of operators to be quantized to int8
*/
struct capture_arguments_pass
{
std::vector<std::string> ins_names = {"dot", "convolution"};
std::function<void(std::size_t, std::vector<argument>)> f{};
std::size_t* param_index = nullptr;
std::string name() const { return "capture_arguments"; }
void apply(module& m) const;
};
/**
* quantize a program to int8
*/
struct quantize_int8_pass
{
std::vector<std::string> ins_names = {"dot", "convolution"};
std::vector<std::pair<float, float>> quant_params;
std::string name() const { return "quantize_int8"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -2,8 +2,13 @@ ...@@ -2,8 +2,13 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_RANGES_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_RANGES_HPP
#include <algorithm> #include <algorithm>
#include <vector>
#include <initializer_list> #include <initializer_list>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/iota_iterator.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -33,6 +38,33 @@ auto generic_find_impl(rank<0>, C&& c, const T& x) ...@@ -33,6 +38,33 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
return std::find(c.begin(), c.end(), x); return std::find(c.begin(), c.end(), x);
} }
template <class C, class T>
auto generic_find_at_impl(rank<1>, C&& c, const T& x) -> decltype(c.find(x))
{
return c.find(x);
}
template <class C, class T>
auto generic_find_at_impl(rank<0>, C&& c, const T& x)
{
auto n = std::distance(c.begin(), c.end());
if(x >= n)
return c.end();
return std::next(c.begin(), x);
}
template <class C, class T, class = typename C::mapped_type>
decltype(auto) generic_at_impl(rank<1>, const C&, T&& it)
{
return it->second;
}
template <class C, class T>
decltype(auto) generic_at_impl(rank<0>, const C&, T&& it)
{
return *it;
}
struct empty struct empty
{ {
}; };
...@@ -45,6 +77,20 @@ auto generic_find(C&& c, const T& x) ...@@ -45,6 +77,20 @@ auto generic_find(C&& c, const T& x)
return detail::generic_find_impl(rank<2>{}, c, x); return detail::generic_find_impl(rank<2>{}, c, x);
} }
template <class C, class T>
decltype(auto) at(C&& c, const T& x, const std::string& msg = "")
{
auto it = detail::generic_find_at_impl(rank<2>{}, c, x);
if(it == c.end())
{
if(msg.empty())
MIGRAPHX_THROW("At operator out of range for " + get_type_name(c));
else
MIGRAPHX_THROW(msg);
}
return detail::generic_at_impl(rank<2>{}, c, it);
}
template <class C, class T> template <class C, class T>
bool contains(const C& c, const T& x) bool contains(const C& c, const T& x)
{ {
...@@ -123,12 +169,41 @@ void copy(Range&& r, Iterator it) ...@@ -123,12 +169,41 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it); std::copy(r.begin(), r.end(), it);
} }
template <class Range, class Iterator, class F>
void transform(Range&& r, Iterator it, F f)
{
std::transform(r.begin(), r.end(), it, f);
}
template <class Range>
auto reverse(Range& r)
{
return range(std::make_reverse_iterator(r.end()), std::make_reverse_iterator(r.begin()));
}
template <class Range, class T> template <class Range, class T>
void replace(Range&& r, const T& old, const T& new_x) void replace(Range&& r, const T& old, const T& new_x)
{ {
std::replace(r.begin(), r.end(), old, new_x); std::replace(r.begin(), r.end(), old, new_x);
} }
template <class R1, class R2>
bool equal(R1&& r1, R2&& r2)
{
return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end());
}
template <class R>
using range_value = std::decay_t<decltype(*std::declval<R>().begin())>;
template <class Range, class Predicate>
std::vector<range_value<Range>> find_all(Range&& r, Predicate p)
{
std::vector<range_value<Range>> result;
std::copy_if(r.begin(), r.end(), std::back_inserter(result), p);
return result;
}
template <class Iterator> template <class Iterator>
struct iterator_range struct iterator_range
{ {
...@@ -140,12 +215,18 @@ struct iterator_range ...@@ -140,12 +215,18 @@ struct iterator_range
Iterator end() const { return last; } Iterator end() const { return last; }
}; };
template <class Iterator> template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
iterator_range<Iterator> range(Iterator start, Iterator last) iterator_range<Iterator> range(Iterator start, Iterator last)
{ {
return {start, last}; return {start, last};
} }
inline iterator_range<iota_iterator> range(std::ptrdiff_t start, std::ptrdiff_t last)
{
return {{start, {}}, {last, {}}};
}
inline iterator_range<iota_iterator> range(std::ptrdiff_t last) { return range(0, last); }
template <class Iterator> template <class Iterator>
iterator_range<Iterator> range(std::pair<Iterator, Iterator> p) iterator_range<Iterator> range(std::pair<Iterator, Iterator> p)
{ {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/tensor_view.hpp> #include <migraphx/tensor_view.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <sstream>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -28,7 +29,15 @@ struct raw_data : raw_data_base ...@@ -28,7 +29,15 @@ struct raw_data : raw_data_base
friend Stream& operator<<(Stream& os, const Derived& d) friend Stream& operator<<(Stream& os, const Derived& d)
{ {
if(not d.empty()) if(not d.empty())
d.visit([&](auto x) { os << x; }); d.visit([&](auto x) { os << x; },
[&](auto&& xs) {
for(auto&& x : xs)
{
os << "{ ";
os << x;
os << " }, ";
}
});
return os; return os;
} }
...@@ -44,9 +53,19 @@ struct raw_data : raw_data_base ...@@ -44,9 +53,19 @@ struct raw_data : raw_data_base
auto&& derived = static_cast<const Derived&>(*this); auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty()) if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!"); MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape(); auto&& s = derived.get_shape();
auto&& buffer = derived.data(); s.visit_type([&](auto as) { v(*(as.from(derived.data()) + s.index(n))); });
s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); }); }
template <class Visitor, class TupleVisitor>
void visit(Visitor v, TupleVisitor tv) const
{
auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
s.visit_type([&](auto as) { v(make_view(s, as.from(derived.data()))); },
[&] { tv(derived.get_sub_objects()); });
} }
/** /**
...@@ -59,12 +78,7 @@ struct raw_data : raw_data_base ...@@ -59,12 +78,7 @@ struct raw_data : raw_data_base
template <class Visitor> template <class Visitor>
void visit(Visitor v) const void visit(Visitor v) const
{ {
auto&& derived = static_cast<const Derived&>(*this); visit(v, [&](const auto&) { MIGRAPHX_THROW("Invalid tuple type"); });
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
auto&& buffer = derived.data();
s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
} }
/// Returns true if the raw data is only one element /// Returns true if the raw data is only one element
...@@ -141,50 +155,41 @@ struct raw_data : raw_data_base ...@@ -141,50 +155,41 @@ struct raw_data : raw_data_base
template <class T> template <class T>
T* cast() const T* cast() const
{ {
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data(); auto&& buffer = static_cast<const Derived&>(*this).data();
assert(s.type() == migraphx::shape::get_type<T>{}); assert(static_cast<const Derived&>(*this).get_shape().type() ==
migraphx::shape::get_type<T>{});
return reinterpret_cast<T*>(buffer); return reinterpret_cast<T*>(buffer);
} }
};
template <class T, std::string to_string() const
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
{ {
auto&& xbuffer = x.data(); std::stringstream ss;
auto&& ybuffer = y.data(); ss << static_cast<const Derived&>(*this);
// TODO: Dont use tensor view for single values return ss.str();
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
});
} }
return result; };
namespace detail {
template <class V1, class V2, class... Ts>
void visit_all_flatten(const shape& s, V1&& v1, V2&& v2, Ts&&... xs)
{
s.visit_type([&](auto as) { v1(make_view(xs.get_shape(), as.from(xs.data()))...); },
[&] { v2(xs.get_sub_objects()...); });
} }
template <class T, template <class V1, class V2, class... Ts>
class U, auto visit_all_pack(const shape& s, V1&& v1, V2&& v2)
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{ {
return !(x == y); return [&](auto&&... xs) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
visit_all_flatten(s, v1, v2, xs...);
};
} }
namespace detail { template <class V1, class... Ts>
template <class V, class... Ts> auto visit_all_pack(const shape& s, V1&& v1)
void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
{ {
s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); }); return visit_all_pack(s, v1, [](auto&&...) { MIGRAPHX_THROW("Invalid tuple type"); });
} }
} // namespace detail } // namespace detail
...@@ -206,10 +211,7 @@ auto visit_all(T&& x, Ts&&... xs) ...@@ -206,10 +211,7 @@ auto visit_all(T&& x, Ts&&... xs)
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...}; std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); })) if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same"); MIGRAPHX_THROW("Types must be the same");
return [&](auto v) { return [&](auto... vs) { detail::visit_all_pack(s, vs...)(x, xs...); };
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail::visit_all_impl(s, v, x, xs...);
};
} }
template <class T> template <class T>
...@@ -231,6 +233,34 @@ auto visit_all(const std::vector<T>& x) ...@@ -231,6 +233,34 @@ auto visit_all(const std::vector<T>& x)
}; };
} }
template <class T,
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() and y.empty();
if(not result and xshape == yshape)
{
visit_all(x, y)([&](auto xview, auto yview) { result = xview == yview; },
[&](auto&& xs, auto&& ys) {
result = std::equal(xs.begin(), xs.end(), ys.begin(), ys.end());
});
}
return result;
}
template <class T,
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
return !(x == y);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_DIMS_HPP
#define MIGRAPHX_GUARD_RTGLIB_REDUCE_DIMS_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::vector<shape> reduce_dims(const std::vector<shape>& shapes);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -102,6 +102,29 @@ void reflect_each(T& x, F f) ...@@ -102,6 +102,29 @@ void reflect_each(T& x, F f)
}); });
} }
template <class T>
struct reflect_equality
{
friend bool operator==(const T& x, const T& y) { return reflect_tie(x) == reflect_tie(y); }
friend bool operator!=(const T& x, const T& y) { return !(x == y); }
};
template <class T>
struct reflect_stream
{
template <class Stream>
friend Stream& operator<<(Stream& os, const T& x)
{
char d = '{';
reflect_each(x, [&](const auto& y, const auto& name) {
os << d << name << "=" << y;
d = ',';
});
os << "}";
return os;
}
};
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REGISTER_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_REGISTER_OP_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/auto_register.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void register_op(const operation& op);
operation load_op(const std::string& name);
bool has_op(const std::string& name);
std::vector<std::string> get_operators();
template <class T>
void register_op()
{
register_op(T{});
}
struct register_op_action
{
template <class T>
static void apply()
{
register_op<T>();
}
};
template <class T>
using auto_register_op = auto_register<register_op_action, T>;
#define MIGRAPHX_REGISTER_OP(...) MIGRAPHX_AUTO_REGISTER(register_op_action, __VA_ARGS__)
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REGISTER_TARGET_HPP
#define MIGRAPHX_GUARD_RTGLIB_REGISTER_TARGET_HPP
#include <migraphx/config.hpp>
#include <migraphx/target.hpp>
#include <migraphx/auto_register.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void register_target(const target& t);
target make_target(const std::string& name);
std::vector<std::string> get_targets();
template <class T>
void register_target()
{
register_target(T{});
}
struct register_target_action
{
template <class T>
static void apply()
{
register_target<T>();
}
};
template <class T>
using auto_register_target = auto_register<register_target_action, T>;
#define MIGRAPHX_REGISTER_TARGET(...) MIGRAPHX_AUTO_REGISTER(register_target_action, __VA_ARGS__)
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REPLACE_ALLOCATE_HPP
#define MIGRAPHX_GUARD_RTGLIB_REPLACE_ALLOCATE_HPP
#include <migraphx/config.hpp>
#include <migraphx/allocation_model.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct replace_allocate
{
allocation_model model;
bool offload_copy = false;
std::string name() const { return "replace_allocate"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -22,12 +22,14 @@ using bool_c = std::integral_constant<bool, B>; ...@@ -22,12 +22,14 @@ using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK #ifdef CPPCHECK
#define MIGRAPHX_REQUIRES(...) class = void #define MIGRAPHX_REQUIRES(...) class = void
#define MIGRAPHX_CLASS_REQUIRES(...) void
#else #else
#define MIGRAPHX_REQUIRES(...) \ #define MIGRAPHX_REQUIRES(...) \
long MIGRAPHX_REQUIRES_VAR() = __LINE__, \ long MIGRAPHX_REQUIRES_VAR() = __LINE__, \
typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() == __LINE__ && \ typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() == __LINE__ && \
(migraphx::and_<__VA_ARGS__>{})), \ (migraphx::and_<__VA_ARGS__>{})), \
int>::type = 0 int>::type = 0
#define MIGRAPHX_CLASS_REQUIRES(...) typename std::enable_if<(migraphx::and_<__VA_ARGS__>{})>::type
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
/** /**
* Rewrite batchnorm to a multiply and add. * Rewrite batchnorm to a multiply and add.
...@@ -16,7 +16,7 @@ struct program; ...@@ -16,7 +16,7 @@ struct program;
struct rewrite_batchnorm struct rewrite_batchnorm
{ {
std::string name() const { return "rewrite_batchnorm"; } std::string name() const { return "rewrite_batchnorm"; }
void apply(program& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
/** /**
* Rewrite pooling to reduce_mean * Rewrite pooling to reduce_mean
...@@ -15,7 +15,7 @@ struct program; ...@@ -15,7 +15,7 @@ struct program;
struct rewrite_pooling struct rewrite_pooling
{ {
std::string name() const { return "rewrite_pooling"; } std::string name() const { return "rewrite_pooling"; }
void apply(program& prog) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Rewrite quantization ops to equivalent operators
*/
struct rewrite_quantization
{
std::string name() const { return "rewrite_quantization"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -6,11 +6,12 @@ ...@@ -6,11 +6,12 @@
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
/** /**
* Rewrite rnn to gemm and add. * Rewrite rnn to gemm and add.
...@@ -18,26 +19,22 @@ struct program; ...@@ -18,26 +19,22 @@ struct program;
struct rewrite_rnn struct rewrite_rnn
{ {
std::string name() const { return "rewrite_rnn"; } std::string name() const { return "rewrite_rnn"; }
void apply(program& prog) const; void apply(module& m) const;
private: private:
// for vanilla rnn operators // for vanilla rnn operators
void apply_vanilla_rnn(program& prog, instruction_ref ins) const; void apply_vanilla_rnn(module& m, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward, std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog, module& m,
instruction_ref ins, instruction_ref ins,
instruction_ref input, std::vector<instruction_ref> inputs,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const; operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const; std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators // for gru operators
void apply_gru(program& prog, instruction_ref ins) const; void apply_gru(module& m, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward, std::vector<instruction_ref> gru_cell(bool is_forward,
program& prog, module& m,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int linear_before_reset, int linear_before_reset,
...@@ -47,9 +44,9 @@ struct rewrite_rnn ...@@ -47,9 +44,9 @@ struct rewrite_rnn
std::vector<operation> gru_actv_funcs(instruction_ref ins) const; std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators // for lstm operators
void apply_lstm(program& prog, instruction_ref ins) const; void apply_lstm(module& m, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward, std::vector<instruction_ref> lstm_cell(bool is_forward,
program& prog, module& m,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
const operation& actv_func1, const operation& actv_func1,
...@@ -57,6 +54,27 @@ struct rewrite_rnn ...@@ -57,6 +54,27 @@ struct rewrite_rnn
const operation& actv_func3) const; const operation& actv_func3) const;
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const; std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const module& m, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
op::rnn_direction dirct) const;
void replace_last_cell_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
instruction_ref last_cell_output,
op::rnn_direction dirct) const;
std::size_t get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(module& m,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class LoopModel, class T>
argument run_loop(const LoopModel& model,
T& ctx,
std::vector<argument> args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run)
{
std::vector<std::vector<argument>> results;
// process argu lists
auto iter_num = args.at(0).at<int64_t>();
auto cond = args.at(1).at<bool>();
auto input_num = (args.size() - 2) / 2;
auto dep_num = input_num - 2;
module_ref mod = mods.at(0);
auto param_name_shapes = mod->get_parameter_shapes();
auto param_names = mod->get_parameter_names();
std::vector<argument> dep0(args.begin() + input_num + 1, args.begin() + 2 * input_num);
std::vector<argument> dep1(args.begin() + 2 * input_num, args.begin() + 2 * input_num + 1);
auto ins_outputs = args.back().get_sub_objects();
dep1.insert(dep1.end(), ins_outputs.begin(), ins_outputs.begin() + dep_num);
std::array<std::vector<argument>, 2> loop_carry_deps = {dep0, dep1};
// loop iter argument
std::vector<argument> in_args = {args.at(input_num), dep1.at(0)};
in_args.insert(in_args.end(), args.begin() + 2, args.begin() + input_num);
std::vector<argument> out_args = dep0;
out_args.insert(out_args.end(), ins_outputs.begin() + dep_num, ins_outputs.end());
std::vector<argument> scan_outputs(ins_outputs.begin() + dep_num, ins_outputs.end());
auto out_param_indices = model.get_output_params(*mod);
int64_t iter = 0;
for(iter = 0; iter < iter_num and cond; ++iter)
{
// copy iter num and cond to device memory
model.copy(ctx, iter, in_args.at(0));
model.copy(ctx, cond, in_args.at(1));
// wrap up the inputs and outputs
std::unordered_map<std::string, argument> params;
int input_index = 0;
for(const auto& name : param_names)
{
auto ps = mod->get_parameter_shape(name);
if(ps == shape{})
{
continue;
}
// it is an input parameter
if(not contains(out_param_indices, name))
{
params[name] = in_args.at(input_index++);
}
else
{
auto output_index = out_param_indices[name];
if(output_index > dep_num)
{
const auto& arg = out_args.at(output_index);
assert((iter + 1) * ps.bytes() <= arg.get_shape().bytes());
params[name] = argument(ps, arg.data() + iter * ps.bytes());
}
else
{
params[name] = out_args.at(output_index);
}
}
}
auto mod_args = run(mod, params);
// copy back cond to be used next iteration
model.copy(ctx, mod_args.at(0), cond);
// mod outputs are used as next loop input
std::copy(mod_args.begin(), mod_args.begin() + dep_num + 1, in_args.begin() + 1);
const auto& dep_out = loop_carry_deps[(iter + 1) % 2];
std::copy(dep_out.begin(), dep_out.end(), out_args.begin());
std::vector<argument> mod_scan_outs(mod_args.begin() + 1 + dep_num, mod_args.end());
model.append(mod_scan_outs, scan_outputs, iter);
}
out_args.erase(out_args.begin());
std::copy(in_args.begin() + 2, in_args.end(), out_args.begin());
model.set_zero(ctx, scan_outputs, iter);
return {out_args};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
/** /**
* Schedule instructions for concurrent execution * Schedule instructions for concurrent execution
...@@ -19,7 +19,7 @@ struct schedule ...@@ -19,7 +19,7 @@ struct schedule
schedule_model model{}; schedule_model model{};
bool enable = true; bool enable = true;
std::string name() const { return "schedule"; } std::string name() const { return "schedule"; }
void apply(program& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
struct operation; struct operation;
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -26,30 +26,35 @@ struct schedule_model ...@@ -26,30 +26,35 @@ struct schedule_model
/// Get the number of concurrent instruction allowed /// Get the number of concurrent instruction allowed
std::size_t concurrency() const; std::size_t concurrency() const;
/// Schedule a concurrent instruction /// Schedule a concurrent instruction
void sched(program& p, instruction_ref ins, std::size_t n) const; void sched(module& m, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction // Insert necessary waits before an instruction
void wait(program& p, instruction_ref ins, std::size_t wait_id) const; void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction // Insert necessary records after an instruction
void record(program& p, instruction_ref ins, std::size_t wait_id) const; void record(module& m, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation /// Compute weights for an operation
std::size_t weight(const operation& op) const; std::size_t weight(const operation& op) const;
}; };
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct schedule_model struct schedule_model
* { {
* std::size_t concurrency() const; //
* void sched(program& p,instruction_ref ins,std::size_t n) const; std::size_t concurrency() const;
* void wait(program& p,instruction_ref ins,std::size_t wait_id) const; //
* void record(program& p,instruction_ref ins,std::size_t wait_id) const; void sched(module& m, instruction_ref ins, std::size_t n) const;
* std::size_t weight(const operation& op) const; //
* }; void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
* //
*/ void record(module& m, instruction_ref ins, std::size_t wait_id) const;
//
std::size_t weight(const operation& op) const;
};
#else
struct schedule_model struct schedule_model
{ {
...@@ -69,11 +74,17 @@ struct schedule_model ...@@ -69,11 +74,17 @@ struct schedule_model
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
schedule_model& operator=(PrivateDetailTypeErasedT value) schedule_model& operator=(PrivateDetailTypeErasedT value)
{ {
if(private_detail_te_handle_mem_var.unique()) using std::swap;
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
else if(!private_detail_te_handle_mem_var) if(derived and private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>( {
std::forward<PrivateDetailTypeErasedT>(value)); *derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
schedule_model rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this; return *this;
} }
...@@ -81,7 +92,7 @@ struct schedule_model ...@@ -81,7 +92,7 @@ struct schedule_model
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast() PrivateDetailTypeErasedT* any_cast()
{ {
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type< ? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>( typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle()) private_detail_te_get_handle())
...@@ -92,7 +103,7 @@ struct schedule_model ...@@ -92,7 +103,7 @@ struct schedule_model
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{ {
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type< ? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>( typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle()) private_detail_te_get_handle())
...@@ -114,22 +125,22 @@ struct schedule_model ...@@ -114,22 +125,22 @@ struct schedule_model
return (*this).private_detail_te_get_handle().concurrency(); return (*this).private_detail_te_get_handle().concurrency();
} }
void sched(program& p, instruction_ref ins, std::size_t n) const void sched(module& m, instruction_ref ins, std::size_t n) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().sched(p, ins, n); (*this).private_detail_te_get_handle().sched(m, ins, n);
} }
void wait(program& p, instruction_ref ins, std::size_t wait_id) const void wait(module& m, instruction_ref ins, std::size_t wait_id) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait(p, ins, wait_id); (*this).private_detail_te_get_handle().wait(m, ins, wait_id);
} }
void record(program& p, instruction_ref ins, std::size_t wait_id) const void record(module& m, instruction_ref ins, std::size_t wait_id) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record(p, ins, wait_id); (*this).private_detail_te_get_handle().record(m, ins, wait_id);
} }
std::size_t weight(const operation& op) const std::size_t weight(const operation& op) const
...@@ -152,11 +163,11 @@ struct schedule_model ...@@ -152,11 +163,11 @@ struct schedule_model
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::size_t concurrency() const = 0; virtual std::size_t concurrency() const = 0;
virtual void sched(program& p, instruction_ref ins, std::size_t n) const = 0; virtual void sched(module& m, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(program& p, instruction_ref ins, std::size_t wait_id) const = 0; virtual void wait(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(program& p, instruction_ref ins, std::size_t wait_id) const = 0; virtual void record(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0; virtual std::size_t weight(const operation& op) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -189,22 +200,22 @@ struct schedule_model ...@@ -189,22 +200,22 @@ struct schedule_model
std::size_t concurrency() const override { return private_detail_te_value.concurrency(); } std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(program& p, instruction_ref ins, std::size_t n) const override void sched(module& m, instruction_ref ins, std::size_t n) const override
{ {
private_detail_te_value.sched(p, ins, n); private_detail_te_value.sched(m, ins, n);
} }
void wait(program& p, instruction_ref ins, std::size_t wait_id) const override void wait(module& m, instruction_ref ins, std::size_t wait_id) const override
{ {
private_detail_te_value.wait(p, ins, wait_id); private_detail_te_value.wait(m, ins, wait_id);
} }
void record(program& p, instruction_ref ins, std::size_t wait_id) const override void record(module& m, instruction_ref ins, std::size_t wait_id) const override
{ {
private_detail_te_value.record(p, ins, wait_id); private_detail_te_value.record(m, ins, wait_id);
} }
std::size_t weight(const operation& op) const override std::size_t weight(const operation& op) const override
...@@ -277,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x) ...@@ -277,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SERIALIZE_HPP
#define MIGRAPHX_GUARD_RTGLIB_SERIALIZE_HPP
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/rank.hpp>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Avoid implicit conversion with ADL lookup
template <class T>
void migraphx_to_value(value&, const T&) = delete;
template <class T>
value to_value(const T& x);
template <class T>
void from_value(const value& v, T& x);
template <class T>
T from_value(const value& v)
{
T x{};
from_value(v, x);
return x;
}
namespace detail {
template <class T, MIGRAPHX_REQUIRES(std::is_empty<T>{})>
value to_value_impl(rank<0>, const T&)
{
return value::object{};
}
template <class T, class U>
value to_value_impl(rank<1>, const std::pair<T, U>& x)
{
return {x.first, x.second};
}
template <class T>
auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
{
value result = value::array{};
for(auto&& y : x)
{
result.insert(to_value(y));
}
return result;
}
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
value to_value_impl(rank<3>, const T& x)
{
value result = value::object{};
reflect_each(x, [&](auto&& y, std::string name) { result.emplace(name, to_value(y)); });
return result;
}
template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})>
value to_value_impl(rank<4>, const T& x)
{
return std::int64_t{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})>
value to_value_impl(rank<5>, const T& x)
{
return std::uint64_t{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
value to_value_impl(rank<6>, const T& x)
{
return double{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
value to_value_impl(rank<7>, const T& x)
{
return x;
}
inline value to_value_impl(rank<8>, const std::string& x) { return x; }
template <class T>
auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x))
{
return migraphx_to_value(x);
}
template <class T>
auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value())
{
return x.to_value();
}
template <class T>
auto to_value_impl(rank<11>, const T& x)
-> decltype(migraphx_to_value(std::declval<value&>(), x), value{})
{
value v;
migraphx_to_value(v, x);
return v;
}
template <class T, MIGRAPHX_REQUIRES(std::is_empty<T>{})>
void from_value_impl(rank<0>, const value& v, T& x)
{
if(not v.is_object())
MIGRAPHX_THROW("Expected an object");
if(not v.get_object().empty())
MIGRAPHX_THROW("Expected an empty object");
x = T{};
}
template <class T>
auto from_value_impl(rank<1>, const value& v, T& x)
-> decltype(x.insert(x.end(), *x.begin()), void())
{
x.clear();
for(auto&& e : v)
x.insert(x.end(), from_value<typename T::value_type>(e));
}
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<typename T::value_type>{})>
auto from_value_impl(rank<2>, const value& v, T& x)
-> decltype(x.insert(x.end(), *x.begin()), void())
{
x.clear();
if(v.is_binary())
{
for(auto&& e : v.get_binary())
x.insert(x.end(), e);
}
else
{
for(auto&& e : v)
x.insert(x.end(), from_value<typename T::value_type>(e));
}
}
template <class T>
auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void())
{
x.clear();
for(auto&& e : v)
x.emplace(e.get_key(), from_value<typename T::mapped_type>(e));
}
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
void from_value_impl(rank<4>, const value& v, T& x)
{
reflect_each(x, [&](auto& y, const std::string& name) {
using type = std::decay_t<decltype(y)>;
if(v.contains(name))
y = from_value<type>(v.at(name).without_key());
});
}
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})>
void from_value_impl(rank<5>, const value& v, T& x)
{
x = v.to<T>();
}
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
void from_value_impl(rank<6>, const value& v, T& x)
{
x = v.to<T>();
}
inline void from_value_impl(rank<7>, const value& v, std::string& x) { x = v.to<std::string>(); }
template <class T>
auto from_value_impl(rank<8>, const value& v, T& x) -> decltype(x.from_value(v), void())
{
x.from_value(v);
}
template <class T>
auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{
migraphx_from_value(v, x);
}
} // namespace detail
template <class T>
value to_value(const T& x)
{
return detail::to_value_impl(rank<11>{}, x);
}
template <class T>
void from_value(const value& v, T& x)
{
detail::from_value_impl(rank<9>{}, v, x);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment