Commit 23354075 authored by wsttiger's avatar wsttiger
Browse files

gemm from rocblas is working

parent 35daafa7
......@@ -45,6 +45,11 @@ struct shape
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
#undef MIGRAPH_SHAPE_GET_TYPE
template <class T>
struct get_type<const T> : get_type<T>
{
};
shape();
shape(type_t t);
shape(type_t t, std::vector<std::size_t> l);
......
......@@ -11,5 +11,5 @@ add_library(migraph_miopen
miopen_target.cpp
)
rocm_clang_tidy_check(migraph_miopen)
target_link_libraries(migraph_miopen migraph MIOpen)
target_link_libraries(migraph_miopen migraph MIOpen rocblas)
target_include_directories(migraph_miopen PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
#include <rocblas.h>
#include <migraph/miopen/miopen_target.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp>
......@@ -160,30 +161,30 @@ struct miopen_gemm
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
// visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
// [&](auto output, auto input1, auto input2) {
// dfor(input1.get_shape().lens()[0],
// input2.get_shape().lens()[1],
// input2.get_shape().lens()[0])(
// [&](auto i, auto j, auto k) { output(i, j) += input1(i, k) * input2(k, j); });
// });
visit_all(args[2], args[0], args[1])(
[&](auto output, auto input1, auto input2) {
float alpha = 1.0;
float beta = 0.0;
rocblas_int lda = input1.get_shape().lens()[1];
rocblas_int ldb = input2.get_shape().lens()[1];
rocblas_int ldc = output.get_shape().lens()[1];
rocblas_int m = ouptut.get_shape().lens()[0];
rocblas_int n = ouptut.get_shape().lens()[1];
rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(rochandle, rocblas_operation_none, rocblas_operation_none, n, m, k,
&alpha,
input2.data(), ldb,
input1.data(), lda,
&beta,
output.data(), ldc);
});
rocblas_handle rochandle;
rocblas_create_handle(&rochandle);
float alpha = 1.0f;
float beta = 0.0f;
rocblas_int lda = args[0].get_shape().lens()[1];
rocblas_int ldb = args[1].get_shape().lens()[1];
rocblas_int ldc = args[2].get_shape().lens()[1];
rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(rochandle,
rocblas_operation_none,
rocblas_operation_none,
n,
m,
k,
&alpha,
args[1].implicit(),
ldb,
args[0].implicit(),
lda,
&beta,
args[2].implicit(),
ldc);
return args[2];
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment