Commit 42e67e3d authored by Paul's avatar Paul
Browse files

Add half support to gemm

parent f3ddd797
......@@ -7,6 +7,41 @@
namespace migraph {
namespace gpu {
template<class... Ts>
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
rocblas_sgemm(std::forward<Ts>(xs)...);
}
template<class... Ts>
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
rocblas_dgemm(std::forward<Ts>(xs)...);
}
template<class... Ts>
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
rocblas_hgemm(std::forward<Ts>(xs)...);
}
template<class T, class... Ts>
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
MIGRAPH_THROW("Type unsupported by rocblas");
}
template<class T>
T to_rocblas_type(T x)
{
return x;
}
rocblas_half to_rocblas_type(half x)
{
return reinterpret_cast<const rocblas_half&>(x);
}
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(3);
......@@ -26,20 +61,25 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(ctx.get_stream().get_rocblas(),
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta));
generic_rocblas_gemm(as, ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha,
&alpha_r,
args[1].implicit(),
ldb,
args[0].implicit(),
lda,
&beta,
&beta_r,
args[2].implicit(),
ldc);
});
return args[2];
}
......
......@@ -493,6 +493,18 @@ struct test_gemm
}
};
struct test_gemm_half
{
migraph::program create_program() const
{
migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::half_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::half_type, {5, 3}});
p.add_instruction(migraph::op::dot{}, a, b);
return p;
}
};
struct test_gemm_ld
{
migraph::program create_program() const
......@@ -761,6 +773,7 @@ int main()
verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>();
verify_program<test_gemm>();
verify_program<test_gemm_half>();
// verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposea>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment