Commit e2eb6036 authored by Paul's avatar Paul
Browse files

Merge

parents 298c93d5 1e0bbd78
......@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average")
if(op.mode == op::pooling_mode::average)
{
return;
}
......
......@@ -32,18 +32,22 @@ struct allocation_model
#else
/*
* Type-erased interface for:
*
* struct allocation_model
* {
* std::string name() const;
* std::string copy() const;
* operation allocate(const shape& s) const;
* operation preallocate(const shape& s,std::string id) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct allocation_model
{
//
std::string name() const;
//
std::string copy() const;
//
operation allocate(const shape& s) const;
//
operation preallocate(const shape& s, std::string id) const;
};
#else
struct allocation_model
{
......@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#include <migraphx/config.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/type_name.hpp>
#include <cassert>
#include <string_view>
#include <typeindex>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct any_ptr
{
any_ptr() = default;
template <class T>
any_ptr(T* p) : ptr(p), ti(typeid(T*)), name(get_name<T*>())
{
}
any_ptr(void* p, std::string_view pname) : ptr(p), name(pname) {}
void* get(std::string_view n) const
{
if(name != n)
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} +
" != " + std::string{n});
return ptr;
}
template <class T>
T get() const
{
static_assert(std::is_pointer<T>{}, "Must be a pointer");
assert(ptr != nullptr);
if(ti and std::type_index{typeid(T)} != *ti)
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name<T>());
else if(name != get_name<T>())
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name<T>());
return reinterpret_cast<T>(ptr);
}
void* unsafe_get() const { return ptr; }
private:
void* ptr = nullptr;
optional<std::type_index> ti = nullopt;
std::string_view name = "";
template <class T>
static const std::string& get_name()
{
return get_type_name<std::remove_cv_t<std::remove_pointer_t<T>>>();
}
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
......@@ -23,6 +23,7 @@ struct src_compiler
std::string compiler = "c++";
std::string flags = "";
std::string output = "";
std::string launcher = "";
std::function<fs::path(fs::path)> process = nullptr;
std::vector<char> compile(const std::vector<src_file>& srcs) const;
};
......
......@@ -30,17 +30,20 @@ struct concat_optimization
#else
/*
* Type-erased interface for:
*
* struct concat_optimization
* {
* std::string name() const;
* std::string allocate() const;
* op::concat get_concat(const operation& op) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct concat_optimization
{
//
std::string name() const;
//
std::string allocate() const;
//
op::concat get_concat(const operation& op) const;
};
#else
struct concat_optimization
{
......@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
......@@ -9,6 +9,7 @@
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -37,17 +38,28 @@ void from_value_context(T&, const value&)
{
}
/*
* Type-erased interface for:
*
* struct context
* {
* value to_value() const;
* void from_value(const value& v) ;
* void finish() const;
* };
*
*/
template <class T>
any_ptr get_queue_context(T&)
{
return {};
}
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct context
{
// (optional)
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
any_ptr get_queue();
//
void finish() const;
};
#else
struct context
{
......@@ -124,6 +136,12 @@ struct context
(*this).private_detail_te_get_handle().from_value(v);
}
any_ptr get_queue()
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_queue();
}
void finish() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -145,6 +163,7 @@ struct context
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void finish() const = 0;
};
......@@ -176,6 +195,19 @@ struct context
from_value_context(private_detail_te_self, v);
}
template <class T>
static auto private_detail_te_default_get_queue(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.get_queue())
{
return private_detail_te_self.get_queue();
}
template <class T>
static any_ptr private_detail_te_default_get_queue(float, T&& private_detail_te_self)
{
return get_queue_context(private_detail_te_self);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -216,6 +248,12 @@ struct context
private_detail_te_default_from_value(char(0), private_detail_te_value, v);
}
any_ptr get_queue() override
{
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
}
void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value;
......@@ -282,6 +320,7 @@ inline const ValueType& any_cast(const context& x)
throw std::bad_cast();
return *y;
}
#endif
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
......
......@@ -68,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f);
void fresult(const std::function<std::string(shape)>& f);
void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
......
......@@ -8,6 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::string to_pretty_json_string(const value& val, std::size_t indent = 4);
std::string to_json_string(const value& val);
value from_json_string(const std::string& str);
value from_json_string(const char* str, std::size_t size);
......
......@@ -20,18 +20,22 @@ inline namespace MIGRAPHX_INLINE_NS {
#else
/*
* Type-erased interface for:
*
* struct marker
* {
* void mark_start(instruction_ref ins_ref) ;
* void mark_start(const program& prog) ;
* void mark_stop(instruction_ref ins) ;
* void mark_stop(const program& prog) ;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct marker
{
//
void mark_start(instruction_ref ins_ref);
//
void mark_start(const program& prog);
//
void mark_stop(instruction_ref ins);
//
void mark_stop(const program& prog);
};
#else
struct marker
{
......@@ -243,6 +247,7 @@ inline const ValueType& any_cast(const marker& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
......@@ -101,17 +101,17 @@ template <class M>
auto bind_match(M m, std::string name)
{
return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins)
->optional<instruction_ref> {
auto result = m.match(ctx, ins);
if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
[=, name = std::move(name)](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = m.match(ctx, ins);
if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
}
/// Convert a matcher to a bindable matcher
......@@ -536,7 +536,7 @@ auto skip_output(Ms... ms)
inline auto name(std::string s)
{
return make_basic_pred_matcher(
[ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
[=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; });
}
inline auto name_contains(const std::string& name)
......@@ -547,7 +547,7 @@ inline auto name_contains(const std::string& name)
inline auto name(std::unordered_set<std::string> names)
{
return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) {
return names.count(ins->name()) > 0;
});
}
......
......@@ -36,7 +36,6 @@ struct as_shape
{
return args.front().reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -67,7 +67,6 @@ struct broadcast
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <ostream>
#include <vector>
#include <migraphx/config.hpp>
#include <utility>
......@@ -15,6 +17,14 @@ enum padding_mode_t
valid
};
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum class pooling_mode
{
average,
max
};
// indicate rnn computation direction
enum class rnn_direction
{
......@@ -23,6 +33,7 @@ enum class rnn_direction
bidirectional,
};
std::ostream& operator<<(std::ostream& os, pooling_mode v);
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op
......
......@@ -97,7 +97,6 @@ struct deconvolution
shape win_shape{output_shape.type(), win_size};
par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) {
const int w = idx_win[0];
......@@ -140,9 +139,7 @@ struct deconvolution
weights(idx_wei.begin(), idx_wei.end());
}
});
});
});
return result;
}
......
......@@ -51,7 +51,6 @@ struct flatten
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct isnan : unary<isnan>
{
auto apply() const
{
return [](auto x) { return std::isnan(x); };
}
std::string name() const { return "isnan"; }
shape compute_shape(std::vector<shape> inputs) const
{
return unary<isnan>::compute_shape(std::move(inputs)).with_type(shape::bool_type);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -69,7 +69,6 @@ struct multibroadcast
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -181,14 +181,15 @@ struct nonmaxsuppression
make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
int64_t box_idx = 0;
transform_if(scores.begin() + score_offset,
scores.begin() + score_offset + box_num,
insert_to_sorted_boxes,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
transform_if(
scores.begin() + score_offset,
scores.begin() + score_offset + box_num,
insert_to_sorted_boxes,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
......
......@@ -16,12 +16,13 @@
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct pooling
{
std::string mode = "average";
pooling_mode mode = {pooling_mode::average};
std::vector<std::size_t> padding = {0, 0};
std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> lengths = {1, 1};
......
......@@ -38,18 +38,38 @@ struct prefix_scan_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.at(0);
auto s = inputs.front();
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
argument compute(const shape&, std::vector<argument> args) const
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result = args[0].copy();
auto s = result.get_shape();
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens();
lens[axis] = 1;
auto batch = shape{s.type(), lens, s.strides()};
auto& self = static_cast<const Derived&>(*this);
argument result{output_shape};
auto s = args[0].get_shape();
if(s == output_shape)
{
result = args[0].copy();
}
else
{
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(),
[&](auto i) { output[output_shape.index(i)] = input[s.index(i)]; });
});
s = output_shape;
}
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens();
lens[axis] = 1;
auto batch = shape{s.type(), lens, s.strides()};
auto& self = static_cast<const Derived&>(*this);
result.visit([&](auto output) {
using type = decltype(output);
par_for(batch.elements(), [&](auto i) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment