Commit 9db47771 authored by Paul's avatar Paul
Browse files

Take the number of streams as a constructor argument

parent b75abee1
...@@ -10,6 +10,8 @@ namespace migraphx { ...@@ -10,6 +10,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
void gpu_sync() { hipDeviceSynchronize(); }
using hip_ptr = MIGRAPHX_MANAGE_PTR(void, hipFree); using hip_ptr = MIGRAPHX_MANAGE_PTR(void, hipFree);
std::string hip_error(int error) { return hipGetErrorString(static_cast<hipError_t>(error)); } std::string hip_error(int error) { return hipGetErrorString(static_cast<hipError_t>(error)); }
...@@ -43,6 +45,7 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false) ...@@ -43,6 +45,7 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
template <class T> template <class T>
std::vector<T> read_from_gpu(const void* x, std::size_t sz) std::vector<T> read_from_gpu(const void* x, std::size_t sz)
{ {
gpu_sync();
std::vector<T> result(sz); std::vector<T> result(sz);
auto status = hipMemcpy(result.data(), x, sz * sizeof(T), hipMemcpyDeviceToHost); auto status = hipMemcpy(result.data(), x, sz * sizeof(T), hipMemcpyDeviceToHost);
if(status != hipSuccess) if(status != hipSuccess)
...@@ -52,6 +55,7 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz) ...@@ -52,6 +55,7 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false) hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
{ {
gpu_sync();
auto result = allocate_gpu(sz, host); auto result = allocate_gpu(sz, host);
auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice); auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice);
if(status != hipSuccess) if(status != hipSuccess)
...@@ -97,8 +101,6 @@ void set_device(std::size_t id) ...@@ -97,8 +101,6 @@ void set_device(std::size_t id)
MIGRAPHX_THROW("Error setting device"); MIGRAPHX_THROW("Error setting device");
} }
void gpu_sync() { hipDeviceSynchronize(); }
void copy_to_gpu(const argument& src, const argument& dst) void copy_to_gpu(const argument& src, const argument& dst)
{ {
std::size_t src_size = src.get_shape().bytes(); std::size_t src_size = src.get_shape().bytes();
......
...@@ -34,8 +34,8 @@ struct hip_device ...@@ -34,8 +34,8 @@ struct hip_device
static hip_stream_ptr create_stream() static hip_stream_ptr create_stream()
{ {
hipStream_t result = nullptr; hipStream_t result = nullptr;
// auto status = hipStreamCreateWithFlags(&result, hipStreamNonBlocking);
auto status = hipStreamCreate(&result); auto status = hipStreamCreate(&result);
// auto status = hipStreamCreateWithFlags(&result, hipStreamNonBlocking);
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed to allocate stream"); MIGRAPHX_THROW("Failed to allocate stream");
...@@ -81,6 +81,12 @@ struct hip_device ...@@ -81,6 +81,12 @@ struct hip_device
return rbhandle.get(); return rbhandle.get();
} }
// Block the host until all work queued on this stream has completed.
// A null stream handle means no stream was ever created, so there is
// nothing to wait on.
void sync() const
{
    if (s == nullptr)
        return;
    hipStreamSynchronize(s.get());
}
private: private:
std::size_t id = 0; std::size_t id = 0;
shared<hip_stream_ptr> s = nullptr; shared<hip_stream_ptr> s = nullptr;
...@@ -125,10 +131,10 @@ struct hip_device ...@@ -125,10 +131,10 @@ struct hip_device
hipStreamWaitEvent(streams.at(current_stream).get(), events.at(event).get(), 0); hipStreamWaitEvent(streams.at(current_stream).get(), events.at(event).get(), 0);
} }
void stream_sync() void sync() const
{ {
for(auto&& stream : streams) for(auto&& stream : streams)
hipStreamSynchronize(stream.get()); stream.sync();
} }
private: private:
...@@ -145,29 +151,35 @@ struct context ...@@ -145,29 +151,35 @@ struct context
{ {
} }
hip_device& get_current_device() const const hip_device& get_current_device() const
{
assert(current_device != nullptr);
return *current_device;
}
hip_device& get_current_device()
{ {
assert(current_device != nullptr); assert(current_device != nullptr);
return *current_device; return *current_device;
} }
hip_device::stream& get_stream() { return get_current_device().get_stream(); } hip_device::stream& get_stream() { return get_current_device().get_stream(); }
void set_stream(int n) const void set_stream(int n)
{ {
if(n >= 0) if(n >= 0)
get_current_device().set_stream(n); get_current_device().set_stream(n);
} }
void create_events(int num_of_events) const void create_events(int num_of_events)
{ {
get_current_device().create_events(num_of_events); get_current_device().create_events(num_of_events);
} }
void record_event(int event) const { get_current_device().record_event(event); } void record_event(int event) { get_current_device().record_event(event); }
void wait_event(int event) const { get_current_device().wait_event(event); } void wait_event(int event) { get_current_device().wait_event(event); }
std::vector<argument> literals{}; std::vector<argument> literals{};
void finish() const void finish() const
{ {
get_current_device().stream_sync(); get_current_device().sync();
gpu_sync(); gpu_sync();
} }
......
...@@ -8,8 +8,6 @@ ...@@ -8,8 +8,6 @@
namespace migraphx { namespace migraphx {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_NULL_STREAM)
struct op_info struct op_info
{ {
op_info() op_info()
...@@ -53,17 +51,6 @@ struct op_info ...@@ -53,17 +51,6 @@ struct op_info
} }
std::unordered_map<std::string, std::pair<int, int>> weight_map; std::unordered_map<std::string, std::pair<int, int>> weight_map;
}; };
// Decides how many explicit HIP streams the scheduler should use.
struct stream_info
{
    // Returns 0 (use the default/null stream) unless the
    // MIGRAPHX_DISABLE_NULL_STREAM environment variable is set,
    // in which case 4 explicit streams are requested.
    int num_of_streams()
    {
        return enabled(MIGRAPHX_DISABLE_NULL_STREAM{}) ? 4 : 0;
    }
};
} // namespace gpu } // namespace gpu
} // namespace migraphx } // namespace migraphx
......
...@@ -131,7 +131,6 @@ migraphx::argument run_gpu(migraphx::program& p) ...@@ -131,7 +131,6 @@ migraphx::argument run_gpu(migraphx::program& p)
p.dry_run(m); p.dry_run(m);
EXPECT(is_shared(ctx, p.get_context())); EXPECT(is_shared(ctx, p.get_context()));
auto eval = p.eval(m); auto eval = p.eval(m);
p.finish();
auto ret_val = migraphx::gpu::from_gpu(eval); auto ret_val = migraphx::gpu::from_gpu(eval);
return ret_val; return ret_val;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment