Commit 42e67e3d authored by Paul
Browse files

Add half support to gemm

parent f3ddd797
...@@ -7,6 +7,41 @@ ...@@ -7,6 +7,41 @@
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
// Single-precision dispatch: chosen when the tag argument is shape::as<float>.
// The tag itself is unused; every remaining argument is perfectly forwarded
// to rocblas_sgemm. Whatever rocblas_sgemm returns is not inspected here.
template <class... Args>
void generic_rocblas_gemm(shape::as<float>, Args&&... args)
{
    rocblas_sgemm(std::forward<Args>(args)...);
}
// Double-precision dispatch: chosen when the tag argument is shape::as<double>.
// The tag itself is unused; every remaining argument is perfectly forwarded
// to rocblas_dgemm. Whatever rocblas_dgemm returns is not inspected here.
template <class... Args>
void generic_rocblas_gemm(shape::as<double>, Args&&... args)
{
    rocblas_dgemm(std::forward<Args>(args)...);
}
// Half-precision dispatch: chosen when the tag argument is shape::as<half>.
// The tag itself is unused; every remaining argument is perfectly forwarded
// to rocblas_hgemm. Whatever rocblas_hgemm returns is not inspected here.
template <class... Args>
void generic_rocblas_gemm(shape::as<half>, Args&&... args)
{
    rocblas_hgemm(std::forward<Args>(args)...);
}
// Catch-all overload for a shape::as<Elem> tag that has no dedicated
// overload. All arguments are ignored and the call always raises an error
// through MIGRAPH_THROW.
template <class Elem, class... Args>
void generic_rocblas_gemm(shape::as<Elem>, Args&&...)
{
    MIGRAPH_THROW("Type unsupported by rocblas");
}
// Generic pass-through conversion: returns its argument unchanged, with the
// same type it was given. Used for value types that need no conversion.
template <class Scalar>
Scalar to_rocblas_type(Scalar x)
{
    return x;
}
// Converts a half value to rocblas_half by reinterpreting the storage of x
// and returning a copy of it; no numeric conversion is performed.
// NOTE(review): this relies on half and rocblas_half sharing the same bit
// representation — confirm against the rocBLAS headers in use.
rocblas_half to_rocblas_type(half x)
{
    // Guard the reinterpretation below: if the two types ever differ in size,
    // the cast would read the wrong number of bytes, so fail at compile time.
    static_assert(sizeof(rocblas_half) == sizeof(half),
                  "rocblas_half and half must have the same size");
    return reinterpret_cast<const rocblas_half&>(x);
}
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
...@@ -26,20 +61,25 @@ argument miopen_gemm::compute(context& ctx, ...@@ -26,20 +61,25 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = output_shape.lens()[0]; rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1]; rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1]; rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(ctx.get_stream().get_rocblas(), output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta));
generic_rocblas_gemm(as, ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
&alpha, &alpha_r,
args[1].implicit(), args[1].implicit(),
ldb, ldb,
args[0].implicit(), args[0].implicit(),
lda, lda,
&beta, &beta_r,
args[2].implicit(), args[2].implicit(),
ldc); ldc);
});
return args[2]; return args[2];
} }
......
...@@ -493,6 +493,18 @@ struct test_gemm ...@@ -493,6 +493,18 @@ struct test_gemm
} }
}; };
// Verification case for a half-precision dot: builds a program with two
// half_type parameters, "a" with lens {4, 5} and "b" with lens {5, 3}, and
// a single op::dot instruction applied to them.
struct test_gemm_half
{
    migraph::program create_program() const
    {
        const auto elem_type = migraph::shape::half_type;
        migraph::program prog;
        auto lhs = prog.add_parameter("a", migraph::shape{elem_type, {4, 5}});
        auto rhs = prog.add_parameter("b", migraph::shape{elem_type, {5, 3}});
        prog.add_instruction(migraph::op::dot{}, lhs, rhs);
        return prog;
    }
};
struct test_gemm_ld struct test_gemm_ld
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -761,6 +773,7 @@ int main() ...@@ -761,6 +773,7 @@ int main()
verify_program<test_global_avg_pooling>(); verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>(); verify_program<test_global_max_pooling>();
verify_program<test_gemm>(); verify_program<test_gemm>();
verify_program<test_gemm_half>();
// verify_program<test_gemm_ld>(); // verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>(); verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposea>(); verify_program<test_gemm_transposea>();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment