Commit 9167146e authored by Paul's avatar Paul
Browse files

Fix error with number of streams

parent 07d27ac2
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/machine_model.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -18,9 +17,9 @@ struct hip_device ...@@ -18,9 +17,9 @@ struct hip_device
{ {
using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy); using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
hip_device() { add_streams(); } hip_device() {}
hip_device(std::size_t id) : device_id(id) { add_streams(); } hip_device(std::size_t id, std::size_t n) : device_id(id) { add_streams(n); }
struct stream struct stream
{ {
...@@ -35,7 +34,8 @@ struct hip_device ...@@ -35,7 +34,8 @@ struct hip_device
static hip_stream_ptr create_stream() static hip_stream_ptr create_stream()
{ {
hipStream_t result = nullptr; hipStream_t result = nullptr;
auto status = hipStreamCreateWithFlags(&result, hipStreamNonBlocking); // auto status = hipStreamCreateWithFlags(&result, hipStreamNonBlocking);
auto status = hipStreamCreate(&result);
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed to allocate stream"); MIGRAPHX_THROW("Failed to allocate stream");
...@@ -97,14 +97,18 @@ struct hip_device ...@@ -97,14 +97,18 @@ struct hip_device
return hip_event_ptr{event}; return hip_event_ptr{event};
} }
void add_streams() void add_streams(std::size_t num_of_streams)
{ {
int num_of_streams = stream_info().num_of_streams();
assert(streams.empty()); assert(streams.empty());
for(int i = 0; i < num_of_streams; ++i) for(int i = 0; i < num_of_streams; ++i)
streams.emplace_back(device_id); streams.emplace_back(device_id);
} }
std::size_t nstreams() const
{
return streams.size();
}
stream& get_stream() { return streams.at(current_stream); } stream& get_stream() { return streams.at(current_stream); }
void set_stream(std::size_t n) { current_stream = n; } void set_stream(std::size_t n) { current_stream = n; }
...@@ -139,7 +143,7 @@ struct hip_device ...@@ -139,7 +143,7 @@ struct hip_device
struct context struct context
{ {
context(std::size_t n = 0) : current_device(std::make_shared<hip_device>(n)) {} context(std::size_t device_id = 0, std::size_t n = 4) : current_device(std::make_shared<hip_device>(device_id, n)) {}
hip_device& get_current_device() const hip_device& get_current_device() const
{ {
......
...@@ -31,7 +31,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -31,7 +31,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
{ {
auto& ctx = any_cast<context>(gctx); auto& ctx = any_cast<context>(gctx);
std::function<std::pair<int, int>(const operation&)> weight_func = op_info(); std::function<std::pair<int, int>(const operation&)> weight_func = op_info();
int num_of_streams = stream_info().num_of_streams(); int num_of_streams = ctx.get_current_device().nstreams();
// clang-format off // clang-format off
return return
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment