Commit 9db47771 authored by Paul's avatar Paul
Browse files

Take number of streams as the constructor

parent b75abee1
......@@ -10,6 +10,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void gpu_sync() { hipDeviceSynchronize(); }
using hip_ptr = MIGRAPHX_MANAGE_PTR(void, hipFree);
std::string hip_error(int error) { return hipGetErrorString(static_cast<hipError_t>(error)); }
......@@ -43,6 +45,7 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
template <class T>
std::vector<T> read_from_gpu(const void* x, std::size_t sz)
{
gpu_sync();
std::vector<T> result(sz);
auto status = hipMemcpy(result.data(), x, sz * sizeof(T), hipMemcpyDeviceToHost);
if(status != hipSuccess)
......@@ -52,6 +55,7 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
{
gpu_sync();
auto result = allocate_gpu(sz, host);
auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice);
if(status != hipSuccess)
......@@ -97,8 +101,6 @@ void set_device(std::size_t id)
MIGRAPHX_THROW("Error setting device");
}
void gpu_sync() { hipDeviceSynchronize(); }
void copy_to_gpu(const argument& src, const argument& dst)
{
std::size_t src_size = src.get_shape().bytes();
......
......@@ -34,8 +34,8 @@ struct hip_device
static hip_stream_ptr create_stream()
{
hipStream_t result = nullptr;
// auto status = hipStreamCreateWithFlags(&result, hipStreamNonBlocking);
auto status = hipStreamCreate(&result);
// auto status = hipStreamCreateWithFlags(&result, hipStreamNonBlocking);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to allocate stream");
......@@ -81,6 +81,12 @@ struct hip_device
return rbhandle.get();
}
void sync() const
{
if (s != nullptr)
hipStreamSynchronize(s.get());
}
private:
std::size_t id = 0;
shared<hip_stream_ptr> s = nullptr;
......@@ -125,10 +131,10 @@ struct hip_device
hipStreamWaitEvent(streams.at(current_stream).get(), events.at(event).get(), 0);
}
void stream_sync()
void sync() const
{
for(auto&& stream : streams)
hipStreamSynchronize(stream.get());
stream.sync();
}
private:
......@@ -145,29 +151,35 @@ struct context
{
}
hip_device& get_current_device() const
const hip_device& get_current_device() const
{
assert(current_device != nullptr);
return *current_device;
}
hip_device& get_current_device()
{
assert(current_device != nullptr);
return *current_device;
}
hip_device::stream& get_stream() { return get_current_device().get_stream(); }
void set_stream(int n) const
void set_stream(int n)
{
if(n >= 0)
get_current_device().set_stream(n);
}
void create_events(int num_of_events) const
void create_events(int num_of_events)
{
get_current_device().create_events(num_of_events);
}
void record_event(int event) const { get_current_device().record_event(event); }
void wait_event(int event) const { get_current_device().wait_event(event); }
void record_event(int event) { get_current_device().record_event(event); }
void wait_event(int event) { get_current_device().wait_event(event); }
std::vector<argument> literals{};
void finish() const
{
get_current_device().stream_sync();
get_current_device().sync();
gpu_sync();
}
......
......@@ -8,8 +8,6 @@
namespace migraphx {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_NULL_STREAM)
struct op_info
{
op_info()
......@@ -53,17 +51,6 @@ struct op_info
}
std::unordered_map<std::string, std::pair<int, int>> weight_map;
};
struct stream_info
{
int num_of_streams()
{
if(!enabled(MIGRAPHX_DISABLE_NULL_STREAM{}))
return 0;
else
return 4;
}
};
} // namespace gpu
} // namespace migraphx
......
......@@ -131,7 +131,6 @@ migraphx::argument run_gpu(migraphx::program& p)
p.dry_run(m);
EXPECT(is_shared(ctx, p.get_context()));
auto eval = p.eval(m);
p.finish();
auto ret_val = migraphx::gpu::from_gpu(eval);
return ret_val;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment