Commit 63c6b0a5 authored by ravil-mobile's avatar ravil-mobile
Browse files

Fixed timing in benchmarking

parent 4637621a
......@@ -232,9 +232,13 @@ argument from_gpu(const argument& arg)
void set_device(std::size_t id)
{
auto status = hipSetDevice(id);
if(status != hipSuccess)
MIGRAPHX_THROW("Error setting device");
static std::size_t curr_id{0};
if (curr_id != id) {
curr_id = id;
auto status = hipSetDevice(curr_id);
if(status != hipSuccess)
MIGRAPHX_THROW("Error setting device");
}
}
void gpu_sync()
......
......@@ -55,17 +55,25 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
op.compute(ctx, output, args);
ctx.finish();
};
gctx.enable_perf_measurement();
run();
double host_time = 0.0;
double device_time = 0.0;
shared<hip_event_ptr> start = gctx.create_event_for_timing();
shared<hip_event_ptr> stop = gctx.create_event_for_timing();
gctx.get_stream().record(start.get());
for(auto i : range(n))
{
(void)i;
host_time += time<milliseconds>(run);
device_time += gctx.get_elapsed_ms();
op.compute(ctx, output, args);
}
return std::make_pair(host_time / n, device_time / n);
gctx.get_stream().record(stop.get());
auto status = hipEventSynchronize(stop.get());
if (status != hipSuccess) { MIGRAPHX_THROW("Failed to `hipEventSynchronize`: " + hip_error(status)); }
float milliseconds = 0.0;
status = hipEventElapsedTime(&milliseconds, start.get(), stop.get());
if (status != hipSuccess) { MIGRAPHX_THROW("Failed to `hipEventElapsedTime`: " + hip_error(status)); }
return std::make_pair(milliseconds, milliseconds);
}
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment