Commit 8879cc93 authored by Paul's avatar Paul
Browse files

Create events during finalize

parent 5b6abcc3
...@@ -87,6 +87,22 @@ struct hip_device ...@@ -87,6 +87,22 @@ struct hip_device
hipStreamSynchronize(s.get()); hipStreamSynchronize(s.get());
} }
void wait(hipEvent_t event)
{
setup();
auto status = hipStreamWaitEvent(get(), event, 0);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait.");
}
void record(hipEvent_t event)
{
setup();
auto status = hipEventRecord(event, get());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to record.");
}
private: private:
std::size_t id = 0; std::size_t id = 0;
shared<hip_stream_ptr> s = nullptr; shared<hip_stream_ptr> s = nullptr;
...@@ -115,20 +131,25 @@ struct hip_device ...@@ -115,20 +131,25 @@ struct hip_device
stream& get_stream() { return streams.at(current_stream); } stream& get_stream() { return streams.at(current_stream); }
void set_stream(std::size_t n) { current_stream = n; } void set_stream(std::size_t n) { current_stream = n; }
void create_events(int num_of_events) void create_events(std::size_t num_of_events)
{ {
for(int i = events.size(); i < num_of_events; ++i) for(int i = events.size(); i < num_of_events+1; ++i)
events.emplace_back(create_event()); events.emplace_back(create_event());
} }
void record_event(int event) void record_event(std::size_t event)
{ {
create_events(event + 1); streams.at(current_stream).record(events.at(event).get());
hipEventRecord(events.at(event).get(), streams.at(current_stream).get());
} }
void wait_event(int event) void wait_event(std::size_t event)
{ {
hipStreamWaitEvent(streams.at(current_stream).get(), events.at(event).get(), 0); streams.at(current_stream).wait(events.at(event).get());
}
void check_events(std::size_t n) const
{
if (n > events.size())
MIGRAPHX_THROW("The number of waits exceed the number of records.");
} }
void sync() const void sync() const
...@@ -164,14 +185,15 @@ struct context ...@@ -164,14 +185,15 @@ struct context
} }
hip_device::stream& get_stream() { return get_current_device().get_stream(); } hip_device::stream& get_stream() { return get_current_device().get_stream(); }
void set_stream(int n) void set_stream(std::size_t n)
{ {
if(n >= 0) get_current_device().set_stream(n);
get_current_device().set_stream(n);
} }
void create_events(int num_of_events) { get_current_device().create_events(num_of_events); } void create_events(std::size_t num_of_events) { get_current_device().create_events(num_of_events); }
void record_event(int event) { get_current_device().record_event(event); } void check_events(std::size_t n) const { get_current_device().check_events(n); }
void wait_event(int event) { get_current_device().wait_event(event); }
void record_event(std::size_t event) { get_current_device().record_event(event); }
void wait_event(std::size_t event) { get_current_device().wait_event(event); }
std::vector<argument> literals{}; std::vector<argument> literals{};
void finish() const void finish() const
......
...@@ -45,6 +45,12 @@ struct record_event ...@@ -45,6 +45,12 @@ struct record_event
ctx.record_event(event); ctx.record_event(event);
return {}; return {};
} }
void finalize(context& ctx, const shape&, std::vector<shape>)
{
assert(event >= 0);
ctx.create_events(event);
}
}; };
struct wait_event struct wait_event
...@@ -63,6 +69,12 @@ struct wait_event ...@@ -63,6 +69,12 @@ struct wait_event
ctx.wait_event(event); ctx.wait_event(event);
return {}; return {};
} }
void finalize(context& ctx, const shape&, std::vector<shape>)
{
assert(event >= 0);
ctx.check_events(event);
}
}; };
struct set_stream struct set_stream
......
...@@ -15,7 +15,7 @@ struct insert_instruction_gpu ...@@ -15,7 +15,7 @@ struct insert_instruction_gpu
{ {
void insert_create_events(program* p, instruction_ref ins, int num_of_events) void insert_create_events(program* p, instruction_ref ins, int num_of_events)
{ {
p->insert_instruction(ins, create_events{num_of_events}); // p->insert_instruction(ins, create_events{num_of_events});
} }
void insert_record_event(program* p, instruction_ref ins, int event) void insert_record_event(program* p, instruction_ref ins, int event)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment