"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "f9b3189391d1441fea16d8b61e5e67a2c38f4442"
Commit 8879cc93 authored by Paul's avatar Paul
Browse files

Create events during finalize

parent 5b6abcc3
......@@ -87,6 +87,22 @@ struct hip_device
hipStreamSynchronize(s.get());
}
void wait(hipEvent_t event)
{
setup();
auto status = hipStreamWaitEvent(get(), event, 0);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait.");
}
void record(hipEvent_t event)
{
setup();
auto status = hipEventRecord(event, get());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to record.");
}
private:
std::size_t id = 0;
shared<hip_stream_ptr> s = nullptr;
......@@ -115,20 +131,25 @@ struct hip_device
stream& get_stream() { return streams.at(current_stream); }
void set_stream(std::size_t n) { current_stream = n; }
void create_events(int num_of_events)
void create_events(std::size_t num_of_events)
{
for(int i = events.size(); i < num_of_events; ++i)
for(int i = events.size(); i < num_of_events+1; ++i)
events.emplace_back(create_event());
}
void record_event(int event)
void record_event(std::size_t event)
{
create_events(event + 1);
hipEventRecord(events.at(event).get(), streams.at(current_stream).get());
streams.at(current_stream).record(events.at(event).get());
}
void wait_event(int event)
void wait_event(std::size_t event)
{
hipStreamWaitEvent(streams.at(current_stream).get(), events.at(event).get(), 0);
streams.at(current_stream).wait(events.at(event).get());
}
void check_events(std::size_t n) const
{
if (n > events.size())
MIGRAPHX_THROW("The number of waits exceed the number of records.");
}
void sync() const
......@@ -164,14 +185,15 @@ struct context
}
hip_device::stream& get_stream() { return get_current_device().get_stream(); }
void set_stream(int n)
void set_stream(std::size_t n)
{
if(n >= 0)
get_current_device().set_stream(n);
get_current_device().set_stream(n);
}
void create_events(int num_of_events) { get_current_device().create_events(num_of_events); }
void record_event(int event) { get_current_device().record_event(event); }
void wait_event(int event) { get_current_device().wait_event(event); }
void create_events(std::size_t num_of_events) { get_current_device().create_events(num_of_events); }
void check_events(std::size_t n) const { get_current_device().check_events(n); }
void record_event(std::size_t event) { get_current_device().record_event(event); }
void wait_event(std::size_t event) { get_current_device().wait_event(event); }
std::vector<argument> literals{};
void finish() const
......
......@@ -45,6 +45,12 @@ struct record_event
ctx.record_event(event);
return {};
}
void finalize(context& ctx, const shape&, std::vector<shape>)
{
assert(event >= 0);
ctx.create_events(event);
}
};
struct wait_event
......@@ -63,6 +69,12 @@ struct wait_event
ctx.wait_event(event);
return {};
}
void finalize(context& ctx, const shape&, std::vector<shape>)
{
assert(event >= 0);
ctx.check_events(event);
}
};
struct set_stream
......
......@@ -15,7 +15,7 @@ struct insert_instruction_gpu
{
void insert_create_events(program* p, instruction_ref ins, int num_of_events)
{
p->insert_instruction(ins, create_events{num_of_events});
// p->insert_instruction(ins, create_events{num_of_events});
}
void insert_record_event(program* p, instruction_ref ins, int event)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment