Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
...@@ -33,7 +33,7 @@ struct check_context ...@@ -33,7 +33,7 @@ struct check_context
}; };
std::string name() const { return "check_context"; } std::string name() const { return "check_context"; }
void apply(module& p) const { p.insert_instruction(p.begin(), op{}); } void apply(module& m) const { m.insert_instruction(m.begin(), op{}); }
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -23,6 +23,8 @@ struct src_compiler ...@@ -23,6 +23,8 @@ struct src_compiler
std::string compiler = "c++"; std::string compiler = "c++";
std::string flags = ""; std::string flags = "";
std::string output = ""; std::string output = "";
std::string launcher = "";
std::string out_ext = ".o";
std::function<fs::path(fs::path)> process = nullptr; std::function<fs::path(fs::path)> process = nullptr;
std::vector<char> compile(const std::vector<src_file>& srcs) const; std::vector<char> compile(const std::vector<src_file>& srcs) const;
}; };
......
...@@ -30,17 +30,20 @@ struct concat_optimization ...@@ -30,17 +30,20 @@ struct concat_optimization
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct concat_optimization struct concat_optimization
* { {
* std::string name() const; //
* std::string allocate() const; std::string name() const;
* op::concat get_concat(const operation& op) const; //
* }; std::string allocate() const;
* //
*/ op::concat get_concat(const operation& op) const;
};
#else
struct concat_optimization struct concat_optimization
{ {
...@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x) ...@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -37,17 +38,28 @@ void from_value_context(T&, const value&) ...@@ -37,17 +38,28 @@ void from_value_context(T&, const value&)
{ {
} }
/* template <class T>
* Type-erased interface for: any_ptr get_queue_context(T&)
* {
* struct context return {};
* { }
* value to_value() const;
* void from_value(const value& v) ; #ifdef TYPE_ERASED_DECLARATION
* void finish() const;
* }; // Type-erased interface for:
* struct context
*/ {
// (optional)
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
any_ptr get_queue();
//
void finish() const;
};
#else
struct context struct context
{ {
...@@ -124,6 +136,12 @@ struct context ...@@ -124,6 +136,12 @@ struct context
(*this).private_detail_te_get_handle().from_value(v); (*this).private_detail_te_get_handle().from_value(v);
} }
any_ptr get_queue()
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_queue();
}
void finish() const void finish() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -145,6 +163,7 @@ struct context ...@@ -145,6 +163,7 @@ struct context
virtual value to_value() const = 0; virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0; virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void finish() const = 0; virtual void finish() const = 0;
}; };
...@@ -176,6 +195,19 @@ struct context ...@@ -176,6 +195,19 @@ struct context
from_value_context(private_detail_te_self, v); from_value_context(private_detail_te_self, v);
} }
template <class T>
static auto private_detail_te_default_get_queue(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.get_queue())
{
return private_detail_te_self.get_queue();
}
template <class T>
static any_ptr private_detail_te_default_get_queue(float, T&& private_detail_te_self)
{
return get_queue_context(private_detail_te_self);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -216,6 +248,12 @@ struct context ...@@ -216,6 +248,12 @@ struct context
private_detail_te_default_from_value(char(0), private_detail_te_value, v); private_detail_te_default_from_value(char(0), private_detail_te_value, v);
} }
any_ptr get_queue() override
{
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
}
void finish() const override { private_detail_te_value.finish(); } void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
...@@ -282,6 +320,7 @@ inline const ValueType& any_cast(const context& x) ...@@ -282,6 +320,7 @@ inline const ValueType& any_cast(const context& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); } inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
......
...@@ -68,6 +68,8 @@ struct cpp_generator ...@@ -68,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f); void fmap(const std::function<std::string(std::string)>& f);
void fresult(const std::function<std::string(shape)>& f);
void add_point_op(const std::string& op_name, const std::string& code); void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args); std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
......
...@@ -19,7 +19,7 @@ struct eliminate_allocation ...@@ -19,7 +19,7 @@ struct eliminate_allocation
std::string allocation_op{}; std::string allocation_op{};
std::size_t alignment = 32; std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; } std::string name() const { return "eliminate_allocation"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct module; ...@@ -16,7 +16,7 @@ struct module;
struct eliminate_common_subexpression struct eliminate_common_subexpression
{ {
std::string name() const { return "eliminate_common_subexpression"; } std::string name() const { return "eliminate_common_subexpression"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -18,7 +18,7 @@ struct eliminate_concat ...@@ -18,7 +18,7 @@ struct eliminate_concat
{ {
concat_optimization concat_opt; concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; } std::string name() const { return "eliminate_concat"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -17,7 +17,7 @@ struct eliminate_contiguous ...@@ -17,7 +17,7 @@ struct eliminate_contiguous
{ {
std::string op_name; std::string op_name;
std::string name() const { return "eliminate_contiguous"; } std::string name() const { return "eliminate_contiguous"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -18,7 +18,7 @@ struct module; ...@@ -18,7 +18,7 @@ struct module;
struct eliminate_identity struct eliminate_identity
{ {
std::string name() const { return "eliminate_identity"; } std::string name() const { return "eliminate_identity"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK) #if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L #if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1 #define MIGRAPHX_HAS_FILESYSTEM 1
#else #else
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/tensor_view.hpp> #include <migraphx/tensor_view.hpp>
namespace migraphx { namespace migraphx {
...@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha ...@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
auto cs = cmat.get_shape();
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { par_for(cs.elements(), [&](auto i) {
auto c_idx = cs.multi(i);
auto a_idx = c_idx; auto a_idx = c_idx;
auto b_idx = c_idx; auto b_idx = c_idx;
double s = 0.0; double s = 0.0;
......
...@@ -88,16 +88,16 @@ struct xorshift_generator ...@@ -88,16 +88,16 @@ struct xorshift_generator
template <class T> template <class T>
auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0) auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0)
{ {
auto result = make_shared_array<T>(s.elements()); auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.elements(), xorshf96_generator<T>{seed}); std::generate(result.get(), result.get() + s.element_space(), xorshf96_generator<T>{seed});
return result; return result;
} }
template <class T> template <class T>
auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0) auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
{ {
auto result = make_shared_array<T>(s.elements()); auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.elements(), [=] { return value; }); std::generate(result.get(), result.get() + s.element_space(), [=] { return value; });
return result; return result;
} }
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
std::string to_pretty_json_string(const value& val, std::size_t indent = 4);
std::string to_json_string(const value& val); std::string to_json_string(const value& val);
value from_json_string(const std::string& str); value from_json_string(const std::string& str);
value from_json_string(const char* str, std::size_t size); value from_json_string(const char* str, std::size_t size);
......
...@@ -9,7 +9,19 @@ namespace migraphx { ...@@ -9,7 +9,19 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name); operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v); operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template <class Value>
operation make_op(const std::string& name, const Value& v)
{
return make_op_from_value(name, v);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -20,18 +20,22 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -20,18 +20,22 @@ inline namespace MIGRAPHX_INLINE_NS {
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct marker struct marker
* { {
* void mark_start(instruction_ref ins_ref) ; //
* void mark_start(const program& prog) ; void mark_start(instruction_ref ins_ref);
* void mark_stop(instruction_ref ins) ; //
* void mark_stop(const program& prog) ; void mark_start(const program& prog);
* }; //
* void mark_stop(instruction_ref ins);
*/ //
void mark_stop(const program& prog);
};
#else
struct marker struct marker
{ {
...@@ -243,6 +247,7 @@ inline const ValueType& any_cast(const marker& x) ...@@ -243,6 +247,7 @@ inline const ValueType& any_cast(const marker& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -101,8 +101,8 @@ template <class M> ...@@ -101,8 +101,8 @@ template <class M>
auto bind_match(M m, std::string name) auto bind_match(M m, std::string name)
{ {
return make_function_matcher( return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) [=, name = std::move(name)](matcher_context& ctx,
->optional<instruction_ref> { instruction_ref ins) -> optional<instruction_ref> {
auto result = m.match(ctx, ins); auto result = m.match(ctx, ins);
if(result) if(result)
{ {
...@@ -156,6 +156,19 @@ struct id_matcher ...@@ -156,6 +156,19 @@ struct id_matcher
} }
}; };
// Forward declare class and constructors
template <class M>
struct basic_matcher;
template <class M>
basic_matcher<M> make_basic_matcher(M m);
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher /// The basic matcher provides the all_of composability of the matcher
template <class M> template <class M>
struct basic_matcher struct basic_matcher
...@@ -167,7 +180,7 @@ struct basic_matcher ...@@ -167,7 +180,7 @@ struct basic_matcher
{ {
// Copy m because we cant capture `this` by value // Copy m because we cant capture `this` by value
auto mm = m; auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> { instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins); auto result = mm.match(ctx, ins);
if(result) if(result)
...@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base ...@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct matcher_result struct matcher_result
{ {
std::unordered_map<std::string, instruction_ref> instructions; struct instruction_container
{
instruction_container() = default;
instruction_container(std::unordered_map<std::string, instruction_ref> x)
: ins_map(std::move(x))
{
}
instruction_ref operator[](const std::string& name) const
{
auto it = ins_map.find(name);
if(it == ins_map.end())
MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
return it->second;
}
auto find(const std::string& name) const { return ins_map.find(name); }
auto begin() const { return ins_map.cbegin(); }
auto end() const { return ins_map.cend(); }
bool has_instructions_in(const module& mod) const
{
return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
return mod.has_instruction(p.second);
});
}
private:
std::unordered_map<std::string, instruction_ref> ins_map;
};
instruction_container instructions;
instruction_ref result; instruction_ref result;
}; };
...@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m) ...@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{ {
result.result = ins; result.result = ins;
result.instructions = ctx.instructions; result.instructions = ctx.instructions;
assert(result.instructions.has_instructions_in(mod));
} }
else else
{ {
...@@ -533,10 +579,22 @@ auto skip_output(Ms... ms) ...@@ -533,10 +579,22 @@ auto skip_output(Ms... ms)
}); });
} }
inline auto var(std::string s)
{
return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s);
if(it == ctx.instructions.end())
return nullopt;
return it->second;
});
}
inline auto name(std::string s) inline auto name(std::string s)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher(
[ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; }); [=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; });
} }
inline auto name_contains(const std::string& name) inline auto name_contains(const std::string& name)
...@@ -547,7 +605,7 @@ inline auto name_contains(const std::string& name) ...@@ -547,7 +605,7 @@ inline auto name_contains(const std::string& name)
inline auto name(std::unordered_set<std::string> names) inline auto name(std::unordered_set<std::string> names)
{ {
return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) { return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) {
return names.count(ins->name()) > 0; return names.count(ins->name()) > 0;
}); });
} }
...@@ -696,10 +754,16 @@ auto skip_broadcasts(Ms... ms) ...@@ -696,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...); return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...);
} }
template <class... Ms>
auto skip_broadcasts_converts(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
template <class T> template <class T>
inline auto has_value(T x, float tolerance = 1e-6) inline auto has_value(T x, float tolerance = 1e-6)
{ {
return skip_broadcasts(make_basic_pred_matcher([=](instruction_ref ins) { return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "@literal") if(ins->name() != "@literal")
return false; return false;
auto l = ins->get_literal(); auto l = ins->get_literal();
......
...@@ -17,7 +17,7 @@ struct memory_coloring ...@@ -17,7 +17,7 @@ struct memory_coloring
std::string allocation_op{}; std::string allocation_op{};
bool verify = false; bool verify = false;
std::string name() const { return "memory coloring"; } std::string name() const { return "memory coloring"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -36,7 +36,6 @@ struct as_shape ...@@ -36,7 +36,6 @@ struct as_shape
{ {
return args.front().reshape(output_shape); return args.front().reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -67,7 +67,6 @@ struct broadcast ...@@ -67,7 +67,6 @@ struct broadcast
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment