"src/vscode:/vscode.git/clone" did not exist on "458ec14999c8aa68ede8ffa2c979ee58ef45e067"
gemm.cpp 3.25 KB
Newer Older
wsttiger's avatar
wsttiger committed
1
2
3
4
5
6
7
#include <migraph/gpu/gemm.hpp>
#include <migraph/operators.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/gpu/miopen.hpp>
#include <cstring>
#include <utility>

namespace migraph {
8
inline namespace MIGRAPH_INLINE_NS {
wsttiger's avatar
wsttiger committed
9
10
namespace gpu {

Paul's avatar
Paul committed
11
template <class... Ts>
Paul's avatar
Paul committed
12
13
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
14
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
15
16
}

Paul's avatar
Paul committed
17
template <class... Ts>
Paul's avatar
Paul committed
18
19
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
20
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
21
22
}

Paul's avatar
Paul committed
23
template <class... Ts>
Paul's avatar
Paul committed
24
25
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
26
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
27
28
}

Paul's avatar
Paul committed
29
template <class T, class... Ts>
Paul's avatar
Paul committed
30
31
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
Paul's avatar
Paul committed
32
    MIGRAPH_THROW("Type unsupported by rocblas");
Paul's avatar
Paul committed
33
34
}

Paul's avatar
Paul committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Trait mapping an element type to the type used for it on the rocBLAS side.
// Primary template: a type with no specialization maps to itself.
template <class T>
struct compute_rocblas_type
{
    using type = T;
};

// Partial specialization for const-qualified types: look up the mapping of the
// unqualified type and re-apply the const qualifier to the result.
template <class T>
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

// Full specialization: `half` is represented as `rocblas_half` on the rocBLAS
// side.
template <>
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

// Shorthand for `compute_rocblas_type<T>::type`.
template <class T>
using rb_type = typename compute_rocblas_type<T>::type;

// Converts a scalar to its rocBLAS element type (rb_type<T>) by copying its
// object representation.
//
// When rb_type<T> is T itself this is a plain copy. When the two types differ,
// the bytes of `x` are copied with std::memcpy rather than read through a
// reference to the other type, because accessing an object through an
// unrelated type is not permitted by the strict-aliasing rules.
//
// x - value to convert (taken by value, so the copy source is a local).
// Returns an rb_type<T> holding the same bytes as `x`.
template <class T>
rb_type<T> to_rocblas_type(T x)
{
    static_assert(sizeof(rb_type<T>) == sizeof(T),
                  "rocBLAS type must have the same size as the source type");
    rb_type<T> result{};
    std::memcpy(&result, &x, sizeof(T));
    return result;
}

Paul's avatar
Paul committed
62
template <class T>
Paul's avatar
Paul committed
63
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
64
{
Paul's avatar
Paul committed
65
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
66
67
}

Paul's avatar
Paul committed
68
// Converts a `half` value to `rocblas_half` by copying its object
// representation. std::memcpy is used instead of reading `x` through a
// `rocblas_half` reference, which the strict-aliasing rules do not permit for
// unrelated types.
rocblas_half to_rocblas_type(half x)
{
    static_assert(sizeof(rocblas_half) == sizeof(half),
                  "rocblas_half must have the same size as half");
    rocblas_half bits{};
    std::memcpy(&bits, &x, sizeof(half));
    return bits;
}
Paul's avatar
Paul committed
69

wsttiger's avatar
wsttiger committed
70
71
72
73
74
// Computes the output shape for the GPU gemm operator.
//
// inputs - shapes of the operator's arguments. They are validated with
//          check_shapes(...).has(3) (presumably requiring exactly three
//          entries - confirm against check_shapes).
//
// Only the first two shapes are handed to the wrapped operator `op`; the third
// is not used for the shape computation.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(3);

    const shape& first_operand  = inputs.at(0);
    const shape& second_operand = inputs.at(1);
    return op.compute_shape({first_operand, second_operand});
}
wsttiger's avatar
wsttiger committed
75
76
77
// Runs the matrix product through rocBLAS.
//
// ctx          - provides the rocBLAS handle via ctx.get_stream().get_rocblas().
// output_shape - shape of the result; its element type selects which
//                generic_rocblas_gemm overload is used, and its first two
//                lens supply m and n.
// args         - args[0] and args[1] are the two input matrices; args[2] is
//                the buffer passed to rocBLAS as the output matrix.
//
// Returns args[2].
// Element types without a dedicated generic_rocblas_gemm overload end up in
// the fallback overload, which throws.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    // Scaling factors: alpha = 1 and beta = 0.
    float alpha     = 1.0f;
    float beta      = 0.0f;
    bool transa     = args[0].get_shape().transposed();
    bool transb     = args[1].get_shape().transposed();
    // Leading dimension of each input: stride index 1 when the shape reports
    // itself as transposed, stride index 0 otherwise.
    rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
    rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
    rocblas_int ldc = args[2].get_shape().strides()[0];
    rocblas_int m   = output_shape.lens()[0];
    rocblas_int n   = output_shape.lens()[1];
    rocblas_int k   = args[0].get_shape().lens()[1];
    output_shape.visit_type([&](auto as) {
        // Convert the scalars to the visited element type, then to the type
        // rocBLAS expects for it.
        auto alpha_r = to_rocblas_type(as(alpha));
        auto beta_r  = to_rocblas_type(as(beta));
        // Typed data pointer of an argument, converted to the rocBLAS pointer type.
        auto to_pointer = [&](auto&& arg) {
            return to_rocblas_type(as.from(arg.data()));
        };
        // NOTE(review): the operands are passed in swapped order (args[1]
        // before args[0], n before m, transb before transa). Presumably this
        // obtains a row-major product from a column-major BLAS routine -
        // confirm against the rocBLAS gemm documentation.
        generic_rocblas_gemm(as,
                             ctx.get_stream().get_rocblas(),
                             transb ? rocblas_operation_transpose : rocblas_operation_none,
                             transa ? rocblas_operation_transpose : rocblas_operation_none,
                             n,
                             m,
                             k,
                             &alpha_r,
                             to_pointer(args[1]),
                             ldb,
                             to_pointer(args[0]),
                             lda,
                             &beta_r,
                             to_pointer(args[2]),
                             ldc);

    });
    return args[2];
}

} // namespace gpu
116
} // namespace MIGRAPH_INLINE_NS
wsttiger's avatar
wsttiger committed
117
} // namespace migraph