gemm.cpp 2.58 KB
Newer Older
wsttiger's avatar
wsttiger committed
1
2
3
4
5
6
7
8
9
#include <migraph/gpu/gemm.hpp>
#include <migraph/operators.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/gpu/miopen.hpp>
#include <cstring>
#include <utility>

namespace migraph {
namespace gpu {

Paul's avatar
Paul committed
10
template <class... Ts>
Paul's avatar
Paul committed
11
12
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
13
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
14
15
}

Paul's avatar
Paul committed
16
template <class... Ts>
Paul's avatar
Paul committed
17
18
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
19
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
20
21
}

Paul's avatar
Paul committed
22
template <class... Ts>
Paul's avatar
Paul committed
23
24
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
25
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
26
27
}

Paul's avatar
Paul committed
28
template <class T, class... Ts>
Paul's avatar
Paul committed
29
30
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
Paul's avatar
Paul committed
31
    MIGRAPH_THROW("Type unsupported by rocblas");
Paul's avatar
Paul committed
32
33
}

Paul's avatar
Paul committed
34
// Identity conversion: returns the value it was given, with the same type.
// Used for scalar types that need no translation before being handed on.
template <class T>
T to_rocblas_type(T value)
{
    return value;
}

Paul's avatar
Paul committed
40
// Converts a half value into a rocblas_half by copying its object
// representation byte for byte.
// A memcpy is used rather than reading x through a reinterpret_cast'ed
// reference, because accessing a half object through an lvalue of an
// unrelated type violates the strict-aliasing rule.
rocblas_half to_rocblas_type(half x)
{
    static_assert(sizeof(rocblas_half) == sizeof(half),
                  "half and rocblas_half must have the same size");
    rocblas_half result;
    std::memcpy(&result, &x, sizeof(result));
    return result;
}
Paul's avatar
Paul committed
41

wsttiger's avatar
wsttiger committed
42
43
44
45
46
// Validates that exactly three input shapes were supplied, then delegates the
// shape computation to the wrapped operator using only the first two of them.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(3);

    const shape& lhs = inputs.at(0);
    const shape& rhs = inputs.at(1);
    return op.compute_shape({lhs, rhs});
}
wsttiger's avatar
wsttiger committed
47
48
49
// Performs the matrix multiply on the GPU through generic_rocblas_gemm.
//
// ctx          - supplies the rocBLAS handle via ctx.get_stream().get_rocblas().
// output_shape - provides m and n (its first two lens) and the element type
//                that selects which generic_rocblas_gemm overload is used.
// args         - args[0] and args[1] are the operands; args[2] is passed as
//                the output buffer and is also the value returned.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    const auto& a_shape = args[0].get_shape();
    const auto& b_shape = args[1].get_shape();
    const auto& c_shape = args[2].get_shape();

    const bool transa = a_shape.transposed();
    const bool transb = b_shape.transposed();

    // The leading dimension comes from stride 1 when the operand's shape
    // reports itself as transposed, and from stride 0 otherwise.
    // The size_t -> rocblas_int narrowing is spelled out explicitly.
    const auto lda = static_cast<rocblas_int>(a_shape.strides()[transa ? 1 : 0]);
    const auto ldb = static_cast<rocblas_int>(b_shape.strides()[transb ? 1 : 0]);
    const auto ldc = static_cast<rocblas_int>(c_shape.strides()[0]);

    const auto m = static_cast<rocblas_int>(output_shape.lens()[0]);
    const auto n = static_cast<rocblas_int>(output_shape.lens()[1]);
    const auto k = static_cast<rocblas_int>(a_shape.lens()[1]);

    // Scaling factors handed to the GEMM call (alpha = 1, beta = 0).
    float alpha = 1.0f;
    float beta  = 0.0f;

    output_shape.visit_type([&](auto as) {
        // Convert the scalars to the element type, then to the matching
        // rocBLAS scalar type.
        auto alpha_r = to_rocblas_type(as(alpha));
        auto beta_r  = to_rocblas_type(as(beta));

        // Operands are passed in swapped order (B before A, n before m).
        // NOTE(review): presumably this maps row-major tensors onto a
        // column-major BLAS interface -- confirm.
        generic_rocblas_gemm(as,
                             ctx.get_stream().get_rocblas(),
                             transb ? rocblas_operation_transpose : rocblas_operation_none,
                             transa ? rocblas_operation_transpose : rocblas_operation_none,
                             n,
                             m,
                             k,
                             &alpha_r,
                             args[1].implicit(),
                             ldb,
                             args[0].implicit(),
                             lda,
                             &beta_r,
                             args[2].implicit(),
                             ldc);
    });

    return args[2];
}

} // namespace gpu

} // namespace migraph