Commit 35daafa7 authored by wsttiger
Browse files

Initial changes to add rocblas for gemm calls

parent 9bc4ce27
...@@ -20,6 +20,10 @@ endif() ...@@ -20,6 +20,10 @@ endif()
add_compile_options(-std=c++14) add_compile_options(-std=c++14)
# rocBLAS: locate the package configuration. REQUIRED makes configuration fail when the
# package cannot be found; /opt/rocm is supplied as an additional search location.
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
include(EnableCompilerWarnings) include(EnableCompilerWarnings)
# Override clang-tidy to not find the version from hcc # Override clang-tidy to not find the version from hcc
......
...@@ -160,16 +160,31 @@ struct miopen_gemm ...@@ -160,16 +160,31 @@ struct miopen_gemm
} }
/// Multiplies args[0] by args[1] with rocBLAS SGEMM and stores the product in args[2].
///
/// @param output_shape  Not used by this implementation; dimensions are read from the
///                      visited views of args[2], args[0] and args[1].
/// @param args          args[0] and args[1] are the operands, args[2] receives the result.
/// @return args[2], the argument the GEMM call writes into.
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
    (void)output_shape;

    visit_all(args[2], args[0], args[1])([&](auto output, auto input1, auto input2) {
        // With alpha = 1 and beta = 0 the call computes a plain product (no accumulation).
        const float alpha = 1.0f;
        const float beta  = 0.0f;

        // Leading dimensions are the row lengths of each matrix.
        // Assumes packed row-major storage — TODO confirm against the shape strides.
        const rocblas_int lda = input1.get_shape().lens()[1];
        const rocblas_int ldb = input2.get_shape().lens()[1];
        const rocblas_int ldc = output.get_shape().lens()[1];

        const rocblas_int m = output.get_shape().lens()[0];
        const rocblas_int n = output.get_shape().lens()[1];
        const rocblas_int k = input1.get_shape().lens()[1];

        // rocBLAS works on column-major matrices. A row-major C = A * B has the same
        // memory layout as the column-major C^T = B^T * A^T, so the operands are passed
        // in swapped order (input2 first) and m/n are exchanged.
        //
        // NOTE(review): `rochandle` is not declared in this block — confirm it is
        // provided by the enclosing scope (or should come from the context argument).
        // NOTE(review): rocblas_sgemm is the single-precision entry point; presumably
        // only float tensors reach this lambda — confirm.
        // NOTE(review): the status value returned by rocblas_sgemm is discarded.
        rocblas_sgemm(rochandle,
                      rocblas_operation_none,
                      rocblas_operation_none,
                      n,
                      m,
                      k,
                      &alpha,
                      input2.data(),
                      ldb,
                      input1.data(),
                      lda,
                      &beta,
                      output.data(),
                      ldc);
    });
    return args[2];
}
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment