"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "4d8826503d88b2966a7d44592e632b9526578d64"
Unverified Commit 40fbef9b authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into threaded_nms

parents d164b151 aeb9f78c
...@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager; struct module_pass_manager;
struct fuse_pointwise struct MIGRAPHX_EXPORT fuse_pointwise
{ {
std::string name() const { return "fuse_pointwise"; } std::string name() const { return "fuse_pointwise"; }
void apply(module_pass_manager& mpm) const; void apply(module_pass_manager& mpm) const;
......
...@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager; struct module_pass_manager;
struct fuse_reduce struct MIGRAPHX_EXPORT fuse_reduce
{ {
std::string name() const { return "fuse_reduce"; } std::string name() const { return "fuse_reduce"; }
void apply(module_pass_manager& mpm) const; void apply(module_pass_manager& mpm) const;
......
...@@ -117,20 +117,20 @@ auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0) ...@@ -117,20 +117,20 @@ auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0)
} }
template <class T> template <class T>
auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0) auto fill_tensor_data(const migraphx::shape& s, double value = 0)
{ {
auto result = make_shared_array<T>(s.element_space()); auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.element_space(), [=] { return value; }); std::generate(result.get(), result.get() + s.element_space(), [=] { return value; });
return result; return result;
} }
argument fill_argument(shape s, unsigned long value = 0); MIGRAPHX_EXPORT argument fill_argument(shape s, double value = 0);
argument generate_argument(shape s, unsigned long seed = 0); MIGRAPHX_EXPORT argument generate_argument(shape s, unsigned long seed = 0);
literal generate_literal(shape s, unsigned long seed = 0); MIGRAPHX_EXPORT literal generate_literal(shape s, unsigned long seed = 0);
literal abs(literal l); MIGRAPHX_EXPORT literal abs(literal l);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#include <migraphx/config.hpp>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
std::size_t hash_value(const T& v)
{
return std::hash<T>{}(v);
}
template <class T>
void hash_combine(std::size_t& seed, const T& v)
{
seed ^= hash_value(v) + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
...@@ -33,7 +33,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -33,7 +33,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
struct inline_module struct MIGRAPHX_EXPORT inline_module
{ {
std::string name() const { return "inline_module"; } std::string name() const { return "inline_module"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -38,7 +38,7 @@ struct module; ...@@ -38,7 +38,7 @@ struct module;
/** /**
* insert pads if attribute of padding is asymmetrical * insert pads if attribute of padding is asymmetrical
*/ */
struct insert_pad struct MIGRAPHX_EXPORT insert_pad
{ {
std::string name() const { return "insert_pad"; } std::string name() const { return "insert_pad"; }
......
...@@ -37,14 +37,15 @@ ...@@ -37,14 +37,15 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args); MIGRAPHX_EXPORT shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
shape compute_shape(const operation& op, MIGRAPHX_EXPORT shape compute_shape(const operation& op,
const std::vector<instruction_ref>& args, const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods); const std::vector<module_ref>& mods);
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args); MIGRAPHX_EXPORT std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
std::vector<shape> try_compute_shape(const operation& op, const std::vector<shape>& inputs); MIGRAPHX_EXPORT std::vector<shape> try_compute_shape(const operation& op,
const std::vector<shape>& inputs);
struct instruction
struct MIGRAPHX_EXPORT instruction
{ {
instruction() {} instruction() {}
...@@ -136,6 +137,10 @@ struct instruction ...@@ -136,6 +137,10 @@ struct instruction
operation normalized_operator() const; operation normalized_operator() const;
std::size_t get_target_id() const;
void set_target_id(std::size_t tid);
void debug_print() const; void debug_print() const;
static void print(std::ostream& os, static void print(std::ostream& os,
...@@ -172,7 +177,8 @@ struct instruction ...@@ -172,7 +177,8 @@ struct instruction
std::vector<instruction_ref> arguments; std::vector<instruction_ref> arguments;
std::vector<module_ref> module_args; std::vector<module_ref> module_args;
literal lit; literal lit;
bool normalized = false; bool normalized = false;
std::size_t target_id = 0;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct instruction; struct instruction;
using instruction_ref = std::list<instruction>::iterator; using instruction_ref = std::list<instruction>::iterator;
migraphx::instruction* as_address(const instruction_ref& ins) noexcept; MIGRAPHX_EXPORT migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -31,10 +31,10 @@ ...@@ -31,10 +31,10 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
std::string to_pretty_json_string(const value& val, std::size_t indent = 4); MIGRAPHX_EXPORT std::string to_pretty_json_string(const value& val, std::size_t indent = 4);
std::string to_json_string(const value& val); MIGRAPHX_EXPORT std::string to_json_string(const value& val);
value from_json_string(const std::string& str); MIGRAPHX_EXPORT value from_json_string(const std::string& str);
value from_json_string(const char* str, std::size_t size); MIGRAPHX_EXPORT value from_json_string(const char* str, std::size_t size);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -36,7 +36,7 @@ struct module_pass_manager; ...@@ -36,7 +36,7 @@ struct module_pass_manager;
/** /**
* Transform convolutions to nhwc * Transform convolutions to nhwc
*/ */
struct layout_nhwc struct MIGRAPHX_EXPORT layout_nhwc
{ {
std::string name() const { return "layout_nhwc"; } std::string name() const { return "layout_nhwc"; }
void apply(module_pass_manager& mpm) const; void apply(module_pass_manager& mpm) const;
......
...@@ -147,8 +147,8 @@ literal transform(literal l1, literal l2, F f) ...@@ -147,8 +147,8 @@ literal transform(literal l1, literal l2, F f)
return result; return result;
} }
void migraphx_to_value(value& v, const literal& l); MIGRAPHX_EXPORT void migraphx_to_value(value& v, const literal& l);
void migraphx_from_value(const value& v, literal& l); MIGRAPHX_EXPORT void migraphx_from_value(const value& v, literal& l);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -36,15 +36,18 @@ struct file_options ...@@ -36,15 +36,18 @@ struct file_options
std::string format = "msgpack"; std::string format = "msgpack";
}; };
program load(const std::string& filename, const file_options& options = file_options{}); MIGRAPHX_EXPORT program load(const std::string& filename,
program load_buffer(const std::vector<char>& buffer, const file_options& options = file_options{}); const file_options& options = file_options{});
program MIGRAPHX_EXPORT program load_buffer(const std::vector<char>& buffer,
load_buffer(const char* buffer, std::size_t size, const file_options& options = file_options{}); const file_options& options = file_options{});
MIGRAPHX_EXPORT program load_buffer(const char* buffer,
std::size_t size,
const file_options& options = file_options{});
void save(const program& p, MIGRAPHX_EXPORT void
const std::string& filename, save(const program& p, const std::string& filename, const file_options& options = file_options{});
const file_options& options = file_options{}); MIGRAPHX_EXPORT std::vector<char> save_buffer(const program& p,
std::vector<char> save_buffer(const program& p, const file_options& options = file_options{}); const file_options& options = file_options{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -33,10 +33,10 @@ ...@@ -33,10 +33,10 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name); MIGRAPHX_EXPORT operation make_op(const std::string& name);
operation make_op(const std::string& name, MIGRAPHX_EXPORT operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v); const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v); MIGRAPHX_EXPORT operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so the initializer_list // A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all // cannot be passed in directly. This is to enforce at compile-time that all
...@@ -48,7 +48,7 @@ operation make_op(const std::string& name, const Value& v) ...@@ -48,7 +48,7 @@ operation make_op(const std::string& name, const Value& v)
return make_op_from_value(name, v); return make_op_from_value(name, v);
} }
operation make_json_op(const std::string& name, const std::string& s); MIGRAPHX_EXPORT operation make_json_op(const std::string& name, const std::string& s);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -46,7 +46,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -46,7 +46,7 @@ inline namespace MIGRAPHX_INLINE_NS {
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
struct marker struct MIGRAPHX_EXPORT marker
{ {
// //
void mark_start(instruction_ref ins_ref); void mark_start(instruction_ref ins_ref);
...@@ -80,7 +80,7 @@ struct marker ...@@ -80,7 +80,7 @@ struct marker
{ {
using std::swap; using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>(); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -233,7 +233,7 @@ struct marker ...@@ -233,7 +233,7 @@ struct marker
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -31,10 +31,15 @@ ...@@ -31,10 +31,15 @@
#include <migraphx/optional.hpp> #include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/source_location.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#ifndef MIGRAPHX_USE_TYPE_ERASED_MATCHERS
#define MIGRAPHX_USE_TYPE_ERASED_MATCHERS 0
#endif
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -103,6 +108,13 @@ struct predicate_matcher ...@@ -103,6 +108,13 @@ struct predicate_matcher
} }
}; };
/// Convert a predicate function into a matcher
template <class P>
predicate_matcher<P> make_predicate_matcher(P p)
{
return {p};
}
/// Convert a function into a matcher /// Convert a function into a matcher
template <class F> template <class F>
struct function_matcher struct function_matcher
...@@ -124,14 +136,14 @@ template <class M> ...@@ -124,14 +136,14 @@ template <class M>
auto bind_match(M m, std::string name) auto bind_match(M m, std::string name)
{ {
return make_function_matcher( return make_function_matcher(
[=, name = std::move(name)](matcher_context& ctx, [=, m_name = std::move(name)](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> { instruction_ref ins) -> optional<instruction_ref> {
auto result = m.match(ctx, ins); auto result = m.match(ctx, ins);
if(result) if(result)
{ {
if(not ctx.has_instruction(ins)) if(not ctx.has_instruction(ins))
return nullopt; return nullopt;
ctx.instructions[name] = ins; ctx.instructions[m_name] = ins;
} }
return result; return result;
}); });
...@@ -183,14 +195,26 @@ struct id_matcher ...@@ -183,14 +195,26 @@ struct id_matcher
template <class M> template <class M>
struct basic_matcher; struct basic_matcher;
struct any_matcher;
template <class M> template <class M>
basic_matcher<M> make_basic_matcher(M m); struct type_erased_matcher
{
#if MIGRAPHX_USE_TYPE_ERASED_MATCHERS
using type = any_matcher;
#else
using type = basic_matcher<M>;
#endif
};
template <class M>
typename type_erased_matcher<M>::type make_basic_matcher(M m);
template <class F> template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f); auto make_basic_fun_matcher(F f);
template <class P> template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p); auto make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher /// The basic matcher provides the all_of composability of the matcher
template <class M> template <class M>
...@@ -222,38 +246,38 @@ struct basic_matcher ...@@ -222,38 +246,38 @@ struct basic_matcher
auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); } auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
}; };
/// Create a typed-erased matcher
using any_matcher_base = basic_matcher<
function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base
{
template <class M>
any_matcher(M mm) : any_matcher_base({[=](auto& ctx, auto ins) { return mm.match(ctx, ins); }})
{
}
};
/// Create a basic matcher from a matcher /// Create a basic matcher from a matcher
template <class M> template <class M>
basic_matcher<M> make_basic_matcher(M m) typename type_erased_matcher<M>::type make_basic_matcher(M m)
{ {
return {m}; return {m};
} }
/// Create a basic matcher from a function /// Create a basic matcher from a function
template <class F> template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f) auto make_basic_fun_matcher(F f)
{ {
return {{f}}; return make_basic_matcher(make_function_matcher(f));
} }
/// Create a basic matcher from a predicate function /// Create a basic matcher from a predicate function
template <class P> template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p) auto make_basic_pred_matcher(P p)
{ {
return {{p}}; return make_basic_matcher(make_predicate_matcher(p));
} }
/// Create a typed-erased matcher
using any_matcher_base = basic_matcher<
function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base
{
template <class M>
any_matcher(M mm) : any_matcher_base({[=](auto& ctx, auto ins) { return mm.match(ctx, ins); }})
{
}
};
/// This macro takes care of the boilerplate for defining a matcher /// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPHX_BASIC_MATCHER(name, ...) \ #define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \ struct name##_m \
...@@ -347,31 +371,30 @@ match::matcher_result find_match(module& modl, M&& m) ...@@ -347,31 +371,30 @@ match::matcher_result find_match(module& modl, M&& m)
} }
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES)
/// Find matches for an instruction in the module /// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms) void find_matches_for(source_location location, Mod& mod, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 const int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
const const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
#endif const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{});
int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); const bool trace_for = not trace_filter.empty() and
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 (contains(std::string{location.file_name()}, trace_filter) or
const contains(std::string{location.function_name()}, trace_filter));
#endif bool match = false;
bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
bool match = false;
each_args( each_args(
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
return; return;
if(trace > 1) if(trace > 1 or trace_for)
std::cout << "Match: " << get_type_name(m) << std::endl; std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher()); auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end()) if(r.result == get_module(mod).end())
return; return;
if(trace > 0) if(trace > 0 or trace_for)
{ {
std::cout << "Matched by " << get_type_name(m) << std::endl; std::cout << "Matched by " << get_type_name(m) << std::endl;
get_module(mod).debug_print(ins); get_module(mod).debug_print(ins);
...@@ -397,13 +420,19 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms) ...@@ -397,13 +420,19 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
/// Find matches in a module /// Find matches in a module
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(Mod& mod, Ms&&... ms) struct find_matches
{ {
for(auto ins : iterator_for(get_module(mod))) find_matches(Mod& mod, Ms&&... ms, source_location location = source_location::current())
{ {
find_matches(mod, ins, ms...); for(auto ins : iterator_for(get_module(mod)))
{
find_matches_for(location, mod, ins, ms...);
}
} }
} };
template <class Mod, class... Ms>
find_matches(Mod& mod, Ms&&... ms) -> find_matches<Mod, Ms...>;
template <class M, class F> template <class M, class F>
struct find_generic_match struct find_generic_match
...@@ -632,9 +661,9 @@ auto skip_output(Ms... ms) ...@@ -632,9 +661,9 @@ auto skip_output(Ms... ms)
inline auto var(std::string s) inline auto var(std::string s)
{ {
return make_basic_fun_matcher( return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx, [=, m_s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> { instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s); auto it = ctx.instructions.find(m_s);
if(it == ctx.instructions.end()) if(it == ctx.instructions.end())
return nullopt; return nullopt;
return it->second; return it->second;
...@@ -644,7 +673,7 @@ inline auto var(std::string s) ...@@ -644,7 +673,7 @@ inline auto var(std::string s)
inline auto name(std::string s) inline auto name(std::string s)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher(
[=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; }); [=, m_s = std::move(s)](instruction_ref ins) { return ins->name() == m_s; });
} }
inline auto name_contains(const std::string& name) inline auto name_contains(const std::string& name)
...@@ -655,8 +684,8 @@ inline auto name_contains(const std::string& name) ...@@ -655,8 +684,8 @@ inline auto name_contains(const std::string& name)
inline auto name(std::unordered_set<std::string> names) inline auto name(std::unordered_set<std::string> names)
{ {
return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) { return make_basic_pred_matcher([=, m_names = std::move(names)](instruction_ref ins) {
return names.count(ins->name()) > 0; return m_names.count(ins->name()) > 0;
}); });
} }
......
...@@ -36,7 +36,7 @@ struct module; ...@@ -36,7 +36,7 @@ struct module;
* Remove multiple memory allocations using graph coloring to find memory allocations that can be * Remove multiple memory allocations using graph coloring to find memory allocations that can be
* reused. * reused.
*/ */
struct memory_coloring struct MIGRAPHX_EXPORT memory_coloring
{ {
std::string allocation_op{}; std::string allocation_op{};
bool verify = false; bool verify = false;
......
...@@ -52,7 +52,7 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins ...@@ -52,7 +52,7 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
/** /**
* @brief Stores the instruction stream * @brief Stores the instruction stream
*/ */
struct module struct MIGRAPHX_EXPORT module
{ {
module(const std::string& name = ""); module(const std::string& name = "");
...@@ -189,7 +189,7 @@ struct module ...@@ -189,7 +189,7 @@ struct module
instruction_ref validate() const; instruction_ref validate() const;
instruction_ref find_dangling_reference() const; instruction_ref find_dangling_reference() const;
void finalize(context& ctx); void finalize(std::vector<context>& contexts);
void debug_print() const; void debug_print() const;
void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins) const;
...@@ -222,11 +222,21 @@ struct module ...@@ -222,11 +222,21 @@ struct module
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const; void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
std::vector<module_ref> get_sub_modules(bool shallow = false) const; std::vector<module_ref> get_sub_modules(bool shallow = false) const;
/* sorts the module in topological order aka reverse-post order (RPO) DFS order
it takes last instruction or @return as the root and walks back the graph and moves inputs
of the each instruction such that it appears before the instruction itself.
*/
module& sort(); module& sort();
/* Any instruction "X" can have module arguments and those modules inside them can use any other
* instruction "Y" from predecessor modules of the instruction "X". Such instruction "Y" inside
* module args are not listed as input instructions to "X". But those instructions "Y" must be
* evaluted before the instruction "X" can. Therefore such "Y" instructions are considered
* implicit dependency to "X".
*/
ins_dep_map calc_implicit_deps() const; ins_dep_map calc_implicit_deps() const;
friend std::ostream& operator<<(std::ostream& os, const module& m); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m);
friend bool operator==(const module& x, const module& y); MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y);
friend bool operator!=(const module& x, const module& y) { return not(x == y); } friend bool operator!=(const module& x, const module& y) { return not(x == y); }
private: private:
......
...@@ -31,10 +31,11 @@ ...@@ -31,10 +31,11 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void to_msgpack(const value& v, std::function<void(const char*, std::size_t)> writer); MIGRAPHX_EXPORT void to_msgpack(const value& v,
std::vector<char> to_msgpack(const value& v); std::function<void(const char*, std::size_t)> writer);
value from_msgpack(const std::vector<char>& buffer); MIGRAPHX_EXPORT std::vector<char> to_msgpack(const value& v);
value from_msgpack(const char* buffer, std::size_t size); MIGRAPHX_EXPORT value from_msgpack(const std::vector<char>& buffer);
MIGRAPHX_EXPORT value from_msgpack(const char* buffer, std::size_t size);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -42,7 +42,8 @@ struct select_dependent_type ...@@ -42,7 +42,8 @@ struct select_dependent_type
template <class T, class... Ts> template <class T, class... Ts>
using dependent_type = typename select_dependent_type<T, Ts...>::type; using dependent_type = typename select_dependent_type<T, Ts...>::type;
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens); MIGRAPHX_EXPORT
bool normalize_attributes(operation& op, const shape& input_shape);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -39,7 +39,7 @@ struct module; ...@@ -39,7 +39,7 @@ struct module;
* Process negative axis attributes of ops * Process negative axis attributes of ops
*/ */
struct normalize_ops struct MIGRAPHX_EXPORT normalize_ops
{ {
std::string name() const { return "normalize_ops"; } std::string name() const { return "normalize_ops"; }
void apply(module& m) const; void apply(module& m) const;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment