gemm.cpp 4.81 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// Strided-batched GEMM dispatch for single precision.
// The leading shape::as<float> tag only selects this overload; every remaining
// argument is perfectly forwarded to rocblas_sgemm_strided_batched. The result of
// that call is not inspected.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

// Strided-batched GEMM dispatch for double precision.
// The leading shape::as<double> tag only selects this overload; every remaining
// argument is perfectly forwarded to rocblas_dgemm_strided_batched. The result of
// that call is not inspected.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

// Strided-batched GEMM dispatch for half precision.
// The leading shape::as<half> tag only selects this overload; every remaining
// argument is perfectly forwarded to rocblas_hgemm_strided_batched. The result of
// that call is not inspected.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

// Catch-all overload for element types without a dedicated overload above.
// The arguments are accepted but ignored; the call always throws via MIGRAPHX_THROW.
template <class Elem, class... Args>
void generic_rocblas_batched_gemm(shape::as<Elem>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
32
template <class... Ts>
Paul's avatar
Paul committed
33
34
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
35
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
36
37
}

Paul's avatar
Paul committed
38
template <class... Ts>
Paul's avatar
Paul committed
39
40
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
41
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
42
43
}

Paul's avatar
Paul committed
44
template <class... Ts>
Paul's avatar
Paul committed
45
46
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
47
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
48
49
}

Paul's avatar
Paul committed
50
template <class T, class... Ts>
Paul's avatar
Paul committed
51
52
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
53
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
54
55
}

Paul's avatar
Paul committed
56
// Type trait mapping an element type to the type handed to rocBLAS.
// Primary template: the type maps to itself. Specializations override this for
// types that need a different representation.
template <class T>
struct compute_rocblas_type
{
    using type = T;
};

Paul's avatar
Paul committed
62
template <class T>
Paul's avatar
Paul committed
63
64
65
66
67
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
68
template <>
Paul's avatar
Paul committed
69
70
71
72
73
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
74
template <class T>
Paul's avatar
Paul committed
75
76
77
78
79
80
81
82
using rb_type = typename compute_rocblas_type<T>::type;

// Converts a value to its rocBLAS-facing type by reinterpreting the object's
// storage as rb_type<Value> and returning a copy of it. No numeric conversion
// is performed.
// NOTE(review): assumes Value and rb_type<Value> have the same size and
// representation - confirm for every mapping in compute_rocblas_type.
template <class Value>
rb_type<Value> to_rocblas_type(Value x)
{
    const auto& reinterpreted = reinterpret_cast<const rb_type<Value>&>(x);
    return reinterpreted;
}

Paul's avatar
Paul committed
83
template <class T>
Paul's avatar
Paul committed
84
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
85
{
Paul's avatar
Paul committed
86
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
87
88
}

Paul's avatar
Paul committed
89
// Non-template overload for half: reinterprets the value's storage as rocblas_half
// and returns a copy of it. No numeric conversion is performed.
rocblas_half to_rocblas_type(half x)
{
    const auto& reinterpreted = reinterpret_cast<const rocblas_half&>(x);
    return reinterpreted;
}
Paul's avatar
Paul committed
90

wsttiger's avatar
wsttiger committed
91
92
// Returns the output shape for the given input shapes by delegating to the
// wrapped operator `op`; the inputs are passed through unchanged.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    return op.compute_shape(inputs);
}
wsttiger's avatar
wsttiger committed
95
96
97
// Runs a strided-batched GEMM through rocBLAS and returns the buffer holding the result.
//
// Argument layout:
//   - args[0], args[1]: the two matrix operands.
//   - args.size() == 4: args[2] is the additive input and args[3] is the output buffer.
//     The output buffer is first filled with a byte copy of args[2].
//   - otherwise: args[2] is the output buffer and is zero-filled first.
//
// The GEMM is called with op.alpha and op.beta and accumulates into the output buffer,
// which is also the returned argument.
//
// Throws (MIGRAPHX_THROW) if the output buffer cannot be initialized, or if the element
// type of output_shape has no rocBLAS overload.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    bool transa        = args[0].get_shape().transposed();
    bool transb        = args[1].get_shape().transposed();
    std::size_t n_dims = args[0].get_shape().lens().size();
    // The matrix dimensions are the last two axes; every leading axis is a batch axis.
    std::size_t dim_0 = n_dims - 2;
    std::size_t dim_1 = n_dims - 1;
    rocblas_int lda   = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
    rocblas_int ldb   = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
    // In the four-argument case the output buffer receives a byte copy of args[2],
    // so the leading dimension read from args[2] describes the copied data as well.
    rocblas_int ldc = args[2].get_shape().strides()[dim_0];
    auto out_lens   = output_shape.lens();
    rocblas_int m   = out_lens[dim_0];
    rocblas_int n   = out_lens[dim_1];
    rocblas_int k   = args[0].get_shape().lens()[dim_1];
    // Number of matrices in the batch: product of all output dims except the last two.
    auto batch_num = std::accumulate(
        out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

    bool is_3inputs = (args.size() == 4);
    // The buffer the GEMM accumulates into must be the one that is returned. Previously
    // the four-argument case wrote the product into args[2] but returned args[3].
    const argument& result = is_3inputs ? args[3] : args[2];

    output_shape.visit_type([&](auto as) {
        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };

        // Seed the output buffer: copy of args[2] when it is a separate input, zeros otherwise.
        // NOTE(review): these HIP calls are not issued on ctx's stream - confirm that
        // ordering against the rocBLAS call below is guaranteed.
        auto status = hipSuccess;
        if(is_3inputs)
            status = hipMemcpy(to_pointer(result),
                               to_pointer(args[2]),
                               output_shape.bytes(),
                               hipMemcpyDeviceToDevice);
        else
            status = hipMemset(to_pointer(result), 0, output_shape.bytes());
        if(status != hipSuccess)
            MIGRAPHX_THROW("GEMM: failed to initialize the output buffer");

        auto alpha_r = to_rocblas_type(as(op.alpha));
        auto beta_r  = to_rocblas_type(as(op.beta));

        // The operands are passed in swapped order (args[1] before args[0], n before m),
        // presumably because rocBLAS is column-major while these tensors are row-major.
        generic_rocblas_batched_gemm(as,
                                     ctx.get_stream().get_rocblas(),
                                     transb ? rocblas_operation_transpose : rocblas_operation_none,
                                     transa ? rocblas_operation_transpose : rocblas_operation_none,
                                     n,
                                     m,
                                     k,
                                     &alpha_r,
                                     to_pointer(args[1]),
                                     ldb,
                                     k * n,
                                     to_pointer(args[0]),
                                     lda,
                                     m * k,
                                     &beta_r,
                                     to_pointer(result),
                                     ldc,
                                     m * n,
                                     batch_num);
    });

    return result;
}

} // namespace gpu
Paul's avatar
Paul committed
155
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
156
} // namespace migraphx