Commit 09d332d9 authored by wsttiger's avatar wsttiger
Browse files

Fixed up ROCBLAS implementation of GEMM on GPU

parent afa4a833
...@@ -20,10 +20,6 @@ endif() ...@@ -20,10 +20,6 @@ endif()
add_compile_options(-std=c++14) add_compile_options(-std=c++14)
# rocblas
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
include(EnableCompilerWarnings) include(EnableCompilerWarnings)
# Override clang-tidy to not find the version from hcc # Override clang-tidy to not find the version from hcc
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc) list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc)
find_package(miopen) find_package(miopen)
# rocblas
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
if(NOT TARGET MIOpen) if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen") message(SEND_ERROR "Cant find miopen")
endif() endif()
...@@ -11,6 +15,7 @@ add_library(migraph_miopen ...@@ -11,6 +15,7 @@ add_library(migraph_miopen
miopen_target.cpp miopen_target.cpp
miopen_lowering.cpp miopen_lowering.cpp
miopen_write_literals.cpp miopen_write_literals.cpp
rocblas.cpp
) )
rocm_clang_tidy_check(migraph_miopen) rocm_clang_tidy_check(migraph_miopen)
target_link_libraries(migraph_miopen migraph MIOpen rocblas) target_link_libraries(migraph_miopen migraph MIOpen rocblas)
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP #define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#include <migraph/miopen/miopen.hpp> #include <migraph/miopen/miopen.hpp>
#include <migraph/miopen/rocblas.hpp>
namespace migraph { namespace migraph {
namespace miopen { namespace miopen {
...@@ -9,6 +10,7 @@ namespace miopen { ...@@ -9,6 +10,7 @@ namespace miopen {
struct miopen_context struct miopen_context
{ {
shared<miopen_handle> handle; shared<miopen_handle> handle;
shared<rocblas_handle_ptr> rbhandle;
}; };
} // namespace miopen } // namespace miopen
......
#include <rocblas.h>
#include <migraph/miopen/miopen_lowering.hpp> #include <migraph/miopen/miopen_lowering.hpp>
#include <migraph/manage_ptr.hpp> #include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
...@@ -7,16 +8,12 @@ ...@@ -7,16 +8,12 @@
#include <migraph/miopen/hip.hpp> #include <migraph/miopen/hip.hpp>
#include <migraph/dfor.hpp> #include <migraph/dfor.hpp>
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <rocblas.h> #include <migraph/miopen/rocblas.hpp>
#include <migraph/miopen/context.hpp>
namespace migraph { namespace migraph {
namespace miopen { namespace miopen {
struct miopen_context
{
shared<miopen_handle> handle;
};
struct miopen_convolution struct miopen_convolution
{ {
convolution op; convolution op;
...@@ -28,7 +25,7 @@ struct miopen_convolution ...@@ -28,7 +25,7 @@ struct miopen_convolution
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx); auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
...@@ -80,7 +77,7 @@ struct miopen_pooling ...@@ -80,7 +77,7 @@ struct miopen_pooling
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(1)}); return op.compute_shape({inputs.at(1)});
} }
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx); auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
...@@ -113,7 +110,7 @@ struct miopen_add ...@@ -113,7 +110,7 @@ struct miopen_add
return inputs.at(0); return inputs.at(0);
} }
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{ {
if(args[1].get_shape().broadcasted()) if(args[1].get_shape().broadcasted())
{ {
...@@ -160,10 +157,10 @@ struct miopen_gemm ...@@ -160,10 +157,10 @@ struct miopen_gemm
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{ {
rocblas_handle rochandle; // rocblas_handle_ptr handle_ptr = create_rocblas_handle_ptr();
rocblas_create_handle(&rochandle); auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
rocblas_int lda = args[0].get_shape().lens()[1]; rocblas_int lda = args[0].get_shape().lens()[1];
...@@ -172,7 +169,7 @@ struct miopen_gemm ...@@ -172,7 +169,7 @@ struct miopen_gemm
rocblas_int m = output_shape.lens()[0]; rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1]; rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1]; rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(rochandle, rocblas_sgemm(ctx.rbhandle.get(),
rocblas_operation_none, rocblas_operation_none,
rocblas_operation_none, rocblas_operation_none,
n, n,
...@@ -200,7 +197,7 @@ struct miopen_relu ...@@ -200,7 +197,7 @@ struct miopen_relu
return inputs.at(1); return inputs.at(1);
} }
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx); auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
......
...@@ -15,7 +15,7 @@ std::string miopen_target::name() const { return "miopen"; } ...@@ -15,7 +15,7 @@ std::string miopen_target::name() const { return "miopen"; }
context miopen_target::get_context() const context miopen_target::get_context() const
{ {
return miopen_context{share(make_obj<miopen_handle>(&miopenCreate))}; return miopen_context{share(make_obj<miopen_handle>(&miopenCreate)), share(create_rocblas_handle_ptr())};
} }
} // namespace miopen } // namespace miopen
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment