Commit a129ea19 authored by mei-ye's avatar mei-ye
Browse files

remove mask and event from instuction. remove event-related methods from common and cpu context

parent 3885c9bc
...@@ -21,10 +21,6 @@ struct context ...@@ -21,10 +21,6 @@ struct context
{ {
/// Wait for any tasks in the context to complete /// Wait for any tasks in the context to complete
void finish(); void finish();
void set_stream(int ndx);
void create_events(int num_of_events);
void record_event(int event);
void wait_event(int event);
}; };
#else #else
...@@ -35,10 +31,6 @@ struct context ...@@ -35,10 +31,6 @@ struct context
* struct context * struct context
* { * {
* void finish() ; * void finish() ;
* void set_stream(int input) ;
* void create_events(int input) ;
* void record_event(int input) ;
* void wait_event(int input) ;
* }; * };
* *
*/ */
...@@ -106,30 +98,6 @@ struct context ...@@ -106,30 +98,6 @@ struct context
(*this).private_detail_te_get_handle().finish(); (*this).private_detail_te_get_handle().finish();
} }
void set_stream(int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().set_stream(input);
}
void create_events(int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().create_events(input);
}
void record_event(int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record_event(input);
}
void wait_event(int input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait_event(input);
}
friend bool is_shared(const context& private_detail_x, const context& private_detail_y) friend bool is_shared(const context& private_detail_x, const context& private_detail_y)
{ {
return private_detail_x.private_detail_te_handle_mem_var == return private_detail_x.private_detail_te_handle_mem_var ==
...@@ -143,11 +111,7 @@ struct context ...@@ -143,11 +111,7 @@ struct context
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual void finish() = 0; virtual void finish() = 0;
virtual void set_stream(int input) = 0;
virtual void create_events(int input) = 0;
virtual void record_event(int input) = 0;
virtual void wait_event(int input) = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -180,14 +144,6 @@ struct context ...@@ -180,14 +144,6 @@ struct context
void finish() override { private_detail_te_value.finish(); } void finish() override { private_detail_te_value.finish(); }
void set_stream(int input) override { private_detail_te_value.set_stream(input); }
void create_events(int input) override { private_detail_te_value.create_events(input); }
void record_event(int input) override { private_detail_te_value.record_event(input); }
void wait_event(int input) override { private_detail_te_value.wait_event(input); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -16,12 +16,6 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -16,12 +16,6 @@ inline namespace MIGRAPHX_INLINE_NS {
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args); shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args); std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
enum instruction_mask : unsigned int
{
record_event = 0,
wait_event = 1
};
struct instruction struct instruction
{ {
instruction() {} instruction() {}
...@@ -49,15 +43,6 @@ struct instruction ...@@ -49,15 +43,6 @@ struct instruction
int get_stream() const; int get_stream() const;
void set_stream(int); void set_stream(int);
int get_event() const;
void set_event(int);
void add_mask(instruction_mask m)
{
if((mask & (1u << m)) == 0)
mask += (1u << m);
}
bool has_mask(instruction_mask m) const { return ((mask & (1u << m)) != 0); }
std::string name() const; std::string name() const;
const std::vector<instruction_ref>& inputs() const; const std::vector<instruction_ref>& inputs() const;
...@@ -111,9 +96,7 @@ struct instruction ...@@ -111,9 +96,7 @@ struct instruction
std::vector<instruction_ref> output; std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments; std::vector<instruction_ref> arguments;
literal lit; literal lit;
int stream = -1; int stream = -1;
unsigned int mask = 0;
int event = -1;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -10,7 +10,6 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,6 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MEMORY_COLORING) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MEMORY_COLORING)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_PRE_SCHEDULING) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_PRE_SCHEDULING)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EVENT_AS_INSTRUCTION)
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -89,8 +89,6 @@ const literal& instruction::get_literal() const ...@@ -89,8 +89,6 @@ const literal& instruction::get_literal() const
int instruction::get_stream() const { return stream; } int instruction::get_stream() const { return stream; }
void instruction::set_stream(int s) { stream = s; } void instruction::set_stream(int s) { stream = s; }
int instruction::get_event() const { return event; }
void instruction::set_event(int e) { event = e; }
const operation& instruction::get_operator() const { return op; } const operation& instruction::get_operator() const { return op; }
......
...@@ -149,18 +149,15 @@ void dom_info::compute_dom(bool reversed) ...@@ -149,18 +149,15 @@ void dom_info::compute_dom(bool reversed)
bool dom_info::is_split_point(instruction_ref ins) bool dom_info::is_split_point(instruction_ref ins)
{ {
if(ins->has_mask(record_event)) std::set<int> stream_set;
for(auto&& arg : ins->outputs())
{ {
std::set<int> stream_set; int arg_stream = arg->get_stream();
for(auto&& arg : ins->outputs()) if(arg_stream >= 0)
{ stream_set.insert(arg_stream);
int arg_stream = arg->get_stream();
if(arg_stream >= 0)
stream_set.insert(arg_stream);
}
if(stream_set.size() > 1)
return true;
} }
if(stream_set.size() > 1)
return true;
return false; return false;
} }
...@@ -168,18 +165,15 @@ bool dom_info::is_split_point(instruction_ref ins) ...@@ -168,18 +165,15 @@ bool dom_info::is_split_point(instruction_ref ins)
// inputs that are executed in different streams. // inputs that are executed in different streams.
bool dom_info::is_merge_point(instruction_ref ins) bool dom_info::is_merge_point(instruction_ref ins)
{ {
if(ins->has_mask(wait_event)) std::set<int> stream_set;
for(auto&& arg : ins->inputs())
{ {
std::set<int> stream_set; int arg_stream = arg->get_stream();
for(auto&& arg : ins->inputs()) if(arg_stream >= 0)
{ stream_set.insert(arg_stream);
int arg_stream = arg->get_stream();
if(arg_stream >= 0)
stream_set.insert(arg_stream);
}
if(stream_set.size() > 1)
return true;
} }
if(stream_set.size() > 1)
return true;
return false; return false;
} }
......
...@@ -224,9 +224,9 @@ void pre_scheduling_impl::splice(std::list<dag_node*>& sorted_nodes) ...@@ -224,9 +224,9 @@ void pre_scheduling_impl::splice(std::list<dag_node*>& sorted_nodes)
// //
void pre_scheduling_impl::annotate(std::list<dag_node*>& sorted_nodes) void pre_scheduling_impl::annotate(std::list<dag_node*>& sorted_nodes)
{ {
int event = 0; int event = 0;
int last_stream = -1; int last_stream = -1;
bool enable_event_as_instr = enabled(MIGRAPHX_ENABLE_EVENT_AS_INSTRUCTION{});
for(auto&& node : sorted_nodes) for(auto&& node : sorted_nodes)
{ {
instruction_ref ins = node->ins; instruction_ref ins = node->ins;
...@@ -250,22 +250,14 @@ void pre_scheduling_impl::annotate(std::list<dag_node*>& sorted_nodes) ...@@ -250,22 +250,14 @@ void pre_scheduling_impl::annotate(std::list<dag_node*>& sorted_nodes)
if(!has_mask(arg, record_event)) if(!has_mask(arg, record_event))
{ {
events.push_back(event); events.push_back(event);
arg->set_event(event); insert_instr.insert_record_event(p_program, std::next(arg), event);
arg->add_mask(record_event);
if(enable_event_as_instr)
insert_instr.insert_record_event(p_program, std::next(arg), event);
event++; event++;
} }
ins->add_mask(wait_event);
add_mask(arg, record_event); add_mask(arg, record_event);
add_mask(ins, wait_event); add_mask(ins, wait_event);
} }
if(enable_event_as_instr) for(auto&& i : events)
{ insert_instr.insert_wait_event(p_program, ins, i);
for(auto&& i : events)
insert_instr.insert_wait_event(p_program, ins, i);
}
} }
} }
......
...@@ -79,6 +79,12 @@ struct stream_info ...@@ -79,6 +79,12 @@ struct stream_info
int max_cycle; int max_cycle;
}; };
enum instruction_mask : unsigned int
{
record_event = 0,
wait_event = 1
};
struct pre_scheduling_impl struct pre_scheduling_impl
{ {
pre_scheduling_impl(program* p, pre_scheduling_impl(program* p,
......
...@@ -55,10 +55,6 @@ static void print_instruction(std::ostream& os, ...@@ -55,10 +55,6 @@ static void print_instruction(std::ostream& os,
} }
if(ins->get_stream() >= 0) if(ins->get_stream() >= 0)
os << "(stream=" << ins->get_stream() << ")"; os << "(stream=" << ins->get_stream() << ")";
if(ins->has_mask(wait_event))
os << " wait ";
if(ins->has_mask(record_event))
os << " record=" << ins->get_event();
os << " -> " << ins->get_shape(); os << " -> " << ins->get_shape();
} }
...@@ -329,14 +325,10 @@ void program::compile(const target& t, tracer trace) ...@@ -329,14 +325,10 @@ void program::compile(const target& t, tracer trace)
void program::finalize() void program::finalize()
{ {
int max_event = -1;
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
ins->finalize(this->impl->ctx); ins->finalize(this->impl->ctx);
max_event = std::max(max_event, ins->get_event());
} }
if(max_event >= 0)
this->impl->ctx.create_events(max_event + 1);
} }
void program::finish() { this->impl->ctx.finish(); } void program::finish() { this->impl->ctx.finish(); }
...@@ -352,12 +344,9 @@ argument generic_eval(const program& p, ...@@ -352,12 +344,9 @@ argument generic_eval(const program& p,
results.reserve(p.size() * 2); results.reserve(p.size() * 2);
std::vector<argument> values; std::vector<argument> values;
values.reserve(16); values.reserve(16);
bool enable_event_as_instr = enabled(MIGRAPHX_ENABLE_EVENT_AS_INSTRUCTION{});
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
int stream = ins->get_stream();
// ctx.set_stream(stream);
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
results.emplace(ins, trace(ins, [&] { return ins->get_literal().get_argument(); })); results.emplace(ins, trace(ins, [&] { return ins->get_literal().get_argument(); }));
...@@ -385,25 +374,9 @@ argument generic_eval(const program& p, ...@@ -385,25 +374,9 @@ argument generic_eval(const program& p,
return results[i]; return results[i];
}); });
if(!enable_event_as_instr && ins->has_mask(wait_event))
{
for(auto&& arg : ins->inputs())
{
int arg_s = arg->get_stream();
if((arg_s < 0) || (arg_s == stream))
continue;
int event = arg->get_event();
assert(event >= 0);
ctx.wait_event(event);
}
}
results.emplace(ins, trace(ins, [&] { results.emplace(ins, trace(ins, [&] {
return ins->get_operator().compute(ctx, ins->get_shape(), values); return ins->get_operator().compute(ctx, ins->get_shape(), values);
})); }));
if(!enable_event_as_instr && ins->has_mask(record_event))
ctx.record_event(ins->get_event());
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
} }
......
...@@ -10,10 +10,6 @@ namespace cpu { ...@@ -10,10 +10,6 @@ namespace cpu {
struct context struct context
{ {
void finish() {} void finish() {}
void set_stream(int) {}
void create_events(int) {}
void record_event(int) {}
void wait_event(int) {}
}; };
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -115,6 +115,7 @@ struct hip_device ...@@ -115,6 +115,7 @@ struct hip_device
} }
void record_event(int event) void record_event(int event)
{ {
create_events(event + 1);
hipEventRecord(events.at(event).get(), streams.at(current_stream).get()); hipEventRecord(events.at(event).get(), streams.at(current_stream).get());
} }
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/verify_args.hpp> #include <migraphx/verify_args.hpp>
migraphx::program create_program(bool is_cpu) migraphx::program create_program()
{ {
migraphx::program p; migraphx::program p;
auto in1 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {32, 64, 1, 1}}); auto in1 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {32, 64, 1, 1}});
...@@ -16,23 +16,14 @@ migraphx::program create_program(bool is_cpu) ...@@ -16,23 +16,14 @@ migraphx::program create_program(bool is_cpu)
auto p1 = p.add_instruction(migraphx::op::convolution{}, in1, in2); auto p1 = p.add_instruction(migraphx::op::convolution{}, in1, in2);
auto in3 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {64, 64, 1, 1}}); auto in3 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {64, 64, 1, 1}});
auto p2 = p.add_instruction(migraphx::op::convolution{}, in1, in3); auto p2 = p.add_instruction(migraphx::op::convolution{}, in1, in3);
if(is_cpu) p.add_instruction(migraphx::op::concat{1}, p1, p2);
{
p2->set_event(0);
p2->add_mask(migraphx::record_event);
}
auto p3 = p.add_instruction(migraphx::op::concat{1}, p1, p2);
if(is_cpu)
{
p3->add_mask(migraphx::wait_event);
}
return p; return p;
} }
migraphx::argument run_gpu() migraphx::argument run_gpu()
{ {
setenv("MIGRAPHX_DISABLE_NULL_STREAM", "1", 1); setenv("MIGRAPHX_DISABLE_NULL_STREAM", "1", 1);
migraphx::program p = create_program(false); migraphx::program p = create_program();
p.compile(migraphx::gpu::target{}); p.compile(migraphx::gpu::target{});
migraphx::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
...@@ -46,7 +37,7 @@ migraphx::argument run_gpu() ...@@ -46,7 +37,7 @@ migraphx::argument run_gpu()
migraphx::argument run_cpu() migraphx::argument run_cpu()
{ {
migraphx::program p = create_program(true); migraphx::program p = create_program();
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
......
...@@ -648,7 +648,6 @@ TEST_CASE(concurrent_test) ...@@ -648,7 +648,6 @@ TEST_CASE(concurrent_test)
auto p1 = p.add_instruction(pass_op{}, a1, in); auto p1 = p.add_instruction(pass_op{}, a1, in);
p.insert_instruction(p1, set_stream{0}); p.insert_instruction(p1, set_stream{0});
p1->set_stream(0); p1->set_stream(0);
p1->add_mask(migraphx::record_event);
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p2->set_stream(0); p2->set_stream(0);
...@@ -659,25 +658,19 @@ TEST_CASE(concurrent_test) ...@@ -659,25 +658,19 @@ TEST_CASE(concurrent_test)
auto p3 = p.add_instruction(pass_op{}, a3, p1); auto p3 = p.add_instruction(pass_op{}, a3, p1);
p3->set_stream(1); p3->set_stream(1);
p.insert_instruction(p3, set_stream{1}); p.insert_instruction(p3, set_stream{1});
p3->add_mask(migraphx::wait_event);
auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p5 = p.add_instruction(pass_op{}, a5, p3); auto p5 = p.add_instruction(pass_op{}, a5, p3);
p5->set_stream(1); p5->set_stream(1);
p5->add_mask(migraphx::record_event);
auto a6 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a6 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p6 = p.add_instruction(pass_op{}, a6, p1); auto p6 = p.add_instruction(pass_op{}, a6, p1);
p6->set_stream(2); p6->set_stream(2);
p6->add_mask(migraphx::wait_event);
p.insert_instruction(p6, set_stream{2}); p.insert_instruction(p6, set_stream{2});
auto a7 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a7 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p7 = p.add_instruction(pass_op{}, a7, p6); auto p7 = p.add_instruction(pass_op{}, a7, p6);
p7->set_stream(2); p7->set_stream(2);
p7->add_mask(migraphx::record_event);
auto a8 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a8 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p8 = p.add_instruction(migraphx::op::concat{0}, a8, p4, p5, p7); auto p8 = p.add_instruction(migraphx::op::concat{0}, a8, p4, p5, p7);
;
p8->set_stream(0); p8->set_stream(0);
p8->add_mask(migraphx::wait_event);
p.insert_instruction(p8, set_stream{0}); p.insert_instruction(p8, set_stream{0});
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 960); CHECK(p.get_parameter_shape("scratch").bytes() == 960);
......
...@@ -101,12 +101,6 @@ TEST_CASE(test1) ...@@ -101,12 +101,6 @@ TEST_CASE(test1)
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "set_stream"; }) == 3); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "set_stream"; }) == 3);
CHECK(std::count_if(p.begin(), p.end(), [](auto&& ins) { return ins.get_stream() == 0; }) == 2); CHECK(std::count_if(p.begin(), p.end(), [](auto&& ins) { return ins.get_stream() == 0; }) == 2);
CHECK(std::count_if(p.begin(), p.end(), [](auto&& ins) { return ins.get_stream() == 1; }) == 1); CHECK(std::count_if(p.begin(), p.end(), [](auto&& ins) { return ins.get_stream() == 1; }) == 1);
CHECK(std::count_if(p.begin(), p.end(), [](auto&& ins) {
return ins.has_mask(migraphx::record_event);
}) == 1);
CHECK(std::count_if(p.begin(), p.end(), [](auto&& ins) {
return ins.has_mask(migraphx::wait_event);
}) == 1);
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -21,10 +21,6 @@ struct context ...@@ -21,10 +21,6 @@ struct context
{ {
/// Wait for any tasks in the context to complete /// Wait for any tasks in the context to complete
void finish(); void finish();
void set_stream(int ndx);
void create_events(int num_of_events);
void record_event(int event);
void wait_event(int event);
}; };
#else #else
...@@ -32,10 +28,6 @@ struct context ...@@ -32,10 +28,6 @@ struct context
<% <%
interface('context', interface('context',
virtual('finish', returns='void'), virtual('finish', returns='void'),
virtual('set_stream', returns='void', input = 'int'),
virtual('create_events', returns='void', input = 'int'),
virtual('record_event', returns='void', input = 'int'),
virtual('wait_event', returns='void', input = 'int'),
) )
%> %>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment