Commit 4332ccf6 authored by Paul
Browse files

Refactor

parent 9b5e0c18
......@@ -18,10 +18,8 @@
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/pre_scheduling.hpp>
#include <migraphx/gpu/machine_model.hpp>
#include <migraphx/gpu/find_concur_gpu.hpp>
#include <migraphx/gpu/insert_instruction_gpu.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/schedule.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -29,9 +27,7 @@ namespace gpu {
std::vector<pass> target::get_passes(migraphx::context& gctx) const
{
auto& ctx = any_cast<context>(gctx);
std::function<std::pair<int, int>(const operation&)> weight_func = op_info();
int num_of_streams = ctx.get_current_device().nstreams();
auto& ctx = any_cast<context>(gctx);
// clang-format off
return
{
......@@ -56,9 +52,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{},
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx},
pre_scheduling{weight_func, num_of_streams, insert_instruction_gpu{}},
memory_coloring{"hip::allocate", num_of_streams, find_concur_gpu{}},
write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}},
memory_coloring{"hip::allocate"},
dead_code_elimination{},
eliminate_workspace{},
eliminate_allocation{"hip::allocate"},
check_context<context>{},
......
......@@ -130,9 +130,7 @@ migraphx::argument run_gpu(migraphx::program& p)
EXPECT(is_shared(ctx, p.get_context()));
p.dry_run(m);
EXPECT(is_shared(ctx, p.get_context()));
auto eval = p.eval(m);
auto ret_val = migraphx::gpu::from_gpu(eval);
return ret_val;
return migraphx::gpu::from_gpu(p.eval(m));
}
template <class V>
......
#include <test.hpp>
#include <basic_ops.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/verify_args.hpp>
// Builds the graph under test: parameter "0" feeds two convolutions (with parameters "1"
// and "2" as their second inputs) that do not depend on each other, and the two results are
// joined by concat{1}.
migraphx::program create_program()
{
    migraphx::program prog;
    const migraphx::shape data_shape{migraphx::shape::float_type, {32, 64, 1, 1}};
    const migraphx::shape filter_shape{migraphx::shape::float_type, {64, 64, 1, 1}};

    auto data     = prog.add_parameter("0", data_shape);
    auto filter_a = prog.add_parameter("1", filter_shape);
    auto conv_a   = prog.add_instruction(migraphx::op::convolution{}, data, filter_a);
    auto filter_b = prog.add_parameter("2", filter_shape);
    auto conv_b   = prog.add_instruction(migraphx::op::convolution{}, data, filter_b);
    prog.add_instruction(migraphx::op::concat{1}, conv_a, conv_b);
    return prog;
}
// Compiles the test program for the GPU target, evaluates it with generated arguments that
// were copied to the device, and returns the result copied back to the host.
migraphx::argument run_gpu()
{
    // Set before compiling; the last argument (1) overwrites any existing value.
    setenv("MIGRAPHX_DISABLE_NULL_STREAM", "1", 1);
    auto prog = create_program();
    prog.compile(migraphx::gpu::target{});

    migraphx::program::parameter_map params;
    for(const auto& entry : prog.get_parameter_shapes())
        params[entry.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(entry.second));

    auto host_result = migraphx::gpu::from_gpu(prog.eval(params));
    prog.finish();
    return host_result;
}
// Compiles the same test program for the CPU target and evaluates it with generated
// arguments; the result serves as the reference for the GPU run.
migraphx::argument run_cpu()
{
    auto prog = create_program();
    prog.compile(migraphx::cpu::target{});

    migraphx::program::parameter_map params;
    for(const auto& entry : prog.get_parameter_shapes())
        params[entry.first] = migraphx::generate_argument(entry.second);

    return prog.eval(params);
}
// Runs the program on the GPU first and then on the CPU, and hands both results to
// verify_args with the CPU result in the first result position.
void gpu_stream_execution_test()
{
    const auto gpu_result = run_gpu();
    const auto cpu_result = run_cpu();
    verify_args("test", cpu_result, gpu_result);
}
// Entry point: runs the single GPU-vs-CPU comparison.
int main()
{
    gpu_stream_execution_test();
    return 0;
}
......@@ -36,6 +36,18 @@ inline std::ostream& operator<<(std::ostream& s, std::nullptr_t)
return s;
}
// Streams a vector as "{ a, b, c, }": every element is followed by ", " (including the last
// one) and an empty vector prints as "{ }".
template <class T>
inline std::ostream& operator<<(std::ostream& s, const std::vector<T>& v)
{
    s << "{ ";
    for(auto it = v.begin(); it != v.end(); ++it)
        s << *it << ", ";
    return s << "}";
}
template <class T, class U, class Operator>
struct expression
{
......
......@@ -2,47 +2,15 @@
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/dom_info.hpp>
#include <migraphx/common_header.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
// Stand-in operator named "gpu::set_stream" that carries a stream index.
struct set_stream
{
    // Stream index carried by the operator; -1 unless set by the caller.
    int stream = -1;

    std::string name() const { return "gpu::set_stream"; }

    // Forwards the shape of the first input, or a default shape when there are no inputs.
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
    {
        if(not inputs.empty())
            return inputs.front();
        return {};
    }
};
struct find_concur
{
void get_concur(
migraphx::program* p,
int num_of_streams,
std::unordered_map<const migraphx::instruction*,
std::vector<std::vector<const migraphx::instruction*>>>& concur_instrs,
std::unordered_map<const migraphx::instruction*, int>& instr2_points) const
{
migraphx::dom_info info(p);
info.compute_dom(true);
info.propagate_splits(num_of_streams, concur_instrs, instr2_points);
}
};
// Minimal target whose pass list contains only a memory_coloring pass; the tests compile
// programs with it.
struct memory_coloring_target
{
std::string name() const { return "memory_coloring"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
// NOTE(review): two consecutive returns - the second one is unreachable as written. They
// look like the old and new side of a diff hunk shown without +/- markers; confirm which
// memory_coloring construction is intended and drop the other.
return {migraphx::memory_coloring{"allocate", 4, find_concur{}, true}};
return {migraphx::memory_coloring{"allocate", true}};
}
// Returns a default-constructed context.
migraphx::context get_context() const { return {}; }
};
......@@ -639,40 +607,4 @@ TEST_CASE(literal_test)
CHECK(lit == result);
}
// Builds three chains that all start from p1 and are tagged with streams 0, 1 and 2
// (p2 -> p4, p3 -> p5, p6 -> p7), joins them with a concat on stream 0, and checks that the
// "scratch" parameter produced by compilation is 960 bytes.
TEST_CASE(concurrent_test)
{
migraphx::program p;
auto in = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {40}});
auto a1 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a1, in);
// A set_stream marker is inserted in front of the first instruction of each stream.
p.insert_instruction(p1, set_stream{0});
p1->set_stream(0);
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p2->set_stream(0);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p4 = p.add_instruction(pass_op{}, a4, p2);
p4->set_stream(0);
auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = p.add_instruction(pass_op{}, a3, p1);
p3->set_stream(1);
p.insert_instruction(p3, set_stream{1});
auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p5 = p.add_instruction(pass_op{}, a5, p3);
p5->set_stream(1);
auto a6 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p6 = p.add_instruction(pass_op{}, a6, p1);
p6->set_stream(2);
p.insert_instruction(p6, set_stream{2});
auto a7 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p7 = p.add_instruction(pass_op{}, a7, p6);
p7->set_stream(2);
auto a8 = add_alloc(p, {migraphx::shape::float_type, {40}});
// Join point: consumes the tail of every chain.
auto p8 = p.add_instruction(migraphx::op::concat{0}, a8, p4, p5, p7);
p8->set_stream(0);
p.insert_instruction(p8, set_stream{0});
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 960);
}
// Entry point: hands the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
#include <migraphx/schedule.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
// Test operator named "unary": passes its first argument and first input shape through
// unchanged and reports output alias 0.
struct unary_op
{
    std::string name() const { return "unary"; }

    // Returns the first argument, or a default argument when none are given.
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(not args.empty())
            return args.front();
        return {};
    }

    // Returns the first input shape, or a default shape when there are no inputs.
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(not inputs.empty())
            return inputs.front();
        return {};
    }

    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};
// Test operator named "nary": passes its first argument and first input shape through
// unchanged. The comment field is exposed through reflect under the name "comment".
struct nary_op
{
    std::string comment = "";

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.comment, "comment"));
    }

    std::string name() const { return "nary"; }

    // Returns the first argument, or a default argument when none are given.
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(not args.empty())
            return args.front();
        return {};
    }

    // Returns the first input shape, or a default shape when there are no inputs.
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(not inputs.empty())
            return inputs.front();
        return {};
    }
};
// Test operator named "stream_free": passes its first argument and first input shape
// through unchanged. schedule_model_test::weight gives this name a weight of 0.
struct stream_free_op
{
    std::string comment = "";

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.comment, "comment"));
    }

    std::string name() const { return "stream_free"; }

    // Returns the first argument, or a default argument when none are given.
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(not args.empty())
            return args.front();
        return {};
    }

    // Returns the first input shape, or a default shape when there are no inputs.
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(not inputs.empty())
            return inputs.front();
        return {};
    }
};
// Marker operator named "wait_event" that the test model inserts in front of an instruction.
// The wait list lives behind a shared_ptr, so copies of the operator refer to the same list.
struct wait_event
{
    std::shared_ptr<std::vector<std::size_t>> wait_for =
        std::make_shared<std::vector<std::size_t>>();

    // Exposes the pointed-to list (not the pointer) under the name "wait_for".
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(*self.wait_for, "wait_for"));
    }

    std::string name() const { return "wait_event"; }

    // Produces no output; asserts that a wait list exists and is not empty when evaluated.
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
    {
        assert(wait_for != nullptr);
        assert(not wait_for->empty());
        return {};
    }

    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
};
// Instruction -> stream index, as stored by schedule_model_test::sched.
using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>;
// Wait id -> stream index, as stored by schedule_model_test::record.
using int_map = std::unordered_map<std::size_t, std::size_t>;
// Instruction -> wait list shared with the wait_event created for it in
// schedule_model_test::wait.
using wait_map =
std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<std::size_t>>>;
// Scheduling model handed to migraphx::schedule in these tests. Instead of emitting real
// stream operations it records every decision in maps held through shared_ptr, so copies of
// the model write to the same bookkeeping.
struct schedule_model_test
{
    std::shared_ptr<instruction_map> ins2stream = std::make_shared<instruction_map>();
    std::shared_ptr<int_map> wait2stream        = std::make_shared<int_map>();
    std::shared_ptr<wait_map> ins2wait_for      = std::make_shared<wait_map>();

    // Number of streams reported to the scheduler.
    std::size_t concurrency() const { return 4; }

    // Remembers that `ins` was assigned stream `n`.
    void sched(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
    {
        auto& streams = *ins2stream;
        streams[ins]  = n;
    }

    // The first wait on `ins` inserts a wait_event before it and keeps that event's wait list;
    // every call then appends the stream recorded for `wait_id` (throws via at() if none was).
    void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
    {
        auto& waits = *ins2wait_for;
        if(waits.find(ins) == waits.end())
        {
            wait_event marker{};
            p.insert_instruction(ins, marker);
            waits[ins] = marker.wait_for;
        }
        waits[ins]->push_back(wait2stream->at(wait_id));
    }

    // Associates `wait_id` with the stream of `ins` (throws via at() if `ins` has no stream).
    void record(migraphx::program&, migraphx::instruction_ref ins, std::size_t wait_id) const
    {
        auto& recorded    = *wait2stream;
        recorded[wait_id] = ins2stream->at(ins);
    }

    // Weight by operator name: "stream_free" -> 0, "binary"/"unary" -> 4, anything else -> 1.
    std::size_t weight(const migraphx::operation& op) const
    {
        const auto op_name = op.name();
        if(op_name == "stream_free")
            return 0;
        if(op_name == "binary" or op_name == "unary")
            return 4;
        return 1;
    }
};
// Returns true when the program contains an "identity" instruction that has both x and y
// among its inputs.
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
    for(auto ins : migraphx::iterator_for(p))
    {
        // Short-circuits in the same order as a chain of early `continue`s would.
        if(ins->name() == "identity" and migraphx::contains(ins->inputs(), x) and
           migraphx::contains(ins->inputs(), y))
        {
            return true;
        }
    }
    return false;
}
struct schedule_target
{
schedule_model_test model{};
std::string name() const { return "schedule"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::schedule{model}};
}
migraphx::context get_context() const { return {}; }
std::size_t get_stream(migraphx::instruction_ref ins) { return model.ins2stream->at(ins); }
std::vector<std::size_t> get_streams(std::vector<migraphx::instruction_ref> inss)
{
std::vector<std::size_t> result;
std::transform(inss.begin(), inss.end(), std::back_inserter(result), [&](auto ins) {
return this->get_stream(ins);
});
return result;
}
bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
void check_conflicts(migraphx::program& p,
std::vector<std::vector<migraphx::instruction_ref>> conflicts,
bool result = true)
{
migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) {
if(i == j)
return;
for(auto ins1 : conflicts[i])
{
for(auto ins2 : conflicts[j])
{
// If both instructions are on the same stream then dont check for a conflict
if(this->has_stream(ins1) and this->has_stream(ins2) and
this->get_stream(ins1) == this->get_stream(ins2))
continue;
CHECK(::check_conflicts(p, ins1, ins2) == result);
}
}
});
}
};
// Returns a copy of the input sorted in ascending order; duplicates are kept.
template <class T>
std::vector<T> sorted(std::vector<T> x)
{
    std::sort(std::begin(x), std::end(x));
    return x;
}
// Returns the distinct values of the input in ascending order.
template <class T>
std::vector<T> unique(std::vector<T> x)
{
    std::sort(std::begin(x), std::end(x));
    // std::unique only shifts duplicates to the tail; erase trims them off.
    const auto tail = std::unique(std::begin(x), std::end(x));
    x.erase(tail, std::end(x));
    return x;
}
// Expected wait list built from a plain list of streams: the distinct streams in ascending
// order.
std::vector<std::size_t> get_wait_for(std::vector<std::size_t> wait_for)
{
    auto streams = unique(std::move(wait_for));
    return streams;
}
// Expected wait list for an instruction that runs on stream `wait_on` and whose inputs run
// on the streams listed in `wait_for`: the distinct input streams, in ascending order, with
// the instruction's own stream left out.
//
// Every occurrence of `wait_on` is dropped, so two inputs sharing the consumer's stream do
// not leave that stream in the result. If `wait_on` is not in the list at all the list is
// simply deduplicated; erasing the end iterator returned by a failed std::find would be
// undefined behaviour.
std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size_t> wait_for)
{
    // erase-remove: a no-op when wait_on is absent, removes all copies otherwise
    wait_for.erase(std::remove(wait_for.begin(), wait_for.end(), wait_on), wait_for.end());
    // sort + unique leaves one copy of each remaining stream
    std::sort(wait_for.begin(), wait_for.end());
    wait_for.erase(std::unique(wait_for.begin(), wait_for.end()), wait_for.end());
    return wait_for;
}
// Actual wait list of `ins`: walks backwards from `ins`, stepping over "identity"
// instructions, and if the instruction found is a "wait_event" returns a sorted copy of its
// wait list (duplicates are kept). Returns an empty list otherwise.
// NOTE(review): nothing stops the walk at the start of the program - presumably callers
// never pass the first instruction; confirm.
std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
{
    auto prior = std::prev(ins);
    // Skip identity operators
    while(prior->name() == "identity")
        --prior;
    if(prior->name() == "wait_event")
    {
        auto streams = *migraphx::any_cast<wait_event>(prior->get_operator()).wait_for;
        std::sort(streams.begin(), streams.end());
        return streams;
    }
    return {};
}
// Appends `n` instructions with operator `x` to the program, the first one reading `input`
// and each following one reading its predecessor. Returns them in creation order.
template <class T>
std::vector<migraphx::instruction_ref>
chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
{
    std::vector<migraphx::instruction_ref> links;
    auto previous = input;
    for(std::size_t remaining = n; remaining > 0; --remaining)
    {
        previous = p.add_instruction(x, previous);
        links.push_back(previous);
    }
    return links;
}
// One literal feeds two unary ops that are joined by a single nary op.
// Expects: no stream for the literal, different streams for the two unary ops, the join on
// stream 0 waiting on the input streams other than its own, and an identity instruction
// taking both unary ops as inputs after compilation.
TEST_CASE(single_entry)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2));
}
// Same shape as single_entry but with stream_free ops (weight 0 in the test model) as the
// two branches. Expects no stream for the literal or the stream_free ops, and the join on
// stream 0.
TEST_CASE(stream_free)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(stream_free_op{}, one);
auto onep2 = p.add_instruction(stream_free_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(onep1));
EXPECT(not t.has_stream(onep2));
EXPECT(t.get_stream(binary) == 0);
}
// Two unary branches are each wrapped in an identity before being joined by a nary op.
// Expects different streams for the unary ops, a stream for the join, the join's wait list
// to be derived from the unary ops' streams, and conflicts between the two branches.
TEST_CASE(zero_record)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto onei1 = p.add_instruction(migraphx::op::identity{}, onep1);
auto onei2 = p.add_instruction(migraphx::op::identity{}, onep2);
auto binary = p.add_instruction(nary_op{}, onei1, onei2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.has_stream(binary));
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2));
t.check_conflicts(p, {{onep1, onei1}, {onep2, onei2}});
}
// The join of the two unary branches is itself an identity. Expects that the identity gets
// no stream and that no wait list is found in front of it.
TEST_CASE(zero_merge1)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
EXPECT(not t.has_stream(binary));
// There is no wait
EXPECT(get_wait_for(binary).empty());
EXPECT(check_conflicts(p, onep1, onep2));
}
// Like zero_merge1, but each unary branch passes through its own identity before the
// joining identity. Expects the same outcome: no stream and no wait for the join.
TEST_CASE(zero_merge2)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(migraphx::op::identity{},
p.add_instruction(migraphx::op::identity{}, onep1),
p.add_instruction(migraphx::op::identity{}, onep2));
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
EXPECT(not t.has_stream(binary));
// There is no wait
EXPECT(get_wait_for(binary).empty());
EXPECT(check_conflicts(p, onep1, onep2));
}
// An identity joins the two unary branches and is followed by one more unary op. Expects no
// stream or wait for the identity, while the trailing unary op is on stream 0 and its wait
// list is derived from the streams of the two branches.
TEST_CASE(zero_merge3)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto id = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
auto final = p.add_instruction(unary_op{}, id);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
EXPECT(not t.has_stream(id));
// There is no wait
EXPECT(get_wait_for(id).empty());
// Stream assignment for final op
EXPECT(t.get_stream(final) == 0);
EXPECT(get_wait_for(final) ==
get_wait_for(t.get_stream(final), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2));
}
// Combination of zero_merge2 and zero_merge3: each branch passes through its own identity,
// an identity joins them, and a unary op follows. Expects no stream or wait for the joining
// identity and the trailing unary op on stream 0 with a wait list derived from the branches.
TEST_CASE(zero_merge4)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto id = p.add_instruction(migraphx::op::identity{},
p.add_instruction(migraphx::op::identity{}, onep1),
p.add_instruction(migraphx::op::identity{}, onep2));
auto final = p.add_instruction(unary_op{}, id);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
EXPECT(not t.has_stream(id));
// There is no wait
EXPECT(get_wait_for(id).empty());
// Stream assignment for final op
EXPECT(t.get_stream(final) == 0);
EXPECT(get_wait_for(final) ==
get_wait_for(t.get_stream(final), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2));
}
// Two separate literals, each read by a stream_free op followed by a unary op; a nary op
// joins the two unary ops. Expects no stream for the stream_free ops, different streams for
// the unary ops, and the join on stream 0 with a wait list derived from the unary ops.
TEST_CASE(double_entry)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_instruction(stream_free_op{}, p.add_literal(1));
auto two = p.add_instruction(stream_free_op{}, p.add_literal(2));
auto onep = p.add_instruction(unary_op{}, one);
auto twop = p.add_instruction(unary_op{}, two);
auto binary = p.add_instruction(nary_op{}, onep, twop);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(two));
EXPECT(t.get_stream(onep) != t.get_stream(twop));
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)}));
t.check_conflicts(p, {{onep, one}, {twop, two}});
}
// A chain of two unary ops and a single unary op both read the same literal and are joined
// by a nary op. Expects the chain and the join on stream 0 and the single op on stream 1.
TEST_CASE(two_branches)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 1);
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
t.check_conflicts(p, {c1, {i1}});
}
// Four branches of length 4, 3, 2 and 1 read the same literal and are joined by a nary op.
// Expects the branches on streams 0, 1, 2 and 3 respectively (longest first) and the join on
// stream 0.
TEST_CASE(four_branches)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 4, unary_op{}, one);
auto c2 = chain(p, 3, unary_op{}, one);
auto c3 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 1);
for(auto ins : c3)
EXPECT(t.get_stream(ins) == 2);
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary),
{t.get_stream(c1.back()),
t.get_stream(c2.back()),
t.get_stream(c3.back()),
t.get_stream(i1)}));
t.check_conflicts(p, {c1, c2, c3, {i1}});
}
// Five branches (lengths 5, 4, 3, 2, 1) with only four streams in the test model. Expects
// the three longest chains on streams 0, 1 and 2, and both the length-2 chain and the single
// op on stream 3.
TEST_CASE(five_branches)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 5, unary_op{}, one);
auto c2 = chain(p, 4, unary_op{}, one);
auto c3 = chain(p, 3, unary_op{}, one);
auto c4 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 1);
for(auto ins : c3)
EXPECT(t.get_stream(ins) == 2);
for(auto ins : c4)
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary) == 0);
// c4.back() is not listed below; it is expected on the same stream as i1 (checked above),
// so the deduplicated list is the same either way.
EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary),
{t.get_stream(c1.back()),
t.get_stream(c2.back()),
t.get_stream(c3.back()),
t.get_stream(i1)}));
t.check_conflicts(p, {c1, c2, c3, c4});
t.check_conflicts(p, {c1, c2, c3, {i1}});
}
// Four single-op branches of equal length joined by a nary op. Expects all four on distinct
// streams (the sorted list of their streams equals the deduplicated one) and the join on
// stream 0.
TEST_CASE(four_branches_eq)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto onep3 = p.add_instruction(unary_op{}, one);
auto onep4 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(
sorted<std::size_t>(
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}) ==
unique<std::size_t>(
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
EXPECT(t.get_stream(binary) == 0);
EXPECT(
get_wait_for(binary) ==
get_wait_for(
t.get_stream(binary),
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
t.check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}});
}
// Two fork/join sections in sequence: the first join (binary1) is the entry of the second
// fork. For each section expects the chain and the single op on different streams and the
// chain on one stream; the first join shares its chain's stream, the last join is on
// stream 0.
TEST_CASE(seq_merge)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary1 = p.add_instruction(nary_op{}, i1, c1.back());
auto c2 = chain(p, 2, unary_op{}, binary1);
auto i2 = p.add_instruction(unary_op{}, binary1);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(c1.back()));
for(auto ins : c1)
EXPECT(t.get_stream(ins) == t.get_stream(c1.back()));
EXPECT(t.get_stream(binary1) == t.get_stream(c1.back()));
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
t.check_conflicts(p, {c1, {i1}});
EXPECT(t.get_stream(i2) != t.get_stream(c2.back()));
for(auto ins : c2)
EXPECT(t.get_stream(ins) == t.get_stream(c2.back()));
EXPECT(t.get_stream(binary2) == 0);
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
t.check_conflicts(p, {c2, {i2}});
}
// Two fork/join sections that both start from the same literal and run side by side, with a
// final nary op joining their results. Expects the first section's chain and join on
// stream 0, the second section's chain and join on stream 3, and the final join on stream 0.
TEST_CASE(par_merge)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one);
auto c1 = chain(p, 3, unary_op{}, start1);
auto i1 = p.add_instruction(unary_op{}, start1);
auto binary1 = p.add_instruction(nary_op{}, i1, c1.back());
auto start2 = p.add_instruction(unary_op{}, one);
auto c2 = chain(p, 2, unary_op{}, start2);
auto i2 = p.add_instruction(unary_op{}, start2);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(binary3) == 0);
EXPECT(t.get_stream(i1) != t.get_stream(i2));
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2));
t.check_conflicts(p, {c1, {i1}, c2, {i2}});
}
// par_merge with two extra single-op branches (outer1, outer2) feeding the final join
// directly. Expects the final join on stream 0, outer1 on stream 1, outer2 on stream 2, and
// the two inner sections on streams 0 and 3 as in par_merge.
TEST_CASE(inner_par_merge)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one);
auto c1 = chain(p, 3, unary_op{}, start1);
auto i1 = p.add_instruction(unary_op{}, start1);
auto binary1 = p.add_instruction(nary_op{}, i1, c1.back());
auto start2 = p.add_instruction(unary_op{}, one);
auto c2 = chain(p, 2, unary_op{}, start2);
auto i2 = p.add_instruction(unary_op{}, start2);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
auto outer1 = p.add_instruction(unary_op{}, one);
auto outer2 = p.add_instruction(unary_op{}, one);
auto output = p.add_instruction(nary_op{}, binary1, binary2, outer1, outer2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output) == get_wait_for(t.get_stream(output),
{t.get_stream(binary1),
t.get_stream(binary2),
t.get_stream(outer1),
t.get_stream(outer2)}));
EXPECT(t.get_stream(outer1) == 1);
EXPECT(t.get_stream(outer2) == 2);
EXPECT(t.get_stream(i1) != t.get_stream(i2));
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2));
t.check_conflicts(p, {c1, {i1}, c2, {i2}, {outer1}, {outer2}});
}
// par_merge where each fork/join section starts from its own literal instead of a shared
// one. Expects the same stream layout as par_merge and no stream for either literal.
TEST_CASE(par_merge_multi_entry)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one);
auto c1 = chain(p, 3, unary_op{}, start1);
auto i1 = p.add_instruction(unary_op{}, start1);
auto binary1 = p.add_instruction(nary_op{}, i1, c1.back());
// NOTE(review): named `two` but created from the literal value 1.
auto two = p.add_literal(1);
auto start2 = p.add_instruction(unary_op{}, two);
auto c2 = chain(p, 2, unary_op{}, start2);
auto i2 = p.add_instruction(unary_op{}, start2);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(two));
EXPECT(t.get_stream(binary3) == 0);
EXPECT(t.get_stream(i1) != t.get_stream(i2));
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2));
t.check_conflicts(p, {c1, {i1}, c2, {i2}});
}
// A chain of two unary ops splits into two single ops (s1, s2) while a separate single op
// (i1) runs beside it; a nary op joins i1, s1 and s2. Expects i1, s1 and s2 on pairwise
// different streams, the chain off i1's stream, and the join on stream 0.
TEST_CASE(inner_split1)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
// s1 and s2 receive the whole vector c1 as their inputs, not just c1.back().
auto s1 = p.add_instruction(unary_op{}, c1);
auto s2 = p.add_instruction(unary_op{}, c1);
auto output = p.add_instruction(nary_op{}, i1, s1, s2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(s1));
EXPECT(t.get_stream(i1) != t.get_stream(s2));
for(auto ins : c1)
EXPECT(t.get_stream(ins) != t.get_stream(i1));
EXPECT(t.get_stream(s1) != t.get_stream(s2));
EXPECT(t.get_stream(output) == 0);
EXPECT(
get_wait_for(output) ==
get_wait_for(t.get_stream(output), {t.get_stream(i1), t.get_stream(s1), t.get_stream(s2)}));
EXPECT(get_wait_for(s1).empty());
// TODO: Remove the extra wait here
// EXPECT(get_wait_for(s2).empty());
t.check_conflicts(p, {c1, {i1}, {s1}, {s2}});
}
// Like inner_split1, but the split branches are chains of length 3 and 4 that start from
// c1.back(). Additionally expects the first op of s1 to wait on the stream of c1.back().
TEST_CASE(inner_split2)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto s1 = chain(p, 3, unary_op{}, c1.back());
auto s2 = chain(p, 4, unary_op{}, c1.back());
auto output = p.add_instruction(nary_op{}, i1, s1.back(), s2.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(s1.back()));
EXPECT(t.get_stream(i1) != t.get_stream(s2.back()));
for(auto ins : c1)
EXPECT(t.get_stream(ins) != t.get_stream(i1));
EXPECT(t.get_stream(s1.back()) != t.get_stream(s2.back()));
EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output) ==
get_wait_for(t.get_stream(output),
{t.get_stream(i1), t.get_stream(s1.back()), t.get_stream(s2.back())}));
EXPECT(get_wait_for(s1.front()) == get_wait_for({t.get_stream(c1.back())}));
t.check_conflicts(p, {c1, {i1}, s1, s2});
}
// A fork/join section whose entry (`input`) is also read directly by the final nary op
// together with the join result. Expects the chain, the join and the final op on stream 0,
// i1 on another stream, and no wait list in front of the final op.
TEST_CASE(inception_resnet)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto input = p.add_instruction(unary_op{}, one);
auto c1 = chain(p, 2, unary_op{}, input);
auto i1 = p.add_instruction(unary_op{}, input);
auto binary = p.add_instruction(nary_op{}, i1, c1.back());
auto output = p.add_instruction(nary_op{}, binary, input);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != 0);
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output).empty());
t.check_conflicts(p, {c1, {i1}});
}
// Larger graph built from nary ops, literals and identity ops: a straight sequence
// i7 -> ... -> i39, then four branches that all read i39 (i48/i54/i61, i69/i75, i80/i86 and
// i94), merged by the identity i96 and consumed by the final nary op `output`.
// Expects the straight sequence on one stream, each branch on a single stream, the
// i48/i54/i61 branch on the same stream as `output`, and the i69 and i80 branches on streams
// that differ from each other, from i7 and from `output`.
TEST_CASE(inception1)
{
schedule_target t{};
migraphx::program p;
auto i1 = p.add_literal(0);
auto i2 = p.add_literal(1);
auto i3 = p.add_literal(1);
auto i4 = p.add_literal(2);
auto i7 = p.add_instruction(nary_op{"i7"}, i1, i4, i3, i2);
auto i8 = p.add_literal(2);
auto i9 = p.add_instruction(migraphx::op::identity{}, i8);
auto i10 = p.add_literal(1);
auto i11 = p.add_instruction(nary_op{"i11"}, i7, i9, i10);
auto i12 = p.add_literal(2);
auto i13 = p.add_instruction(migraphx::op::identity{}, i12);
auto i14 = p.add_literal(1);
auto i15 = p.add_literal(1);
auto i16 = p.add_literal(2);
auto i17 = p.add_instruction(nary_op{"i17"}, i11, i16, i15, i13, i14);
auto i18 = p.add_literal(2);
auto i19 = p.add_instruction(migraphx::op::identity{}, i18);
auto i20 = p.add_literal(1);
auto i21 = p.add_literal(1);
auto i22 = p.add_literal(2);
auto i23 = p.add_instruction(nary_op{"i23"}, i17, i22, i21, i19, i20);
auto i24 = p.add_literal(1);
auto i25 = p.add_instruction(nary_op{"i25"}, i23, i24);
auto i26 = p.add_literal(2);
auto i27 = p.add_instruction(migraphx::op::identity{}, i26);
auto i28 = p.add_literal(1);
auto i29 = p.add_literal(1);
auto i30 = p.add_literal(2);
auto i31 = p.add_instruction(nary_op{"i31"}, i25, i30, i29, i27, i28);
auto i32 = p.add_literal(2);
auto i33 = p.add_instruction(migraphx::op::identity{}, i32);
auto i34 = p.add_literal(1);
auto i35 = p.add_literal(1);
auto i36 = p.add_literal(2);
auto i37 = p.add_instruction(nary_op{"i37"}, i31, i36, i35, i33, i34);
auto i38 = p.add_literal(1);
auto i39 = p.add_instruction(nary_op{"i39"}, i37, i38);
// First branch off i39: i48 -> i54 -> i61
auto i41 = p.add_literal(2);
auto i42 = p.add_instruction(migraphx::op::identity{}, i41);
auto i43 = p.add_literal(1);
auto i44 = p.add_literal(1);
auto i45 = p.add_literal(2);
auto i48 = p.add_instruction(nary_op{"i48"}, i39, i45, i44, i42, i43);
auto i49 = p.add_literal(2);
auto i50 = p.add_instruction(migraphx::op::identity{}, i49);
auto i51 = p.add_literal(1);
auto i52 = p.add_literal(1);
auto i53 = p.add_literal(2);
auto i54 = p.add_instruction(nary_op{"i54"}, i48, i53, i52, i50, i51);
auto i55 = p.add_literal(1);
auto i56 = p.add_instruction(migraphx::op::identity{}, i55);
auto i57 = p.add_literal(2);
auto i58 = p.add_instruction(migraphx::op::identity{}, i57);
auto i59 = p.add_literal(1);
auto i60 = p.add_literal(2);
auto i61 = p.add_instruction(nary_op{"i61"}, i54, i60, i59, i58, i56);
// Second branch off i39: i69 -> i75
auto i62 = p.add_literal(2);
auto i63 = p.add_instruction(migraphx::op::identity{}, i62);
auto i64 = p.add_literal(1);
auto i65 = p.add_literal(1);
auto i66 = p.add_literal(2);
auto i69 = p.add_instruction(nary_op{"i69"}, i39, i66, i65, i63, i64);
auto i70 = p.add_instruction(migraphx::op::identity{}, i55);
auto i71 = p.add_literal(2);
auto i72 = p.add_instruction(migraphx::op::identity{}, i71);
auto i73 = p.add_literal(1);
auto i74 = p.add_literal(2);
auto i75 = p.add_instruction(nary_op{"i75"}, i69, i74, i73, i72, i70);
// Third branch off i39: i80 -> i86
auto i77 = p.add_literal(1);
auto i80 = p.add_instruction(nary_op{"i80"}, i39, i77);
auto i81 = p.add_instruction(migraphx::op::identity{}, i55);
auto i82 = p.add_literal(2);
auto i83 = p.add_instruction(migraphx::op::identity{}, i82);
auto i84 = p.add_literal(1);
auto i85 = p.add_literal(2);
auto i86 = p.add_instruction(nary_op{"i86"}, i80, i85, i84, i83, i81);
// Fourth branch off i39: i94
auto i88 = p.add_instruction(migraphx::op::identity{}, i55);
auto i89 = p.add_literal(2);
auto i90 = p.add_instruction(migraphx::op::identity{}, i89);
auto i91 = p.add_literal(1);
auto i92 = p.add_literal(2);
auto i94 = p.add_instruction(nary_op{"i94"}, i39, i92, i91, i90, i88);
// Merge of the four branches through an identity, then the final op
auto i96 = p.add_instruction(migraphx::op::identity{}, i55, i94, i75, i61, i86);
auto i97 = p.add_literal(2);
auto i98 = p.add_instruction(migraphx::op::identity{}, i97);
auto i99 = p.add_literal(3);
auto i100 = p.add_literal(1);
auto i101 = p.add_literal(2);
auto output = p.add_instruction(nary_op{"output"}, i96, i101, i100, i98, i99);
p.compile(t);
EXPECT(t.get_streams({i7, i11, i17, i23, i25, i31, i37, i39}) ==
t.get_streams({i7, i7, i7, i7, i7, i7, i7, i7}));
EXPECT(t.get_streams({i48, i54, i61, output}) ==
t.get_streams({output, output, output, output}));
EXPECT(t.get_streams({i80, i86}) == t.get_streams({i80, i80}));
EXPECT(t.get_streams({i69, i75}) == t.get_streams({i69, i69}));
EXPECT(t.get_stream(i7) != t.get_stream(i80));
EXPECT(t.get_stream(i69) != t.get_stream(i80));
EXPECT(t.get_stream(i69) != t.get_stream(i7));
EXPECT(t.get_stream(output) != t.get_stream(i69));
EXPECT(t.get_stream(output) != t.get_stream(i80));
EXPECT(get_wait_for(i80) == get_wait_for({t.get_stream(i39)}));
EXPECT(get_wait_for(i69) == get_wait_for({t.get_stream(i39)}));
EXPECT(get_wait_for(i94) == get_wait_for({t.get_stream(i39)}));
EXPECT(
get_wait_for(output) ==
get_wait_for(t.get_stream(output),
{t.get_stream(i94), t.get_stream(i75), t.get_stream(i61), t.get_stream(i86)}));
t.check_conflicts(p, {{i80, i86}, {i69, i75}, {i48, i54, i61}, {i94}});
}
// Entry point of the test executable: forwards the command line to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
#include <migraphx/pre_scheduling.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
// Test-local operation named "set_stream"; it carries a stream id and no compute logic.
struct set_stream
{
    // Stream id carried by the operation; defaults to -1.
    int stream = -1;

    std::string name() const { return "set_stream"; }

    // Output shape is the first input shape, or a default-constructed shape
    // when the operation has no inputs.
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
    {
        if(!inputs.empty())
            return inputs.front();
        return {};
    }
};
// Test-local operation named "gpu::create_events"; it carries an event count only.
struct create_events
{
    // Number of events carried by the operation; defaults to 0.
    int num_of_events = 0;

    std::string name() const { return "gpu::create_events"; }

    // Output shape is the first input shape, or a default-constructed shape
    // when the operation has no inputs.
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
    {
        if(!inputs.empty())
            return inputs.front();
        return {};
    }
};
// Weight callback used by the test target below.
// Operations named "@param" and "@literal" map to the pair (1, 1); every other
// operation maps to (1, 0).
// NOTE(review): the meaning of the two pair components is defined by the
// pre_scheduling pass, which is not visible here.
struct weight_func
{
    // Registers the operation names that have an explicit weight entry.
    weight_func()
    {
        weight_map["@param"]   = std::make_pair(1, 1);
        weight_map["@literal"] = std::make_pair(1, 1);
    }

    // Returns the registered pair for `op.name()`, or (1, 0) when the name is
    // not registered. A single find() is used so the table is never modified
    // on lookup, which lets the operator be const.
    std::pair<int, int> operator()(const migraphx::operation& op) const
    {
        const auto entry = weight_map.find(op.name());
        if(entry != weight_map.end())
            return entry->second;
        return std::make_pair(1, 0);
    }

    // Operation name -> weight pair.
    std::unordered_map<std::string, std::pair<int, int>> weight_map;
};
// Insertion hooks passed to pre_scheduling by the test target below.
// Only the stream hook touches the program; the three event hooks do nothing.
struct insert_instruction
{
    // Event hooks: deliberately empty bodies.
    void insert_create_events(migraphx::program*, migraphx::instruction_ref, int) {}
    void insert_record_event(migraphx::program*, migraphx::instruction_ref, int) {}
    void insert_wait_event(migraphx::program*, migraphx::instruction_ref, int) {}

    // Adds a set_stream operation carrying `stream` to `*p`, using `ins` as the
    // position argument of program::insert_instruction.
    void insert_stream(migraphx::program* p, migraphx::instruction_ref ins, int stream)
    {
        (*p).insert_instruction(ins, set_stream{stream});
    }
};
// Minimal target used by this test: its pass list contains only pre_scheduling.
struct stream_execution_target
{
    // Context whose member functions are all empty.
    struct context
    {
        void finish() const {}
        void set_stream(int) {}
        void create_events(int) {}
        void record_event(int) {}
        void wait_event(int) {}
    };

    // Holds an instance of the nested context above.
    migraphx::context ctx = context{};

    std::string name() const { return "stream_execution"; }

    migraphx::context get_context() const { return {ctx}; }

    // Builds the single pre_scheduling pass from the test-local weight callback
    // and insertion hooks.
    // NOTE(review): the trailing `true` is a pre_scheduling option whose meaning
    // is not visible in this file.
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
    {
        constexpr int num_of_streams = 2;
        return {migraphx::pre_scheduling{weight_func(), num_of_streams, insert_instruction{}, true}};
    }
};
// Builds a program in which two convolutions read the same parameter and feed a
// concat, compiles it for stream_execution_target, prints it, and then counts
// the set_stream instructions and the instructions reporting stream 0 / stream 1.
TEST_CASE(test1)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {32, 256, 35, 35}};
    const migraphx::shape weights64_shape{migraphx::shape::float_type, {64, 256, 1, 1}};
    const migraphx::shape weights48_shape{migraphx::shape::float_type, {48, 256, 1, 1}};

    migraphx::program p;
    auto input = p.add_parameter("0", input_shape);

    // First branch.
    auto weights64 = p.add_literal(migraphx::generate_literal(weights64_shape));
    auto conv64    = p.add_instruction(migraphx::op::convolution{}, input, weights64);

    // Second branch.
    auto weights48 = p.add_literal(migraphx::generate_literal(weights48_shape));
    auto conv48    = p.add_instruction(migraphx::op::convolution{}, input, weights48);

    p.add_instruction(migraphx::op::concat{1}, conv64, conv48);
    p.compile(stream_execution_target{});
    std::cout << p << std::endl;

    // The CHECK expressions are left exactly as written so that any text the
    // macro derives from them stays the same.
    CHECK(std::count_if(
    p.begin(), p.end(), [](auto&& ins) { return ins.name() == "set_stream"; }) == 3);
    CHECK(std::count_if(p.begin(), p.end(), [](auto&& ins) { return ins.get_stream() == 0; }) == 2);
    CHECK(std::count_if(p.begin(), p.end(), [](auto&& ins) { return ins.get_stream() == 1; }) == 1);
}
// Entry point of the test executable: forwards the command line to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
......@@ -27,11 +27,12 @@ struct context
<%
interface('context',
virtual('finish', returns='void', const=True),
virtual('finish', returns='void', const=True)
)
%>
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_FIND_CONCUR_HPP
#define MIGRAPHX_GUARD_FIND_CONCUR_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <unordered_map>
#include <vector>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent analysis to find concurrent instructions
/// executing in different streams.
struct find_concur
{
    /// Run the analysis on a program.
    ///
    /// The member is const, matching `const=True` in the interface description
    /// for `get_concur`.
    ///
    /// @param p              Program to analyze.
    /// @param num_of_streams Number of streams.
    /// @param concur_instrs  Map from an instruction to groups of instructions;
    ///                       passed by non-const reference (presumably filled in
    ///                       by the implementation -- confirm).
    /// @param instr2_points  Map from an instruction to an integer; passed by
    ///                       non-const reference (meaning of the integer is not
    ///                       visible here -- confirm against the implementation).
    void get_concur(program* p,
                    int num_of_streams,
                    std::unordered_map<const instruction*,
                                       std::vector<std::vector<const instruction*>>>& concur_instrs,
                    std::unordered_map<const instruction*, int>& instr2_points) const;
};
#else
// NOTE(review): the delimited section below is not C++. It looks like input for
// a code generator that describes the `find_concur` interface (one const member,
// `get_concur`) -- confirm against the build tooling before editing it by hand.
<%
interface('find_concur',
virtual('get_concur', returns='void', p = 'program*', num_of_stream = 'int', concur_instrs = 'std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>&', input = 'std::unordered_map<const instruction*, int>&', const=True)
)
%>
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_INSERT_INSTRUCTION_HPP
#define MIGRAPHX_GUARD_INSERT_INSTRUCTION_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent instruction insertion
/// for multi-stream execution.
struct insert_instruction
{
/// Insert event creation for `num_of_events` events into program `p`, relative
/// to instruction `ins` (exact placement is target-defined -- TODO confirm).
void insert_create_events(program* p, instruction_ref ins, int num_of_events);
/// Insert a record of event `event` into program `p`, relative to instruction `ins`.
void insert_record_event(program* p, instruction_ref ins, int event);
/// Insert a wait on event `event` into program `p`, relative to instruction `ins`.
void insert_wait_event(program* p, instruction_ref ins, int event);
/// Insert a stream assignment for stream `stream` into program `p`, relative
/// to instruction `ins`.
void insert_stream(program* p, instruction_ref ins, int stream);
};
#else
// NOTE(review): the delimited section below is not C++. It looks like input for
// a code generator that describes the four `insert_instruction` members --
// confirm against the build tooling before editing it by hand.
<%
interface('insert_instruction',
virtual('insert_create_events', returns='void', p = 'program*', ins ='instruction_ref', input = 'int'),
virtual('insert_record_event', returns='void', p = 'program*', ins ='instruction_ref', input = 'int'),
virtual('insert_wait_event', returns='void', p = 'program*', ins = 'instruction_ref', input = 'int'),
virtual('insert_stream', returns='void', p = 'program*', ins ='instruction_ref', input = 'int')
)
%>
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_SCHEDULE_MODEL_HPP
#define MIGRAPHX_GUARD_SCHEDULE_MODEL_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct operation;
#ifdef DOXYGEN
/// An interface for target-dependent model for the scheduler
struct schedule_model
{
/// Get the number of concurrent instructions allowed
std::size_t concurrency() const;
/// Schedule a concurrent instruction
void sched(program& p, instruction_ref ins, std::size_t n) const;
/// Insert necessary waits before an instruction
void wait(program& p, instruction_ref ins, std::size_t wait_id) const;
/// Insert necessary records after an instruction
void record(program& p, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation
std::size_t weight(const operation& op) const;
};
#else
// NOTE(review): the delimited section below is not C++. It looks like input for
// a code generator that describes the five const `schedule_model` members --
// confirm against the build tooling before editing it by hand.
<%
interface('schedule_model',
virtual('concurrency', returns='std::size_t', const=True),
virtual('sched', p='program&', ins='instruction_ref', n='std::size_t', const=True),
virtual('wait', p='program&', ins='instruction_ref', wait_id='std::size_t', const=True),
virtual('record', p='program&', ins='instruction_ref', wait_id='std::size_t', const=True),
virtual('weight', returns='std::size_t', op='const operation&', const=True)
)
%>
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
import string, sys, re, os
# Type names that convert_member handles on the same branch as types ending in
# '&' or '*' (see its `t in trivial` check).
trivial = [
'std::size_t',
'instruction_ref'
]
headers = '''
#include <algorithm>
#include <cassert>
......@@ -286,7 +292,7 @@ def convert_member(d, struct_name):
member['this'] = x
if 'const' in t:
member['member_const'] = 'const'
if t.endswith(('&', '*')):
if t.endswith(('&', '*')) or t in trivial:
if use_member: member_args.append(x)
args.append(arg_name)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment