Commit 23354075 authored by wsttiger's avatar wsttiger
Browse files

gemm from rocblas is working

parent 35daafa7
...@@ -45,6 +45,11 @@ struct shape ...@@ -45,6 +45,11 @@ struct shape
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE) MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
#undef MIGRAPH_SHAPE_GET_TYPE #undef MIGRAPH_SHAPE_GET_TYPE
// Partial specialization for const-qualified types: get_type<const T> derives
// from get_type<T>, so looking up a const type yields whatever get_type<T>
// provides for the unqualified type.
template <class T>
struct get_type<const T> : get_type<T>
{
};
shape(); shape();
shape(type_t t); shape(type_t t);
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
......
...@@ -11,5 +11,5 @@ add_library(migraph_miopen ...@@ -11,5 +11,5 @@ add_library(migraph_miopen
miopen_target.cpp miopen_target.cpp
) )
rocm_clang_tidy_check(migraph_miopen) rocm_clang_tidy_check(migraph_miopen)
target_link_libraries(migraph_miopen migraph MIOpen) target_link_libraries(migraph_miopen migraph MIOpen rocblas)
target_include_directories(migraph_miopen PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraph_miopen PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
#include <rocblas.h>
#include <migraph/miopen/miopen_target.hpp> #include <migraph/miopen/miopen_target.hpp>
#include <migraph/manage_ptr.hpp> #include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
...@@ -160,30 +161,30 @@ struct miopen_gemm ...@@ -160,30 +161,30 @@ struct miopen_gemm
} }
// Multiplies args[0] by args[1] with rocblas_sgemm, writing the product into
// args[2], and returns args[2].
//
// Parameters:
//   (context&)    - unused.
//   output_shape  - shape of the result; lens()[0] and lens()[1] give m and n.
//   args          - args[0] and args[1] are the inputs, args[2] is the output
//                   buffer.
//
// The operands are handed to rocBLAS in swapped order (args[1] before args[0],
// n before m). NOTE(review): presumably this is the usual trick for driving a
// column-major BLAS with row-major buffers (C^T = B^T * A^T) - confirm that the
// arguments really are densely packed row-major.
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
    // alpha = 1 and beta = 0: the output is overwritten, not accumulated into.
    float alpha = 1.0f;
    float beta  = 0.0f;

    // Leading dimensions are taken from the second extent of each argument.
    const auto lda = static_cast<rocblas_int>(args[0].get_shape().lens()[1]);
    const auto ldb = static_cast<rocblas_int>(args[1].get_shape().lens()[1]);
    const auto ldc = static_cast<rocblas_int>(args[2].get_shape().lens()[1]);

    const auto m = static_cast<rocblas_int>(output_shape.lens()[0]);
    const auto n = static_cast<rocblas_int>(output_shape.lens()[1]);
    const auto k = static_cast<rocblas_int>(args[0].get_shape().lens()[1]);

    // NOTE(review): a handle is created on every call; consider caching one
    // (e.g. in the context) instead. Return statuses are still not inspected.
    rocblas_handle rochandle = nullptr;
    rocblas_create_handle(&rochandle);

    rocblas_sgemm(rochandle,
                  rocblas_operation_none,
                  rocblas_operation_none,
                  n,
                  m,
                  k,
                  &alpha,
                  args[1].implicit(),
                  ldb,
                  args[0].implicit(),
                  lda,
                  &beta,
                  args[2].implicit(),
                  ldc);

    // Release the per-call handle; previously it was never destroyed, leaking
    // one rocBLAS handle per invocation.
    rocblas_destroy_handle(rochandle);

    return args[2];
}
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment