Commit 09d332d9 authored by wsttiger's avatar wsttiger
Browse files

Fixed up ROCBLAS implementation of GEMM on GPU

parent afa4a833
......@@ -20,10 +20,6 @@ endif()
add_compile_options(-std=c++14)
# rocblas
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
include(EnableCompilerWarnings)
# Override clang-tidy to not find the version from hcc
......
......@@ -2,6 +2,10 @@
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc)
find_package(miopen)
# rocblas
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
......@@ -11,6 +15,7 @@ add_library(migraph_miopen
miopen_target.cpp
miopen_lowering.cpp
miopen_write_literals.cpp
rocblas.cpp
)
rocm_clang_tidy_check(migraph_miopen)
target_link_libraries(migraph_miopen migraph MIOpen rocblas)
......
......@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#include <migraph/miopen/miopen.hpp>
#include <migraph/miopen/rocblas.hpp>
namespace migraph {
namespace miopen {
......@@ -9,6 +10,7 @@ namespace miopen {
struct miopen_context
{
shared<miopen_handle> handle;
shared<rocblas_handle_ptr> rbhandle;
};
} // namespace miopen
......
#include <rocblas.h>
#include <migraph/miopen/miopen_lowering.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp>
......@@ -7,16 +8,12 @@
#include <migraph/miopen/hip.hpp>
#include <migraph/dfor.hpp>
#include <migraph/iterator_for.hpp>
#include <rocblas.h>
#include <migraph/miopen/rocblas.hpp>
#include <migraph/miopen/context.hpp>
namespace migraph {
namespace miopen {
struct miopen_context
{
shared<miopen_handle> handle;
};
struct miopen_convolution
{
convolution op;
......@@ -28,7 +25,7 @@ struct miopen_convolution
check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{
auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape());
......@@ -80,7 +77,7 @@ struct miopen_pooling
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(1)});
}
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{
auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape());
......@@ -113,7 +110,7 @@ struct miopen_add
return inputs.at(0);
}
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{
if(args[1].get_shape().broadcasted())
{
......@@ -160,10 +157,10 @@ struct miopen_gemm
check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{
rocblas_handle rochandle;
rocblas_create_handle(&rochandle);
// rocblas_handle_ptr handle_ptr = create_rocblas_handle_ptr();
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1.0f;
float beta = 0.0f;
rocblas_int lda = args[0].get_shape().lens()[1];
......@@ -172,7 +169,7 @@ struct miopen_gemm
rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(rochandle,
rocblas_sgemm(ctx.rbhandle.get(),
rocblas_operation_none,
rocblas_operation_none,
n,
......@@ -200,7 +197,7 @@ struct miopen_relu
return inputs.at(1);
}
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0;
......
......@@ -15,7 +15,7 @@ std::string miopen_target::name() const { return "miopen"; }
context miopen_target::get_context() const
{
return miopen_context{share(make_obj<miopen_handle>(&miopenCreate))};
return miopen_context{share(make_obj<miopen_handle>(&miopenCreate)), share(create_rocblas_handle_ptr())};
}
} // namespace miopen
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment