// ds_transformer_cuda.h
#pragma once

#include <cuda_runtime_api.h>
#include <curand.h>
#include <memory>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "dropout.h"
#include "feed_forward.h"
#include "gelu.h"
#include "general_kernels.h"
#include "normalize_layer.h"
#include "softmax.h"
#include "strided_batch_gemm.h"

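// GEMM algorithm indices for the five cuBLAS GEMMs in a BERT layer: the QKV
// projection, the intermediate and output feed-forward GEMMs, and the two
// strided-batched attention GEMMs. A value of -1 presumably means "let cuBLAS
// pick its default algorithm".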
struct BertGemmAlgos {
    int m_gemm_qkv_algo;
    int m_gemm_inter_algo;
    int m_gemm_output_algo;
    int m_gemm_batch1_algo;
    int m_gemm_batch2_algo;

    BertGemmAlgos()
        : m_gemm_qkv_algo(-1),
          m_gemm_inter_algo(-1),
          m_gemm_output_algo(-1),
          m_gemm_batch1_algo(-1),
          m_gemm_batch2_algo(-1)
    {
    }
};

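// One fused BERT transformer layer (self-attention + feed-forward) executed on
// the GPU. T is the activation/parameter element type (e.g. float or __half).
// The constructor fixes the layer geometry, the two dropout ratios, the
// pre-/post-LayerNorm variant, the per-GEMM algorithm choices, and the
// memory-saving / stochastic-mode flags mirrored by the members below.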
template <typename T>
class BertTransformerLayer {
public:
    BertTransformerLayer(int layer_id,
                         int batch_size,
                         int hidden_size,
                         int num_heads,
                         int intermediate_size,
                         int seq_length,
                         float attn_dropout_ratio,
                         float hidden_output_dropout_ratio,
                         bool pre_or_postLayerNorm,
                         const std::vector<std::array<int, 3>>& gemm_algos,
                         bool attn_dropout_checkpoint,
                         bool normalize_invertible,
                         bool gelu_checkpoint,
                         bool stochastic_mode);

    virtual ~BertTransformerLayer();

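    // Forward pass through the layer for a (micro-)batch of size bsz. The
    // const T* arguments are the layer input, attention mask, and parameter
    // tensors; the trailing non-const T* arguments receive the layer output
    // together with intermediate activations (input norm, Q/K/V transforms,
    // softmax output, attention context, residual add, feed-forward inputs)
    // that Backward later takes back as const inputs.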
    void Forward(int bsz,
                 const T* input_ptr,
                 const T* input_mask_ptr,
                 const T* attn_qkvw_ptr,
                 const T* attn_qkvb_ptr,
                 const T* attn_ow_ptr,
                 const T* attn_ob_ptr,
                 const T* attn_nw_ptr,
                 const T* attn_nb_ptr,
                 const T* inter_w_ptr,
                 const T* inter_b_ptr,
                 const T* output_w_ptr,
                 const T* output_b_ptr,
                 const T* norm_w_ptr,
                 const T* norm_b_ptr,
                 T* out_ptr,
                 T* inp_norm_ptr,
                 T* q_tf_ptr,
                 T* k_tf_ptr,
                 T* v_tf_ptr,
                 T* softmax_output_ptr,
                 T* ctx_bufB_ptr,
                 T* attn_o_inp_ptr,
                 T* add_res_ptr,
                 T* ff1_inp_ptr,
                 T* gelu_inp_ptr,
                 T* ff2_inp_ptr);

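    // Backward pass: consumes the upstream gradient, the activations saved by
    // Forward, and the parameter tensors, and writes the gradient w.r.t. the
    // layer input as well as the gradients of every weight and bias tensor.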
    void Backward(int bsz,
                  const T* grad_output_ptr,
                  const T* input_ptr,
                  const T* output_ptr,
                  const T* inp_norm_ptr,
                  const T* q_tf_ptr,
                  const T* k_tf_ptr,
                  const T* v_tf_ptr,
                  const T* softmax_output_ptr,
                  const T* ctx_bufB_ptr,
                  const T* attn_o_inp_ptr,
                  const T* add_res_ptr,
                  const T* ff1_inp_ptr,
                  const T* gelu_inp_ptr,
                  const T* ff2_inp_ptr,
                  const T* input_mask_ptr,
                  const T* attn_qkvw_ptr,
                  const T* attn_ow_ptr,
                  const T* attn_nw_ptr,
                  const T* attn_nb_ptr,
                  const T* inter_w_ptr,
                  const T* inter_b_ptr,
                  const T* output_w_ptr,
                  const T* norm_w_ptr,
                  const T* norm_b_ptr,

                  T* grad_input_ptr,
                  T* grad_attn_qkvw_ptr,
                  T* grad_attn_qkvb_ptr,
                  T* grad_attn_ow_ptr,
                  T* grad_attn_ob_ptr,
                  T* grad_attn_nw_ptr,
                  T* grad_attn_nb_ptr,
                  T* grad_inter_w_ptr,
                  T* grad_inter_b_ptr,
                  T* grad_output_w_ptr,
                  T* grad_output_b_ptr,
                  T* grad_norm_w_ptr,
                  T* grad_norm_b_ptr);

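    // Registers externally allocated buffers for the three dropout masks and
    // the layer-norm mean/variance statistics; presumably the caller owns
    // these so they can be kept alive for the backward pass.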
    void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                uint8_t* attn_output_dropout_mask_ptr,
                                uint8_t* layer_output_dropout_mask_ptr,
                                T* layer_norm_var,
                                T* layer_norm_mean,
                                T* attn_layer_norm_var,
                                T* attn_layer_norm_mean);

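    // Accessors and mutators for the layer geometry and runtime mode.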
    inline int GetBatchSize() const { return _batch_size; }
    inline int GetNumHeads() const { return _heads; }
    inline int GetSeqLength() const { return _seq_length; }
    inline int GetIntermediateSize() const { return _intermediate_size; }

    void SetSeqLength(int seq_len);
    inline int GetHiddenSize() const { return _hidden_size; }
    void SetTrainingMode(bool training);
    inline bool IsTrainingMode() const { return _training; }
    inline bool GeluCheckpoint() const { return _gelu_checkpoint; }

private:
    void Initialize();
    size_t getWorkspaceSize(int maxBatchSize) const;

    // Params
    int _layer_id;
    int _batch_size;
    int _hidden_size;
    int _heads;
    int _size_per_head;
    int _intermediate_size;
    int _seq_length;

    bool _pre_or_postLayerNorm;

    cublasHandle_t _cublasHandle;
    cudaStream_t _stream;

    // layers
    FeedForward<T> _qkv_linear;
    FeedForward<T> _attn_out_linear;
    Normalize_Layer<T> _attn_layer_norm;
    Normalize_Layer<T> _layer_norm;
    Normalize_Layer<T>* _last_normalize;
    FeedForward<T> _ff1, _ff2;
    Softmax<T> _softmax;
    Gelu<T> _gelu;
    Dropout<T> _attn_prob_dropout;
    Dropout<T> _attn_output_dropout;
    Dropout<T> _layer_output_dropout;
    StridedBatchGemm<T> _attn_scores;
    StridedBatchGemm<T> _attn_context;

    bool _training;

    // Memory saving flags
    bool _attn_dropout_checkpoint;
    bool _normalize_invertible;
    bool _gelu_checkpoint;

    // High performance flags
    bool _stochastic_mode;
};
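
// Usage sketch (illustrative only; the concrete values below are assumptions,
// not taken from this header): constructing a float layer with BERT-Base-like
// dimensions and enabling training mode.
//
//   std::vector<std::array<int, 3>> gemm_algos(5, std::array<int, 3>{-1, -1, -1});
//   BertTransformerLayer<float> layer(/*layer_id=*/0,
//                                     /*batch_size=*/8,
//                                     /*hidden_size=*/768,
//                                     /*num_heads=*/12,
//                                     /*intermediate_size=*/3072,
//                                     /*seq_length=*/128,
//                                     /*attn_dropout_ratio=*/0.1f,
//                                     /*hidden_output_dropout_ratio=*/0.1f,
//                                     /*pre_or_postLayerNorm=*/true,
//                                     gemm_algos,
//                                     /*attn_dropout_checkpoint=*/false,
//                                     /*normalize_invertible=*/false,
//                                     /*gelu_checkpoint=*/false,
//                                     /*stochastic_mode=*/false);
//   layer.SetTrainingMode(true);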