Commit e2eb6036 authored by Paul's avatar Paul
Browse files

Merge

parents 298c93d5 1e0bbd78
...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins, ...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{ {
auto op = any_cast<op::pooling>(ins->get_operator()); auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
return; return;
} }
......
...@@ -32,18 +32,22 @@ struct allocation_model ...@@ -32,18 +32,22 @@ struct allocation_model
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct allocation_model struct allocation_model
* { {
* std::string name() const; //
* std::string copy() const; std::string name() const;
* operation allocate(const shape& s) const; //
* operation preallocate(const shape& s,std::string id) const; std::string copy() const;
* }; //
* operation allocate(const shape& s) const;
*/ //
operation preallocate(const shape& s, std::string id) const;
};
#else
struct allocation_model struct allocation_model
{ {
...@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x) ...@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#include <migraphx/config.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/type_name.hpp>
#include <cassert>
#include <string_view>
#include <typeindex>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct any_ptr
{
any_ptr() = default;
template <class T>
any_ptr(T* p) : ptr(p), ti(typeid(T*)), name(get_name<T*>())
{
}
any_ptr(void* p, std::string_view pname) : ptr(p), name(pname) {}
void* get(std::string_view n) const
{
if(name != n)
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} +
" != " + std::string{n});
return ptr;
}
template <class T>
T get() const
{
static_assert(std::is_pointer<T>{}, "Must be a pointer");
assert(ptr != nullptr);
if(ti and std::type_index{typeid(T)} != *ti)
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name<T>());
else if(name != get_name<T>())
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name<T>());
return reinterpret_cast<T>(ptr);
}
void* unsafe_get() const { return ptr; }
private:
void* ptr = nullptr;
optional<std::type_index> ti = nullopt;
std::string_view name = "";
template <class T>
static const std::string& get_name()
{
return get_type_name<std::remove_cv_t<std::remove_pointer_t<T>>>();
}
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
...@@ -23,6 +23,7 @@ struct src_compiler ...@@ -23,6 +23,7 @@ struct src_compiler
std::string compiler = "c++"; std::string compiler = "c++";
std::string flags = ""; std::string flags = "";
std::string output = ""; std::string output = "";
std::string launcher = "";
std::function<fs::path(fs::path)> process = nullptr; std::function<fs::path(fs::path)> process = nullptr;
std::vector<char> compile(const std::vector<src_file>& srcs) const; std::vector<char> compile(const std::vector<src_file>& srcs) const;
}; };
......
...@@ -30,17 +30,20 @@ struct concat_optimization ...@@ -30,17 +30,20 @@ struct concat_optimization
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct concat_optimization struct concat_optimization
* { {
* std::string name() const; //
* std::string allocate() const; std::string name() const;
* op::concat get_concat(const operation& op) const; //
* }; std::string allocate() const;
* //
*/ op::concat get_concat(const operation& op) const;
};
#else
struct concat_optimization struct concat_optimization
{ {
...@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x) ...@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -37,17 +38,28 @@ void from_value_context(T&, const value&) ...@@ -37,17 +38,28 @@ void from_value_context(T&, const value&)
{ {
} }
/* template <class T>
* Type-erased interface for: any_ptr get_queue_context(T&)
* {
* struct context return {};
* { }
* value to_value() const;
* void from_value(const value& v) ; #ifdef TYPE_ERASED_DECLARATION
* void finish() const;
* }; // Type-erased interface for:
* struct context
*/ {
// (optional)
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
any_ptr get_queue();
//
void finish() const;
};
#else
struct context struct context
{ {
...@@ -124,6 +136,12 @@ struct context ...@@ -124,6 +136,12 @@ struct context
(*this).private_detail_te_get_handle().from_value(v); (*this).private_detail_te_get_handle().from_value(v);
} }
any_ptr get_queue()
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_queue();
}
void finish() const void finish() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -145,6 +163,7 @@ struct context ...@@ -145,6 +163,7 @@ struct context
virtual value to_value() const = 0; virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0; virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void finish() const = 0; virtual void finish() const = 0;
}; };
...@@ -176,6 +195,19 @@ struct context ...@@ -176,6 +195,19 @@ struct context
from_value_context(private_detail_te_self, v); from_value_context(private_detail_te_self, v);
} }
template <class T>
static auto private_detail_te_default_get_queue(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.get_queue())
{
return private_detail_te_self.get_queue();
}
template <class T>
static any_ptr private_detail_te_default_get_queue(float, T&& private_detail_te_self)
{
return get_queue_context(private_detail_te_self);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -216,6 +248,12 @@ struct context ...@@ -216,6 +248,12 @@ struct context
private_detail_te_default_from_value(char(0), private_detail_te_value, v); private_detail_te_default_from_value(char(0), private_detail_te_value, v);
} }
any_ptr get_queue() override
{
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
}
void finish() const override { private_detail_te_value.finish(); } void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
...@@ -282,6 +320,7 @@ inline const ValueType& any_cast(const context& x) ...@@ -282,6 +320,7 @@ inline const ValueType& any_cast(const context& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); } inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
......
...@@ -68,6 +68,8 @@ struct cpp_generator ...@@ -68,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f); void fmap(const std::function<std::string(std::string)>& f);
void fresult(const std::function<std::string(shape)>& f);
void add_point_op(const std::string& op_name, const std::string& code); void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args); std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
std::string to_pretty_json_string(const value& val, std::size_t indent = 4);
std::string to_json_string(const value& val); std::string to_json_string(const value& val);
value from_json_string(const std::string& str); value from_json_string(const std::string& str);
value from_json_string(const char* str, std::size_t size); value from_json_string(const char* str, std::size_t size);
......
...@@ -20,18 +20,22 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -20,18 +20,22 @@ inline namespace MIGRAPHX_INLINE_NS {
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct marker struct marker
* { {
* void mark_start(instruction_ref ins_ref) ; //
* void mark_start(const program& prog) ; void mark_start(instruction_ref ins_ref);
* void mark_stop(instruction_ref ins) ; //
* void mark_stop(const program& prog) ; void mark_start(const program& prog);
* }; //
* void mark_stop(instruction_ref ins);
*/ //
void mark_stop(const program& prog);
};
#else
struct marker struct marker
{ {
...@@ -243,6 +247,7 @@ inline const ValueType& any_cast(const marker& x) ...@@ -243,6 +247,7 @@ inline const ValueType& any_cast(const marker& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -101,8 +101,8 @@ template <class M> ...@@ -101,8 +101,8 @@ template <class M>
auto bind_match(M m, std::string name) auto bind_match(M m, std::string name)
{ {
return make_function_matcher( return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) [=, name = std::move(name)](matcher_context& ctx,
->optional<instruction_ref> { instruction_ref ins) -> optional<instruction_ref> {
auto result = m.match(ctx, ins); auto result = m.match(ctx, ins);
if(result) if(result)
{ {
...@@ -536,7 +536,7 @@ auto skip_output(Ms... ms) ...@@ -536,7 +536,7 @@ auto skip_output(Ms... ms)
inline auto name(std::string s) inline auto name(std::string s)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher(
[ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; }); [=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; });
} }
inline auto name_contains(const std::string& name) inline auto name_contains(const std::string& name)
...@@ -547,7 +547,7 @@ inline auto name_contains(const std::string& name) ...@@ -547,7 +547,7 @@ inline auto name_contains(const std::string& name)
inline auto name(std::unordered_set<std::string> names) inline auto name(std::unordered_set<std::string> names)
{ {
return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) { return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) {
return names.count(ins->name()) > 0; return names.count(ins->name()) > 0;
}); });
} }
......
...@@ -36,7 +36,6 @@ struct as_shape ...@@ -36,7 +36,6 @@ struct as_shape
{ {
return args.front().reshape(output_shape); return args.front().reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -67,7 +67,6 @@ struct broadcast ...@@ -67,7 +67,6 @@ struct broadcast
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <ostream>
#include <vector>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <utility> #include <utility>
...@@ -15,6 +17,14 @@ enum padding_mode_t ...@@ -15,6 +17,14 @@ enum padding_mode_t
valid valid
}; };
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum class pooling_mode
{
average,
max
};
// indicate rnn computation direction // indicate rnn computation direction
enum class rnn_direction enum class rnn_direction
{ {
...@@ -23,6 +33,7 @@ enum class rnn_direction ...@@ -23,6 +33,7 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator<<(std::ostream& os, pooling_mode v);
std::ostream& operator<<(std::ostream& os, rnn_direction v); std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
......
...@@ -97,7 +97,6 @@ struct deconvolution ...@@ -97,7 +97,6 @@ struct deconvolution
shape win_shape{output_shape.type(), win_size}; shape win_shape{output_shape.type(), win_size};
par_dfor(in_n, wei_c)([&](int o, int k) { par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) { shape_for_each(win_shape, [&](auto idx_win) {
const int w = idx_win[0]; const int w = idx_win[0];
...@@ -140,9 +139,7 @@ struct deconvolution ...@@ -140,9 +139,7 @@ struct deconvolution
weights(idx_wei.begin(), idx_wei.end()); weights(idx_wei.begin(), idx_wei.end());
} }
}); });
}); });
}); });
return result; return result;
} }
......
...@@ -51,7 +51,6 @@ struct flatten ...@@ -51,7 +51,6 @@ struct flatten
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct isnan : unary<isnan>
{
auto apply() const
{
return [](auto x) { return std::isnan(x); };
}
std::string name() const { return "isnan"; }
shape compute_shape(std::vector<shape> inputs) const
{
return unary<isnan>::compute_shape(std::move(inputs)).with_type(shape::bool_type);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -69,7 +69,6 @@ struct multibroadcast ...@@ -69,7 +69,6 @@ struct multibroadcast
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -181,7 +181,8 @@ struct nonmaxsuppression ...@@ -181,7 +181,8 @@ struct nonmaxsuppression
make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); }); make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
int64_t box_idx = 0; int64_t box_idx = 0;
transform_if(scores.begin() + score_offset, transform_if(
scores.begin() + score_offset,
scores.begin() + score_offset + box_num, scores.begin() + score_offset + box_num,
insert_to_sorted_boxes, insert_to_sorted_boxes,
[&](auto sc) { [&](auto sc) {
......
...@@ -16,12 +16,13 @@ ...@@ -16,12 +16,13 @@
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct pooling struct pooling
{ {
std::string mode = "average"; pooling_mode mode = {pooling_mode::average};
std::vector<std::size_t> padding = {0, 0}; std::vector<std::size_t> padding = {0, 0};
std::vector<std::size_t> stride = {1, 1}; std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> lengths = {1, 1}; std::vector<std::size_t> lengths = {1, 1};
......
...@@ -38,13 +38,33 @@ struct prefix_scan_op : op_name<Derived> ...@@ -38,13 +38,33 @@ struct prefix_scan_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
return inputs.at(0); auto s = inputs.front();
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
} }
argument compute(const shape&, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto s = args[0].get_shape();
if(s == output_shape)
{ {
argument result = args[0].copy(); result = args[0].copy();
auto s = result.get_shape(); }
else
{
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(),
[&](auto i) { output[output_shape.index(i)] = input[s.index(i)]; });
});
s = output_shape;
}
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}}; auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens(); auto lens = s.lens();
lens[axis] = 1; lens[axis] = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment