// strided_batch_gemm.h
#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>

#include <array>

/**
 * Wrapper around a strided batched GEMM issued through the project helper
 * cublas_strided_batched_gemm() (declared outside this header).
 *
 * Shapes, scaling factors, transpose flags and cuBLAS algorithm ids are fixed
 * at construction via Config; each call only supplies buffers, the cuBLAS
 * handle and (for Forward/Backward) the batch count.
 *
 * T is the element type of the operand/result buffers — presumably float or
 * __half, i.e. whatever cublas_strided_batched_gemm is overloaded for (confirm).
 */
template <typename T>
class StridedBatchGemm {
public:
    /** Static description of the batched GEMM, shared by forward and backward. */
    struct Config {
        int batch_size;  // Batch count used by ForwardPlusSave().
        int m;           // GEMM dimensions, using cuBLAS (m, n, k) naming.
        int n;
        int k;
        float alpha;  // Scale on the product; Backward() reuses the same value.
        float beta;   // Scale on the existing output; Backward() reuses the same value.
        cublasOperation_t op_A;  // Transpose flag for operand A.
        cublasOperation_t op_B;  // Transpose flag for operand B.
        // Algorithm ids, cast to cublasGemmAlgo_t when used:
        //   [0] Forward / ForwardPlusSave, [1] Backward -> inpGradA, [2] Backward -> inpGradB.
        std::array<int, 3> gemm_algos;

        Config(int batch,
               int mm,
               int nn,
               int kk,
               float param_alpha,
               float param_beta,
               cublasOperation_t opA,
               cublasOperation_t opB,
               const std::array<int, 3>& algos)
            : batch_size(batch),
              m(mm),
              n(nn),
              k(kk),
              alpha(param_alpha),
              beta(param_beta),
              op_A(opA),
              op_B(opB),
              gemm_algos(algos)
        {
        }
    };

    StridedBatchGemm(const Config& config) : _config(config) {}

    virtual ~StridedBatchGemm() {}

    /**
     * Issues the forward strided batched GEMM over `bsz` batch entries using
     * algorithm gemm_algos[0], config alpha/beta and op_A/op_B.
     *
     * Consecutive batch entries of A, B and the output are m*k, n*k and m*n
     * apart respectively (strides forwarded to cublas_strided_batched_gemm,
     * presumably in elements).
     *
     * @param bsz       number of batch entries to process
     * @param output    result buffer (GEMM output operand)
     * @param _buffer_a A operand buffer
     * @param _buffer_b B operand buffer
     * @param handle    cuBLAS handle the GEMM is issued on
     */
    void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
    {
        const int stride_a = _config.m * _config.k;
        const int stride_b = _config.n * _config.k;
        const int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
    }

    /**
     * Same as Forward() with bsz = Config::batch_size, and additionally records
     * the operand pointers so GetBufferA()/GetBufferB() can return them later.
     *
     * The recorded pointers are non-owning: the caller must keep both buffers
     * alive for as long as the returned pointers are used.
     */
    void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
    {
        Forward(_config.batch_size, output, _buffer_a, _buffer_b, handle);

        k_buf = _buffer_a;
        q_buf = _buffer_b;
    }

    /**
     * Issues up to two strided batched GEMMs driven by `d_output`, presumably
     * the gradients of the forward GEMM w.r.t. A and B.
     *
     * The forward alpha and beta are reused for both GEMMs, so a non-zero beta
     * scales and accumulates into the existing contents of inpGradA/inpGradB.
     *
     * @param bsz       number of batch entries to process
     * @param d_output  incoming gradient of the forward output
     * @param _buffer_a forward A operand
     * @param _buffer_b forward B operand
     * @param handle    cuBLAS handle the GEMMs are issued on
     * @param inpGradA  output for the first GEMM (algo gemm_algos[1]); skipped if null
     * @param inpGradB  output for the second GEMM (algo gemm_algos[2]); skipped if null
     */
    void Backward(int bsz,
                  const T* d_output,
                  const T* _buffer_a,
                  const T* _buffer_b,
                  cublasHandle_t handle,
                  T* inpGradA = nullptr,
                  T* inpGradB = nullptr)
    {
        // Both outputs default to nullptr. Handing a null output matrix to cuBLAS
        // would be an illegal access, so null means "gradient not requested".
        if (inpGradA != nullptr) {
            // mb/kb swap m and k when A was transposed in the forward pass.
            const int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
            const int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);

            const int stride_a = mb * _config.n;
            const int stride_b = _config.n * kb;
            const int stride_c = _config.m * _config.k;

            // B needs to be transposed relative to its forward orientation.
            const cublasOperation_t op_b =
                (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

            // Operand order depends on whether A was transposed in the forward pass.
            const T* lhs = (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output);
            const T* rhs = (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b);

            // Calculate d_A.
            cublas_strided_batched_gemm(handle,
                                        mb,
                                        kb,
                                        _config.n,
                                        &_config.alpha,
                                        &_config.beta,
                                        lhs,
                                        rhs,
                                        inpGradA,
                                        CUBLAS_OP_N,
                                        op_b,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
                                        cublasGemmAlgo_t(_config.gemm_algos[1]));
        }

        if (inpGradB != nullptr) {
            // A needs to be transposed relative to its forward orientation.
            const cublasOperation_t op_a =
                (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

            const int stride_a = _config.m * _config.k;
            const int stride_b = _config.m * _config.n;
            const int stride_c = _config.n * _config.k;

            // Calculate d_B.
            cublas_strided_batched_gemm(handle,
                                        _config.k,
                                        _config.n,
                                        _config.m,
                                        &_config.alpha,
                                        &_config.beta,
                                        _buffer_a,
                                        d_output,
                                        inpGradB,
                                        op_a,
                                        CUBLAS_OP_N,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
                                        cublasGemmAlgo_t(_config.gemm_algos[2]));
        }
    }

    // NOTE(review): returns the configured k dimension, not n; name kept for callers.
    inline int GetN() const { return _config.k; }

    // A operand recorded by the last ForwardPlusSave(); nullptr if never called.
    inline const T* GetBufferA() const { return k_buf; }

    // B operand recorded by the last ForwardPlusSave(); nullptr if never called.
    inline const T* GetBufferB() const { return q_buf; }

private:
    Config _config;
    // Non-owning operand pointers saved by ForwardPlusSave(). They are initialized
    // so the getters never return an indeterminate pointer.
    const T* q_buf = nullptr;  // B operand
    const T* k_buf = nullptr;  // A operand
};