Commit 4332ccf6 authored by Paul
Browse files

Refactor

parent 9b5e0c18
...@@ -17,13 +17,11 @@ add_library(migraphx ...@@ -17,13 +17,11 @@ add_library(migraphx
instruction.cpp instruction.cpp
program.cpp program.cpp
shape.cpp shape.cpp
schedule.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
opt/memory_coloring.cpp opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp opt/memory_coloring_impl.cpp
opt/pre_scheduling.cpp
opt/pre_scheduling_impl.cpp
opt/dom_info.cpp
) )
rocm_clang_tidy_check(migraphx) rocm_clang_tidy_check(migraphx)
rocm_install_targets( rocm_install_targets(
......
...@@ -41,9 +41,9 @@ void dead_code_elimination::apply(program& p) const ...@@ -41,9 +41,9 @@ void dead_code_elimination::apply(program& p) const
// Skip the last instruction // Skip the last instruction
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its a builtin or undefined // Skip instruction with empty shape as output unless its a builtin or undefined or identity
if(i->get_shape().elements() == 0 and not(i->name().front() == '@') and if(i->get_shape().elements() == 0 and i->name().front() != '@' and
not(i->name() == "undefined")) i->name() != "undefined" and i->name() != "identity")
continue; continue;
assert(bidistance(p, i, last) > 0); assert(bidistance(p, i, last) > 0);
fix([&](auto self, auto leaf) { fix([&](auto self, auto leaf) {
......
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Uncomment the next line (or define MIGRAPHX_DEBUG_OPT on the command line)
// to enable the MIGRAPHX_DEBUG macro below.
//#define MIGRAPHX_DEBUG_OPT
#ifdef MIGRAPHX_DEBUG_OPT
// MIGRAPHX_DEBUG(s) expands to its argument, so `s` is compiled in.
#define MIGRAPHX_DEBUG(s) s
#else
// MIGRAPHX_DEBUG(s) expands to nothing, so `s` is removed by the preprocessor.
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
...@@ -210,6 +210,7 @@ inline const ValueType& any_cast(const context& x) ...@@ -210,6 +210,7 @@ inline const ValueType& any_cast(const context& x)
} }
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
#define MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
#include <migraphx/common_header.hpp>
#include <migraphx/set_operator.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Compute dominators, post-dominators, dominator tree, post-dominator tree
// for instructions with streams. Also do program analysis to identify
// concurrent instructions in different streams.
//
// Only the declarations live here; every member function except the
// constructor is defined elsewhere.
struct dom_info
{
    // Remember the program to analyse. The two result maps are
    // default-constructed and therefore already empty, so they do not need an
    // explicit clear().
    dom_info(program* p) : p_program(p) {}

    // NOTE(review): the meaning of the bool flag is not visible in this header;
    // presumably it selects dominators vs. post-dominators -- confirm against
    // the definition.
    void compute_dom(bool);

    // NOTE(review): judging by the parameter names this derives a dominator-tree
    // entry for `p_ins` from the full dominator sets in `instr2_doms`, updating
    // `instr2_dom_tree` and `idom` in place -- confirm against the definition.
    void
    find_dom_tree(std::unordered_map<const instruction*, std::set<const instruction*>>& instr2_doms,
                  const instruction* p_ins,
                  std::unordered_map<const instruction*, const instruction*>& instr2_dom_tree,
                  std::unordered_map<const instruction*, const instruction*>& idom);

#ifdef MIGRAPHX_DEBUG_OPT
    // Debug-only helper, compiled only when MIGRAPHX_DEBUG_OPT is defined.
    void dump_doms(std::unordered_map<const instruction*, int>&, bool);
#endif

    bool is_split_point(instruction_ref ins);
    bool is_merge_point(instruction_ref ins);

    // whether ins1 strictly post-dominates ins2.
    bool strictly_post_dominates(const instruction* ins1, const instruction* ins2);

    // Program analysis to identify concurrent instructions.
    // `concur_instrs` and `instr2_points` are taken by non-const reference.
    void propagate_splits(
        int num_of_streams,
        std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>&
            concur_instrs,
        std::unordered_map<const instruction*, int>& instr2_points);

    // Non-owning pointer to the program being analysed (set by the constructor).
    program* p_program;
    // map instruction to its immediate dominator.
    std::unordered_map<const instruction*, const instruction*> instr2_idom;
    // map instruction to its immediate post dominator.
    std::unordered_map<const instruction*, const instruction*> instr2_ipdom;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_FIND_CONCUR_HPP
#define MIGRAPHX_GUARD_FIND_CONCUR_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <unordered_map>
#include <vector>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent analysis to find concurrent instructions
/// executing in different streams.
struct find_concur
{
    /// Find the concurrent instructions of program `p`.
    ///
    /// The member function is const-qualified, matching the type-erased
    /// interface that is used when DOXYGEN is not defined.
    void get_concur(program* p,
                    int num_of_streams,
                    std::unordered_map<const instruction*,
                                       std::vector<std::vector<const instruction*>>>& concur_instrs,
                    std::unordered_map<const instruction*, int>& instr2_points) const;
};
#else
/*
* Type-erased interface for:
*
* struct find_concur
* {
* void get_concur(program* p,int num_of_stream,std::unordered_map<const instruction*,
* std::vector<std::vector<const instruction*>>>& concur_instrs,std::unordered_map<const
* instruction*, int>& input) const;
* };
*
*/
/// Type-erased holder for any object that provides a const `get_concur`
/// member with the signature declared below.
///
/// The wrapped object lives in a heap-allocated handle held through a
/// `std::shared_ptr`, so copies of a `find_concur` share one handle. The
/// non-const handle accessor clones the handle first whenever it is shared,
/// so mutable access through one copy does not affect the others.
struct find_concur
{
    // Constructors
    find_concur() = default;

    /// Wrap `value` (taken by value, then moved into a freshly allocated handle).
    template <typename PrivateDetailTypeErasedT>
    find_concur(PrivateDetailTypeErasedT value)
        : private_detail_te_handle_mem_var(
              std::make_shared<private_detail_te_handle_type<
                  typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
                  std::forward<PrivateDetailTypeErasedT>(value)))
    {
    }

    // Assignment
    /// Replace the held object with `value`.
    ///
    /// A new handle is built through the converting constructor and swapped in.
    /// Other `find_concur` objects that shared the previous handle keep it
    /// unchanged.
    template <typename PrivateDetailTypeErasedT>
    find_concur& operator=(PrivateDetailTypeErasedT value)
    {
        find_concur rhs(std::move(value));
        private_detail_te_handle_mem_var.swap(rhs.private_detail_te_handle_mem_var);
        return *this;
    }

    // Cast
    /// Return a pointer to the held object if its type is exactly
    /// `PrivateDetailTypeErasedT`, otherwise `nullptr` (also when empty).
    template <typename PrivateDetailTypeErasedT>
    PrivateDetailTypeErasedT* any_cast()
    {
        if(private_detail_te_handle_empty())
            return nullptr;
        return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    /// Const overload of `any_cast`; never clones the handle.
    template <typename PrivateDetailTypeErasedT>
    const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
    {
        if(private_detail_te_handle_empty())
            return nullptr;
        return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<const private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    /// Type of the held object, or `typeid(std::nullptr_t)` when empty.
    const std::type_info& type_id() const
    {
        if(private_detail_te_handle_empty())
            return typeid(std::nullptr_t);
        else
            return private_detail_te_get_handle().type();
    }

    /// Forward to the held object's `get_concur`. The object must not be empty.
    void get_concur(program* p,
                    int num_of_stream,
                    std::unordered_map<const instruction*,
                                       std::vector<std::vector<const instruction*>>>& concur_instrs,
                    std::unordered_map<const instruction*, int>& input) const
    {
        assert((*this).private_detail_te_handle_mem_var);
        (*this).private_detail_te_get_handle().get_concur(p, num_of_stream, concur_instrs, input);
    }

    /// True when both objects refer to the very same handle (or are both empty).
    friend bool is_shared(const find_concur& private_detail_x, const find_concur& private_detail_y)
    {
        return private_detail_x.private_detail_te_handle_mem_var ==
               private_detail_y.private_detail_te_handle_mem_var;
    }

    private:
    // Abstract interface every concrete handle implements.
    struct private_detail_te_handle_base_type
    {
        virtual ~private_detail_te_handle_base_type() {}
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
        virtual const std::type_info& type() const                                = 0;
        virtual void
        get_concur(program* p,
                   int num_of_stream,
                   std::unordered_map<const instruction*,
                                      std::vector<std::vector<const instruction*>>>& concur_instrs,
                   std::unordered_map<const instruction*, int>& input) const = 0;
    };

    // Concrete handle storing a value (or a reference, for the
    // std::reference_wrapper specialization further down).
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type : private_detail_te_handle_base_type
    {
        // Selected when the stored type is a reference: bind, do not move.
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
                nullptr)
            : private_detail_te_value(value)
        {
        }

        // Selected when the stored type is a value: move it into place.
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
                                    int>::type* = nullptr) noexcept
            : private_detail_te_value(std::move(value))
        {
        }

        std::shared_ptr<private_detail_te_handle_base_type> clone() const override
        {
            return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
        }

        const std::type_info& type() const override { return typeid(private_detail_te_value); }

        void
        get_concur(program* p,
                   int num_of_stream,
                   std::unordered_map<const instruction*,
                                      std::vector<std::vector<const instruction*>>>& concur_instrs,
                   std::unordered_map<const instruction*, int>& input) const override
        {
            private_detail_te_value.get_concur(p, num_of_stream, concur_instrs, input);
        }

        PrivateDetailTypeErasedT private_detail_te_value;
    };

    // A std::reference_wrapper<T> is stored as a handle on T&.
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
        : private_detail_te_handle_type<PrivateDetailTypeErasedT&>
    {
        private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
            : private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
        {
        }
    };

    bool private_detail_te_handle_empty() const
    {
        return private_detail_te_handle_mem_var == nullptr;
    }

    const private_detail_te_handle_base_type& private_detail_te_get_handle() const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        return *private_detail_te_handle_mem_var;
    }

    // Mutable access: clone the handle first when it is shared with another
    // find_concur. use_count() replaces shared_ptr::unique(), which is
    // deprecated in C++17 and removed in C++20.
    private_detail_te_handle_base_type& private_detail_te_get_handle()
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        if(private_detail_te_handle_mem_var.use_count() > 1)
            private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
        return *private_detail_te_handle_mem_var;
    }

    std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const find_concur* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(find_concur* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(find_concur& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const find_concur& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -137,6 +137,18 @@ auto fold(F f) ...@@ -137,6 +137,18 @@ auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); }; return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
} }
/// Build a callable that applies `proj` to each of its arguments and then
/// passes the projected values, in order, to `f`.
/// Both `f` and `proj` are copied into the returned closure.
template <class F, class Proj>
auto by(F f, Proj proj)
{
    return [f, proj](auto&&... args) {
        return f(proj(std::forward<decltype(args)>(args))...);
    };
}
/// Build a callable that looks `key` up in `x` with `x[key]` and returns the
/// result by value. `x` is captured by reference, so it must outlive the
/// returned closure.
template <class T>
auto index_of(T& x)
{
    return [&x](auto&& key) { return x[key]; };
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -41,8 +41,6 @@ struct instruction ...@@ -41,8 +41,6 @@ struct instruction
const operation& get_operator() const; const operation& get_operator() const;
int get_stream() const;
void set_stream(int);
std::string name() const; std::string name() const;
const std::vector<instruction_ref>& inputs() const; const std::vector<instruction_ref>& inputs() const;
...@@ -96,7 +94,6 @@ struct instruction ...@@ -96,7 +94,6 @@ struct instruction
std::vector<instruction_ref> output; std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments; std::vector<instruction_ref> arguments;
literal lit; literal lit;
int stream = -1;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -112,6 +109,7 @@ struct hash<migraphx::instruction_ref> ...@@ -112,6 +109,7 @@ struct hash<migraphx::instruction_ref>
return std::hash<migraphx::instruction*>{}(&*x); return std::hash<migraphx::instruction*>{}(&*x);
} }
}; };
} // namespace std } // namespace std
#endif #endif
...@@ -3,21 +3,64 @@ ...@@ -3,21 +3,64 @@
#include <cassert> #include <cassert>
#include <type_traits> #include <type_traits>
#include <iterator>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> struct iterator_for_select
{
template <class T>
static T deref(T x)
{
return x;
}
template <class T>
static auto begin(T* x)
{
return x->begin();
}
template <class T>
static auto end(T* x)
{
return x->end();
}
};
/// Selector that walks a range back to front using reverse iterators.
struct iterator_for_select_reverse
{
    /// Convert a reverse iterator into the forward iterator that designates
    /// the same element, i.e. the one just before `it.base()`.
    template <class Iterator>
    static auto deref(Iterator it)
    {
        auto fwd = it.base();
        return std::prev(fwd);
    }

    /// Reverse iterator positioned on the last element of `*r`.
    template <class Range>
    static auto begin(Range* r)
    {
        auto last = r->end();
        return std::make_reverse_iterator(last);
    }

    /// Reverse iterator positioned one before the first element of `*r`.
    template <class Range>
    static auto end(Range* r)
    {
        auto first = r->begin();
        return std::make_reverse_iterator(first);
    }
};
template <class T, class Selector = iterator_for_select>
struct iterator_for_range struct iterator_for_range
{ {
T* base; T* base;
using base_iterator = std::remove_reference_t<decltype(base->begin())>; using base_iterator = std::remove_reference_t<decltype(Selector::begin(base))>;
struct iterator struct iterator
{ {
base_iterator i; base_iterator i;
base_iterator operator*() const { return i; } auto operator*() const { return Selector::deref(i); }
base_iterator operator++() { return ++i; } base_iterator operator++() { return ++i; }
bool operator!=(const iterator& rhs) const { return i != rhs.i; } bool operator!=(const iterator& rhs) const { return i != rhs.i; }
}; };
...@@ -25,12 +68,12 @@ struct iterator_for_range ...@@ -25,12 +68,12 @@ struct iterator_for_range
iterator begin() iterator begin()
{ {
assert(base != nullptr); assert(base != nullptr);
return {base->begin()}; return {Selector::begin(base)};
} }
iterator end() iterator end()
{ {
assert(base != nullptr); assert(base != nullptr);
return {base->end()}; return {Selector::end(base)};
} }
}; };
template <class T> template <class T>
...@@ -39,6 +82,12 @@ iterator_for_range<T> iterator_for(T& x) ...@@ -39,6 +82,12 @@ iterator_for_range<T> iterator_for(T& x)
return {&x}; return {&x};
} }
template <class T>
iterator_for_range<T, iterator_for_select_reverse> reverse_iterator_for(T& x)
{
return {&x};
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
#include <string> #include <string>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/find_concur.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -17,12 +15,11 @@ struct program; ...@@ -17,12 +15,11 @@ struct program;
struct memory_coloring struct memory_coloring
{ {
std::string allocation_op{}; std::string allocation_op{};
int num_of_streams = 0;
find_concur f_concur;
bool verify = false; bool verify = false;
std::string name() const { return "memory coloring"; } std::string name() const { return "memory coloring"; }
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -1174,9 +1174,19 @@ struct load ...@@ -1174,9 +1174,19 @@ struct load
} }
argument compute(const shape&, const std::vector<argument>& args) const argument compute(const shape&, const std::vector<argument>& args) const
{ {
if((offset + s.bytes()) > args[0].get_shape().bytes())
MIGRAPHX_THROW("Load access is out of bounds");
return {s, args[0].data() + offset}; return {s, args[0].data() + offset};
} }
int output_alias(const std::vector<shape>&) const { return 0; } int output_alias(const std::vector<shape>&) const { return 0; }
// Print the op as `<name>[offset=<offset>,end=<offset + s.bytes()>]`.
friend std::ostream& operator<<(std::ostream& os, const load& op)
{
    const auto end_offset = op.offset + op.s.bytes();
    os << op.name() << "["
       << "offset=" << op.offset << ","
       << "end=" << end_offset << "]";
    return os;
}
}; };
struct outline struct outline
......
...@@ -9,7 +9,6 @@ namespace migraphx { ...@@ -9,7 +9,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MEMORY_COLORING) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MEMORY_COLORING)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_PRE_SCHEDULING)
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_PRE_SCHEDULING_HPP
#define MIGRAPHX_GUARD_RTGLIB_PRE_SCHEDULING_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/program.hpp>
#include <migraphx/insert_instruction.hpp>
namespace migraphx {
// Pass object named "pre scheduling". Only the configuration and the
// declaration of apply() are visible here; apply() is defined elsewhere.
struct pre_scheduling
{
    // Callback returning a pair of ints for an operation.
    // NOTE(review): the meaning of the two ints is not visible in this header --
    // confirm against the pass implementation.
    std::function<std::pair<int, int>(const operation&)> weight_func;
    // Defaults to 0 like memory_coloring::num_of_streams, so a default-initialized
    // pass never reads an indeterminate value.
    int num_of_streams = 0;
    // Type-erased helper for target-dependent instruction insertion.
    insert_instruction insert_instr;
    bool verify = false;
    std::string name() const { return "pre scheduling"; }
    void apply(program& p) const;
};
} // namespace migraphx
#endif
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp> #include <migraphx/target.hpp>
#include <migraphx/tracer.hpp> #include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
...@@ -16,6 +17,9 @@ ...@@ -16,6 +17,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl; struct program_impl;
const operation& get_operation(instruction_ref ins); const operation& get_operation(instruction_ref ins);
...@@ -98,7 +102,7 @@ struct program ...@@ -98,7 +102,7 @@ struct program
void compile(const target& t, tracer trace = tracer{}); void compile(const target& t, tracer trace = tracer{});
void finalize(); void finalize();
void finish();
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const; void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
void debug_print() const; void debug_print() const;
...@@ -107,6 +111,8 @@ struct program ...@@ -107,6 +111,8 @@ struct program
void dry_run(parameter_map params) const; void dry_run(parameter_map params) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
friend std::ostream& operator<<(std::ostream& os, const program& p); friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
...@@ -114,6 +120,7 @@ struct program ...@@ -114,6 +120,7 @@ struct program
private: private:
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SCHEDULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_SCHEDULE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/schedule_model.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Schedule instructions for concurrent execution
*
* Pass object: `name()` identifies the pass and `apply()` (defined elsewhere)
* takes the program by non-const reference.
*/
struct schedule
{
// Type-erased scheduling model, value-initialized (empty) by default.
schedule_model model{};
// Name of this pass: always the literal "schedule".
std::string name() const { return "schedule"; }
// Declared here, defined elsewhere.
void apply(program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_INSERT_INSTRUCTION_HPP #ifndef MIGRAPHX_GUARD_SCHEDULE_MODEL_HPP
#define MIGRAPHX_GUARD_INSERT_INSTRUCTION_HPP #define MIGRAPHX_GUARD_SCHEDULE_MODEL_HPP
#include <cassert> #include <cassert>
#include <string> #include <string>
...@@ -8,24 +8,31 @@ ...@@ -8,24 +8,31 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <vector>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
struct operation;
#ifdef DOXYGEN #ifdef DOXYGEN
/// An interface for target-dependent instruction insertion. /// An interface for target-dependent model for the scheduler
/// for multi-stream execution. struct schedule_model
struct insert_instruction
{ {
void insert_create_events(program* p, instruction_ref ins, int num_of_events); /// Get the number of concurrent instruction allowed
void insert_record_event(program* p, instruction_ref ins, int event); std::size_t concurrency() const;
void insert_wait_event(program* p, instruction_ref ins, int event); /// Schedule a concurrent instruction
void sched(program& p, instruction_ref ins, std::size_t n) const;
void insert_stream(program* p, instruction_ref ins, int stream); // Insert necessary waits before an instruction
void wait(program& p, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction
void record(program& p, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation
std::size_t weight(const operation& op) const;
}; };
#else #else
...@@ -33,23 +40,24 @@ struct insert_instruction ...@@ -33,23 +40,24 @@ struct insert_instruction
/* /*
* Type-erased interface for: * Type-erased interface for:
* *
* struct insert_instruction * struct schedule_model
* { * {
* void insert_create_events(program* p,instruction_ref ins,int input) ; * std::size_t concurrency() const;
* void insert_record_event(program* p,instruction_ref ins,int input) ; * void sched(program& p,instruction_ref ins,std::size_t n) const;
* void insert_wait_event(program* p,instruction_ref ins,int input) ; * void wait(program& p,instruction_ref ins,std::size_t wait_id) const;
* void insert_stream(program* p,instruction_ref ins,int input) ; * void record(program& p,instruction_ref ins,std::size_t wait_id) const;
* std::size_t weight(const operation& op) const;
* }; * };
* *
*/ */
struct insert_instruction struct schedule_model
{ {
// Constructors // Constructors
insert_instruction() = default; schedule_model() = default;
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
insert_instruction(PrivateDetailTypeErasedT value) schedule_model(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var( : private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type< std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>( typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
...@@ -59,7 +67,7 @@ struct insert_instruction ...@@ -59,7 +67,7 @@ struct insert_instruction
// Assignment // Assignment
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
insert_instruction& operator=(PrivateDetailTypeErasedT value) schedule_model& operator=(PrivateDetailTypeErasedT value)
{ {
if(private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value); *private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
...@@ -100,32 +108,38 @@ struct insert_instruction ...@@ -100,32 +108,38 @@ struct insert_instruction
return private_detail_te_get_handle().type(); return private_detail_te_get_handle().type();
} }
void insert_create_events(program* p, instruction_ref ins, int input) std::size_t concurrency() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_create_events(p, ins, input); return (*this).private_detail_te_get_handle().concurrency();
} }
void insert_record_event(program* p, instruction_ref ins, int input) void sched(program& p, instruction_ref ins, std::size_t n) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_record_event(p, ins, input); (*this).private_detail_te_get_handle().sched(p, ins, n);
} }
void insert_wait_event(program* p, instruction_ref ins, int input) void wait(program& p, instruction_ref ins, std::size_t wait_id) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_wait_event(p, ins, input); (*this).private_detail_te_get_handle().wait(p, ins, wait_id);
} }
void insert_stream(program* p, instruction_ref ins, int input) void record(program& p, instruction_ref ins, std::size_t wait_id) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_stream(p, ins, input); (*this).private_detail_te_get_handle().record(p, ins, wait_id);
} }
friend bool is_shared(const insert_instruction& private_detail_x, std::size_t weight(const operation& op) const
const insert_instruction& private_detail_y) {
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().weight(op);
}
friend bool is_shared(const schedule_model& private_detail_x,
const schedule_model& private_detail_y)
{ {
return private_detail_x.private_detail_te_handle_mem_var == return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var; private_detail_y.private_detail_te_handle_mem_var;
...@@ -138,10 +152,11 @@ struct insert_instruction ...@@ -138,10 +152,11 @@ struct insert_instruction
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual void insert_create_events(program* p, instruction_ref ins, int input) = 0; virtual std::size_t concurrency() const = 0;
virtual void insert_record_event(program* p, instruction_ref ins, int input) = 0; virtual void sched(program& p, instruction_ref ins, std::size_t n) const = 0;
virtual void insert_wait_event(program* p, instruction_ref ins, int input) = 0; virtual void wait(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void insert_stream(program* p, instruction_ref ins, int input) = 0; virtual void record(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -172,28 +187,30 @@ struct insert_instruction ...@@ -172,28 +187,30 @@ struct insert_instruction
const std::type_info& type() const override { return typeid(private_detail_te_value); } const std::type_info& type() const override { return typeid(private_detail_te_value); }
void insert_create_events(program* p, instruction_ref ins, int input) override std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(program& p, instruction_ref ins, std::size_t n) const override
{ {
private_detail_te_value.insert_create_events(p, ins, input); private_detail_te_value.sched(p, ins, n);
} }
void insert_record_event(program* p, instruction_ref ins, int input) override void wait(program& p, instruction_ref ins, std::size_t wait_id) const override
{ {
private_detail_te_value.insert_record_event(p, ins, input); private_detail_te_value.wait(p, ins, wait_id);
} }
void insert_wait_event(program* p, instruction_ref ins, int input) override void record(program& p, instruction_ref ins, std::size_t wait_id) const override
{ {
private_detail_te_value.insert_wait_event(p, ins, input); private_detail_te_value.record(p, ins, wait_id);
} }
void insert_stream(program* p, instruction_ref ins, int input) override std::size_t weight(const operation& op) const override
{ {
private_detail_te_value.insert_stream(p, ins, input); return private_detail_te_value.weight(op);
} }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
...@@ -232,19 +249,19 @@ struct insert_instruction ...@@ -232,19 +249,19 @@ struct insert_instruction
}; };
template <typename ValueType> template <typename ValueType>
inline const ValueType* any_cast(const insert_instruction* x) inline const ValueType* any_cast(const schedule_model* x)
{ {
return x->any_cast<ValueType>(); return x->any_cast<ValueType>();
} }
template <typename ValueType> template <typename ValueType>
inline ValueType* any_cast(insert_instruction* x) inline ValueType* any_cast(schedule_model* x)
{ {
return x->any_cast<ValueType>(); return x->any_cast<ValueType>();
} }
template <typename ValueType> template <typename ValueType>
inline ValueType& any_cast(insert_instruction& x) inline ValueType& any_cast(schedule_model& x)
{ {
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>(); auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr) if(y == nullptr)
...@@ -253,7 +270,7 @@ inline ValueType& any_cast(insert_instruction& x) ...@@ -253,7 +270,7 @@ inline ValueType& any_cast(insert_instruction& x)
} }
template <typename ValueType> template <typename ValueType>
inline const ValueType& any_cast(const insert_instruction& x) inline const ValueType& any_cast(const schedule_model& x)
{ {
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>(); const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr) if(y == nullptr)
...@@ -262,6 +279,7 @@ inline const ValueType& any_cast(const insert_instruction& x) ...@@ -262,6 +279,7 @@ inline const ValueType& any_cast(const insert_instruction& x)
} }
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SET_OPERATOR_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_SET_OPERATOR_IMPL_HPP
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Return a new set holding every key that is present in both `lhs` and `rhs`.
///
/// The smaller operand is the one iterated, so the number of lookups is
/// min(lhs.size(), rhs.size()). When the sizes are equal, `lhs` is iterated.
/// The result is returned as a plain local, because `return std::move(local)`
/// would block copy elision.
template <typename Set, typename Key = typename Set::value_type>
static inline Set set_intersection(const Set& lhs, const Set& rhs)
{
    const bool lhs_is_smaller = lhs.size() <= rhs.size();
    const Set& smaller        = lhs_is_smaller ? lhs : rhs;
    const Set& larger         = lhs_is_smaller ? rhs : lhs;

    Set iset;
    for(const Key& key : smaller)
    {
        if(larger.count(key) > 0)
        {
            iset.insert(key);
        }
    }
    return iset;
}
/// Return a new set holding every key that is present in `lhs`, in `rhs`, or in both.
///
/// `Key` is unused but kept so existing explicit template arguments still compile.
/// The result is returned as a plain local, because `return std::move(local)`
/// would block copy elision.
template <typename Set, typename Key = typename Set::value_type>
static inline Set set_union(const Set& lhs, const Set& rhs)
{
    Set uset(lhs);
    uset.insert(rhs.begin(), rhs.end());
    return uset;
}
/// Return a new set holding the keys of `lhs` that are not present in `rhs`.
///
/// `Key` is unused but kept so existing explicit template arguments still compile.
/// The result is returned as a plain local, because `return std::move(local)`
/// would block copy elision.
template <typename Set, typename Key = typename Set::value_type>
static inline Set set_difference(const Set& lhs, const Set& rhs)
{
    Set dset(lhs);
    for(const auto& key : rhs)
    {
        dset.erase(key);
    }
    return dset;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -36,6 +36,8 @@ inline stream_range_container<Range> stream_range(const Range& r) ...@@ -36,6 +36,8 @@ inline stream_range_container<Range> stream_range(const Range& r)
namespace detail { namespace detail {
inline void stream_write_value_impl(rank<2>, std::ostream& os, const std::string& x) { os << x; }
template <class Range> template <class Range>
auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r) auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
-> decltype(r.begin(), r.end(), void()) -> decltype(r.begin(), r.end(), void())
...@@ -53,7 +55,7 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x) ...@@ -53,7 +55,7 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
template <class T> template <class T>
void stream_write_value(std::ostream& os, const T& x) void stream_write_value(std::ostream& os, const T& x)
{ {
detail::stream_write_value_impl(rank<1>{}, os, x); detail::stream_write_value_impl(rank<2>{}, os, x);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -22,10 +22,8 @@ struct target ...@@ -22,10 +22,8 @@ struct target
{ {
/// A unique name used to identify the target /// A unique name used to identify the target
std::string name() const; std::string name() const;
/// The transformation passes to be run
/** /**
* @brief The transformation pass to be run during compilation. * @brief The transformation pass to be run during compilation.
* @details [long description]
* *
* @param ctx This is the target-dependent context that is created by `get_context` * @param ctx This is the target-dependent context that is created by `get_context`
* @return The passes to be run * @return The passes to be run
......
...@@ -87,9 +87,6 @@ const literal& instruction::get_literal() const ...@@ -87,9 +87,6 @@ const literal& instruction::get_literal() const
return lit; return lit;
} }
int instruction::get_stream() const { return stream; }
void instruction::set_stream(int s) { stream = s; }
const operation& instruction::get_operator() const { return op; } const operation& instruction::get_operator() const { return op; }
std::string instruction::name() const { return op.name(); } std::string instruction::name() const { return op.name(); }
...@@ -214,5 +211,6 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg ...@@ -214,5 +211,6 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg
{ {
return op.compute_shape(to_shapes(args)); return op.compute_shape(to_shapes(args));
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment