Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
#ifndef MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
#include "migraphx/make_op.hpp"
#include "migraphx/normalize_attributes.hpp"
#include "migraphx/operation.hpp"
#include <migraphx/instruction_ref.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
const literal& alpha,
const literal& beta);
template <typename T = float>
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
T alpha = 1,
T beta = 0)
{
return insert_apply_alpha_beta(m, pos, args, op, literal{T{alpha}}, literal{T{beta}});
}
template <typename T = float>
instruction_ref add_apply_alpha_beta(module& m,
const std::vector<instruction_ref>& args,
const operation& op,
T alpha = 1,
T beta = 0)
{
return insert_apply_alpha_beta(m, m.end(), args, op, alpha, beta);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_APPLY_ALPHA_BETA_HPP
...@@ -27,29 +27,29 @@ struct argument : raw_data<argument> ...@@ -27,29 +27,29 @@ struct argument : raw_data<argument>
template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})> template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
argument(shape s, F d) argument(shape s, F d)
: m_shape(std::move(s)), : m_shape(std::move(s))
m_data({[f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); }})
{ {
assign_buffer([f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); });
} }
template <class T> template <class T>
argument(shape s, T* d) argument(shape s, T* d)
: m_shape(std::move(s)), m_data({[d] { return reinterpret_cast<char*>(d); }}) : m_shape(std::move(s))
{ {
assign_buffer([d] { return reinterpret_cast<char*>(d); });
} }
template <class T> template <class T>
argument(shape s, std::shared_ptr<T> d) argument(shape s, std::shared_ptr<T> d)
: m_shape(std::move(s)), m_data({[d] { return reinterpret_cast<char*>(d.get()); }}) : m_shape(std::move(s))
{ {
assign_buffer([d] { return reinterpret_cast<char*>(d.get()); });
} }
argument(shape s, std::nullptr_t); argument(shape s, std::nullptr_t);
argument(const std::vector<argument>& args); argument(const std::vector<argument>& args);
static argument load(const shape& s, char* buffer);
/// Provides a raw pointer to the data /// Provides a raw pointer to the data
char* data() const; char* data() const;
...@@ -67,7 +67,11 @@ struct argument : raw_data<argument> ...@@ -67,7 +67,11 @@ struct argument : raw_data<argument>
std::vector<argument> get_sub_objects() const; std::vector<argument> get_sub_objects() const;
/// Return the ith element
argument element(std::size_t i) const;
private: private:
void assign_buffer(std::function<char*()> d);
struct data_t struct data_t
{ {
std::function<char*()> get = nullptr; std::function<char*()> get = nullptr;
......
...@@ -13,7 +13,7 @@ struct module; ...@@ -13,7 +13,7 @@ struct module;
struct auto_contiguous struct auto_contiguous
{ {
std::string name() const { return "auto_contiguous"; } std::string name() const { return "auto_contiguous"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -33,7 +33,7 @@ struct check_context ...@@ -33,7 +33,7 @@ struct check_context
}; };
std::string name() const { return "check_context"; } std::string name() const { return "check_context"; }
void apply(module& p) const { p.insert_instruction(p.begin(), op{}); } void apply(module& m) const { m.insert_instruction(m.begin(), op{}); }
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -23,6 +23,8 @@ struct src_compiler ...@@ -23,6 +23,8 @@ struct src_compiler
std::string compiler = "c++"; std::string compiler = "c++";
std::string flags = ""; std::string flags = "";
std::string output = ""; std::string output = "";
std::string launcher = "";
std::string out_ext = ".o";
std::function<fs::path(fs::path)> process = nullptr; std::function<fs::path(fs::path)> process = nullptr;
std::vector<char> compile(const std::vector<src_file>& srcs) const; std::vector<char> compile(const std::vector<src_file>& srcs) const;
}; };
......
...@@ -30,17 +30,20 @@ struct concat_optimization ...@@ -30,17 +30,20 @@ struct concat_optimization
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct concat_optimization struct concat_optimization
* { {
* std::string name() const; //
* std::string allocate() const; std::string name() const;
* op::concat get_concat(const operation& op) const; //
* }; std::string allocate() const;
* //
*/ op::concat get_concat(const operation& op) const;
};
#else
struct concat_optimization struct concat_optimization
{ {
...@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x) ...@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -37,17 +38,28 @@ void from_value_context(T&, const value&) ...@@ -37,17 +38,28 @@ void from_value_context(T&, const value&)
{ {
} }
/* template <class T>
* Type-erased interface for: any_ptr get_queue_context(T&)
* {
* struct context return {};
* { }
* value to_value() const;
* void from_value(const value& v) ; #ifdef TYPE_ERASED_DECLARATION
* void finish() const;
* }; // Type-erased interface for:
* struct context
*/ {
// (optional)
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
any_ptr get_queue();
//
void finish() const;
};
#else
struct context struct context
{ {
...@@ -124,6 +136,12 @@ struct context ...@@ -124,6 +136,12 @@ struct context
(*this).private_detail_te_get_handle().from_value(v); (*this).private_detail_te_get_handle().from_value(v);
} }
any_ptr get_queue()
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_queue();
}
void finish() const void finish() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -145,6 +163,7 @@ struct context ...@@ -145,6 +163,7 @@ struct context
virtual value to_value() const = 0; virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0; virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void finish() const = 0; virtual void finish() const = 0;
}; };
...@@ -176,6 +195,19 @@ struct context ...@@ -176,6 +195,19 @@ struct context
from_value_context(private_detail_te_self, v); from_value_context(private_detail_te_self, v);
} }
template <class T>
static auto private_detail_te_default_get_queue(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.get_queue())
{
return private_detail_te_self.get_queue();
}
template <class T>
static any_ptr private_detail_te_default_get_queue(float, T&& private_detail_te_self)
{
return get_queue_context(private_detail_te_self);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -216,6 +248,12 @@ struct context ...@@ -216,6 +248,12 @@ struct context
private_detail_te_default_from_value(char(0), private_detail_te_value, v); private_detail_te_default_from_value(char(0), private_detail_te_value, v);
} }
any_ptr get_queue() override
{
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
}
void finish() const override { private_detail_te_value.finish(); } void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
...@@ -282,6 +320,7 @@ inline const ValueType& any_cast(const context& x) ...@@ -282,6 +320,7 @@ inline const ValueType& any_cast(const context& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); } inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
......
...@@ -34,6 +34,7 @@ struct cpp_generator ...@@ -34,6 +34,7 @@ struct cpp_generator
std::string return_type = "void"; std::string return_type = "void";
std::string name = ""; std::string name = "";
std::vector<std::string> attributes = {}; std::vector<std::string> attributes = {};
std::vector<std::string> tparams = {};
function& set_body(const module& m, const generate_module_callback& g); function& set_body(const module& m, const generate_module_callback& g);
function& set_body(const std::string& s) function& set_body(const std::string& s)
{ {
...@@ -52,6 +53,7 @@ struct cpp_generator ...@@ -52,6 +53,7 @@ struct cpp_generator
} }
function& set_types(const module& m); function& set_types(const module& m);
function& set_types(const module& m, const std::function<std::string(shape)>& parse); function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m);
}; };
cpp_generator(); cpp_generator();
...@@ -66,6 +68,10 @@ struct cpp_generator ...@@ -66,6 +68,10 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f); void fmap(const std::function<std::string(std::string)>& f);
void fresult(const std::function<std::string(shape)>& f);
void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args); std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
std::string str() const; std::string str() const;
......
...@@ -19,7 +19,7 @@ struct eliminate_allocation ...@@ -19,7 +19,7 @@ struct eliminate_allocation
std::string allocation_op{}; std::string allocation_op{};
std::size_t alignment = 32; std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; } std::string name() const { return "eliminate_allocation"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct module; ...@@ -16,7 +16,7 @@ struct module;
struct eliminate_common_subexpression struct eliminate_common_subexpression
{ {
std::string name() const { return "eliminate_common_subexpression"; } std::string name() const { return "eliminate_common_subexpression"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -18,7 +18,7 @@ struct eliminate_concat ...@@ -18,7 +18,7 @@ struct eliminate_concat
{ {
concat_optimization concat_opt; concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; } std::string name() const { return "eliminate_concat"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -17,7 +17,7 @@ struct eliminate_contiguous ...@@ -17,7 +17,7 @@ struct eliminate_contiguous
{ {
std::string op_name; std::string op_name;
std::string name() const { return "eliminate_contiguous"; } std::string name() const { return "eliminate_contiguous"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -18,7 +18,7 @@ struct module; ...@@ -18,7 +18,7 @@ struct module;
struct eliminate_identity struct eliminate_identity
{ {
std::string name() const { return "eliminate_identity"; } std::string name() const { return "eliminate_identity"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -49,12 +49,13 @@ inline std::string make_source_context(const std::string& file, int line, const ...@@ -49,12 +49,13 @@ inline std::string make_source_context(const std::string& file, int line, const
return file + ":" + std::to_string(line) + ": " + fname; return file + ":" + std::to_string(line) + ": " + fname;
} }
// NOLINTNEXTLINE
#define MIGRAPHX_MAKE_SOURCE_CTX() migraphx::make_source_context(__FILE__, __LINE__, __func__)
/** /**
* @brief Throw an exception with context information * @brief Throw an exception with context information
*/ */
#define MIGRAPHX_THROW(...) \ #define MIGRAPHX_THROW(...) throw migraphx::make_exception(MIGRAPHX_MAKE_SOURCE_CTX(), __VA_ARGS__)
throw migraphx::make_exception(migraphx::make_source_context(__FILE__, __LINE__, __func__), \
__VA_ARGS__)
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK) #if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L #if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1 #define MIGRAPHX_HAS_FILESYSTEM 1
#else #else
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
struct fuse_pointwise
{
std::string name() const { return "fuse_pointwise"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/tensor_view.hpp> #include <migraphx/tensor_view.hpp>
namespace migraphx { namespace migraphx {
...@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha ...@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
auto cs = cmat.get_shape();
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { par_for(cs.elements(), [&](auto i) {
auto c_idx = cs.multi(i);
auto a_idx = c_idx; auto a_idx = c_idx;
auto b_idx = c_idx; auto b_idx = c_idx;
double s = 0.0; double s = 0.0;
......
...@@ -88,16 +88,16 @@ struct xorshift_generator ...@@ -88,16 +88,16 @@ struct xorshift_generator
template <class T> template <class T>
auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0) auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0)
{ {
auto result = make_shared_array<T>(s.elements()); auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.elements(), xorshf96_generator<T>{seed}); std::generate(result.get(), result.get() + s.element_space(), xorshf96_generator<T>{seed});
return result; return result;
} }
template <class T> template <class T>
auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0) auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
{ {
auto result = make_shared_array<T>(s.elements()); auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.elements(), [=] { return value; }); std::generate(result.get(), result.get() + s.element_space(), [=] { return value; });
return result; return result;
} }
......
...@@ -152,30 +152,4 @@ struct instruction ...@@ -152,30 +152,4 @@ struct instruction
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
template <>
struct hash<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
result_type operator()(const argument_type& x) const noexcept
{
return std::hash<migraphx::instruction*>{}(&*x);
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return &*x == &*y;
}
};
} // namespace std
#endif #endif
...@@ -11,7 +11,35 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +11,35 @@ inline namespace MIGRAPHX_INLINE_NS {
struct instruction; struct instruction;
using instruction_ref = std::list<instruction>::iterator; using instruction_ref = std::list<instruction>::iterator;
migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
template <>
struct hash<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
result_type operator()(const migraphx::instruction_ref& x) const noexcept
{
return std::hash<migraphx::instruction*>{}(migraphx::as_address(x));
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return migraphx::as_address(x) == migraphx::as_address(y);
}
};
} // namespace std
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment