Commit 35daafa7 authored by wsttiger
Browse files

Initial changes to add rocblas for gemm calls

parent 9bc4ce27
......@@ -20,6 +20,10 @@ endif()
# Pass -std=c++14 to the compiler for targets in this directory and below.
add_compile_options(-std=c++14)
# rocblas: mandatory dependency (REQUIRED makes configuration fail when the
# package is not found); /opt/rocm is given as an extra location to search.
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
# Make modules in this project's cmake/ directory visible to include() and
# find_package().
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
# Presumably resolved from the cmake/ directory appended above -- confirm.
include(EnableCompilerWarnings)
# Override clang-tidy to not find the version from hcc
......
......@@ -160,16 +160,31 @@ struct miopen_gemm
}
/// Multiplies args[0] (A, m x k) by args[1] (B, k x n) with rocBLAS and writes
/// the product into args[2] (C, m x n), which is also the returned argument.
///
/// This replaces the earlier host-side triple loop
/// (output(i, j) += input1(i, k) * input2(k, j)) with one rocblas_sgemm call.
///
/// @param output_shape shape of the result; unused because the result is
///                     written into the caller-supplied args[2]
/// @param args         {A, B, C}; presumably all three already live in device
///                     memory, since no from_gpu/to_gpu copies are made --
///                     TODO confirm against the caller
/// @return args[2] after the GEMM has been issued
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
    // Nothing is allocated here any more; silence the unused-parameter warning.
    (void)output_shape;

    // NOTE(review): rocblas_sgemm is single precision only. If visit_all
    // instantiates this lambda for non-float element types, the data()
    // pointers will not match the rocBLAS signature -- confirm.
    visit_all(args[2], args[0], args[1])([&](auto output, auto input1, auto input2) {
        // C = alpha * (A * B) + beta * C with alpha = 1, beta = 0, i.e. plain
        // A * B; the previous contents of C are ignored.
        float alpha = 1.0f;
        float beta  = 0.0f;

        // Leading dimension = number of columns of each matrix (assumes packed
        // row-major storage -- TODO confirm no strided/transposed shapes arrive).
        const auto lda = static_cast<rocblas_int>(input1.get_shape().lens()[1]);
        const auto ldb = static_cast<rocblas_int>(input2.get_shape().lens()[1]);
        const auto ldc = static_cast<rocblas_int>(output.get_shape().lens()[1]);

        const auto m = static_cast<rocblas_int>(output.get_shape().lens()[0]);
        const auto n = static_cast<rocblas_int>(output.get_shape().lens()[1]);
        const auto k = static_cast<rocblas_int>(input1.get_shape().lens()[1]);

        // rocBLAS is column-major. A row-major m x n matrix is the same memory
        // as a column-major n x m matrix, so C = A * B is evaluated as
        // C^T = B^T * A^T: operands are swapped and m/n exchanged, with no
        // explicit transposition.
        // NOTE(review): `rochandle` is not declared in the visible code --
        // confirm where the rocBLAS handle comes from. The returned
        // rocblas_status is currently ignored.
        rocblas_sgemm(rochandle,
                      rocblas_operation_none,
                      rocblas_operation_none,
                      n,
                      m,
                      k,
                      &alpha,
                      input2.data(),
                      ldb,
                      input1.data(),
                      lda,
                      &beta,
                      output.data(),
                      ldc);
    });
    return args[2];
}
};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment