Unverified commit 1b098fd7, authored by Paul Fultz II and committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
......@@ -7,7 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* Replace instructions which take all literals with a literal of the computation.
......@@ -15,7 +15,7 @@ struct program;
struct propagate_constant
{
std::string name() const { return "propagate_constant"; }
void apply(program& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -17,32 +17,10 @@ struct program;
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names = {"all"});
// insert the capture operator for the inputs of each operator to be quantized
// to int8
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names);
template <class T>
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_names)
{
static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} &&
std::is_lvalue_reference<T>{},
"Dangling reference to target!");
return capture_arguments_impl(prog, t, ins_names);
}
void quantize_int8(program& prog,
const target& t,
const std::vector<program::parameter_map>& calibration,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#include <string>
#include <vector>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
 * quantize a program to fp16
 */
struct quantize_fp16_pass
{
    // Names of the instructions to quantize; the default "all" selects every
    // instruction. NOTE(review): the exact matching rule lives in the .cpp.
    std::vector<std::string> ins_names = {"all"};
    // Pass name reported to the pass manager.
    std::string name() const { return "quantize_fp16"; }
    // Rewrites the module in place (implemented in the .cpp).
    void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#include <string>
#include <vector>
#include <functional>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
 * capture inputs of operators to be quantized to int8
 */
struct capture_arguments_pass
{
    // Operators whose inputs should be captured for calibration.
    std::vector<std::string> ins_names = {"dot", "convolution"};
    // Callback invoked with (parameter index, captured input arguments).
    std::function<void(std::size_t, std::vector<argument>)> f{};
    // Counter shared with the caller for assigning parameter indices.
    // Non-owning; the caller keeps the pointee alive while the pass runs.
    std::size_t* param_index = nullptr;
    std::string name() const { return "capture_arguments"; }
    // Inserts capture operators into the module (implemented in the .cpp).
    void apply(module& m) const;
};
/**
 * quantize a program to int8
 */
struct quantize_int8_pass
{
    // Operators to quantize.
    std::vector<std::string> ins_names = {"dot", "convolution"};
    // Per-capture (scale, shift) calibration parameters gathered beforehand.
    std::vector<std::pair<float, float>> quant_params;
    std::string name() const { return "quantize_int8"; }
    // Rewrites the module in place (implemented in the .cpp).
    void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -2,8 +2,13 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_RANGES_HPP
#include <algorithm>
#include <vector>
#include <initializer_list>
#include <migraphx/rank.hpp>
#include <migraphx/iota_iterator.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......@@ -33,6 +38,33 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
return std::find(c.begin(), c.end(), x);
}
// Preferred overload: containers with a member find() (maps/sets) are looked
// up by key.
template <class C, class T>
auto generic_find_at_impl(rank<1>, C&& c, const T& x) -> decltype(c.find(x))
{
    return c.find(x);
}
// Fallback for plain sequences: x is treated as a positional index and end()
// is returned when it is out of range.
// NOTE(review): assumes x is non-negative — a negative signed index passes the
// x >= n check and would make std::next step backwards; confirm callers only
// pass unsigned/valid indices.
template <class C, class T>
auto generic_find_at_impl(rank<0>, C&& c, const T& x)
{
    auto n = std::distance(c.begin(), c.end());
    if(x >= n)
        return c.end();
    return std::next(c.begin(), x);
}
// For map iterators (detected via mapped_type), yield the mapped value.
template <class C, class T, class = typename C::mapped_type>
decltype(auto) generic_at_impl(rank<1>, const C&, T&& it)
{
    return it->second;
}
// For sequence iterators, dereference directly.
template <class C, class T>
decltype(auto) generic_at_impl(rank<0>, const C&, T&& it)
{
    return *it;
}
// Empty tag type.
struct empty
{
};
......@@ -45,6 +77,20 @@ auto generic_find(C&& c, const T& x)
return detail::generic_find_impl(rank<2>{}, c, x);
}
// Bounds-checked element access that works uniformly for map-like containers
// (lookup by key) and sequences (lookup by index). Throws when the element is
// absent; msg, when non-empty, overrides the default error text.
template <class C, class T>
decltype(auto) at(C&& c, const T& x, const std::string& msg = "")
{
    // rank<2> starts the overload-priority dispatch in detail::.
    auto it = detail::generic_find_at_impl(rank<2>{}, c, x);
    if(it == c.end())
    {
        if(msg.empty())
            MIGRAPHX_THROW("At operator out of range for " + get_type_name(c));
        else
            MIGRAPHX_THROW(msg);
    }
    return detail::generic_at_impl(rank<2>{}, c, it);
}
template <class C, class T>
bool contains(const C& c, const T& x)
{
......@@ -123,12 +169,41 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it);
}
// Applies f to every element of the range and writes each result through the
// output iterator, advancing it once per element (same contract as
// std::transform over [r.begin(), r.end())).
template <class Range, class Iterator, class F>
void transform(Range&& r, Iterator it, F f)
{
    for(auto&& elem : r)
    {
        *it = f(elem);
        ++it;
    }
}
// Returns a lazily reversed view over r built from reverse iterators.
// r must outlive the returned range (the view does not own the elements).
template <class Range>
auto reverse(Range& r)
{
    return range(std::make_reverse_iterator(r.end()), std::make_reverse_iterator(r.begin()));
}
// Replaces every element equal to old with new_x, in place (same contract as
// std::replace over [r.begin(), r.end())).
template <class Range, class T>
void replace(Range&& r, const T& old, const T& new_x)
{
    for(auto&& elem : r)
    {
        if(elem == old)
            elem = new_x;
    }
}
// True when the two ranges have the same length and elements compare equal
// pairwise (delegates to the four-iterator std::equal, which handles
// differing lengths).
template <class R1, class R2>
bool equal(R1&& r1, R2&& r2)
{
    auto first1 = r1.begin();
    auto last1  = r1.end();
    auto first2 = r2.begin();
    auto last2  = r2.end();
    return std::equal(first1, last1, first2, last2);
}
// Decayed element type of a range, deduced from dereferencing its begin().
template <class R>
using range_value = std::decay_t<decltype(*std::declval<R>().begin())>;

// Collects copies of every element of r that satisfies the predicate,
// preserving the original order (equivalent to std::copy_if into a vector).
template <class Range, class Predicate>
std::vector<range_value<Range>> find_all(Range&& r, Predicate p)
{
    std::vector<range_value<Range>> matches;
    for(auto&& elem : r)
    {
        if(p(elem))
            matches.push_back(elem);
    }
    return matches;
}
template <class Iterator>
struct iterator_range
{
......@@ -140,12 +215,18 @@ struct iterator_range
Iterator end() const { return last; }
};
template <class Iterator>
template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
iterator_range<Iterator> range(Iterator start, Iterator last)
{
return {start, last};
}
// Numeric half-open range [start, last) iterated via iota_iterator.
inline iterator_range<iota_iterator> range(std::ptrdiff_t start, std::ptrdiff_t last)
{
    return {{start, {}}, {last, {}}};
}
// Numeric half-open range [0, last).
inline iterator_range<iota_iterator> range(std::ptrdiff_t last) { return range(0, last); }
template <class Iterator>
iterator_range<Iterator> range(std::pair<Iterator, Iterator> p)
{
......
......@@ -5,6 +5,7 @@
#include <migraphx/tensor_view.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
#include <sstream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -28,7 +29,15 @@ struct raw_data : raw_data_base
friend Stream& operator<<(Stream& os, const Derived& d)
{
if(not d.empty())
d.visit([&](auto x) { os << x; });
d.visit([&](auto x) { os << x; },
[&](auto&& xs) {
for(auto&& x : xs)
{
os << "{ ";
os << x;
os << " }, ";
}
});
return os;
}
......@@ -44,9 +53,19 @@ struct raw_data : raw_data_base
auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
auto&& buffer = derived.data();
s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); });
auto&& s = derived.get_shape();
s.visit_type([&](auto as) { v(*(as.from(derived.data()) + s.index(n))); });
}
template <class Visitor, class TupleVisitor>
void visit(Visitor v, TupleVisitor tv) const
{
auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
s.visit_type([&](auto as) { v(make_view(s, as.from(derived.data()))); },
[&] { tv(derived.get_sub_objects()); });
}
/**
......@@ -59,12 +78,7 @@ struct raw_data : raw_data_base
template <class Visitor>
void visit(Visitor v) const
{
auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
auto&& buffer = derived.data();
s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
visit(v, [&](const auto&) { MIGRAPHX_THROW("Invalid tuple type"); });
}
/// Returns true if the raw data is only one element
......@@ -141,50 +155,41 @@ struct raw_data : raw_data_base
template <class T>
T* cast() const
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
assert(s.type() == migraphx::shape::get_type<T>{});
assert(static_cast<const Derived&>(*this).get_shape().type() ==
migraphx::shape::get_type<T>{});
return reinterpret_cast<T*>(buffer);
}
};
template <class T,
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
std::string to_string() const
{
auto&& xbuffer = x.data();
auto&& ybuffer = y.data();
// TODO: Dont use tensor view for single values
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
});
std::stringstream ss;
ss << static_cast<const Derived&>(*this);
return ss.str();
}
return result;
};
namespace detail {
// Visits all arguments through the same shape s: plain tensor data is passed
// to v1 as typed views, tuple data is passed to v2 as sub-object lists
// (the second shape::visit_type callback handles the tuple case).
template <class V1, class V2, class... Ts>
void visit_all_flatten(const shape& s, V1&& v1, V2&& v2, Ts&&... xs)
{
    s.visit_type([&](auto as) { v1(make_view(xs.get_shape(), as.from(xs.data()))...); },
                 [&] { v2(xs.get_sub_objects()...); });
}
template <class T,
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
template <class V1, class V2, class... Ts>
auto visit_all_pack(const shape& s, V1&& v1, V2&& v2)
{
return !(x == y);
return [&](auto&&... xs) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
visit_all_flatten(s, v1, v2, xs...);
};
}
namespace detail {
template <class V, class... Ts>
void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
template <class V1, class... Ts>
auto visit_all_pack(const shape& s, V1&& v1)
{
s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); });
return visit_all_pack(s, v1, [](auto&&...) { MIGRAPHX_THROW("Invalid tuple type"); });
}
} // namespace detail
......@@ -206,10 +211,7 @@ auto visit_all(T&& x, Ts&&... xs)
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
return [&](auto v) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail::visit_all_impl(s, v, x, xs...);
};
return [&](auto... vs) { detail::visit_all_pack(s, vs...)(x, xs...); };
}
template <class T>
......@@ -231,6 +233,34 @@ auto visit_all(const std::vector<T>& x)
};
}
// Deep equality for raw_data-derived objects: equal when both are empty, or
// when the shapes match and the underlying data compares equal elementwise.
template <class T,
          class U,
          MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
                            std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
    auto&& xshape = x.get_shape();
    auto&& yshape = y.get_shape();
    // Two empty objects compare equal regardless of shape.
    bool result = x.empty() and y.empty();
    if(not result and xshape == yshape)
    {
        // First visitor compares tensor views; second compares tuple
        // sub-objects (recursing through this operator for each element).
        visit_all(x, y)([&](auto xview, auto yview) { result = xview == yview; },
                        [&](auto&& xs, auto&& ys) {
                            result = std::equal(xs.begin(), xs.end(), ys.begin(), ys.end());
                        });
    }
    return result;
}
// Inequality for raw_data-derived objects, defined as the negation of the
// deep-equality operator above.
template <class T,
          class U,
          MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
                            std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
    return !(x == y);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_DIMS_HPP
#define MIGRAPHX_GUARD_RTGLIB_REDUCE_DIMS_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Returns the input shapes with their dimensions reduced/merged where
// possible (declaration only — the merging rules live in reduce_dims.cpp).
std::vector<shape> reduce_dims(const std::vector<shape>& shapes);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -102,6 +102,29 @@ void reflect_each(T& x, F f)
});
}
// CRTP mixin: deriving T from reflect_equality<T> provides ==/!= that compare
// all of T's reflected members via reflect_tie.
template <class T>
struct reflect_equality
{
    friend bool operator==(const T& x, const T& y) { return reflect_tie(x) == reflect_tie(y); }
    friend bool operator!=(const T& x, const T& y) { return !(x == y); }
};
// CRTP mixin: deriving T from reflect_stream<T> provides operator<< that
// prints T as "{name=value,name=value,...}" using its reflected members.
// NOTE(review): a type with zero reflected members would print just "}" —
// confirm all users have at least one member, or special-case it.
template <class T>
struct reflect_stream
{
    template <class Stream>
    friend Stream& operator<<(Stream& os, const T& x)
    {
        // Delimiter trick: '{' before the first member, ',' before the rest.
        char d = '{';
        reflect_each(x, [&](const auto& y, const auto& name) {
            os << d << name << "=" << y;
            d = ',';
        });
        os << "}";
        return os;
    }
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REGISTER_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_REGISTER_OP_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/auto_register.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Add an operation to the global operator registry (defined in the .cpp).
void register_op(const operation& op);
// Construct a registered operation by name.
operation load_op(const std::string& name);
// True when an operation with this name has been registered.
bool has_op(const std::string& name);
// Names of all registered operations.
std::vector<std::string> get_operators();
// Register a default-constructed instance of T.
template <class T>
void register_op()
{
    register_op(T{});
}
// Action used with the auto_register machinery so operator types register
// themselves automatically (see auto_register.hpp for when apply runs).
struct register_op_action
{
    template <class T>
    static void apply()
    {
        register_op<T>();
    }
};
template <class T>
using auto_register_op = auto_register<register_op_action, T>;
// Registers one or more operator types at namespace scope.
#define MIGRAPHX_REGISTER_OP(...) MIGRAPHX_AUTO_REGISTER(register_op_action, __VA_ARGS__)
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REGISTER_TARGET_HPP
#define MIGRAPHX_GUARD_RTGLIB_REGISTER_TARGET_HPP
#include <migraphx/config.hpp>
#include <migraphx/target.hpp>
#include <migraphx/auto_register.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Add a target to the global target registry (defined in the .cpp).
void register_target(const target& t);
// Construct a registered target by name.
target make_target(const std::string& name);
// Names of all registered targets.
std::vector<std::string> get_targets();
// Register a default-constructed instance of T.
template <class T>
void register_target()
{
    register_target(T{});
}
// Action used with the auto_register machinery so target types register
// themselves automatically (see auto_register.hpp for when apply runs).
struct register_target_action
{
    template <class T>
    static void apply()
    {
        register_target<T>();
    }
};
template <class T>
using auto_register_target = auto_register<register_target_action, T>;
// Registers one or more target types at namespace scope.
#define MIGRAPHX_REGISTER_TARGET(...) MIGRAPHX_AUTO_REGISTER(register_target_action, __VA_ARGS__)
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REPLACE_ALLOCATE_HPP
#define MIGRAPHX_GUARD_RTGLIB_REPLACE_ALLOCATE_HPP
#include <migraphx/config.hpp>
#include <migraphx/allocation_model.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
// Pass that replaces generic allocations using the target's allocation model.
struct replace_allocate
{
    // Target-specific model describing how allocations are materialized.
    allocation_model model;
    // Whether parameters are copied between host and device automatically.
    bool offload_copy = false;
    std::string name() const { return "replace_allocate"; }
    // Rewrites the module in place (implemented in the .cpp).
    void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -22,12 +22,14 @@ using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK
// Simplified stand-ins so static analysis (cppcheck) does not have to parse
// the enable_if machinery below.
#define MIGRAPHX_REQUIRES(...) class = void
#define MIGRAPHX_CLASS_REQUIRES(...) void
#else
// SFINAE constraint for function templates. The __LINE__-defaulted dummy
// parameter makes each expansion's signature distinct (so overloads declared
// on different lines do not collide); the == __LINE__ comparison is always
// true and only ties the enable_if to that dummy parameter.
#define MIGRAPHX_REQUIRES(...)                                                \
    long MIGRAPHX_REQUIRES_VAR() = __LINE__,                                  \
         typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() == __LINE__ &&      \
                                  (migraphx::and_<__VA_ARGS__>{})),           \
                                 int>::type = 0
// SFINAE constraint usable in class-template parameter lists.
#define MIGRAPHX_CLASS_REQUIRES(...) typename std::enable_if<(migraphx::and_<__VA_ARGS__>{})>::type
#endif
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -8,7 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* Rewrite batchnorm to a multiply and add.
......@@ -16,7 +16,7 @@ struct program;
struct rewrite_batchnorm
{
std::string name() const { return "rewrite_batchnorm"; }
void apply(program& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -7,7 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* Rewrite pooling to reduce_mean
......@@ -15,7 +15,7 @@ struct program;
struct rewrite_pooling
{
std::string name() const { return "rewrite_pooling"; }
void apply(program& prog) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
 * Rewrite quantization ops to equivalent operators
 */
struct rewrite_quantization
{
    std::string name() const { return "rewrite_quantization"; }
    // Rewrites the module in place (implemented in the .cpp).
    void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -6,11 +6,12 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* Rewrite rnn to gemm and add.
......@@ -18,26 +19,22 @@ struct program;
struct rewrite_rnn
{
std::string name() const { return "rewrite_rnn"; }
void apply(program& prog) const;
void apply(module& m) const;
private:
// for vanilla rnn operators
void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
void apply_vanilla_rnn(module& m, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog,
module& m,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
std::vector<instruction_ref> inputs,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators
void apply_gru(program& prog, instruction_ref ins) const;
void apply_gru(module& m, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward,
program& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
......@@ -47,9 +44,9 @@ struct rewrite_rnn
std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators
void apply_lstm(program& prog, instruction_ref ins) const;
void apply_lstm(module& m, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward,
program& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
......@@ -57,6 +54,27 @@ struct rewrite_rnn
const operation& actv_func3) const;
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const module& m, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
op::rnn_direction dirct) const;
void replace_last_cell_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
instruction_ref last_cell_output,
op::rnn_direction dirct) const;
std::size_t get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(module& m,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <array>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Shared driver for executing a loop operator's body module up to iter_num
// times, or until the loop condition becomes false. LoopModel supplies the
// target-specific pieces (copying scalars to/from device memory, appending
// and zero-filling scan outputs, discovering output parameters); `run`
// executes the body module with a name->argument parameter map and returns
// its results.
//
// NOTE(review): the layout of `args` is inferred from the indexing below as
// [max_iters, initial_cond, per-iteration inputs..., two alternating
// dependency buffers..., tuple argument holding final outputs] — confirm
// against the loop operator's compute before changing any index arithmetic.
template <class LoopModel, class T>
argument run_loop(const LoopModel& model,
                  T& ctx,
                  std::vector<argument> args,
                  const std::vector<module_ref>& mods,
                  const std::function<std::vector<argument>(
                      module_ref&, const std::unordered_map<std::string, argument>&)>& run)
{
    // process argu lists
    auto iter_num  = args.at(0).at<int64_t>(); // maximum trip count
    auto cond      = args.at(1).at<bool>();    // initial loop condition
    auto input_num = (args.size() - 2) / 2;
    auto dep_num   = input_num - 2; // number of loop-carried dependencies

    module_ref mod         = mods.at(0);
    auto param_name_shapes = mod->get_parameter_shapes();
    auto param_names       = mod->get_parameter_names();

    // Two dependency buffers that are swapped between iterations (ping-pong)
    // so an iteration reads one set while writing the other.
    std::vector<argument> dep0(args.begin() + input_num + 1, args.begin() + 2 * input_num);
    std::vector<argument> dep1(args.begin() + 2 * input_num, args.begin() + 2 * input_num + 1);
    auto ins_outputs = args.back().get_sub_objects();
    dep1.insert(dep1.end(), ins_outputs.begin(), ins_outputs.begin() + dep_num);
    std::array<std::vector<argument>, 2> loop_carry_deps = {dep0, dep1};

    // loop iter argument
    std::vector<argument> in_args = {args.at(input_num), dep1.at(0)};
    in_args.insert(in_args.end(), args.begin() + 2, args.begin() + input_num);
    std::vector<argument> out_args = dep0;
    out_args.insert(out_args.end(), ins_outputs.begin() + dep_num, ins_outputs.end());
    // Scan outputs accumulate one slice per iteration.
    std::vector<argument> scan_outputs(ins_outputs.begin() + dep_num, ins_outputs.end());
    auto out_param_indices = model.get_output_params(*mod);

    int64_t iter = 0;
    for(iter = 0; iter < iter_num and cond; ++iter)
    {
        // copy iter num and cond to device memory
        model.copy(ctx, iter, in_args.at(0));
        model.copy(ctx, cond, in_args.at(1));

        // wrap up the inputs and outputs
        std::unordered_map<std::string, argument> params;
        int input_index = 0;
        for(const auto& name : param_names)
        {
            auto ps = mod->get_parameter_shape(name);
            // Skip parameters the body module does not actually use.
            if(ps == shape{})
            {
                continue;
            }

            // it is an input parameter
            if(not contains(out_param_indices, name))
            {
                params[name] = in_args.at(input_index++);
            }
            else
            {
                auto output_index = out_param_indices[name];
                if(output_index > dep_num)
                {
                    // Scan output: write this iteration's slice into the
                    // preallocated buffer at offset iter * ps.bytes().
                    const auto& arg = out_args.at(output_index);
                    assert((iter + 1) * ps.bytes() <= arg.get_shape().bytes());
                    params[name] = argument(ps, arg.data() + iter * ps.bytes());
                }
                else
                {
                    params[name] = out_args.at(output_index);
                }
            }
        }

        auto mod_args = run(mod, params);

        // copy back cond to be used next iteration
        model.copy(ctx, mod_args.at(0), cond);

        // mod outputs are used as next loop input
        std::copy(mod_args.begin(), mod_args.begin() + dep_num + 1, in_args.begin() + 1);

        // Swap to the other dependency buffer for the next iteration.
        const auto& dep_out = loop_carry_deps[(iter + 1) % 2];
        std::copy(dep_out.begin(), dep_out.end(), out_args.begin());

        std::vector<argument> mod_scan_outs(mod_args.begin() + 1 + dep_num, mod_args.end());
        model.append(mod_scan_outs, scan_outputs, iter);
    }

    // Final outputs: the last iteration's carried values followed by the scan
    // outputs; unused scan-output tail is zero-filled by the model.
    out_args.erase(out_args.begin());
    std::copy(in_args.begin() + 2, in_args.end(), out_args.begin());
    model.set_zero(ctx, scan_outputs, iter);

    return {out_args};
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -9,7 +9,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* Schedule instructions for concurrent execution
......@@ -19,7 +19,7 @@ struct schedule
schedule_model model{};
bool enable = true;
std::string name() const { return "schedule"; }
void apply(program& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -15,7 +15,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
struct operation;
#ifdef DOXYGEN
......@@ -26,30 +26,35 @@ struct schedule_model
/// Get the number of concurrent instruction allowed
std::size_t concurrency() const;
/// Schedule a concurrent instruction
void sched(program& p, instruction_ref ins, std::size_t n) const;
void sched(module& m, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction
void wait(program& p, instruction_ref ins, std::size_t wait_id) const;
void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction
void record(program& p, instruction_ref ins, std::size_t wait_id) const;
void record(module& m, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation
std::size_t weight(const operation& op) const;
};
#else
/*
* Type-erased interface for:
*
* struct schedule_model
* {
* std::size_t concurrency() const;
* void sched(program& p,instruction_ref ins,std::size_t n) const;
* void wait(program& p,instruction_ref ins,std::size_t wait_id) const;
* void record(program& p,instruction_ref ins,std::size_t wait_id) const;
* std::size_t weight(const operation& op) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct schedule_model
{
//
std::size_t concurrency() const;
//
void sched(module& m, instruction_ref ins, std::size_t n) const;
//
void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
//
void record(module& m, instruction_ref ins, std::size_t wait_id) const;
//
std::size_t weight(const operation& op) const;
};
#else
struct schedule_model
{
......@@ -69,11 +74,17 @@ struct schedule_model
template <typename PrivateDetailTypeErasedT>
schedule_model& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
schedule_model rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this;
}
......@@ -81,7 +92,7 @@ struct schedule_model
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
......@@ -92,7 +103,7 @@ struct schedule_model
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
......@@ -114,22 +125,22 @@ struct schedule_model
return (*this).private_detail_te_get_handle().concurrency();
}
void sched(program& p, instruction_ref ins, std::size_t n) const
void sched(module& m, instruction_ref ins, std::size_t n) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().sched(p, ins, n);
(*this).private_detail_te_get_handle().sched(m, ins, n);
}
void wait(program& p, instruction_ref ins, std::size_t wait_id) const
void wait(module& m, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait(p, ins, wait_id);
(*this).private_detail_te_get_handle().wait(m, ins, wait_id);
}
void record(program& p, instruction_ref ins, std::size_t wait_id) const
void record(module& m, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record(p, ins, wait_id);
(*this).private_detail_te_get_handle().record(m, ins, wait_id);
}
std::size_t weight(const operation& op) const
......@@ -152,11 +163,11 @@ struct schedule_model
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::size_t concurrency() const = 0;
virtual void sched(program& p, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0;
virtual std::size_t concurrency() const = 0;
virtual void sched(module& m, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -189,22 +200,22 @@ struct schedule_model
std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(program& p, instruction_ref ins, std::size_t n) const override
void sched(module& m, instruction_ref ins, std::size_t n) const override
{
private_detail_te_value.sched(p, ins, n);
private_detail_te_value.sched(m, ins, n);
}
void wait(program& p, instruction_ref ins, std::size_t wait_id) const override
void wait(module& m, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.wait(p, ins, wait_id);
private_detail_te_value.wait(m, ins, wait_id);
}
void record(program& p, instruction_ref ins, std::size_t wait_id) const override
void record(module& m, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.record(p, ins, wait_id);
private_detail_te_value.record(m, ins, wait_id);
}
std::size_t weight(const operation& op) const override
......@@ -277,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SERIALIZE_HPP
#define MIGRAPHX_GUARD_RTGLIB_SERIALIZE_HPP
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/rank.hpp>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Avoid implicit conversion with ADL lookup
template <class T>
void migraphx_to_value(value&, const T&) = delete;
// Serialize x into a migraphx::value (dispatched through the rank ladder in
// detail:: below; higher rank overloads win when viable).
template <class T>
value to_value(const T& x);
// Deserialize v into an existing x.
template <class T>
void from_value(const value& v, T& x);
// Convenience overload: value-initialize a T and deserialize into it.
template <class T>
T from_value(const value& v)
{
    T x{};
    from_value(v, x);
    return x;
}
namespace detail {
// ---- to_value dispatch (rank<0> lowest priority, rank<11> highest) ----
// Empty (stateless) types serialize as an empty object.
template <class T, MIGRAPHX_REQUIRES(std::is_empty<T>{})>
value to_value_impl(rank<0>, const T&)
{
    return value::object{};
}
// std::pair becomes a key/value entry.
template <class T, class U>
value to_value_impl(rank<1>, const std::pair<T, U>& x)
{
    return {x.first, x.second};
}
// Iterable ranges become arrays of recursively serialized elements.
template <class T>
auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
{
    value result = value::array{};
    for(auto&& y : x)
    {
        result.insert(to_value(y));
    }
    return result;
}
// Reflectable types become objects keyed by member name.
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
value to_value_impl(rank<3>, const T& x)
{
    value result = value::object{};
    reflect_each(x, [&](auto&& y, std::string name) { result.emplace(name, to_value(y)); });
    return result;
}
// Numeric types are widened to 64-bit integers (or double) before storing.
template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})>
value to_value_impl(rank<4>, const T& x)
{
    return std::int64_t{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})>
value to_value_impl(rank<5>, const T& x)
{
    return std::uint64_t{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
value to_value_impl(rank<6>, const T& x)
{
    return double{x};
}
// Enums are passed straight through to value's conversion.
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
value to_value_impl(rank<7>, const T& x)
{
    return x;
}
inline value to_value_impl(rank<8>, const std::string& x) { return x; }
// Highest-priority hooks: ADL migraphx_to_value customization points and a
// member to_value, preferred over all the generic overloads above.
template <class T>
auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x))
{
    return migraphx_to_value(x);
}
template <class T>
auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value())
{
    return x.to_value();
}
template <class T>
auto to_value_impl(rank<11>, const T& x)
    -> decltype(migraphx_to_value(std::declval<value&>(), x), value{})
{
    value v;
    migraphx_to_value(v, x);
    return v;
}
// ---- from_value dispatch (rank<0> lowest priority, rank<9> highest) ----
// Empty types: accept only an empty object.
template <class T, MIGRAPHX_REQUIRES(std::is_empty<T>{})>
void from_value_impl(rank<0>, const value& v, T& x)
{
    if(not v.is_object())
        MIGRAPHX_THROW("Expected an object");
    if(not v.get_object().empty())
        MIGRAPHX_THROW("Expected an empty object");
    x = T{};
}
// Sequence containers: rebuild element by element.
template <class T>
auto from_value_impl(rank<1>, const value& v, T& x)
    -> decltype(x.insert(x.end(), *x.begin()), void())
{
    x.clear();
    for(auto&& e : v)
        x.insert(x.end(), from_value<typename T::value_type>(e));
}
// Containers of arithmetic elements may also be restored from raw binary.
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<typename T::value_type>{})>
auto from_value_impl(rank<2>, const value& v, T& x)
    -> decltype(x.insert(x.end(), *x.begin()), void())
{
    x.clear();
    if(v.is_binary())
    {
        for(auto&& e : v.get_binary())
            x.insert(x.end(), e);
    }
    else
    {
        for(auto&& e : v)
            x.insert(x.end(), from_value<typename T::value_type>(e));
    }
}
// Map-like containers: keys come from the value entries' keys.
template <class T>
auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void())
{
    x.clear();
    for(auto&& e : v)
        x.emplace(e.get_key(), from_value<typename T::mapped_type>(e));
}
// Reflectable types: restore each named member present in v; members missing
// from v keep their current value.
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
void from_value_impl(rank<4>, const value& v, T& x)
{
    reflect_each(x, [&](auto& y, const std::string& name) {
        using type = std::decay_t<decltype(y)>;
        if(v.contains(name))
            y = from_value<type>(v.at(name).without_key());
    });
}
// Arithmetic and enum types convert through value::to.
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})>
void from_value_impl(rank<5>, const value& v, T& x)
{
    x = v.to<T>();
}
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
void from_value_impl(rank<6>, const value& v, T& x)
{
    x = v.to<T>();
}
inline void from_value_impl(rank<7>, const value& v, std::string& x) { x = v.to<std::string>(); }
// Highest-priority hooks: a member from_value, then ADL migraphx_from_value.
template <class T>
auto from_value_impl(rank<8>, const value& v, T& x) -> decltype(x.from_value(v), void())
{
    x.from_value(v);
}
template <class T>
auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{
    migraphx_from_value(v, x);
}
} // namespace detail
// Entry points start the overload search at the highest rank.
template <class T>
value to_value(const T& x)
{
    return detail::to_value_impl(rank<11>{}, x);
}
template <class T>
void from_value(const value& v, T& x)
{
    detail::from_value_impl(rank<9>{}, v, x);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment