// ds_transformer_hip.h
// !!! This is a file automatically generated by hipify!!!
#pragma once

#include <hip/hip_runtime_api.h>
#include <hiprand/hiprand.h>
#include <memory>
#include <vector>

#include "hip/hip_runtime.h"
#include "rocblas.h"

#include "dropout_hip.h"
#include "feed_forward_hip.h"
#include "gelu_hip.h"
#include "general_kernels_hip.h"
#include "normalize_layer_hip.h"
#include "softmax_hip.h"
#include "strided_batch_gemm_hip.h"

struct BertGemmAlgos {
    int m_gemm_qkv_algo;
    int m_gemm_inter_algo;
    int m_gemm_output_algo;
    int m_gemm_batch1_algo;
    int m_gemm_batch2_algo;

    BertGemmAlgos()
        : m_gemm_qkv_algo(-1),
          m_gemm_inter_algo(-1),
          m_gemm_output_algo(-1),
          m_gemm_batch1_algo(-1),
          m_gemm_batch2_algo(-1)
    {
    }
};
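
// The five indices select GEMM algorithms for the QKV projection, the two
// feed-forward GEMMs, and the two batched attention GEMMs; -1 conventionally
// lets the BLAS library pick its default algorithm. A minimal sketch of how
// the fields might be populated from the constructor's `gemm_algos` argument
// (this index layout is an assumption, not something this header specifies):
//
//   std::vector<std::array<int, 3>> gemm_algos(5, {-1, -1, -1});
//   BertGemmAlgos algos;                        // every field starts at -1
//   algos.m_gemm_qkv_algo    = gemm_algos[0][0];
//   algos.m_gemm_inter_algo  = gemm_algos[1][0];
//   algos.m_gemm_output_algo = gemm_algos[2][0];
//   algos.m_gemm_batch1_algo = gemm_algos[3][0];
//   algos.m_gemm_batch2_algo = gemm_algos[4][0];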

template <typename T>
class BertTransformerLayer {
public:
    BertTransformerLayer(unsigned layer_id,
                         unsigned batch_size,
                         unsigned hidden_size,
                         unsigned num_heads,
                         unsigned intermediate_size,
                         unsigned seq_length,
                         float attn_dropout_ratio,
                         float hidden_output_dropout_ratio,
                         float layer_norm_eps,
                         bool pre_or_postLayerNorm,
                         const std::vector<std::array<int, 3>>& gemm_algos,
                         bool attn_dropout_checkpoint,
                         bool normalize_invertible,
                         bool gelu_checkpoint,
                         bool stochastic_mode);
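
    // Usage sketch (hypothetical values, BERT-base-like geometry; the five
    // {-1, -1, -1} entries request the BLAS library's default algorithms):
    //
    //   std::vector<std::array<int, 3>> algos(5, {-1, -1, -1});
    //   BertTransformerLayer<float> layer(0,       // layer_id
    //                                     8,       // batch_size
    //                                     768,     // hidden_size
    //                                     12,      // num_heads
    //                                     3072,    // intermediate_size
    //                                     128,     // seq_length
    //                                     0.1f,    // attn_dropout_ratio
    //                                     0.1f,    // hidden_output_dropout_ratio
    //                                     1e-12f,  // layer_norm_eps
    //                                     true,    // pre_or_postLayerNorm
    //                                     algos,
    //                                     false,   // attn_dropout_checkpoint
    //                                     false,   // normalize_invertible
    //                                     false,   // gelu_checkpoint
    //                                     false);  // stochastic_mode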

    virtual ~BertTransformerLayer();

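    // Runs the forward pass for a batch of `bsz` sequences (presumably no
    // larger than the batch_size given at construction). All pointers are
    // device buffers; the trailing non-const pointers receive intermediate
    // activations that Backward later consumes.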
    void Forward(unsigned bsz,
                 const T* input_ptr,
                 const T* input_mask_ptr,
                 const T* attn_qkvw_ptr,
                 const T* attn_qkvb_ptr,
                 const T* attn_ow_ptr,
                 const T* attn_ob_ptr,
                 const T* attn_nw_ptr,
                 const T* attn_nb_ptr,
                 const T* inter_w_ptr,
                 const T* inter_b_ptr,
                 const T* output_w_ptr,
                 const T* output_b_ptr,
                 const T* norm_w_ptr,
                 const T* norm_b_ptr,
                 T* out_ptr,
                 T* inp_norm_ptr,
                 T* q_tf_ptr,
                 T* k_tf_ptr,
                 T* v_tf_ptr,
                 T* softmax_output_ptr,
                 T* ctx_bufB_ptr,
                 T* attn_o_inp_ptr,
                 T* add_res_ptr,
                 T* ff1_inp_ptr,
                 T* gelu_inp_ptr,
                 T* ff2_inp_ptr);

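    // Computes gradients from `grad_output_ptr`, re-using the activations
    // stashed by Forward along with the weights (the const pointers). The
    // trailing block of non-const pointers receives the input gradient plus
    // the gradients of every weight and bias in the layer.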
    void Backward(unsigned bsz,
                  const T* grad_output_ptr,
                  const T* input_ptr,
                  const T* output_ptr,
                  const T* inp_norm_ptr,
                  const T* q_tf_ptr,
                  const T* k_tf_ptr,
                  const T* v_tf_ptr,
                  const T* softmax_output_ptr,
                  const T* ctx_bufB_ptr,
                  const T* attn_o_inp_ptr,
                  const T* add_res_ptr,
                  const T* ff1_inp_ptr,
                  const T* gelu_inp_ptr,
                  const T* ff2_inp_ptr,
                  const T* input_mask_ptr,
                  const T* attn_qkvw_ptr,
                  const T* attn_ow_ptr,
                  const T* attn_nw_ptr,
                  const T* attn_nb_ptr,
                  const T* inter_w_ptr,
                  const T* inter_b_ptr,
                  const T* output_w_ptr,
                  const T* norm_w_ptr,
                  const T* norm_b_ptr,

                  T* grad_input_ptr,
                  T* grad_attn_qkvw_ptr,
                  T* grad_attn_qkvb_ptr,
                  T* grad_attn_ow_ptr,
                  T* grad_attn_ob_ptr,
                  T* grad_attn_nw_ptr,
                  T* grad_attn_nb_ptr,
                  T* grad_inter_w_ptr,
                  T* grad_inter_b_ptr,
                  T* grad_output_w_ptr,
                  T* grad_output_b_ptr,
                  T* grad_norm_w_ptr,
                  T* grad_norm_b_ptr);

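    // Registers externally allocated buffers for the three dropout masks and
    // the mean/variance statistics of both layer norms, so that they persist
    // between Forward and Backward.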
    void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                uint8_t* attn_output_dropout_mask_ptr,
                                uint8_t* layer_output_dropout_mask_ptr,
                                T* layer_norm_var,
                                T* layer_norm_mean,
                                T* attn_layer_norm_var,
                                T* attn_layer_norm_mean);

    inline unsigned GetBatchSize() const { return _batch_size; }
    inline unsigned GetNumHeads() const { return _heads; }
    inline unsigned GetSeqLength() const { return _seq_length; }
    inline unsigned GetIntermediateSize() const { return _intermediate_size; }
    inline unsigned GetHiddenSize() const { return _hidden_size; }
    inline bool IsTrainingMode() const { return _training; }
    inline bool GeluCheckpoint() const { return _gelu_checkpoint; }

    void SetSeqLength(unsigned seq_len);
    void SetTrainingMode(bool training);
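
    // Sketch of handling a longer incoming batch at runtime (the grow-only
    // policy shown here is an assumption, not part of this header):
    //
    //   if (incoming_seq_len > layer.GetSeqLength()) {
    //       layer.SetSeqLength(incoming_seq_len);
    //   }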

private:
    void Initialize();
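    // Scratch-buffer requirement for up to `maxBatchSize` sequences; whether
    // the value counts elements or bytes is not visible from this header.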
    size_t getWorkspaceSize(int maxBatchSize) const;

    // Params
    unsigned _layer_id;
    unsigned _batch_size;
    unsigned _hidden_size;
    unsigned _heads;
    unsigned _size_per_head;
    unsigned _intermediate_size;
    unsigned _seq_length;

    bool _pre_or_postLayerNorm;

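    // rocBLAS handle; the CUDA-era member name survives because this header
    // was machine-translated from the cuBLAS-based original by hipify (see
    // the banner at the top of the file).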
    rocblas_handle _cublasHandle;
    hipStream_t _stream;

    // Layers
    FeedForward<T> _qkv_linear;
    FeedForward<T> _attn_out_linear;
    Normalize_Layer<T> _attn_layer_norm;
    Normalize_Layer<T> _layer_norm;
    Normalize_Layer<T>* _last_normalize;
    FeedForward<T> _ff1, _ff2;
    Softmax<T> _softmax;
    Gelu<T> _gelu;
    Dropout<T> _attn_prob_dropout;
    Dropout<T> _attn_output_dropout;
    Dropout<T> _layer_output_dropout;
    StridedBatchGemm<T> _attn_scores;
    StridedBatchGemm<T> _attn_context;

    bool _training;

    // Memory-saving flags
    bool _attn_dropout_checkpoint;
    bool _normalize_invertible;
    bool _gelu_checkpoint;

    // High-performance flags
    bool _stochastic_mode;
};
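
// End-to-end sketch of one training step (hypothetical driver code; the
// "..." placeholders stand for the full argument lists declared above and
// are deliberately left unfilled):
//
//   layer.SetTrainingMode(true);
//   layer.SetIntermediateBuffers(attn_prob_mask, attn_out_mask,
//                                layer_out_mask, ln_var, ln_mean,
//                                attn_ln_var, attn_ln_mean);
//   layer.Forward(bsz, input, input_mask, /* weights & biases */ ...,
//                 out, /* saved activations */ ...);
//   layer.Backward(bsz, grad_out, input, out, /* saved activations */ ...,
//                  grad_input, /* parameter gradients */ ...);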