Commit 4332ccf6 authored by Paul
Browse files

Refactor

parent 9b5e0c18
......@@ -17,13 +17,11 @@ add_library(migraphx
instruction.cpp
program.cpp
shape.cpp
schedule.cpp
simplify_algebra.cpp
simplify_reshapes.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
opt/pre_scheduling.cpp
opt/pre_scheduling_impl.cpp
opt/dom_info.cpp
)
rocm_clang_tidy_check(migraphx)
rocm_install_targets(
......
......@@ -41,9 +41,9 @@ void dead_code_elimination::apply(program& p) const
// Skip the last instruction
if(i == last)
break;
// Skip instruction with empty shape as output unless its a builtin or undefined
if(i->get_shape().elements() == 0 and not(i->name().front() == '@') and
not(i->name() == "undefined"))
// Skip instruction with empty shape as output unless its a builtin or undefined or identity
if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity")
continue;
assert(bidistance(p, i, last) > 0);
fix([&](auto self, auto leaf) {
......
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Debug switch for the optimization passes: uncomment the define below (or
// define MIGRAPHX_DEBUG_OPT on the compiler command line) to enable it.
//#define MIGRAPHX_DEBUG_OPT
// MIGRAPHX_DEBUG(s) expands to its argument `s` when MIGRAPHX_DEBUG_OPT is
// defined and to nothing otherwise, so wrapped statements vanish from
// non-debug builds.
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
......@@ -210,6 +210,7 @@ inline const ValueType& any_cast(const context& x)
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
#define MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
#include <migraphx/common_header.hpp>
#include <migraphx/set_operator.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Compute dominators, post-dominators, dominator tree, post-dominator tree
// for instructions with streams. Also do program analysis to identify
// concurrent instructions in different streams.
struct dom_info
{
    // Bind the analysis to program `p`. Only the pointer is stored here.
    // Both result maps are default-constructed empty, so no explicit
    // clearing is required.
    dom_info(program* p) : p_program(p) {}

    // Compute dominator information.
    // NOTE(review): the bool presumably selects dominators vs. post-dominators;
    // confirm against the definition in dom_info.cpp.
    void compute_dom(bool);

    // Build a dominator tree.
    // NOTE(review): the roles below are inferred from the parameter names
    // only; confirm against dom_info.cpp.
    //   instr2_doms     - per-instruction set of instructions
    //   p_ins           - the instruction being processed
    //   instr2_dom_tree - per-instruction tree entry
    //   idom            - per-instruction immediate (post-)dominator
    void
    find_dom_tree(std::unordered_map<const instruction*, std::set<const instruction*>>& instr2_doms,
                  const instruction* p_ins,
                  std::unordered_map<const instruction*, const instruction*>& instr2_dom_tree,
                  std::unordered_map<const instruction*, const instruction*>& idom);
#ifdef MIGRAPHX_DEBUG_OPT
    // Debug-only dump helper; compiled only when MIGRAPHX_DEBUG_OPT is defined.
    void dump_doms(std::unordered_map<const instruction*, int>&, bool);
#endif
    bool is_split_point(instruction_ref ins);
    bool is_merge_point(instruction_ref ins);
    // whether ins1 strictly post-dominates ins2.
    bool strictly_post_dominates(const instruction* ins1, const instruction* ins2);
    // Program analysis to identify concurrent instructions.
    void propagate_splits(
        int num_of_streams,
        std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>&
            concur_instrs,
        std::unordered_map<const instruction*, int>& instr2_points);

    // Program under analysis; stored as a raw pointer, so this struct does
    // not manage the program's lifetime.
    program* p_program;
    // map instruction to its immediate dominator.
    std::unordered_map<const instruction*, const instruction*> instr2_idom;
    // map instruction to its immediate post dominator.
    std::unordered_map<const instruction*, const instruction*> instr2_ipdom;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_FIND_CONCUR_HPP
#define MIGRAPHX_GUARD_FIND_CONCUR_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <unordered_map>
#include <vector>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent analysis to find concurrent instructions
/// executing in different streams.
struct find_concur
{
    /// Target-provided analysis entry point.
    ///
    /// The signature mirrors the type-erased interface: the call is `const`
    /// and the two maps are passed by non-const reference.
    ///
    /// @param p             program to analyse
    /// @param num_of_stream stream count forwarded to the implementation
    /// @param concur_instrs map passed through to the implementation by reference
    /// @param input         map passed through to the implementation by reference
    void get_concur(program* p,
                    int num_of_stream,
                    std::unordered_map<const instruction*,
                                       std::vector<std::vector<const instruction*>>>& concur_instrs,
                    std::unordered_map<const instruction*, int>& input) const;
};
#else
/*
* Type-erased interface for:
*
* struct find_concur
* {
* void get_concur(program* p,int num_of_stream,std::unordered_map<const instruction*,
* std::vector<std::vector<const instruction*>>>& concur_instrs,std::unordered_map<const
* instruction*, int>& input) const;
* };
*
*/
// Type-erased holder for any object providing a const
// `get_concur(program*, int, concur map&, points map&)` member.
//
// The wrapped object lives behind a shared_ptr. Copies of find_concur share
// the handle; the non-const handle accessor clones it first when it is
// shared, so a mutable access never touches another copy's object.
struct find_concur
{
    // Constructors
    find_concur() = default;

    // Wrap `value` (taken by value and moved into the handle).
    template <typename PrivateDetailTypeErasedT>
    find_concur(PrivateDetailTypeErasedT value)
        : private_detail_te_handle_mem_var(
              std::make_shared<private_detail_te_handle_type<
                  typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
                  std::forward<PrivateDetailTypeErasedT>(value)))
    {
    }

    // Assignment
    //
    // Rebinds this object to a freshly wrapped `value`, exactly as the
    // converting constructor would. A new handle is always installed:
    //  - assigning a raw value to the abstract handle base is not possible,
    //  - a shared handle must not be left silently unchanged,
    //  - an empty holder needs a handle wrapper, not the bare value type.
    // Other find_concur objects that shared the old handle keep it.
    template <typename PrivateDetailTypeErasedT>
    find_concur& operator=(PrivateDetailTypeErasedT value)
    {
        private_detail_te_handle_mem_var = std::make_shared<private_detail_te_handle_type<
            typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(std::move(value));
        return *this;
    }

    // Cast
    //
    // Returns a pointer to the wrapped object when its dynamic type is
    // exactly PrivateDetailTypeErasedT, otherwise nullptr. The holder must
    // not be empty (the handle accessor asserts this).
    template <typename PrivateDetailTypeErasedT>
    PrivateDetailTypeErasedT* any_cast()
    {
        return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    // Const overload of any_cast; never clones the handle.
    template <typename PrivateDetailTypeErasedT>
    const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
    {
        return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<const private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    // typeid of the wrapped object, or typeid(std::nullptr_t) when empty.
    const std::type_info& type_id() const
    {
        if(private_detail_te_handle_empty())
            return typeid(std::nullptr_t);
        else
            return private_detail_te_get_handle().type();
    }

    // Forward to the wrapped object's get_concur. The holder must not be empty.
    void get_concur(program* p,
                    int num_of_stream,
                    std::unordered_map<const instruction*,
                                       std::vector<std::vector<const instruction*>>>& concur_instrs,
                    std::unordered_map<const instruction*, int>& input) const
    {
        assert((*this).private_detail_te_handle_mem_var);
        (*this).private_detail_te_get_handle().get_concur(p, num_of_stream, concur_instrs, input);
    }

    // True when both holders refer to the very same handle (this includes
    // two empty holders).
    friend bool is_shared(const find_concur& private_detail_x, const find_concur& private_detail_y)
    {
        return private_detail_x.private_detail_te_handle_mem_var ==
               private_detail_y.private_detail_te_handle_mem_var;
    }

    private:
    // Abstract interface every wrapped type is adapted to.
    struct private_detail_te_handle_base_type
    {
        virtual ~private_detail_te_handle_base_type() = default;
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
        virtual const std::type_info& type() const                                = 0;
        virtual void
        get_concur(program* p,
                   int num_of_stream,
                   std::unordered_map<const instruction*,
                                      std::vector<std::vector<const instruction*>>>& concur_instrs,
                   std::unordered_map<const instruction*, int>& input) const = 0;
    };

    // Concrete adapter storing a value (or a reference, when instantiated
    // with a reference type) and forwarding the interface calls to it.
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type : private_detail_te_handle_base_type
    {
        // Selected when the stored type is a reference: bind, do not move.
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
                nullptr)
            : private_detail_te_value(value)
        {
        }

        // Selected when the stored type is a value: move it into place.
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
                                    int>::type* = nullptr) noexcept
            : private_detail_te_value(std::move(value))
        {
        }

        std::shared_ptr<private_detail_te_handle_base_type> clone() const override
        {
            return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
        }

        const std::type_info& type() const override { return typeid(private_detail_te_value); }

        void
        get_concur(program* p,
                   int num_of_stream,
                   std::unordered_map<const instruction*,
                                      std::vector<std::vector<const instruction*>>>& concur_instrs,
                   std::unordered_map<const instruction*, int>& input) const override
        {
            private_detail_te_value.get_concur(p, num_of_stream, concur_instrs, input);
        }

        PrivateDetailTypeErasedT private_detail_te_value;
    };

    // std::reference_wrapper<T> is stored as a T& adapter.
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
        : private_detail_te_handle_type<PrivateDetailTypeErasedT&>
    {
        private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
            : private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
        {
        }
    };

    bool private_detail_te_handle_empty() const
    {
        return private_detail_te_handle_mem_var == nullptr;
    }

    const private_detail_te_handle_base_type& private_detail_te_get_handle() const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        return *private_detail_te_handle_mem_var;
    }

    // Mutable access: clone first when the handle is shared with another
    // holder. use_count() != 1 is the exact equivalent of !unique(), which is
    // deprecated in C++17 and removed in C++20.
    private_detail_te_handle_base_type& private_detail_te_get_handle()
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        if(private_detail_te_handle_mem_var.use_count() != 1)
            private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
        return *private_detail_te_handle_mem_var;
    }

    std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
// Pointer form for a const holder: yields the wrapped object as
// `const ValueType*`, or nullptr when the holder wraps a different type.
template <typename ValueType>
inline const ValueType* any_cast(const find_concur* x)
{
    const ValueType* held = x->any_cast<ValueType>();
    return held;
}
// Pointer form for a mutable holder: yields the wrapped object as
// `ValueType*`, or nullptr when the holder wraps a different type.
template <typename ValueType>
inline ValueType* any_cast(find_concur* x)
{
    ValueType* held = x->any_cast<ValueType>();
    return held;
}
// Reference form for a mutable holder: returns the wrapped object, or throws
// std::bad_cast when the holder wraps a different type.
template <typename ValueType>
inline ValueType& any_cast(find_concur& x)
{
    using stored_type = typename std::remove_reference<ValueType>::type;
    if(auto* held = x.any_cast<stored_type>())
        return *held;
    throw std::bad_cast();
}
template <typename ValueType>
inline const ValueType& any_cast(const find_concur& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -137,6 +137,18 @@ auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
}
// Adapt `f` so that every argument is first passed through `proj`:
// by(f, proj)(a, b, ...) evaluates f(proj(a), proj(b), ...).
// Both callables are copied into the returned closure.
template <class F, class Proj>
auto by(F f, Proj proj)
{
    return [f, proj](auto&&... args) {
        return f(proj(std::forward<decltype(args)>(args))...);
    };
}
// Build a lookup function over `x`: the returned closure maps a key `y` to
// `x[y]`. `x` is captured by reference, so it must outlive the closure.
template <class T>
auto index_of(T& x)
{
    return [&x](auto&& key) { return x[key]; };
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -41,8 +41,6 @@ struct instruction
const operation& get_operator() const;
int get_stream() const;
void set_stream(int);
std::string name() const;
const std::vector<instruction_ref>& inputs() const;
......@@ -96,7 +94,6 @@ struct instruction
std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments;
literal lit;
int stream = -1;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -112,6 +109,7 @@ struct hash<migraphx::instruction_ref>
return std::hash<migraphx::instruction*>{}(&*x);
}
};
} // namespace std
#endif
......@@ -3,21 +3,64 @@
#include <cassert>
#include <type_traits>
#include <iterator>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
// Forward-traversal policy: the range is walked from begin() to end() and
// each iterator is handed out unchanged.
struct iterator_for_select
{
    // Identity: the iterator itself is the value handed out.
    template <class Iterator>
    static Iterator deref(Iterator it)
    {
        return it;
    }

    // First position of the traversal.
    template <class Range>
    static auto begin(Range* range)
    {
        return range->begin();
    }

    // One-past-the-last position of the traversal.
    template <class Range>
    static auto end(Range* range)
    {
        return range->end();
    }
};
// Reverse-traversal policy: the range is walked with reverse iterators, and
// each position is handed out as the underlying forward iterator to that
// same element.
struct iterator_for_select_reverse
{
    // A reverse iterator's base() points one past the element it refers to,
    // so step back once to get a forward iterator to that element.
    template <class ReverseIterator>
    static auto deref(ReverseIterator rit)
    {
        auto forward_it = rit.base();
        return std::prev(forward_it);
    }

    // Reverse traversal starts at the last element.
    template <class Range>
    static auto begin(Range* range)
    {
        auto last = range->end();
        return std::make_reverse_iterator(last);
    }

    // Reverse traversal stops before the first element.
    template <class Range>
    static auto end(Range* range)
    {
        auto first = range->begin();
        return std::make_reverse_iterator(first);
    }
};
template <class T, class Selector = iterator_for_select>
struct iterator_for_range
{
T* base;
using base_iterator = std::remove_reference_t<decltype(base->begin())>;
using base_iterator = std::remove_reference_t<decltype(Selector::begin(base))>;
struct iterator
{
base_iterator i;
base_iterator operator*() const { return i; }
auto operator*() const { return Selector::deref(i); }
base_iterator operator++() { return ++i; }
bool operator!=(const iterator& rhs) const { return i != rhs.i; }
};
......@@ -25,12 +68,12 @@ struct iterator_for_range
iterator begin()
{
assert(base != nullptr);
return {base->begin()};
return {Selector::begin(base)};
}
iterator end()
{
assert(base != nullptr);
return {base->end()};
return {Selector::end(base)};
}
};
template <class T>
......@@ -39,6 +82,12 @@ iterator_for_range<T> iterator_for(T& x)
return {&x};
}
// Build a range over `x` that uses the reverse traversal policy. Only the
// address of `x` is stored, so `x` must outlive the returned range.
template <class T>
iterator_for_range<T, iterator_for_select_reverse> reverse_iterator_for(T& x)
{
    using range_type = iterator_for_range<T, iterator_for_select_reverse>;
    return range_type{&x};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -4,8 +4,6 @@
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/find_concur.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -17,12 +15,11 @@ struct program;
// Pass object for the "memory coloring" transformation. The fields below
// configure the pass; the work is done by apply(), defined elsewhere.
struct memory_coloring
{
// Name of the allocation operator — presumably the op whose buffers the
// pass manages; confirm against the pass implementation.
std::string allocation_op{};
// Stream count, 0 by default.
// NOTE(review): its use is not visible in this header; confirm it.
int num_of_streams = 0;
// Type-erased, target-provided analysis used to find concurrent instructions.
find_concur f_concur;
// Off by default. NOTE(review): presumably enables extra self-checks; confirm.
bool verify = false;
// Pass name.
std::string name() const { return "memory coloring"; }
// Run the pass on `p`; const, so the pass object itself is not modified.
void apply(program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -1174,9 +1174,19 @@ struct load
}
// Produce an argument of shape `s` that points `offset` bytes into the
// buffer of the first input. Throws (via MIGRAPHX_THROW) when the requested
// region [offset, offset + s.bytes()) extends past the input's byte size.
argument compute(const shape&, const std::vector<argument>& args) const
{
    const auto& buffer    = args[0];
    const auto region_end = offset + s.bytes();
    if(region_end > buffer.get_shape().bytes())
        MIGRAPHX_THROW("Load access is out of bounds");
    return {s, buffer.data() + offset};
}
// Always returns 0, regardless of the input shapes.
// NOTE(review): presumably the index of the input whose memory the output
// aliases (compute() returns a pointer into args[0]); confirm.
int output_alias(const std::vector<shape>&) const { return 0; }
// Print the operator as `name[offset=<offset>,end=<offset + size in bytes>]`.
friend std::ostream& operator<<(std::ostream& os, const load& op)
{
    const auto region_end = op.offset + op.s.bytes();
    os << op.name() << "[";
    os << "offset=" << op.offset << ",";
    os << "end=" << region_end << "]";
    return os;
}
};
struct outline
......
......@@ -9,7 +9,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MEMORY_COLORING)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_PRE_SCHEDULING)
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_PRE_SCHEDULING_HPP
#define MIGRAPHX_GUARD_RTGLIB_PRE_SCHEDULING_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/program.hpp>
#include <migraphx/insert_instruction.hpp>
namespace migraphx {
// Pass object for the "pre scheduling" transformation. The fields configure
// the pass; the work is done by apply(), defined elsewhere.
struct pre_scheduling
{
    // Callback returning a pair of ints for an operation.
    // NOTE(review): the meaning of the two values is not visible here; confirm.
    std::function<std::pair<int, int>(const operation&)> weight_func;
    // Stream count. Defaults to 0 so a default-constructed pass never reads
    // an indeterminate value.
    int num_of_streams = 0;
    // Type-erased, target-provided instruction insertion hooks.
    insert_instruction insert_instr;
    // Off by default. NOTE(review): presumably enables extra self-checks; confirm.
    bool verify = false;
    // Pass name.
    std::string name() const { return "pre scheduling"; }
    // Run the pass on `p`; const, so the pass object itself is not modified.
    void apply(program& p) const;
};
} // namespace migraphx
#endif
......@@ -9,6 +9,7 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>
......@@ -16,6 +17,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl;
const operation& get_operation(instruction_ref ins);
......@@ -98,7 +102,7 @@ struct program
void compile(const target& t, tracer trace = tracer{});
void finalize();
void finish();
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
void debug_print() const;
......@@ -107,6 +111,8 @@ struct program
void dry_run(parameter_map params) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......@@ -114,6 +120,7 @@ struct program
private:
std::unique_ptr<program_impl> impl;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SCHEDULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_SCHEDULE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/schedule_model.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Schedule instructions for concurrent execution
*
* The pass is parameterised by a type-erased, target-provided schedule_model.
*/
struct schedule
{
// Target-specific scheduling model; value-initialised by default.
schedule_model model{};
// Pass name.
std::string name() const { return "schedule"; }
// Run the pass on `p`; const, so the pass object itself is not modified.
void apply(program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_INSERT_INSTRUCTION_HPP
#define MIGRAPHX_GUARD_INSERT_INSTRUCTION_HPP
#ifndef MIGRAPHX_GUARD_SCHEDULE_MODEL_HPP
#define MIGRAPHX_GUARD_SCHEDULE_MODEL_HPP
#include <cassert>
#include <string>
......@@ -8,24 +8,31 @@
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct operation;
#ifdef DOXYGEN
/// An interface for target-dependent instruction insertion.
/// for multi-stream execution.
struct insert_instruction
/// An interface for target-dependent model for the scheduler
struct schedule_model
{
void insert_create_events(program* p, instruction_ref ins, int num_of_events);
void insert_record_event(program* p, instruction_ref ins, int event);
void insert_wait_event(program* p, instruction_ref ins, int event);
void insert_stream(program* p, instruction_ref ins, int stream);
/// Get the number of concurrent instruction allowed
std::size_t concurrency() const;
/// Schedule a concurrent instruction
void sched(program& p, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction
void wait(program& p, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction
void record(program& p, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation
std::size_t weight(const operation& op) const;
};
#else
......@@ -33,23 +40,24 @@ struct insert_instruction
/*
* Type-erased interface for:
*
* struct insert_instruction
* struct schedule_model
* {
* void insert_create_events(program* p,instruction_ref ins,int input) ;
* void insert_record_event(program* p,instruction_ref ins,int input) ;
* void insert_wait_event(program* p,instruction_ref ins,int input) ;
* void insert_stream(program* p,instruction_ref ins,int input) ;
* std::size_t concurrency() const;
* void sched(program& p,instruction_ref ins,std::size_t n) const;
* void wait(program& p,instruction_ref ins,std::size_t wait_id) const;
* void record(program& p,instruction_ref ins,std::size_t wait_id) const;
* std::size_t weight(const operation& op) const;
* };
*
*/
struct insert_instruction
struct schedule_model
{
// Constructors
insert_instruction() = default;
schedule_model() = default;
template <typename PrivateDetailTypeErasedT>
insert_instruction(PrivateDetailTypeErasedT value)
schedule_model(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
......@@ -59,7 +67,7 @@ struct insert_instruction
// Assignment
template <typename PrivateDetailTypeErasedT>
insert_instruction& operator=(PrivateDetailTypeErasedT value)
schedule_model& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
......@@ -100,32 +108,38 @@ struct insert_instruction
return private_detail_te_get_handle().type();
}
void insert_create_events(program* p, instruction_ref ins, int input)
std::size_t concurrency() const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_create_events(p, ins, input);
return (*this).private_detail_te_get_handle().concurrency();
}
void insert_record_event(program* p, instruction_ref ins, int input)
void sched(program& p, instruction_ref ins, std::size_t n) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_record_event(p, ins, input);
(*this).private_detail_te_get_handle().sched(p, ins, n);
}
void insert_wait_event(program* p, instruction_ref ins, int input)
void wait(program& p, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_wait_event(p, ins, input);
(*this).private_detail_te_get_handle().wait(p, ins, wait_id);
}
void insert_stream(program* p, instruction_ref ins, int input)
void record(program& p, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().insert_stream(p, ins, input);
(*this).private_detail_te_get_handle().record(p, ins, wait_id);
}
friend bool is_shared(const insert_instruction& private_detail_x,
const insert_instruction& private_detail_y)
std::size_t weight(const operation& op) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().weight(op);
}
friend bool is_shared(const schedule_model& private_detail_x,
const schedule_model& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
......@@ -138,10 +152,11 @@ struct insert_instruction
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual void insert_create_events(program* p, instruction_ref ins, int input) = 0;
virtual void insert_record_event(program* p, instruction_ref ins, int input) = 0;
virtual void insert_wait_event(program* p, instruction_ref ins, int input) = 0;
virtual void insert_stream(program* p, instruction_ref ins, int input) = 0;
virtual std::size_t concurrency() const = 0;
virtual void sched(program& p, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -172,28 +187,30 @@ struct insert_instruction
const std::type_info& type() const override { return typeid(private_detail_te_value); }
void insert_create_events(program* p, instruction_ref ins, int input) override
std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(program& p, instruction_ref ins, std::size_t n) const override
{
private_detail_te_value.insert_create_events(p, ins, input);
private_detail_te_value.sched(p, ins, n);
}
void insert_record_event(program* p, instruction_ref ins, int input) override
void wait(program& p, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.insert_record_event(p, ins, input);
private_detail_te_value.wait(p, ins, wait_id);
}
void insert_wait_event(program* p, instruction_ref ins, int input) override
void record(program& p, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.insert_wait_event(p, ins, input);
private_detail_te_value.record(p, ins, wait_id);
}
void insert_stream(program* p, instruction_ref ins, int input) override
std::size_t weight(const operation& op) const override
{
private_detail_te_value.insert_stream(p, ins, input);
return private_detail_te_value.weight(op);
}
PrivateDetailTypeErasedT private_detail_te_value;
......@@ -232,19 +249,19 @@ struct insert_instruction
};
template <typename ValueType>
inline const ValueType* any_cast(const insert_instruction* x)
inline const ValueType* any_cast(const schedule_model* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(insert_instruction* x)
inline ValueType* any_cast(schedule_model* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(insert_instruction& x)
inline ValueType& any_cast(schedule_model& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
......@@ -253,7 +270,7 @@ inline ValueType& any_cast(insert_instruction& x)
}
template <typename ValueType>
inline const ValueType& any_cast(const insert_instruction& x)
inline const ValueType& any_cast(const schedule_model& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
......@@ -262,6 +279,7 @@ inline const ValueType& any_cast(const insert_instruction& x)
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SET_OPERATOR_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_SET_OPERATOR_IMPL_HPP
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Return the set of keys present in both `lhs` and `rhs`.
//
// Only the smaller of the two sets is iterated, with each key probed in the
// larger one. The result is returned by plain value so the compiler can
// elide the copy; wrapping it in std::move would prevent that.
template <typename Set, typename Key = typename Set::value_type>
static inline Set set_intersection(const Set& lhs, const Set& rhs)
{
    const bool lhs_is_smaller = lhs.size() <= rhs.size();
    const Set& smaller        = lhs_is_smaller ? lhs : rhs;
    const Set& larger         = lhs_is_smaller ? rhs : lhs;

    Set common;
    for(const Key& key : smaller)
    {
        if(larger.count(key) > 0)
            common.insert(key);
    }
    return common;
}
// Return the set of keys present in `lhs`, `rhs`, or both.
//
// The copy uses parentheses so it is always the copy constructor, never an
// initializer_list constructor. The result is returned by plain value to
// keep copy elision available.
template <typename Set, typename Key = typename Set::value_type>
static inline Set set_union(const Set& lhs, const Set& rhs)
{
    Set merged(lhs);
    merged.insert(rhs.begin(), rhs.end());
    return merged;
}
// Return the keys of `lhs` that do not appear in `rhs`.
//
// Starts from a copy of `lhs` and erases every key of `rhs`. The copy uses
// parentheses so it is always the copy constructor. The result is returned
// by plain value to keep copy elision available.
template <typename Set, typename Key = typename Set::value_type>
static inline Set set_difference(const Set& lhs, const Set& rhs)
{
    Set remaining(lhs);
    for(const auto& key : rhs)
    {
        remaining.erase(key);
    }
    return remaining;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -36,6 +36,8 @@ inline stream_range_container<Range> stream_range(const Range& r)
namespace detail {
// Overload for std::string, tagged rank<2>: writes the string directly to `os`.
inline void stream_write_value_impl(rank<2>, std::ostream& os, const std::string& x) { os << x; }
template <class Range>
auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
-> decltype(r.begin(), r.end(), void())
......@@ -53,7 +55,7 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
template <class T>
void stream_write_value(std::ostream& os, const T& x)
{
detail::stream_write_value_impl(rank<1>{}, os, x);
detail::stream_write_value_impl(rank<2>{}, os, x);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -22,10 +22,8 @@ struct target
{
/// A unique name used to identify the target
std::string name() const;
/// The transformation passes to be run
/**
* @brief The transformation pass to be run during compilation.
* @details [long description]
*
* @param ctx This is the target-dependent context that is created by `get_context`
* @return The passes to be run
......
......@@ -87,9 +87,6 @@ const literal& instruction::get_literal() const
return lit;
}
// Accessor: returns the value of the `stream` member.
int instruction::get_stream() const { return stream; }
// Mutator: stores `s` in the `stream` member; no validation is performed.
void instruction::set_stream(int s) { stream = s; }
// Accessor: returns a const reference to the `op` member.
const operation& instruction::get_operator() const { return op; }
// Name of this instruction: forwards to the name() of its operator `op`.
std::string instruction::name() const { return op.name(); }
......@@ -214,5 +211,6 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg
{
return op.compute_shape(to_shapes(args));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment