".github/git@developer.sourcefind.cn:change/sglang.git" did not exist on "74630ba3109197a3f5238592af36eba1c97e77bd"
Commit 42e67e3d authored by Paul's avatar Paul
Browse files

Add half support to gemm

parent f3ddd797
...@@ -7,6 +7,41 @@ ...@@ -7,6 +7,41 @@
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
template<class... Ts>
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
rocblas_sgemm(std::forward<Ts>(xs)...);
}
template<class... Ts>
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
rocblas_dgemm(std::forward<Ts>(xs)...);
}
template<class... Ts>
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
rocblas_hgemm(std::forward<Ts>(xs)...);
}
template<class T, class... Ts>
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
MIGRAPH_THROW("Type unsupported by rocblas");
}
template<class T>
T to_rocblas_type(T x)
{
return x;
}
rocblas_half to_rocblas_type(half x)
{
return reinterpret_cast<const rocblas_half&>(x);
}
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
...@@ -26,20 +61,25 @@ argument miopen_gemm::compute(context& ctx, ...@@ -26,20 +61,25 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = output_shape.lens()[0]; rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1]; rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1]; rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(ctx.get_stream().get_rocblas(), output_shape.visit_type([&](auto as) {
transb ? rocblas_operation_transpose : rocblas_operation_none, auto alpha_r = to_rocblas_type(as(alpha));
transa ? rocblas_operation_transpose : rocblas_operation_none, auto beta_r = to_rocblas_type(as(beta));
n, generic_rocblas_gemm(as, ctx.get_stream().get_rocblas(),
m, transb ? rocblas_operation_transpose : rocblas_operation_none,
k, transa ? rocblas_operation_transpose : rocblas_operation_none,
&alpha, n,
args[1].implicit(), m,
ldb, k,
args[0].implicit(), &alpha_r,
lda, args[1].implicit(),
&beta, ldb,
args[2].implicit(), args[0].implicit(),
ldc); lda,
&beta_r,
args[2].implicit(),
ldc);
});
return args[2]; return args[2];
} }
......
...@@ -493,6 +493,18 @@ struct test_gemm ...@@ -493,6 +493,18 @@ struct test_gemm
} }
}; };
struct test_gemm_half
{
migraph::program create_program() const
{
migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::half_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::half_type, {5, 3}});
p.add_instruction(migraph::op::dot{}, a, b);
return p;
}
};
struct test_gemm_ld struct test_gemm_ld
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -761,6 +773,7 @@ int main() ...@@ -761,6 +773,7 @@ int main()
verify_program<test_global_avg_pooling>(); verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>(); verify_program<test_global_max_pooling>();
verify_program<test_gemm>(); verify_program<test_gemm>();
verify_program<test_gemm_half>();
// verify_program<test_gemm_ld>(); // verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>(); verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposea>(); verify_program<test_gemm_transposea>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment