Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
......@@ -2,13 +2,11 @@
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -23,7 +21,7 @@ struct transpose
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
return pack(f(self.dims, "permutation"));
}
std::string name() const { return "transpose"; }
......@@ -34,6 +32,7 @@ struct transpose
auto input_lens = input.lens();
auto input_strides = input.strides();
auto t = input.type();
if(dims.size() != input_lens.size())
{
MIGRAPHX_THROW("Permutation has wrong number of axes");
......@@ -42,7 +41,7 @@ struct transpose
std::iota(axes.begin(), axes.end(), 0);
if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{
MIGRAPHX_THROW("Invalid permutation");
MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
}
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size());
......@@ -55,7 +54,7 @@ struct transpose
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
return args[0].reshape(output_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -2,6 +2,11 @@
#define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -10,52 +15,57 @@ namespace op {
template <class Derived>
struct unary : op_name<Derived>
{
std::string point_function() const { return this->name(); }
std::string point_op() const
{
const auto& self = static_cast<const Derived&>(*this);
auto pf = self.point_function();
if(pf.empty())
return {};
if(with_char(::ispunct)(pf.front()))
{
return pf + "${0}";
}
else
{
return "${function:" + pf + "}(${0})";
}
}
value base_attributes() const
{
const auto& self = static_cast<const Derived&>(*this);
return {{"pointwise", true}, {"point_op", self.point_op()}};
}
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
auto s = inputs.at(0);
if(s.packed())
if(s.scalar())
{
return s;
}
else
else if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto in_shape = args[0].get_shape();
if(in_shape.packed())
{
shape std_in_shape{in_shape.type(), in_shape.lens()};
shape std_out_shape{output_shape.type(), output_shape.lens()};
argument arg_in{std_in_shape, args[0].data()};
argument arg_out{std_out_shape, result.data()};
arg_out.visit([&](auto output) {
arg_in.visit([&](auto input) {
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
});
}
else
{
result.visit([&](auto output) {
args[0].visit([&](auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input(idx.begin(), idx.end()));
});
});
result.visit([&](auto output) {
args[0].visit([&](auto input) {
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
}
});
return result;
}
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct unary_not : unary<unary_not>
{
std::string point_function() const { return "!"; }
auto apply() const
{
return [](auto x) { return not x; };
}
std::string name() const { return "not"; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_UNDEFINED_HPP
#define MIGRAPHX_GUARD_RTGLIB_UNDEFINED_HPP
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct undefined
{
std::string name() const { return "undefined"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(0);
return {};
}
argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ABNORMAL_OPS_HPP
#define MIGRAPHX_GUARD_OPERATORS_ABNORMAL_OPS_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP
#define MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct not_computable
{
argument compute(const shape&, const std::vector<argument>&) const
{
MIGRAPHX_THROW("not computable");
}
};
struct undefined
{
std::string name() const { return "undefined"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(0);
return {};
}
argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
};
struct unknown
{
std::string op;
......
......@@ -2,13 +2,13 @@
#define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -26,36 +26,60 @@ struct unsqueeze
return pack(f(self.axes, "axes"));
}
value attributes() const
{
value normalize;
normalize["axes"] =
value::array{normalize_attribute::include_min, normalize_attribute::use_output};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard_or_scalar();
check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
return shape{type, old_lens};
{
if(old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
else
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++)
for(auto i : range(new_size))
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
new_lens[i] = 1;
if(p == 0) // unsqueeze on the first axes
{
new_strides[i] = old_lens[0] * old_strides[0];
}
else // unsqueeze on middle or last axes
{
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1;
}
}
else
{
new_lens[i] = old_lens[p++];
new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
}
}
return shape{type, new_lens};
return shape{type, new_lens, new_strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
return args[0].reshape(output_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct where
{
std::string name() const { return "where"; }
value attributes() const { return {{"pointwise", true}, {"point_op", "${0} ? ${1} : ${2}"}}; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).same_dims();
auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
if(s1 == s2 and s1.packed())
{
return s1;
}
else if(s1.packed() != s2.packed())
{
return s1.packed() ? s1 : s2;
}
else if(s1.broadcasted() != s2.broadcasted())
{
return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens());
}
else
{
return {s1.type(), s1.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) {
args[0].visit([&](const auto condition) {
par_for(output_shape.elements(),
[&](auto i) { output[i] = condition[i] ? x[i] : y[i]; });
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -7,10 +7,15 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <unordered_map>
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......@@ -57,6 +62,8 @@ struct operation
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
/// Returns true if operation needs normalization before running compute
bool need_normalization(const operation& x);
/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);
......@@ -96,7 +103,73 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators
template <class T>
auto compute_op(rank<2>,
auto compute_shape_op(rank<3>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs))
{
return x.compute_shape(inputs);
}
template <class T>
auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens());
return any_cast<T>(y).normalize_compute_shape(inputs);
}
template <class T>
auto compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs, {}))
{
return x.compute_shape(inputs, {});
}
template <class T>
shape compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
{
return compute_shape_op(rank<3>{}, x, inputs);
}
template <class T>
auto mod_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(x.compute_shape(inputs, mod_args))
{
return x.compute_shape(inputs, mod_args);
}
template <class T>
shape mod_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
if(mod_args.empty())
return compute_shape_op(x, inputs);
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape mod_compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return mod_compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output_shape,
......@@ -106,14 +179,6 @@ auto compute_op(rank<2>,
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
......@@ -125,35 +190,132 @@ template <class T>
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, ctx, output_shape, input);
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<1>{}, x, output_shape, input);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
template <class T, class F>
argument compute_op(const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<1>{}, x, output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<4>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
{
return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<3>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return compute_op(rank<2>{}, x, output_shape, input);
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<2>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs))
{
return x.compute(output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
context&,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T, class F>
argument compute_op(const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<4>{}, x, ctx, output, inputs, module_args, f);
}
template <class T>
......@@ -174,6 +336,20 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
return {};
}
template <class T>
auto need_normalization_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs), std::true_type{});
template <class T>
auto need_normalization_op(rank<0>, const T&, const std::vector<shape>&) -> std::false_type;
template <class T>
auto need_normalization_op(const T& x)
-> decltype(need_normalization_op(rank<1>{}, x, std::declval<std::vector<shape>>()))
{
return {};
}
template <class T>
std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&)
{
......@@ -218,26 +394,113 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {};
}
template <class T>
auto compile_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.compile(auto_any_cast(ctx), output_shape, input))
{
return x.compile(auto_any_cast(ctx), output_shape, input);
}
template <class T>
value compile_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
return value::object{};
}
template <class T>
value compile_op(const T& x,
context& ctx,
const shape& output_shape,
const std::vector<shape>& input)
{
return compile_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
value attributes_op(const T&)
{
return value::object{};
}
template <class T>
value to_value_op(const T& x)
{
return migraphx::to_value(x);
}
template <class T>
void from_value_op(T& x, const value& v)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object");
return migraphx::from_value(v, x);
}
template <class T>
lifetime get_lifetime_op(const T&)
{
return lifetime::local;
}
} // namespace detail
/*
* Type-erased interface for:
*
* struct operation
* {
* std::string name() const;
* bool is_context_free() const;
* bool has_finalize() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct operation
{
//
std::string name() const;
// (optional)
bool is_context_free() const;
// (optional)
bool need_normalization() const;
// (optional)
bool has_finalize() const;
// (optional)
lifetime get_lifetime() const;
// (optional)
std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
// (optional)
value compile(context& ctx, const shape& output, const std::vector<shape>& input);
// (optional)
void finalize(context& ctx, const shape& output, const std::vector<shape>& input);
// (optional)
shape compute_shape(const std::vector<shape>& input) const;
// (optional)
shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const;
// (optional)
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
// (optional)
argument compute(const shape& output, const std::vector<argument>& input) const;
// (optional)
argument compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
// (optional)
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
// (optional)
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
value attributes() const;
//
friend std::ostream& operator<<(std::ostream& os, const operation& op);
//
friend bool operator==(const operation& x, const operation& y);
};
#else
struct operation
{
......@@ -257,11 +520,17 @@ struct operation
template <typename PrivateDetailTypeErasedT>
operation& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
operation rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this;
}
......@@ -269,7 +538,7 @@ struct operation
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
......@@ -280,7 +549,7 @@ struct operation
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
......@@ -308,18 +577,36 @@ struct operation
return (*this).private_detail_te_get_handle().is_context_free();
}
bool need_normalization() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().need_normalization();
}
bool has_finalize() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().has_finalize();
}
lifetime get_lifetime() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_lifetime();
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().output_alias(input);
}
value compile(context& ctx, const shape& output, const std::vector<shape>& input)
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compile(ctx, output, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -332,6 +619,13 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(input);
}
shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute_shape(inputs, mod_args);
}
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -344,6 +638,47 @@ struct operation
return (*this).private_detail_te_get_handle().compute(output, input);
}
argument compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(
output, input, module_args, std::move(run));
}
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(
ctx, output, input, module_args, std::move(run));
}
value to_value() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().to_value();
}
void from_value(const value& v)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().from_value(v);
}
value attributes() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().attributes();
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
{
assert(op.private_detail_te_handle_mem_var);
......@@ -371,16 +706,38 @@ struct operation
virtual std::string name() const = 0;
virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0;
virtual lifetime get_lifetime() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual value
compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual void
finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const = 0;
virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
virtual argument
compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
virtual argument
compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual value attributes() const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
};
template <class T>
......@@ -396,6 +753,19 @@ struct operation
return detail::is_context_free_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_need_normalization(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.need_normalization())
{
return private_detail_te_self.need_normalization();
}
template <class T>
static bool private_detail_te_default_need_normalization(float, T&& private_detail_te_self)
{
return detail::need_normalization_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_has_finalize(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.has_finalize())
......@@ -409,6 +779,19 @@ struct operation
return detail::has_finalize_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_get_lifetime(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.get_lifetime())
{
return private_detail_te_self.get_lifetime();
}
template <class T>
static lifetime private_detail_te_default_get_lifetime(float, T&& private_detail_te_self)
{
return detail::get_lifetime_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_output_alias(char,
T&& private_detail_te_self,
......@@ -426,6 +809,27 @@ struct operation
return detail::output_alias_op(private_detail_te_self, input);
}
template <class T>
static auto private_detail_te_default_compile(char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.compile(ctx, output, input))
{
return private_detail_te_self.compile(ctx, output, input);
}
template <class T>
static value private_detail_te_default_compile(float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
{
return detail::compile_op(private_detail_te_self, ctx, output, input);
}
template <class T>
static auto private_detail_te_default_finalize(char,
T&& private_detail_te_self,
......@@ -447,6 +851,42 @@ struct operation
detail::finalize_op(private_detail_te_self, ctx, output, input);
}
template <class T>
static auto private_detail_te_default_compute_shape(char,
T&& private_detail_te_self,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.compute_shape(input))
{
return private_detail_te_self.compute_shape(input);
}
template <class T>
static shape private_detail_te_default_compute_shape(float,
T&& private_detail_te_self,
const std::vector<shape>& input)
{
return detail::compute_shape_op(private_detail_te_self, input);
}
template <class T>
static auto private_detail_te_default_compute_shape(char,
T&& private_detail_te_self,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(private_detail_te_self.compute_shape(inputs, mod_args))
{
return private_detail_te_self.compute_shape(inputs, mod_args);
}
template <class T>
static shape private_detail_te_default_compute_shape(float,
T&& private_detail_te_self,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return detail::mod_compute_shape_op(private_detail_te_self, inputs, mod_args);
}
template <class T>
static auto private_detail_te_default_compute(char,
T&& private_detail_te_self,
......@@ -487,6 +927,105 @@ struct operation
return detail::compute_op(private_detail_te_self, output, input);
}
template <class T>
static auto private_detail_te_default_compute(
char,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
-> decltype(private_detail_te_self.compute(output, input, module_args, std::move(run)))
{
return private_detail_te_self.compute(output, input, module_args, std::move(run));
}
template <class T>
static argument private_detail_te_default_compute(
float,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
{
return detail::compute_op(
private_detail_te_self, output, input, module_args, std::move(run));
}
template <class T>
static auto private_detail_te_default_compute(
char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
-> decltype(private_detail_te_self.compute(ctx, output, input, module_args, std::move(run)))
{
return private_detail_te_self.compute(ctx, output, input, module_args, std::move(run));
}
template <class T>
static argument private_detail_te_default_compute(
float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
{
return detail::compute_op(
private_detail_te_self, ctx, output, input, module_args, std::move(run));
}
template <class T>
static auto private_detail_te_default_to_value(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.to_value())
{
return private_detail_te_self.to_value();
}
template <class T>
static value private_detail_te_default_to_value(float, T&& private_detail_te_self)
{
return detail::to_value_op(private_detail_te_self);
}
template <class T>
static auto
private_detail_te_default_from_value(char, T&& private_detail_te_self, const value& v)
-> decltype(private_detail_te_self.from_value(v))
{
private_detail_te_self.from_value(v);
}
template <class T>
static void
private_detail_te_default_from_value(float, T&& private_detail_te_self, const value& v)
{
detail::from_value_op(private_detail_te_self, v);
}
template <class T>
static auto private_detail_te_default_attributes(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.attributes())
{
return private_detail_te_self.attributes();
}
template <class T>
static value private_detail_te_default_attributes(float, T&& private_detail_te_self)
{
return detail::attributes_op(private_detail_te_self);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -523,18 +1062,37 @@ struct operation
return private_detail_te_default_is_context_free(char(0), private_detail_te_value);
}
bool need_normalization() const override
{
return private_detail_te_default_need_normalization(char(0), private_detail_te_value);
}
bool has_finalize() const override
{
return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
}
lifetime get_lifetime() const override
{
return private_detail_te_default_get_lifetime(char(0), private_detail_te_value);
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
{
return private_detail_te_default_output_alias(char(0), private_detail_te_value, input);
}
value compile(context& ctx, const shape& output, const std::vector<shape>& input) override
{
return private_detail_te_default_compile(
char(0), private_detail_te_value, ctx, output, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
{
......@@ -545,7 +1103,15 @@ struct operation
shape compute_shape(const std::vector<shape>& input) const override
{
return private_detail_te_value.compute_shape(input);
return private_detail_te_default_compute_shape(char(0), private_detail_te_value, input);
}
shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const override
{
return private_detail_te_default_compute_shape(
char(0), private_detail_te_value, inputs, mod_args);
}
argument compute(context& ctx,
......@@ -564,6 +1130,49 @@ struct operation
char(0), private_detail_te_value, output, input);
}
argument compute(
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
{
return private_detail_te_default_compute(
char(0), private_detail_te_value, output, input, module_args, std::move(run));
}
argument compute(
context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
{
return private_detail_te_default_compute(
char(0), private_detail_te_value, ctx, output, input, module_args, std::move(run));
}
value to_value() const override
{
return private_detail_te_default_to_value(char(0), private_detail_te_value);
}
void from_value(const value& v) override
{
private_detail_te_default_from_value(char(0), private_detail_te_value, v);
}
value attributes() const override
{
return private_detail_te_default_attributes(char(0), private_detail_te_value);
}
std::ostream& operator_shift_left(std::ostream& os) const override
{
using migraphx::detail::operation_operators::operator<<;
......@@ -640,9 +1249,72 @@ inline const ValueType& any_cast(const operation& x)
throw std::bad_cast();
return *y;
}
#endif
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
return op.compile(ctx, output_shape, input);
}
template <class Context>
inline value
compile(operation& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
dependent_type<context, Context> ctx2 = std::ref(ctx);
return compile(op, ctx2, output_shape, input);
}
template <class T, class Context>
inline auto compile(T& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(op.compile(ctx, ctx, output_shape, input))
{
return op.compile(ctx, ctx, output_shape, input);
}
inline shape compute_shape(const operation& op, const std::vector<shape>& inputs)
{
return op.compute_shape(inputs);
}
template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
-> decltype(op.compute_shape(inputs))
{
return op.compute_shape(inputs);
}
template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
-> decltype(op.normalize_compute_shape(inputs))
{
return detail::compute_shape_op(op, inputs);
}
inline shape compute_shape(const operation& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return op.compute_shape(inputs, mod_args);
}
template <class T>
inline auto compute_shape(const T& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(op.compute_shape(inputs, mod_args))
{
return op.compute_shape(inputs, mod_args);
}
template <class T>
inline auto compute_shape(const T& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(op.normalize_compute_shape(inputs, mod_args))
{
return detail::compute_shape_op(op, inputs, mod_args);
}
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T>
......@@ -651,6 +1323,14 @@ bool is_context_free(const T& x)
return detail::is_context_free_op(x);
}
inline bool need_normalization(const operation& op) { return op.need_normalization(); }
template <class T>
bool need_normalization(const T& x)
{
return detail::need_normalization_op(x);
}
inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T>
......@@ -659,6 +1339,9 @@ bool has_finalize(const T& x)
return detail::has_finalize_op(x);
}
void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op);
#endif
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_OPERATORS_HPP
#define MIGRAPHX_GUARD_OPERATORS_HPP
#include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/acosh.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp>
#include <migraphx/op/asinh.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/atanh.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
......@@ -23,21 +25,33 @@
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/cosh.hpp>
#include <migraphx/op/cos.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/div.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/equal.hpp>
#include <migraphx/op/erf.hpp>
#include <migraphx/op/exp.hpp>
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/isnan.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/less.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/log.hpp>
#include <migraphx/op/logical_and.hpp>
#include <migraphx/op/logical_or.hpp>
#include <migraphx/op/logical_xor.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/max.hpp>
......@@ -45,24 +59,40 @@
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/nonmaxsuppression.hpp>
#include <migraphx/op/nonzero.hpp>
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/prefix_scan_sum.hpp>
#include <migraphx/op/prelu.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/recip.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_min.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/op/reduce_prod.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reverse.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/rnn_last_hs_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/roialign.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_mul.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
......@@ -72,11 +102,17 @@
#include <migraphx/op/sqrt.hpp>
#include <migraphx/op/sqdiff.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/step.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/op/tan.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp>
#endif
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#include <migraphx/config.hpp>
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_OPTIONAL 1
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#elif defined(__has_include)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#endif
#if __has_include(<experimental/optional>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#else
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#if MIGRAPHX_HAS_OPTIONAL
#include <optional>
#elif MIGRAPHX_HAS_OPTIONAL_TS
#include <experimental/optional>
#else
#error "No optional include available"
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#if MIGRAPHX_HAS_OPTIONAL
template <class T>
using optional = std::optional<T>;
using nullopt_t = std::nullopt_t;
constexpr auto nullopt = std::nullopt;
#elif MIGRAPHX_HAS_OPTIONAL_TS
template <class T>
using optional = std::experimental::optional<T>;
using nullopt_t = std::experimental::nullopt_t;
constexpr auto nullopt = std::experimental::nullopt;
#endif
template <class T>
bool has_value(const optional<T>& x)
{
return x != nullopt;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <iterator>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class F>
struct function_output_iterator
{
F f;
using self = function_output_iterator;
using difference_type = void;
using reference = void;
using value_type = void;
using pointer = void;
using iterator_category = std::output_iterator_tag;
struct output_proxy
{
template <class T>
output_proxy& operator=(const T& value)
{
assert(f);
(*f)(value);
return *this;
}
F* f;
};
output_proxy operator*() { return output_proxy{&f}; }
self& operator++() { return *this; }
self& operator++(int) { return *this; } // NOLINT
};
template <class F>
function_output_iterator<F> make_function_output_iterator(F f)
{
return {std::move(f)};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
......@@ -13,14 +13,25 @@ inline void calculate_padding(int64_t idx,
int64_t input_dim,
int64_t stride,
int64_t dilation,
int64_t weight_dim)
int64_t weight_dim,
bool is_same_upper = true)
{
int64_t output_dim = (input_dim + stride - 1) / stride; // round up result
int64_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
int64_t pad =
std::max(static_cast<int64_t>(0), (output_dim - 1) * stride + new_weight_dim - input_dim);
pads[idx] = pad / 2;
pads[idx + 2] = pad - pad / 2;
auto pad_ndims = pads.size() / 2;
if(is_same_upper)
{
pads[idx] = pad / 2;
pads[idx + pad_ndims] = pad - pad / 2;
}
else
{
pads[idx + pad_ndims] = pad / 2;
pads[idx] = pad - pad / 2;
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -41,7 +41,6 @@ auto par_dfor(Ts... xs)
{
dfor(xs...)(f);
}
};
}
......
......@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize =
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain);
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f);
}
......
......@@ -3,16 +3,19 @@
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
struct module_pass_manager;
#ifdef DOXYGEN
......@@ -22,22 +25,53 @@ struct pass
{
/// A unique name used to identify the pass
std::string name() const;
/// Run the pass on the module
void apply(module_pass_manager& mpm) const;
void apply(module& m) const;
/// Run the pass on the program
void apply(program& p) const;
};
#else
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(program & p) const;
* };
*
*/
module& get_module(module_pass_manager& mpm);
namespace detail {
template <class T>
auto module_pass_manager_apply(rank<1>, const T& x, module_pass_manager& mpm)
-> decltype(x.apply(get_module(mpm)))
{
return x.apply(get_module(mpm));
}
template <class T>
void module_pass_manager_apply(rank<0>, const T&, module_pass_manager&)
{
}
template <class T>
void module_pass_manager_apply(const T& x, module_pass_manager& mpm)
{
module_pass_manager_apply(rank<1>{}, x, mpm);
}
} // namespace detail
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct pass
{
//
std::string name() const;
// (optional)
void apply(module_pass_manager& mpm) const;
// (optional)
void apply(program& p) const;
};
#else
struct pass
{
......@@ -57,11 +91,17 @@ struct pass
template <typename PrivateDetailTypeErasedT>
pass& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
pass rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this;
}
......@@ -69,7 +109,7 @@ struct pass
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
......@@ -80,7 +120,7 @@ struct pass
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
......@@ -102,6 +142,12 @@ struct pass
return (*this).private_detail_te_get_handle().name();
}
void apply(module_pass_manager& mpm) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().apply(mpm);
}
void apply(program& p) const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -121,10 +167,39 @@ struct pass
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual void apply(program& p) const = 0;
virtual std::string name() const = 0;
virtual void apply(module_pass_manager& mpm) const = 0;
virtual void apply(program& p) const = 0;
};
template <class T>
static auto
private_detail_te_default_apply(char, T&& private_detail_te_self, module_pass_manager& mpm)
-> decltype(private_detail_te_self.apply(mpm))
{
private_detail_te_self.apply(mpm);
}
template <class T>
static void
private_detail_te_default_apply(float, T&& private_detail_te_self, module_pass_manager& mpm)
{
migraphx::detail::module_pass_manager_apply(private_detail_te_self, mpm);
}
template <class T>
static auto private_detail_te_default_apply(char, T&& private_detail_te_self, program& p)
-> decltype(private_detail_te_self.apply(p))
{
private_detail_te_self.apply(p);
}
template <class T>
static void private_detail_te_default_apply(float, T&& private_detail_te_self, program& p)
{
migraphx::nop(private_detail_te_self, p);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -155,7 +230,17 @@ struct pass
std::string name() const override { return private_detail_te_value.name(); }
void apply(program& p) const override { private_detail_te_value.apply(p); }
void apply(module_pass_manager& mpm) const override
{
private_detail_te_default_apply(char(0), private_detail_te_value, mpm);
}
void apply(program& p) const override
{
private_detail_te_default_apply(char(0), private_detail_te_value, p);
}
PrivateDetailTypeErasedT private_detail_te_value;
};
......@@ -221,6 +306,7 @@ inline const ValueType& any_cast(const pass& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#include <list>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>
#include <migraphx/pass.hpp>
#include <migraphx/tracer.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager
{
module_pass_manager() = default;
module_pass_manager(const module_pass_manager&) = delete;
virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0;
virtual void run_pass(const pass& p) = 0;
protected:
virtual ~module_pass_manager() {}
};
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace = tracer{});
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace = tracer{});
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -20,29 +20,22 @@ inline Vector reorder_dims(const Vector& dims, const std::vector<int64_t>& permu
return result;
}
inline shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
{
return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)};
}
shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation);
template <class Vector, class Op>
inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
std::vector<std::int64_t> result(data.size());
std::iota(result.begin(), result.end(), 0);
std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
std::stable_sort(
result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
return result;
}
inline std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return sort_permutation(permutation, std::less<>{});
}
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation);
inline std::vector<int64_t> find_permutation(const shape& s)
{
return sort_permutation(s.strides(), std::greater<>{});
}
std::vector<int64_t> find_permutation(const shape& s);
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#include <migraphx/config.hpp>
#include <migraphx/allocation_model.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct preallocate_param
{
std::string param;
allocation_model model;
std::string name() const { return "preallocate_param"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROCESS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PROCESS_HPP
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
#include <string>
#include <memory>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct process_impl;
struct process
{
process(const std::string& cmd);
// move constructor
process(process&&) noexcept;
// copy assignment operator
process& operator=(process rhs);
~process() noexcept;
process& cwd(const fs::path& p);
void exec();
private:
std::unique_ptr<process_impl> impl;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PROCESS_HPP
......@@ -4,6 +4,7 @@
#include <list>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/module.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
......@@ -22,7 +23,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl;
const operation& get_operation(instruction_ref ins);
struct marker;
/**
* @brief Stores the instruction stream
......@@ -42,50 +43,7 @@ struct program
~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>;
template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args)
{
return add_instruction(op, {args...});
}
instruction_ref add_instruction(const operation& op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args)
{
return insert_instruction(ins, op, {args...});
}
instruction_ref
insert_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args)
{
return replace_instruction(ins, op, {args...});
}
instruction_ref replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args);
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
{
return add_literal(literal{std::forward<Ts>(xs)...});
}
instruction_ref add_literal(literal l);
instruction_ref add_outline(const shape& s);
instruction_ref add_parameter(std::string name, shape s);
std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const;
......@@ -93,15 +51,11 @@ struct program
std::unordered_map<std::string, shape> get_parameter_shapes() const;
argument eval(parameter_map params) const;
bool has_instruction(instruction_ref ins) const;
std::vector<argument> eval(parameter_map params) const;
std::size_t size() const;
instruction_ref begin() const;
instruction_ref end() const;
shape get_shape() const;
std::vector<shape> get_output_shapes() const;
context& get_context() const;
......@@ -109,27 +63,57 @@ struct program
void compile(const target& t, compile_options options = compile_options{});
bool is_compiled() const;
void finalize();
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
void
perf_report(std::ostream& os, std::size_t n, parameter_map params, std::size_t batch = 1) const;
void mark(const parameter_map& params, marker&& m);
value to_value() const;
void from_value(const value& v);
void debug_print() const;
void debug_print(instruction_ref ins) const;
void debug_print(const std::vector<instruction_ref>& inss) const;
void print(std::unordered_map<instruction_ref, std::string>& names,
const std::function<void(instruction_ref,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const;
void print(const std::function<void(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const;
void dry_run(parameter_map params) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
void annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const;
program& sort();
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
private:
void assign(const program& p);
// module related api
module* create_module(const std::string& name);
module* get_module(const std::string& name);
const module* get_module(const std::string& name) const;
module* get_main_module();
const module* get_main_module() const;
std::vector<const module*> get_modules() const;
std::vector<module*> get_modules();
void remove_module(const std::string& name);
void remove_unused_modules();
private:
void assign(const program& p);
std::unique_ptr<program_impl> impl;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment