"vscode:/vscode.git/clone" did not exist on "0490e86048f3314e6be1353f99f9b2f9b7370d7d"
Commit 3885c9bc authored by mei-ye's avatar mei-ye
Browse files

merge in develop

parent a5b0afa0
...@@ -21,6 +21,9 @@ add_library(migraphx ...@@ -21,6 +21,9 @@ add_library(migraphx
simplify_reshapes.cpp simplify_reshapes.cpp
opt/memory_coloring.cpp opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp opt/memory_coloring_impl.cpp
opt/pre_scheduling.cpp
opt/pre_scheduling_impl.cpp
opt/dom_info.cpp
) )
rocm_clang_tidy_check(migraphx) rocm_clang_tidy_check(migraphx)
rocm_install_targets( rocm_install_targets(
......
...@@ -23,7 +23,6 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -23,7 +23,6 @@ inline namespace MIGRAPHX_INLINE_NS {
#else #else
#define MIGRAPHX_DEBUG(s) #define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT #endif // MIGRAPHX_DEBUG_OPT
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -20,7 +20,11 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -20,7 +20,11 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context struct context
{ {
/// Wait for any tasks in the context to complete /// Wait for any tasks in the context to complete
void finish() const; void finish();
void set_stream(int ndx);
void create_events(int num_of_events);
void record_event(int event);
void wait_event(int event);
}; };
#else #else
...@@ -30,7 +34,11 @@ struct context ...@@ -30,7 +34,11 @@ struct context
* *
* struct context * struct context
* { * {
* void finish() const; * void finish() ;
* void set_stream(int input) ;
* void create_events(int input) ;
* void record_event(int input) ;
* void wait_event(int input) ;
* }; * };
* *
*/ */
...@@ -92,12 +100,36 @@ struct context ...@@ -92,12 +100,36 @@ struct context
return private_detail_te_get_handle().type(); return private_detail_te_get_handle().type();
} }
void finish() const void finish()
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().finish(); (*this).private_detail_te_get_handle().finish();
} }
void set_stream(int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().set_stream(input);
}
void create_events(int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().create_events(input);
}
void record_event(int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record_event(input);
}
void wait_event(int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait_event(input);
}
friend bool is_shared(const context& private_detail_x, const context& private_detail_y) friend bool is_shared(const context& private_detail_x, const context& private_detail_y)
{ {
return private_detail_x.private_detail_te_handle_mem_var == return private_detail_x.private_detail_te_handle_mem_var ==
...@@ -111,7 +143,11 @@ struct context ...@@ -111,7 +143,11 @@ struct context
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual void finish() const = 0; virtual void finish() = 0;
virtual void set_stream(int input) = 0;
virtual void create_events(int input) = 0;
virtual void record_event(int input) = 0;
virtual void wait_event(int input) = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -142,7 +178,15 @@ struct context ...@@ -142,7 +178,15 @@ struct context
const std::type_info& type() const override { return typeid(private_detail_te_value); } const std::type_info& type() const override { return typeid(private_detail_te_value); }
void finish() const override { private_detail_te_value.finish(); } void finish() override { private_detail_te_value.finish(); }
void set_stream(int input) override { private_detail_te_value.set_stream(input); }
void create_events(int input) override { private_detail_te_value.create_events(input); }
void record_event(int input) override { private_detail_te_value.record_event(input); }
void wait_event(int input) override { private_detail_te_value.wait_event(input); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
...@@ -210,7 +254,6 @@ inline const ValueType& any_cast(const context& x) ...@@ -210,7 +254,6 @@ inline const ValueType& any_cast(const context& x)
} }
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
#define MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
#include <migraphx/common_header.hpp>
#include <migraphx/set_operator.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Compute dominators, post-dominators, dominator tree, post-dominator tree
// for instructions with streams. Also do program analysis to identify
// concurrent instructions in different streams.
struct dom_info
{
dom_info(program* p) : p_program(p)
{
instr2_idom.clear();
instr2_ipdom.clear();
}
void compute_dom(bool);
void
find_dom_tree(std::unordered_map<const instruction*, std::set<const instruction*>>& instr2_doms,
const instruction* p_ins,
std::unordered_map<const instruction*, const instruction*>& instr2_dom_tree,
std::unordered_map<const instruction*, const instruction*>& idom);
#ifdef MIGRAPHX_DEBUG_OPT
void dump_doms(std::unordered_map<const instruction*, int>&, bool);
#endif
bool is_split_point(instruction_ref ins);
bool is_merge_point(instruction_ref ins);
// whether ins1 strictly post-dominates ins2.
bool strictly_post_dominates(const instruction* ins1, const instruction* ins2);
// Program analysis to identify concurrent instructions.
void propagate_splits(
int num_of_streams,
std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>&
concur_instrs,
std::unordered_map<const instruction*, int>& instr2_points);
program* p_program;
// map instruction to its immediate dominator.
std::unordered_map<const instruction*, const instruction*> instr2_idom;
// map instruction to its immediate post dominator.
std::unordered_map<const instruction*, const instruction*> instr2_ipdom;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_FIND_CONCUR_HPP
#define MIGRAPHX_GUARD_FIND_CONCUR_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <unordered_map>
#include <vector>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent analysis to find concurrent instructions
/// executing in different streams.
struct find_concur
{
void get_concur(program* p,
int num_of_streams,
std::unordered_map<const instruction*,
std::vector<std::vector<const instruction*>>>& concur_instrs,
std::unordered_map<const instruction*, int>& instr2_points);
} const;
#else
/*
* Type-erased interface for:
*
* struct find_concur
* {
* void get_concur(program* p,int num_of_stream,std::unordered_map<const instruction*,
* std::vector<std::vector<const instruction*>>>& concur_instrs,std::unordered_map<const
* instruction*, int>& input) const;
* };
*
*/
struct find_concur
{
// Constructors
find_concur() = default;
template <typename PrivateDetailTypeErasedT>
find_concur(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
find_concur& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
void get_concur(program* p,
int num_of_stream,
std::unordered_map<const instruction*,
std::vector<std::vector<const instruction*>>>& concur_instrs,
std::unordered_map<const instruction*, int>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().get_concur(p, num_of_stream, concur_instrs, input);
}
friend bool is_shared(const find_concur& private_detail_x, const find_concur& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual void
get_concur(program* p,
int num_of_stream,
std::unordered_map<const instruction*,
std::vector<std::vector<const instruction*>>>& concur_instrs,
std::unordered_map<const instruction*, int>& input) const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
void
get_concur(program* p,
int num_of_stream,
std::unordered_map<const instruction*,
std::vector<std::vector<const instruction*>>>& concur_instrs,
std::unordered_map<const instruction*, int>& input) const override
{
private_detail_te_value.get_concur(p, num_of_stream, concur_instrs, input);
}
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const find_concur* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(find_concur* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(find_concur& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const find_concur& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_INSERT_INSTRUCTION_HPP
#define MIGRAPHX_GUARD_INSERT_INSTRUCTION_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent instruction insertion.
/// for multi-stream execution.
struct insert_instruction
{
void insert_create_events(program* p, instruction_ref ins, int num_of_events);
void insert_record_event(program* p, instruction_ref ins, int event);
void insert_wait_event(program* p, instruction_ref ins, int event);
void insert_stream(program* p, instruction_ref ins, int stream);
};
#else
/*
* Type-erased interface for:
*
* struct insert_instruction
* {
* void insert_create_events(program* p,instruction_ref ins,int input) ;
* void insert_record_event(program* p,instruction_ref ins,int input) ;
* void insert_wait_event(program* p,instruction_ref ins,int input) ;
* void insert_stream(program* p,instruction_ref ins,int input) ;
* };
*
*/
struct insert_instruction
{
// Constructors
insert_instruction() = default;
template <typename PrivateDetailTypeErasedT>
insert_instruction(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
insert_instruction& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
void insert_create_events(program* p, instruction_ref ins, int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_create_events(p, ins, input);
}
void insert_record_event(program* p, instruction_ref ins, int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_record_event(p, ins, input);
}
void insert_wait_event(program* p, instruction_ref ins, int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_wait_event(p, ins, input);
}
void insert_stream(program* p, instruction_ref ins, int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_stream(p, ins, input);
}
friend bool is_shared(const insert_instruction& private_detail_x,
const insert_instruction& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual void insert_create_events(program* p, instruction_ref ins, int input) = 0;
virtual void insert_record_event(program* p, instruction_ref ins, int input) = 0;
virtual void insert_wait_event(program* p, instruction_ref ins, int input) = 0;
virtual void insert_stream(program* p, instruction_ref ins, int input) = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
void insert_create_events(program* p, instruction_ref ins, int input) override
{
private_detail_te_value.insert_create_events(p, ins, input);
}
void insert_record_event(program* p, instruction_ref ins, int input) override
{
private_detail_te_value.insert_record_event(p, ins, input);
}
void insert_wait_event(program* p, instruction_ref ins, int input) override
{
private_detail_te_value.insert_wait_event(p, ins, input);
}
void insert_stream(program* p, instruction_ref ins, int input) override
{
private_detail_te_value.insert_stream(p, ins, input);
}
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const insert_instruction* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(insert_instruction* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(insert_instruction& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const insert_instruction& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -16,6 +16,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -16,6 +16,12 @@ inline namespace MIGRAPHX_INLINE_NS {
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args); shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args); std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
enum instruction_mask : unsigned int
{
record_event = 0,
wait_event = 1
};
struct instruction struct instruction
{ {
instruction() {} instruction() {}
...@@ -41,6 +47,17 @@ struct instruction ...@@ -41,6 +47,17 @@ struct instruction
const operation& get_operator() const; const operation& get_operator() const;
int get_stream() const;
void set_stream(int);
int get_event() const;
void set_event(int);
void add_mask(instruction_mask m)
{
if((mask & (1u << m)) == 0)
mask += (1u << m);
}
bool has_mask(instruction_mask m) const { return ((mask & (1u << m)) != 0); }
std::string name() const; std::string name() const;
const std::vector<instruction_ref>& inputs() const; const std::vector<instruction_ref>& inputs() const;
...@@ -94,6 +111,9 @@ struct instruction ...@@ -94,6 +111,9 @@ struct instruction
std::vector<instruction_ref> output; std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments; std::vector<instruction_ref> arguments;
literal lit; literal lit;
int stream = -1;
unsigned int mask = 0;
int event = -1;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -109,7 +129,6 @@ struct hash<migraphx::instruction_ref> ...@@ -109,7 +129,6 @@ struct hash<migraphx::instruction_ref>
return std::hash<migraphx::instruction*>{}(&*x); return std::hash<migraphx::instruction*>{}(&*x);
} }
}; };
} // namespace std } // namespace std
#endif #endif
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
#include <string> #include <string>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/find_concur.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -15,11 +17,12 @@ struct program; ...@@ -15,11 +17,12 @@ struct program;
struct memory_coloring struct memory_coloring
{ {
std::string allocation_op{}; std::string allocation_op{};
int num_of_streams = 0;
find_concur f_concur;
bool verify = false; bool verify = false;
std::string name() const { return "memory coloring"; } std::string name() const { return "memory coloring"; }
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -9,6 +9,8 @@ namespace migraphx { ...@@ -9,6 +9,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MEMORY_COLORING) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MEMORY_COLORING)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_PRE_SCHEDULING)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EVENT_AS_INSTRUCTION)
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_PRE_SCHEDULING_HPP
#define MIGRAPHX_GUARD_RTGLIB_PRE_SCHEDULING_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/program.hpp>
#include <migraphx/insert_instruction.hpp>
namespace migraphx {
struct pre_scheduling
{
std::function<std::pair<int, int>(const operation&)> weight_func;
int num_of_streams;
insert_instruction insert_instr;
bool verify = false;
std::string name() const { return "pre scheduling"; }
void apply(program& p) const;
};
} // namespace migraphx
#endif
...@@ -98,7 +98,7 @@ struct program ...@@ -98,7 +98,7 @@ struct program
void compile(const target& t, tracer trace = tracer{}); void compile(const target& t, tracer trace = tracer{});
void finalize(); void finalize();
void finish();
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const; void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
void debug_print() const; void debug_print() const;
...@@ -114,7 +114,6 @@ struct program ...@@ -114,7 +114,6 @@ struct program
private: private:
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SET_OPERATOR_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_SET_OPERATOR_IMPL_HPP
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <typename Set, typename Key = typename Set::value_type>
static inline Set set_intersection(const Set& lhs, const Set& rhs)
{
if(lhs.size() <= rhs.size())
{
Set iset;
for(const Key& key : lhs)
{
if(rhs.count(key) > 0)
{
iset.insert(key);
}
}
return std::move(iset);
}
else
{
return set_intersection(rhs, lhs);
}
}
template <typename Set, typename Key = typename Set::value_type>
static inline Set set_union(const Set& lhs, const Set& rhs)
{
Set uset{lhs};
uset.insert(rhs.begin(), rhs.end());
return std::move(uset);
}
template <typename Set, typename Key = typename Set::value_type>
static inline Set set_difference(const Set& lhs, const Set& rhs)
{
Set dset{lhs};
for(auto& iter : rhs)
{
dset.erase(iter);
}
return std::move(dset);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -87,6 +87,11 @@ const literal& instruction::get_literal() const ...@@ -87,6 +87,11 @@ const literal& instruction::get_literal() const
return lit; return lit;
} }
int instruction::get_stream() const { return stream; }
void instruction::set_stream(int s) { stream = s; }
int instruction::get_event() const { return event; }
void instruction::set_event(int e) { event = e; }
const operation& instruction::get_operator() const { return op; } const operation& instruction::get_operator() const { return op; }
std::string instruction::name() const { return op.name(); } std::string instruction::name() const { return op.name(); }
...@@ -211,6 +216,5 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg ...@@ -211,6 +216,5 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg
{ {
return op.compute_shape(to_shapes(args)); return op.compute_shape(to_shapes(args));
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -31,6 +31,7 @@ int main(int argc, char const* argv[]) ...@@ -31,6 +31,7 @@ int main(int argc, char const* argv[])
std::cout << "Allocating params ... " << std::endl; std::cout << "Allocating params ... " << std::endl;
auto m = create_param_map(p); auto m = create_param_map(p);
std::cout << "Running performance report ... " << std::endl; std::cout << "Running performance report ... " << std::endl;
p.perf_report(std::cout, n, m); p.perf_report(std::cout, n, m);
} }
} }
#include <migraphx/dom_info.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// A unified interface to visit programs top-down or bottom-up.
struct program_visitor
{
program* p_program;
bool reversed;
instruction_ref begin() { return reversed ? std::prev(p_program->end()) : p_program->begin(); }
instruction_ref end() { return reversed ? p_program->begin() : std::prev(p_program->end()); }
instruction_ref next(instruction_ref ins) { return reversed ? std::prev(ins) : std::next(ins); }
const std::vector<instruction_ref>& get_inputs(instruction_ref ins)
{
return reversed ? ins->outputs() : ins->inputs();
}
};
// Query whether ins1 strictly post-dominates ins2. ins1 strictly post-dominates
// ins2 if ins1 post-dominates ins2 and ins1 is not ins2.
//
bool dom_info::strictly_post_dominates(const instruction* ins1, const instruction* ins2)
{
if(ins1 != ins2)
{
const instruction* iter = ins2;
while(instr2_ipdom.find(iter) != instr2_ipdom.end())
{
if(ins1 == instr2_ipdom[iter])
return true;
iter = instr2_ipdom[iter];
}
}
return false;
}
// Among p_ins's dominators, find ones that strictly dominates or post-dominators others.
//
void dom_info::find_dom_tree(
std::unordered_map<const instruction*, std::set<const instruction*>>& instr2_doms,
const instruction* p_ins,
std::unordered_map<const instruction*, const instruction*>& instr2_dom_tree,
std::unordered_map<const instruction*, const instruction*>& idom)
{
for(auto& iter1 : instr2_doms[p_ins])
{
auto dom_check = [& dom_tree = idom, ins1 = iter1 ](const instruction* ins2)
{
if(ins1 == ins2)
return false;
const instruction* iter = ins2;
;
while(dom_tree.find(iter) != dom_tree.end())
{
if(ins1 == dom_tree[iter])
return true;
iter = dom_tree[iter];
}
return false;
};
// check whether iter1 strictly dominates or post-dominates any other notes in
// p_ins's dominators or post-dominators.
if(!std::any_of(instr2_doms[p_ins].begin(), instr2_doms[p_ins].end(), dom_check))
{
assert(instr2_dom_tree.find(p_ins) == instr2_dom_tree.end());
instr2_dom_tree[p_ins] = iter1;
}
}
}
// Compute dominator or post-dominator. Instructions that do not use
// streams are left out.
//
void dom_info::compute_dom(bool reversed)
{
std::size_t num_of_instrs = p_program->size();
if(num_of_instrs == 0)
return;
std::unordered_map<const instruction*, std::set<const instruction*>> instr2_doms;
std::unordered_map<const instruction*, int> instr2_points;
int cur_points = reversed ? num_of_instrs - 1 : 0;
bool seen_stream = false;
program_visitor vis{p_program, reversed};
std::unordered_map<const instruction*, const instruction*>& instr2_dom_tree =
(reversed ? instr2_ipdom : instr2_idom);
for(auto ins = vis.begin(), end = vis.end();; ins = vis.next(ins))
{
const instruction* p_ins = &(*ins);
instr2_points[p_ins] = cur_points;
if(ins->get_stream() < 0)
{
if(reversed)
cur_points--;
else
cur_points++;
;
if(ins == end)
break;
continue;
}
seen_stream = true;
const instruction* p_tmp = nullptr;
int cnt = 0;
// find dominators.
for(auto&& iter : vis.get_inputs(ins))
{
if(iter->get_stream() < 0)
continue;
const instruction* p_arg = &(*iter);
cnt++;
assert(instr2_doms.find(p_arg) != instr2_doms.end());
if(p_tmp == nullptr)
instr2_doms[p_ins] = instr2_doms[p_arg];
else
instr2_doms[p_ins] = set_intersection(instr2_doms[p_ins], instr2_doms[p_arg]);
p_tmp = p_arg;
}
// find immediate dominators.
if(cnt == 1)
{
instr2_dom_tree[p_ins] = p_tmp;
}
else if(cnt > 0)
{
std::unordered_map<const instruction*, const instruction*>& idom =
reversed ? instr2_ipdom : instr2_idom;
find_dom_tree(instr2_doms, p_ins, instr2_dom_tree, idom);
}
instr2_doms[p_ins].insert(p_ins);
if(ins == end)
break;
if(reversed)
cur_points--;
else
cur_points++;
}
if(seen_stream)
{
MIGRAPHX_DEBUG(dump_doms(instr2_points, reversed));
}
}
// Identify split points. A split point has more than one
// outputs that are executed in different streams.
bool dom_info::is_split_point(instruction_ref ins)
{
if(ins->has_mask(record_event))
{
std::set<int> stream_set;
for(auto&& arg : ins->outputs())
{
int arg_stream = arg->get_stream();
if(arg_stream >= 0)
stream_set.insert(arg_stream);
}
if(stream_set.size() > 1)
return true;
}
return false;
}
// Identify merge points. A merge point has more than one
// inputs that are executed in different streams.
bool dom_info::is_merge_point(instruction_ref ins)
{
if(ins->has_mask(wait_event))
{
std::set<int> stream_set;
for(auto&& arg : ins->inputs())
{
int arg_stream = arg->get_stream();
if(arg_stream >= 0)
stream_set.insert(arg_stream);
}
if(stream_set.size() > 1)
return true;
}
return false;
}
// Propagate split points through the graph and identify concurrent instructions.
// Concurrent instructions have the same split points and different streams.
//
void dom_info::propagate_splits(
int num_of_streams,
std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>&
concur_instrs,
std::unordered_map<const instruction*, int>& instr2_points)
{
std::unordered_map<instruction_ref, bool> is_split;
std::unordered_map<instruction_ref, bool> is_merge;
std::unordered_map<instruction_ref, std::set<const instruction*>> split_from;
int cur_points = 0;
instr2_points.clear();
for(auto ins : iterator_for(*p_program))
{
const instruction* p_iter = &(*ins);
instr2_points[p_iter] = cur_points++;
int stream = ins->get_stream();
if(stream < 0)
continue;
is_split[ins] = is_split_point(ins);
is_merge[ins] = is_merge_point(ins);
for(auto&& arg : ins->inputs())
{
// Input is a split point.
if(is_split.find(arg) != is_split.end())
split_from[ins].insert(&(*arg));
// Union inputs' split points.
if((split_from.find(arg) != split_from.end()) && !split_from[arg].empty())
{
if(split_from.find(ins) == split_from.end())
split_from[ins] = split_from[arg];
else
split_from[ins] = set_union(split_from[ins], split_from[arg]);
}
}
if(is_merge[ins])
{
assert(split_from.find(ins) != split_from.end());
std::set<const instruction*> del_set;
// post-dominator kills split point.
for(auto& split : split_from[ins])
{
if(strictly_post_dominates(p_iter, split))
del_set.insert(split);
}
split_from[ins] = set_difference(split_from[ins], del_set);
}
if(split_from.find(ins) != split_from.end())
{
// Collect concur instructions for each split point.
for(auto& split : split_from[ins])
{
if(concur_instrs.find(split) == concur_instrs.end())
{
std::vector<std::vector<const instruction*>> instr_stack;
instr_stack.resize(num_of_streams);
concur_instrs[split] = instr_stack;
}
concur_instrs[split][stream].push_back(p_iter);
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT
void dom_info::dump_doms(std::unordered_map<const instruction*, int>& instr2_points, bool post_dom)
{
std::cout << "---dominator tree---" << std::endl;
for(auto ins : iterator_for(*p_program))
{
const instruction* p_ins = &(*ins);
if(!post_dom && (instr2_idom.find(p_ins) != instr2_idom.end()))
{
const instruction* idom = instr2_idom[p_ins];
std::cout << "@" << instr2_points[p_ins] << " imm dominator: "
<< "@" << instr2_points[idom] << std::endl;
}
if(post_dom && (instr2_ipdom.find(p_ins) != instr2_ipdom.end()))
{
const instruction* ipdom = instr2_ipdom[p_ins];
std::cout << "@" << instr2_points[p_ins] << " imm post domimator: "
<< "@" << instr2_points[ipdom] << std::endl;
}
}
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -8,10 +8,9 @@ void memory_coloring::apply(program& p) const ...@@ -8,10 +8,9 @@ void memory_coloring::apply(program& p) const
{ {
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{})) if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{ {
memory_coloring_impl opt(&p, allocation_op, verify); memory_coloring_impl opt(&p, allocation_op, verify, num_of_streams, f_concur);
opt.run(); opt.run();
} }
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -11,6 +11,8 @@ void memory_coloring_impl::run() ...@@ -11,6 +11,8 @@ void memory_coloring_impl::run()
if(num_of_lives != 0) if(num_of_lives != 0)
{ {
MIGRAPHX_DEBUG(dump_intervals()); MIGRAPHX_DEBUG(dump_intervals());
if(num_of_streams > 0)
add_stream_conflicts();
// Coloring // Coloring
while(!alloc_queue.empty()) while(!alloc_queue.empty())
{ {
...@@ -152,7 +154,8 @@ void memory_coloring_impl::build() ...@@ -152,7 +154,8 @@ void memory_coloring_impl::build()
interval->segment.end = cur_points; interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number; interval->segment.vn = ++max_value_number;
interval->add_use(cur_points); interval->add_use(cur_points);
instr2_live[p_arg] = interval; instr2_live[p_arg] = interval;
instr2_live[&(*arg)] = interval;
add_conflicts(live_set, max_value_number); add_conflicts(live_set, max_value_number);
live_set.insert(max_value_number); live_set.insert(max_value_number);
live_ranges[max_value_number] = &(interval->segment); live_ranges[max_value_number] = &(interval->segment);
...@@ -165,6 +168,7 @@ void memory_coloring_impl::build() ...@@ -165,6 +168,7 @@ void memory_coloring_impl::build()
interval_ptr interval = instr2_live[p_arg]; interval_ptr interval = instr2_live[p_arg];
interval->add_use(cur_points); interval->add_use(cur_points);
assert(live_set.find(interval->id) != live_set.end()); assert(live_set.find(interval->id) != live_set.end());
instr2_live[&(*arg)] = interval;
} }
} }
if(is_dead) if(is_dead)
...@@ -259,6 +263,57 @@ void memory_coloring_impl::verify() ...@@ -259,6 +263,57 @@ void memory_coloring_impl::verify()
} }
} }
// Add conflicts of concurrent instructions into conflict table.
//
void memory_coloring_impl::add_stream_conflicts(std::vector<const instruction*>& i1,
std::vector<const instruction*>& i2)
{
for(auto& ins1 : i1)
{
if(instr2_live.find(ins1) == instr2_live.end())
continue;
interval_ptr interval1 = instr2_live[ins1];
int id1 = interval1->id;
for(auto& ins2 : i2)
{
if(instr2_live.find(ins2) == instr2_live.end())
continue;
interval_ptr interval2 = instr2_live[ins2];
int id2 = interval2->id;
conflict_table[id1].insert(id2);
conflict_table[id2].insert(id1);
#ifdef MIGRAPHX_DEBUG_OPT
std::cout << "@" << instr2_points[ins1] << " id:" << id1 << " => "
<< "@" << instr2_points[ins2] << " id:" << id2 << std::endl;
#endif
}
}
}
// Identify concurrent instructions in different streams and add conflicts to
// conflict table.
//
void memory_coloring_impl::add_stream_conflicts()
{
std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>
concur_instrs;
f_concur.get_concur(p_program, num_of_streams, concur_instrs, instr2_points);
MIGRAPHX_DEBUG(dump_concur_instrs(concur_instrs));
for(auto& iter : concur_instrs)
{
for(auto s1 = 0; s1 < num_of_streams; ++s1)
{
std::vector<const instruction*>& i1 = iter.second[s1];
for(auto s2 = s1 + 1; s2 < num_of_streams; ++s2)
{
std::vector<const instruction*>& i2 = iter.second[s2];
add_stream_conflicts(i1, i2);
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; } void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
...@@ -290,6 +345,29 @@ void memory_coloring_impl::dump_intervals() ...@@ -290,6 +345,29 @@ void memory_coloring_impl::dump_intervals()
} }
} }
void memory_coloring_impl::dump_concur_instrs(
std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>&
concur_instrs)
{
for(auto iter = concur_instrs.begin(), end = concur_instrs.end(); iter != end; ++iter)
{
std::cout << "concurrent instructions for split @" << instr2_points[iter->first]
<< std::endl;
for(auto s1 = 0; s1 < num_of_streams; ++s1)
{
std::vector<const instruction*>& instrs = iter->second[s1];
if(instrs.empty())
continue;
std::cout << "stream:" << s1 << std::endl;
for(auto ins = instrs.begin(), ins_end = instrs.end(); ins != ins_end; ++ins)
{
std::cout << " @" << instr2_points[*ins];
}
std::cout << std::endl;
}
}
}
// map liveness tracking point to instruction enum. // map liveness tracking point to instruction enum.
static int get_ins_enum(int x) static int get_ins_enum(int x)
{ {
...@@ -332,6 +410,5 @@ void live_interval::dump() ...@@ -332,6 +410,5 @@ void live_interval::dump()
} }
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP #define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include "common_header.hpp" #include <migraphx/common_header.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/find_concur.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -52,8 +53,12 @@ using interval_ptr = live_interval*; ...@@ -52,8 +53,12 @@ using interval_ptr = live_interval*;
struct memory_coloring_impl struct memory_coloring_impl
{ {
memory_coloring_impl(program* p, std::string alloc_op, bool p_verify) memory_coloring_impl(program* p, std::string alloc_op, bool p_verify, int num, find_concur f)
: p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify) : p_program(p),
allocation_op(std::move(alloc_op)),
enable_verify(p_verify),
num_of_streams(num),
f_concur(std::move(f))
{ {
instr2_live.clear(); instr2_live.clear();
live_ranges.clear(); live_ranges.clear();
...@@ -74,6 +79,8 @@ struct memory_coloring_impl ...@@ -74,6 +79,8 @@ struct memory_coloring_impl
conflict_table[val].insert(iter); conflict_table[val].insert(iter);
} }
} }
void add_stream_conflicts();
void add_stream_conflicts(std::vector<const instruction*>&, std::vector<const instruction*>&);
void build(); void build();
void run(); void run();
void rewrite(); void rewrite();
...@@ -105,6 +112,8 @@ struct memory_coloring_impl ...@@ -105,6 +112,8 @@ struct memory_coloring_impl
void dump(const std::string&); void dump(const std::string&);
void dump_program(); void dump_program();
void dump_intervals(); void dump_intervals();
void dump_concur_instrs(
std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>&);
#endif #endif
struct ordering struct ordering
{ {
...@@ -130,6 +139,7 @@ struct memory_coloring_impl ...@@ -130,6 +139,7 @@ struct memory_coloring_impl
return (i1->offset > i2->offset); return (i1->offset > i2->offset);
} }
}; };
program* p_program; program* p_program;
std::unordered_map<const instruction*, interval_ptr> instr2_live; std::unordered_map<const instruction*, interval_ptr> instr2_live;
// universe of live intervals. // universe of live intervals.
...@@ -140,7 +150,7 @@ struct memory_coloring_impl ...@@ -140,7 +150,7 @@ struct memory_coloring_impl
std::unordered_map<int, std::set<int>> conflict_table; std::unordered_map<int, std::set<int>> conflict_table;
// Priority queue for coloring. // Priority queue for coloring.
std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue; std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue;
std::unordered_map<const instruction*, int> instr2_points;
int num_of_lives; int num_of_lives;
int max_value_number; int max_value_number;
long long required_bytes; long long required_bytes;
...@@ -152,8 +162,9 @@ struct memory_coloring_impl ...@@ -152,8 +162,9 @@ struct memory_coloring_impl
bool unify_literals; bool unify_literals;
std::string allocation_op{}; std::string allocation_op{};
bool enable_verify; bool enable_verify;
int num_of_streams;
find_concur f_concur;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#include <migraphx/pre_scheduling.hpp>
#include "pre_scheduling_impl.hpp"
namespace migraphx {
void pre_scheduling::apply(program& p) const
{
if(!enabled(MIGRAPHX_DISABLE_PRE_SCHEDULING{}))
{
pre_scheduling_impl opt(&p, weight_func, num_of_streams, insert_instr, verify);
opt.run();
}
}
} // namespace migraphx
#include "pre_scheduling_impl.hpp"
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <stack>
namespace migraphx {
// Compute accumulated weights for each node in the DAG. Collect exit nodes
// and sort them according to accumulated weights.
//
void pre_scheduling_impl::compute_weights()
{
int ndx = 0;
std::unordered_map<dag_node*, bool> visited;
for(auto ins : iterator_for(*p_program))
{
dag_node& node = nodes[ndx];
std::pair<int, int> weight = weight_func(ins->get_operator());
node.weight = weight.first;
node.run_on_cpu = weight.second;
node.weight_sum += node.weight;
visited.clear();
for(auto&& arg : ins->inputs())
{
assert(instr2_node.find(arg) != instr2_node.end());
dag_node* def_node = instr2_node[arg];
if(visited.find(def_node) == visited.end())
{
node.weight_sum += def_node->weight_sum;
visited[def_node] = true;
}
}
if(ins->outputs().empty())
{
exit_nodes.push_back(&node);
}
node.ins = ins;
node.ins_ndx = ndx++;
instr2_node[ins] = &node;
}
int size = exit_nodes.size();
if(size > 1)
{
std::sort(exit_nodes.begin(), exit_nodes.end(), compare_exit_nodes);
}
}
// Do topology sort according to accumulated weight. Identify critial paths.
// Schedule nodes into streams. Reorder instructions according to topological
// order and annoate streams and events in the instructions.
//
void pre_scheduling_impl::reorder()
{
std::list<dag_node*> sorted_nodes;
std::stack<dag_node*> stack;
std::priority_queue<dag_node*, std::vector<dag_node*>, weighted_topology_ordering> child_queue;
std::unordered_map<dag_node*, bool> visited;
std::unordered_map<dag_node*, bool> dequeued;
for(auto&& node : exit_nodes)
{
stack.push(node);
node->partition = partition_info.create_partition();
partition_info.add_weight(node);
while(!stack.empty())
{
auto cur = stack.top();
if(dequeued.find(cur) != dequeued.end())
{
stack.pop();
continue;
}
else if((visited.find(cur) != visited.end()) || cur->ins->inputs().empty())
{
stack.pop();
sorted_nodes.push_back(cur);
dequeued[cur] = true;
continue;
}
// sort child nodes.
for(auto&& arg : cur->ins->inputs())
{
dag_node* child_node = instr2_node[arg];
if(dequeued.find(child_node) == dequeued.end())
{
child_queue.push(child_node);
}
}
// Last item in queue is on critical path.
while(!child_queue.empty())
{
dag_node* child = child_queue.top();
stack.push(child);
child_queue.pop();
if(child->weight_sum < min_partition_threshold)
child->partition = cur->partition;
else if(!child_queue.empty())
child->partition = partition_info.create_partition();
else
{
cur->first_child = child;
child->partition = cur->partition;
}
partition_info.add_weight(child);
}
visited[cur] = true;
}
}
#ifdef MIGRAPHX_DEBUG_OPT
MIGRAPHX_DEBUG(dump("---After weighted topology sort---"));
MIGRAPHX_DEBUG(dump(sorted_nodes));
#endif
schedule(sorted_nodes);
splice(sorted_nodes);
annotate(sorted_nodes);
if(enable_verify)
verify();
}
// Assign stream to nodes according to load balance.
//
int pre_scheduling_impl::get_stream(stream_info& info, dag_node* node)
{
int max_cycle = info.max_cycle;
if(max_cycle == 0)
return 0;
int partition_load = partition_info.weight_sum[node->partition];
int earliest_cycle = node->earliest_cycle;
int min_cycle = -1;
int min_cycle_stream = -1;
for(auto stream = 0; stream < num_of_streams; ++stream)
{
int cycle = std::max(info.next_cycles[stream], earliest_cycle);
if((cycle < max_cycle) && ((max_cycle - cycle) > partition_load))
return stream;
if((min_cycle_stream == -1) || (cycle < min_cycle))
{
min_cycle = cycle;
min_cycle_stream = stream;
}
}
return min_cycle_stream;
}
// Record the stream-assignment.
//
void pre_scheduling_impl::record(stream_info& info, dag_node* node)
{
int stream = node->stream;
int next_cycle = info.next_cycles[stream];
node->sched_cycle = std::max(node->earliest_cycle, next_cycle);
next_cycle = node->sched_cycle + node->weight;
info.next_cycles[stream] = next_cycle;
info.max_cycle = std::max(info.max_cycle, next_cycle);
for(auto&& arg : node->ins->outputs())
{
assert(instr2_node.find(arg) != instr2_node.end());
dag_node* use_node = instr2_node[arg];
use_node->earliest_cycle = std::max(use_node->earliest_cycle, next_cycle);
}
if(node->can_use_stream())
instr2_stream[node->ins] = stream;
}
// Assign nodes to streams.
//
void pre_scheduling_impl::schedule(std::list<dag_node*>& sorted_nodes)
{
if(num_of_streams == 0)
return;
stream_info info(num_of_streams);
std::unordered_map<int, int> partition2_stream;
partition2_stream.clear();
for(auto&& node : sorted_nodes)
{
int cur_partition = node->partition;
assert(cur_partition >= 0);
if(partition2_stream.find(cur_partition) != partition2_stream.end())
{
node->stream = partition2_stream[cur_partition];
}
else
{
node->stream = get_stream(info, node);
}
assert(node->stream >= 0);
record(info, node);
partition2_stream[cur_partition] = node->stream;
}
#ifdef MIGRAPHX_DEBUG_OPT
MIGRAPHX_DEBUG(dump("---After assigning stream---"));
MIGRAPHX_DEBUG(dump(sorted_nodes));
#endif
}
// Reorder the instructions ino topological order.
//
void pre_scheduling_impl::splice(std::list<dag_node*>& sorted_nodes)
{
if(sorted_nodes.size() <= 1)
return;
auto begin = sorted_nodes.begin();
auto iter = sorted_nodes.end();
instruction_ref insert_before = (*(--iter))->ins;
do
{
iter--;
insert_before = p_program->move_instruction((*iter)->ins, insert_before);
} while(iter != begin);
#ifdef MIGRAPHX_DEBUG_OPT
MIGRAPHX_DEBUG(dump("---After splice in pre-scheduling---"));
MIGRAPHX_DEBUG(dump_program());
#endif
}
// Annotate streams and events in the instruction. Insert set_stream
// instructions.
//
void pre_scheduling_impl::annotate(std::list<dag_node*>& sorted_nodes)
{
int event = 0;
int last_stream = -1;
bool enable_event_as_instr = enabled(MIGRAPHX_ENABLE_EVENT_AS_INSTRUCTION{});
for(auto&& node : sorted_nodes)
{
instruction_ref ins = node->ins;
if(instr2_stream.find(ins) == instr2_stream.end())
continue;
int stream = instr2_stream[ins];
ins->set_stream(stream);
if(last_stream != stream)
{
insert_instr.insert_stream(p_program, ins, stream);
last_stream = stream;
}
std::vector<int> events;
for(auto&& arg : ins->inputs())
{
if(instr2_stream.find(arg) == instr2_stream.end())
continue;
int arg_s = instr2_stream[arg];
if(arg_s == stream)
continue;
if(!has_mask(arg, record_event))
{
events.push_back(event);
arg->set_event(event);
arg->add_mask(record_event);
if(enable_event_as_instr)
insert_instr.insert_record_event(p_program, std::next(arg), event);
event++;
}
ins->add_mask(wait_event);
add_mask(arg, record_event);
add_mask(ins, wait_event);
}
if(enable_event_as_instr)
{
for(auto&& i : events)
insert_instr.insert_wait_event(p_program, ins, i);
}
}
}
void pre_scheduling_impl::run()
{
std::size_t num_of_instrs = p_program->size();
if(num_of_instrs == 0)
return;
MIGRAPHX_DEBUG(dump("---Before pre-scheduling---"));
MIGRAPHX_DEBUG(dump_program());
nodes.resize(num_of_instrs);
compute_weights();
reorder();
}
void pre_scheduling_impl::verify()
{
std::unordered_map<instruction_ref, bool> visited;
for(auto ins : iterator_for(*p_program))
{
for(auto&& arg : ins->inputs())
{
if(visited.find(arg) == visited.end())
MIGRAPHX_THROW("Input not visited");
}
visited[ins] = true;
}
}
#ifdef MIGRAPHX_DEBUG_OPT
void pre_scheduling_impl::dump(const std::string& str) { std::cout << str << std::endl; }
void pre_scheduling_impl::dump_program() { std::cout << *p_program << std::endl; }
void pre_scheduling_impl::dump(std::list<dag_node*>& sorted_nodes)
{
for(auto&& node : sorted_nodes)
{
node->dump();
if(!node->ins->inputs().empty())
{
std::cout << " inputs: ";
for(auto&& arg : node->ins->inputs())
{
dag_node* def_node = instr2_node[arg];
std::cout << " @" << def_node->ins_ndx;
}
std::cout << std::endl;
}
}
}
void dag_node::dump()
{
std::cout << " @" << ins_ndx;
std::cout << " name: " << ins->name();
std::cout << " weight: " << weight;
std::cout << " weight_sum: " << weight_sum;
if(can_use_stream())
std::cout << " stream: " << stream;
std::cout << " partition: " << partition;
std::cout << " sched_cycle: " << sched_cycle;
std::cout << std::endl;
}
#endif
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment