Commit 3f84bc67 authored by Paul

Formatting

parent 6df03163
@@ -47,8 +47,12 @@ shape miopen_convolution::compile(context& ctx,
     auto y_desc = make_tensor(output_shape);

     std::size_t workspace_size = 0;
-    miopenConvolutionForwardGetWorkSpaceSize(
-        ctx.get_stream().get_miopen(), w_desc.get(), x_desc.get(), cd.get(), y_desc.get(), &workspace_size);
+    miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
+                                             w_desc.get(),
+                                             x_desc.get(),
+                                             cd.get(),
+                                             y_desc.get(),
+                                             &workspace_size);
     workspace_shape = shape{shape::int8_type, {workspace_size}};

     auto x = to_gpu(generate_argument(inputs[0]->get_shape()));
...
@@ -82,7 +82,8 @@ struct fusion
         // int algo_count = 1;
         // miopenConvFwdAlgorithm_t algo;
         // miopenFusionPlanConvolutionGetAlgo(fp.get(), 1, &algo_count, &algo);
-        // miopenFusionPlanGetWorkSpaceSize(ctx.get_stream().get_miopen(), fp.get(), &ws_size, algo);
+        // miopenFusionPlanGetWorkSpaceSize(ctx.get_stream().get_miopen(), fp.get(), &ws_size,
+        // algo);
         return shape{shape::int8_type, {ws_size}};
     }

...
@@ -10,32 +10,22 @@ namespace gpu {
 struct hip_device
 {
-    hip_device()
-    {
-        add_stream();
-    }
+    hip_device() { add_stream(); }

-    hip_device(std::size_t id)
-        : device_id(id)
-    {
-        add_stream();
-    }
+    hip_device(std::size_t id) : device_id(id) { add_stream(); }

     struct stream
     {
         using hip_stream_ptr = MIGRAPH_MANAGE_PTR(hipStream_t, hipStreamDestroy);

-        stream()
-        {}
+        stream() {}

-        stream(std::size_t device_number)
-            : id(device_number)
-        {}
+        stream(std::size_t device_number) : id(device_number) {}

         static hip_stream_ptr create_stream()
         {
             hipStream_t result = nullptr;
             auto status = hipStreamCreate(&result);
             if(status != hipSuccess)
                 MIGRAPH_THROW("Failed to allocate stream");
             return hip_stream_ptr{result};
@@ -68,39 +58,28 @@ struct hip_device
             return rbhandle.get();
         }

         private:
         std::size_t id = 0;
         shared<hip_stream_ptr> s = nullptr;
         shared<miopen_handle> mihandle = nullptr;
         shared<rocblas_handle_ptr> rbhandle = nullptr;
     };

-    void add_stream()
-    {
-        streams.emplace_back(device_id);
-    }
+    void add_stream() { streams.emplace_back(device_id); }

-    stream& get_stream()
-    {
-        return streams.at(current_stream);
-    }
+    stream& get_stream() { return streams.at(current_stream); }

-    void set_stream(std::size_t n)
-    {
-        current_stream = n;
-    }
+    void set_stream(std::size_t n) { current_stream = n; }

     private:
     std::size_t device_id = 0;
     std::size_t current_stream = 0;
     std::vector<stream> streams;
 };

 struct context
 {
-    context(std::size_t n=0)
-        : current_device(std::make_shared<hip_device>(n))
-    {}
+    context(std::size_t n = 0) : current_device(std::make_shared<hip_device>(n)) {}

     hip_device& get_current_device()
     {
@@ -108,14 +87,12 @@ struct context
         return *current_device;
     }

-    hip_device::stream& get_stream()
-    {
-        return get_current_device().get_stream();
-    }
+    hip_device::stream& get_stream() { return get_current_device().get_stream(); }

     std::vector<argument> literals{};
     void finish() const { gpu_sync(); }

     private:
     // TODO: Make this a vector to support multiple devices
     std::shared_ptr<hip_device> current_device;
 };
...
@@ -54,9 +54,6 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
 std::string target::name() const { return "miopen"; }

-migraph::context target::get_context() const
-{
-    return context{};
-}
+migraph::context target::get_context() const { return context{}; }

 } // namespace gpu
 } // namespace migraph
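
Side note on the hip_device::stream code touched above: MIGRAPH_MANAGE_PTR(hipStream_t, hipStreamDestroy) is an RAII wrapper around the raw HIP stream handle. As a rough sketch only (not the project's actual macro expansion; the stream_ptr alias and make_stream name below are made up for illustration), the same idiom can be written directly with std::unique_ptr and the public HIP API:

    // Illustrative RAII wrapper for a hipStream_t, assuming only the public HIP
    // runtime API (hipStreamCreate / hipStreamDestroy). Names are hypothetical.
    #include <hip/hip_runtime.h>
    #include <memory>
    #include <stdexcept>
    #include <type_traits>

    using stream_ptr =
        std::unique_ptr<std::remove_pointer_t<hipStream_t>, decltype(&hipStreamDestroy)>;

    inline stream_ptr make_stream()
    {
        hipStream_t raw = nullptr;
        if(hipStreamCreate(&raw) != hipSuccess) // mirrors the MIGRAPH_THROW check above
            throw std::runtime_error("Failed to allocate stream");
        return stream_ptr{raw, &hipStreamDestroy}; // deleter releases the stream on scope exit
    }

Either way, the deleter guarantees hipStreamDestroy runs when the owning object is destroyed, which is why create_stream() in the diff can simply return the wrapper by value.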