Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
......@@ -33,7 +33,7 @@ struct check_context
};
std::string name() const { return "check_context"; }
void apply(module& p) const { p.insert_instruction(p.begin(), op{}); }
void apply(module& m) const { m.insert_instruction(m.begin(), op{}); }
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -23,6 +23,8 @@ struct src_compiler
std::string compiler = "c++";
std::string flags = "";
std::string output = "";
std::string launcher = "";
std::string out_ext = ".o";
std::function<fs::path(fs::path)> process = nullptr;
std::vector<char> compile(const std::vector<src_file>& srcs) const;
};
......
......@@ -30,17 +30,20 @@ struct concat_optimization
#else
/*
* Type-erased interface for:
*
* struct concat_optimization
* {
* std::string name() const;
* std::string allocate() const;
* op::concat get_concat(const operation& op) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct concat_optimization
{
//
std::string name() const;
//
std::string allocate() const;
//
op::concat get_concat(const operation& op) const;
};
#else
struct concat_optimization
{
......@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
......@@ -9,6 +9,7 @@
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -37,17 +38,28 @@ void from_value_context(T&, const value&)
{
}
/*
* Type-erased interface for:
*
* struct context
* {
* value to_value() const;
* void from_value(const value& v) ;
* void finish() const;
* };
*
*/
template <class T>
any_ptr get_queue_context(T&)
{
return {};
}
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct context
{
// (optional)
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
any_ptr get_queue();
//
void finish() const;
};
#else
struct context
{
......@@ -124,6 +136,12 @@ struct context
(*this).private_detail_te_get_handle().from_value(v);
}
any_ptr get_queue()
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_queue();
}
void finish() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -145,6 +163,7 @@ struct context
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void finish() const = 0;
};
......@@ -176,6 +195,19 @@ struct context
from_value_context(private_detail_te_self, v);
}
template <class T>
static auto private_detail_te_default_get_queue(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.get_queue())
{
return private_detail_te_self.get_queue();
}
template <class T>
static any_ptr private_detail_te_default_get_queue(float, T&& private_detail_te_self)
{
return get_queue_context(private_detail_te_self);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -216,6 +248,12 @@ struct context
private_detail_te_default_from_value(char(0), private_detail_te_value, v);
}
any_ptr get_queue() override
{
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
}
void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value;
......@@ -282,6 +320,7 @@ inline const ValueType& any_cast(const context& x)
throw std::bad_cast();
return *y;
}
#endif
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
......
......@@ -68,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f);
void fresult(const std::function<std::string(shape)>& f);
void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
......
......@@ -19,7 +19,7 @@ struct eliminate_allocation
std::string allocation_op{};
std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct module;
struct eliminate_common_subexpression
{
std::string name() const { return "eliminate_common_subexpression"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -18,7 +18,7 @@ struct eliminate_concat
{
concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -17,7 +17,7 @@ struct eliminate_contiguous
{
std::string op_name;
std::string name() const { return "eliminate_contiguous"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -18,7 +18,7 @@ struct module;
struct eliminate_identity
{
std::string name() const { return "eliminate_identity"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1
#else
......
......@@ -3,7 +3,7 @@
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/tensor_view.hpp>
namespace migraphx {
......@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
auto cs = cmat.get_shape();
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
par_for(cs.elements(), [&](auto i) {
auto c_idx = cs.multi(i);
auto a_idx = c_idx;
auto b_idx = c_idx;
double s = 0.0;
......
......@@ -88,16 +88,16 @@ struct xorshift_generator
template <class T>
auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0)
{
auto result = make_shared_array<T>(s.elements());
std::generate(result.get(), result.get() + s.elements(), xorshf96_generator<T>{seed});
auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.element_space(), xorshf96_generator<T>{seed});
return result;
}
template <class T>
auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
{
auto result = make_shared_array<T>(s.elements());
std::generate(result.get(), result.get() + s.elements(), [=] { return value; });
auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.element_space(), [=] { return value; });
return result;
}
......
......@@ -8,6 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::string to_pretty_json_string(const value& val, std::size_t indent = 4);
std::string to_json_string(const value& val);
value from_json_string(const std::string& str);
value from_json_string(const char* str, std::size_t size);
......
......@@ -9,7 +9,19 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v);
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template <class Value>
operation make_op(const std::string& name, const Value& v)
{
return make_op_from_value(name, v);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -20,18 +20,22 @@ inline namespace MIGRAPHX_INLINE_NS {
#else
/*
* Type-erased interface for:
*
* struct marker
* {
* void mark_start(instruction_ref ins_ref) ;
* void mark_start(const program& prog) ;
* void mark_stop(instruction_ref ins) ;
* void mark_stop(const program& prog) ;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct marker
{
//
void mark_start(instruction_ref ins_ref);
//
void mark_start(const program& prog);
//
void mark_stop(instruction_ref ins);
//
void mark_stop(const program& prog);
};
#else
struct marker
{
......@@ -243,6 +247,7 @@ inline const ValueType& any_cast(const marker& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
......@@ -101,17 +101,17 @@ template <class M>
auto bind_match(M m, std::string name)
{
return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins)
->optional<instruction_ref> {
auto result = m.match(ctx, ins);
if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
[=, name = std::move(name)](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = m.match(ctx, ins);
if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
}
/// Convert a matcher to a bindable matcher
......@@ -156,6 +156,19 @@ struct id_matcher
}
};
// Forward declare class and constructors
template <class M>
struct basic_matcher;
template <class M>
basic_matcher<M> make_basic_matcher(M m);
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher
template <class M>
struct basic_matcher
......@@ -167,8 +180,8 @@ struct basic_matcher
{
// Copy m because we cant capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result)
{
......@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct matcher_result
{
std::unordered_map<std::string, instruction_ref> instructions;
struct instruction_container
{
instruction_container() = default;
instruction_container(std::unordered_map<std::string, instruction_ref> x)
: ins_map(std::move(x))
{
}
instruction_ref operator[](const std::string& name) const
{
auto it = ins_map.find(name);
if(it == ins_map.end())
MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
return it->second;
}
auto find(const std::string& name) const { return ins_map.find(name); }
auto begin() const { return ins_map.cbegin(); }
auto end() const { return ins_map.cend(); }
bool has_instructions_in(const module& mod) const
{
return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
return mod.has_instruction(p.second);
});
}
private:
std::unordered_map<std::string, instruction_ref> ins_map;
};
instruction_container instructions;
instruction_ref result;
};
......@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
result.result = ins;
result.instructions = ctx.instructions;
assert(result.instructions.has_instructions_in(mod));
}
else
{
......@@ -533,10 +579,22 @@ auto skip_output(Ms... ms)
});
}
inline auto var(std::string s)
{
return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s);
if(it == ctx.instructions.end())
return nullopt;
return it->second;
});
}
inline auto name(std::string s)
{
return make_basic_pred_matcher(
[ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
[=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; });
}
inline auto name_contains(const std::string& name)
......@@ -547,7 +605,7 @@ inline auto name_contains(const std::string& name)
inline auto name(std::unordered_set<std::string> names)
{
return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) {
return names.count(ins->name()) > 0;
});
}
......@@ -696,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...);
}
template <class... Ms>
auto skip_broadcasts_converts(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
return skip_broadcasts(make_basic_pred_matcher([=](instruction_ref ins) {
return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "@literal")
return false;
auto l = ins->get_literal();
......
......@@ -17,7 +17,7 @@ struct memory_coloring
std::string allocation_op{};
bool verify = false;
std::string name() const { return "memory coloring"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -36,7 +36,6 @@ struct as_shape
{
return args.front().reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -67,7 +67,6 @@ struct broadcast
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment