"...composable_kernel.git" did not exist on "863e069bbb5fbebf6d8bd26252385b8763804bf8"
Unverified Commit 1e66a536 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #18 from ROCmSoftwarePlatform/rocblas-integration

Rocblas integration
parents a47f8e4b 061bc56e
...@@ -45,6 +45,11 @@ struct shape ...@@ -45,6 +45,11 @@ struct shape
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE) MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
#undef MIGRAPH_SHAPE_GET_TYPE #undef MIGRAPH_SHAPE_GET_TYPE
template <class T>
struct get_type<const T> : get_type<T>
{
};
shape(); shape();
shape(type_t t); shape(type_t t);
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc) list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc)
find_package(miopen) find_package(miopen)
# rocblas
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
if(NOT TARGET MIOpen) if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen") message(SEND_ERROR "Cant find miopen")
endif() endif()
...@@ -11,7 +15,8 @@ add_library(migraph_miopen ...@@ -11,7 +15,8 @@ add_library(migraph_miopen
miopen_target.cpp miopen_target.cpp
miopen_lowering.cpp miopen_lowering.cpp
miopen_write_literals.cpp miopen_write_literals.cpp
rocblas.cpp
) )
rocm_clang_tidy_check(migraph_miopen) rocm_clang_tidy_check(migraph_miopen)
target_link_libraries(migraph_miopen migraph MIOpen) target_link_libraries(migraph_miopen migraph MIOpen rocblas)
target_include_directories(migraph_miopen PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraph_miopen PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP #define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#include <migraph/miopen/miopen.hpp> #include <migraph/miopen/miopen.hpp>
#include <migraph/miopen/rocblas.hpp>
namespace migraph { namespace migraph {
namespace miopen { namespace miopen {
...@@ -9,6 +10,7 @@ namespace miopen { ...@@ -9,6 +10,7 @@ namespace miopen {
struct miopen_context struct miopen_context
{ {
shared<miopen_handle> handle; shared<miopen_handle> handle;
shared<rocblas_handle_ptr> rbhandle;
}; };
} // namespace miopen } // namespace miopen
......
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#include <migraph/manage_ptr.hpp>
#include <migraph/operators.hpp>
#include <rocblas.h>
namespace migraph {
namespace miopen {
using rocblas_handle_ptr = MIGRAPH_MANAGE_PTR(rocblas_handle, rocblas_destroy_handle);
rocblas_handle_ptr create_rocblas_handle_ptr();
} // namespace miopen
} // namespace migraph
#endif
#include <rocblas.h>
#include <migraph/miopen/miopen_lowering.hpp> #include <migraph/miopen/miopen_lowering.hpp>
#include <migraph/manage_ptr.hpp> #include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
...@@ -7,15 +8,12 @@ ...@@ -7,15 +8,12 @@
#include <migraph/miopen/hip.hpp> #include <migraph/miopen/hip.hpp>
#include <migraph/dfor.hpp> #include <migraph/dfor.hpp>
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <migraph/miopen/rocblas.hpp>
#include <migraph/miopen/context.hpp>
namespace migraph { namespace migraph {
namespace miopen { namespace miopen {
struct miopen_context
{
shared<miopen_handle> handle;
};
struct miopen_convolution struct miopen_convolution
{ {
convolution op; convolution op;
...@@ -27,7 +25,7 @@ struct miopen_convolution ...@@ -27,7 +25,7 @@ struct miopen_convolution
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx); auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
...@@ -79,7 +77,7 @@ struct miopen_pooling ...@@ -79,7 +77,7 @@ struct miopen_pooling
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(1)}); return op.compute_shape({inputs.at(1)});
} }
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx); auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
...@@ -112,7 +110,7 @@ struct miopen_add ...@@ -112,7 +110,7 @@ struct miopen_add
return inputs.at(0); return inputs.at(0);
} }
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{ {
if(args[1].get_shape().broadcasted()) if(args[1].get_shape().broadcasted())
{ {
...@@ -159,18 +157,32 @@ struct miopen_gemm ...@@ -159,18 +157,32 @@ struct miopen_gemm
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1.0f;
visit_all(result, from_gpu(args[0]), from_gpu(args[1]))( float beta = 0.0f;
[&](auto output, auto input1, auto input2) { rocblas_int lda = args[0].get_shape().lens()[1];
dfor(input1.get_shape().lens()[0], rocblas_int ldb = args[1].get_shape().lens()[1];
input2.get_shape().lens()[1], rocblas_int ldc = args[2].get_shape().lens()[1];
input2.get_shape().lens()[0])( rocblas_int m = output_shape.lens()[0];
[&](auto i, auto j, auto k) { output(i, j) += input1(i, k) * input2(k, j); }); rocblas_int n = output_shape.lens()[1];
}); rocblas_int k = args[0].get_shape().lens()[1];
return to_gpu(result); rocblas_sgemm(ctx.rbhandle.get(),
rocblas_operation_none,
rocblas_operation_none,
n,
m,
k,
&alpha,
args[1].implicit(),
ldb,
args[0].implicit(),
lda,
&beta,
args[2].implicit(),
ldc);
return args[2];
} }
}; };
...@@ -184,7 +196,7 @@ struct miopen_relu ...@@ -184,7 +196,7 @@ struct miopen_relu
return inputs.at(1); return inputs.at(1);
} }
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx); auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
......
...@@ -15,7 +15,8 @@ std::string miopen_target::name() const { return "miopen"; } ...@@ -15,7 +15,8 @@ std::string miopen_target::name() const { return "miopen"; }
context miopen_target::get_context() const context miopen_target::get_context() const
{ {
return miopen_context{share(make_obj<miopen_handle>(&miopenCreate))}; return miopen_context{share(make_obj<miopen_handle>(&miopenCreate)),
share(create_rocblas_handle_ptr())};
} }
} // namespace miopen } // namespace miopen
......
#include <migraph/miopen/rocblas.hpp>
namespace migraph {
namespace miopen {
rocblas_handle_ptr create_rocblas_handle_ptr()
{
rocblas_handle handle;
rocblas_create_handle(&handle);
return rocblas_handle_ptr{handle};
}
} // namespace miopen
} // namespace migraph
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment