gemm.cpp 3.35 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
#include <cstring>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

Paul's avatar
Paul committed
8
template <class... Ts>
Paul's avatar
Paul committed
9
10
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
11
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
12
13
}

Paul's avatar
Paul committed
14
template <class... Ts>
Paul's avatar
Paul committed
15
16
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
17
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
18
19
}

Paul's avatar
Paul committed
20
template <class... Ts>
Paul's avatar
Paul committed
21
22
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
23
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
24
25
}

Paul's avatar
Paul committed
26
template <class T, class... Ts>
Paul's avatar
Paul committed
27
28
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
Paul's avatar
Paul committed
29
    MIGRAPHX_THROW("Type unsupported by rocblas");
Paul's avatar
Paul committed
30
31
}

Paul's avatar
Paul committed
32
// Type trait mapping a MIGraphX element type to the element type handed to
// rocBLAS. Types without a specialization map to themselves.
template <class T>
struct compute_rocblas_type
{
    using type = T; // identity mapping by default
};

Paul's avatar
Paul committed
38
template <class T>
Paul's avatar
Paul committed
39
40
41
42
43
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
44
template <>
Paul's avatar
Paul committed
45
46
47
48
49
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
50
template <class T>
Paul's avatar
Paul committed
51
52
53
54
55
56
57
58
using rb_type = typename compute_rocblas_type<T>::type;

// Convert a scalar value to its rocBLAS-side representation (rb_type<T>).
//
// The bits are copied with std::memcpy rather than read through a
// reinterpret_cast'ed reference. Reading a T object through a glvalue of an
// unrelated type (e.g. half -> rocblas_half) violates the strict-aliasing rule
// and is undefined behavior; memcpy between same-sized objects is well-defined.
// For types that map to themselves this is a plain copy.
template <class T>
rb_type<T> to_rocblas_type(T x)
{
    static_assert(sizeof(rb_type<T>) == sizeof(T),
                  "rocBLAS type must have the same size as the source type");
    rb_type<T> result{};
    std::memcpy(&result, &x, sizeof(result));
    return result;
}

Paul's avatar
Paul committed
59
template <class T>
Paul's avatar
Paul committed
60
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
61
{
Paul's avatar
Paul committed
62
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
63
64
}

Paul's avatar
Paul committed
65
// Convert a MIGraphX half to rocBLAS's rocblas_half by copying its raw bits.
//
// std::memcpy is used instead of reinterpret_cast<const rocblas_half&>(x).
// Accessing a `half` object through a `rocblas_half` glvalue violates the
// strict-aliasing rule (undefined behavior), while a byte copy is well-defined.
// Only equal size is checked at compile time; an identical bit layout of the
// two 16-bit formats is assumed.
rocblas_half to_rocblas_type(half x)
{
    static_assert(sizeof(rocblas_half) == sizeof(half),
                  "rocblas_half and half must have the same size");
    rocblas_half result{};
    std::memcpy(&result, &x, sizeof(result));
    return result;
}
Paul's avatar
Paul committed
66

wsttiger's avatar
wsttiger committed
67
68
69
70
71
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    // Inputs are A, B and the buffer C that compute() writes its result into.
    check_shapes{inputs, *this}.has(3);
    // The output shape is determined by the two GEMM operands alone.
    const shape& lhs = inputs.at(0);
    const shape& rhs = inputs.at(1);
    return op.compute_shape({lhs, rhs});
}
wsttiger's avatar
wsttiger committed
72
73
74
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    // args[0] = A, args[1] = B, args[2] = C. C receives the product and is
    // returned at the end.
    const auto& a_shape = args[0].get_shape();
    const auto& b_shape = args[1].get_shape();
    const auto& c_shape = args[2].get_shape();

    // Fixed scaling factors: C = 1 * (A x B) + 0 * C.
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    const bool a_transposed = a_shape.transposed();
    const bool b_transposed = b_shape.transposed();

    // Only the two trailing dimensions are consulted and a single GEMM is issued.
    // NOTE(review): leading (batch) dimensions are not iterated here, and a
    // rank below 2 would make these size_t indices wrap around.
    const std::size_t rank    = a_shape.lens().size();
    const std::size_t row_dim = rank - 2;
    const std::size_t col_dim = rank - 1;

    // Leading dimensions: the row-axis stride, or the column-axis stride when
    // the operand is marked transposed. C always uses its row-axis stride.
    const rocblas_int lda = a_shape.strides()[a_transposed ? col_dim : row_dim];
    const rocblas_int ldb = b_shape.strides()[b_transposed ? col_dim : row_dim];
    const rocblas_int ldc = c_shape.strides()[row_dim];
    const rocblas_int m   = output_shape.lens()[row_dim];
    const rocblas_int n   = output_shape.lens()[col_dim];
    const rocblas_int k   = a_shape.lens()[col_dim];

    const auto op_a = a_transposed ? rocblas_operation_transpose : rocblas_operation_none;
    const auto op_b = b_transposed ? rocblas_operation_transpose : rocblas_operation_none;

    output_shape.visit_type([&](auto as) {
        auto alpha_r = to_rocblas_type(as(alpha));
        auto beta_r  = to_rocblas_type(as(beta));
        auto as_rocblas_ptr = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };

        // rocBLAS follows the column-major BLAS convention. The operands are
        // therefore passed swapped (op_b/op_a, n/m, B before A). Computing
        // B^T * A^T in column-major terms yields A * B in the row-major layout
        // described by the strides above.
        generic_rocblas_gemm(as,
                             ctx.get_stream().get_rocblas(),
                             op_b,
                             op_a,
                             n,
                             m,
                             k,
                             &alpha_r,
                             as_rocblas_ptr(args[1]),
                             ldb,
                             as_rocblas_ptr(args[0]),
                             lda,
                             &beta_r,
                             as_rocblas_ptr(args[2]),
                             ldc);
    });

    return args[2];
}

} // namespace gpu
Paul's avatar
Paul committed
115
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
116
} // namespace migraphx