// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <cuda_runtime_api.h>
#include <curand.h>
#include <memory>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "dropout.h"
#include "feed_forward.h"
#include "gelu.h"
#include "general_kernels.h"
#include "normalize_layer.h"
#include "softmax.h"
#include "strided_batch_gemm.h"

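// cuBLAS GEMM algorithm indices for the five large matmuls in a layer
// (QKV projection, intermediate FF, output FF, and the two batched
// attention GEMMs); -1 means no explicit algorithm has been selected.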
struct BertGemmAlgos {
    int m_gemm_qkv_algo;
    int m_gemm_inter_algo;
    int m_gemm_output_algo;
    int m_gemm_batch1_algo;
    int m_gemm_batch2_algo;

    BertGemmAlgos()
        : m_gemm_qkv_algo(-1),
          m_gemm_inter_algo(-1),
          m_gemm_output_algo(-1),
          m_gemm_batch1_algo(-1),
          m_gemm_batch2_algo(-1)
    {
    }
};

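// Fused CUDA implementation of a single BERT transformer (encoder) layer:
// QKV projection, scaled dot-product attention, attention output projection,
// layer normalization, and the intermediate/output feed-forward pair with
// GeLU, plus the associated dropouts. T is the element type (e.g. float or
// __half).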
template <typename T>
class BertTransformerLayer {
public:
    BertTransformerLayer(unsigned layer_id,
                         unsigned batch_size,
                         unsigned hidden_size,
                         unsigned num_heads,
                         unsigned intermediate_size,
                         unsigned seq_length,
                         float attn_dropout_ratio,
                         float hidden_output_dropout_ratio,
                         float layer_norm_eps,
                         bool pre_or_postLayerNorm,
                         const std::vector<std::array<int, 3>>& gemm_algos,
                         bool attn_dropout_checkpoint,
                         bool normalize_invertible,
                         bool gelu_checkpoint,
                         bool stochastic_mode);

    virtual ~BertTransformerLayer();

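    // Runs the forward pass for a batch of bsz sequences. The weight/bias
    // pointers are the layer parameters; the non-const pointers receive the
    // layer output plus intermediate activations (normalized input, Q/K/V
    // transforms, softmax output, attention context, residual adds, and
    // feed-forward inputs) that Backward later consumes.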
    void Forward(unsigned bsz,
                 const T* input_ptr,
                 const T* input_mask_ptr,
                 const T* attn_qkvw_ptr,
                 const T* attn_qkvb_ptr,
                 const T* attn_ow_ptr,
                 const T* attn_ob_ptr,
                 const T* attn_nw_ptr,
                 const T* attn_nb_ptr,
                 const T* inter_w_ptr,
                 const T* inter_b_ptr,
                 const T* output_w_ptr,
                 const T* output_b_ptr,
                 const T* norm_w_ptr,
                 const T* norm_b_ptr,
                 T* out_ptr,
                 T* inp_norm_ptr,
                 T* q_tf_ptr,
                 T* k_tf_ptr,
                 T* v_tf_ptr,
                 T* softmax_output_ptr,
                 T* ctx_bufB_ptr,
                 T* attn_o_inp_ptr,
                 T* add_res_ptr,
                 T* ff1_inp_ptr,
                 T* gelu_inp_ptr,
                 T* ff2_inp_ptr);

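    // Runs the backward pass: consumes the output gradient together with the
    // activations saved by Forward and produces the input gradient and the
    // gradients of every weight and bias in the layer.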
    void Backward(unsigned bsz,
                  const T* grad_output_ptr,
                  const T* input_ptr,
                  const T* output_ptr,
                  const T* inp_norm_ptr,
                  const T* q_tf_ptr,
                  const T* k_tf_ptr,
                  const T* v_tf_ptr,
                  const T* softmax_output_ptr,
                  const T* ctx_bufB_ptr,
                  const T* attn_o_inp_ptr,
                  const T* add_res_ptr,
                  const T* ff1_inp_ptr,
                  const T* gelu_inp_ptr,
                  const T* ff2_inp_ptr,
                  const T* input_mask_ptr,
                  const T* attn_qkvw_ptr,
                  const T* attn_ow_ptr,
                  const T* attn_nw_ptr,
                  const T* attn_nb_ptr,
                  const T* inter_w_ptr,
                  const T* inter_b_ptr,
                  const T* output_w_ptr,
                  const T* norm_w_ptr,
                  const T* norm_b_ptr,

                  T* grad_input_ptr,
                  T* grad_attn_qkvw_ptr,
                  T* grad_attn_qkvb_ptr,
                  T* grad_attn_ow_ptr,
                  T* grad_attn_ob_ptr,
                  T* grad_attn_nw_ptr,
                  T* grad_attn_nb_ptr,
                  T* grad_inter_w_ptr,
                  T* grad_inter_b_ptr,
                  T* grad_output_w_ptr,
                  T* grad_output_b_ptr,
                  T* grad_norm_w_ptr,
                  T* grad_norm_b_ptr);

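    // Attaches externally allocated buffers for the dropout masks and the
    // layer-norm mean/variance statistics shared between Forward and Backward.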
    void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                uint8_t* attn_output_dropout_mask_ptr,
                                uint8_t* layer_output_dropout_mask_ptr,
                                T* layer_norm_var,
                                T* layer_norm_mean,
                                T* attn_layer_norm_var,
                                T* attn_layer_norm_mean);

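    // Layer geometry accessors and runtime configuration.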
    inline unsigned GetBatchSize() const { return _batch_size; }
    inline unsigned GetNumHeads() const { return _heads; }
    inline unsigned GetSeqLength() const { return _seq_length; }
    inline unsigned GetIntermediateSize() const { return _intermediate_size; }

    void SetSeqLength(unsigned seq_len);
    inline unsigned GetHiddenSize() const { return _hidden_size; }
    void SetTrainingMode(bool training);
    inline bool IsTrainingMode() const { return _training; }
    inline bool GeluCheckpoint() const { return _gelu_checkpoint; }

private:
    void Initialize();
    size_t getWorkspaceSize(int maxBatchSize) const;

    // Params
    unsigned _layer_id;
    unsigned _batch_size;
    unsigned _hidden_size;
    unsigned _heads;
    unsigned _size_per_head;
    unsigned _intermediate_size;
    unsigned _seq_length;

    bool _pre_or_postLayerNorm;

    cublasHandle_t _cublasHandle;
    cudaStream_t _stream;

    // layers
    FeedForward<T> _qkv_linear;
    FeedForward<T> _attn_out_linear;
    Normalize_Layer<T> _attn_layer_norm;
    Normalize_Layer<T> _layer_norm;
    Normalize_Layer<T>* _last_normalize;
    FeedForward<T> _ff1, _ff2;
    Softmax<T> _softmax;
    Gelu<T> _gelu;
    Dropout<T> _attn_prob_dropout;
    Dropout<T> _attn_output_dropout;
    Dropout<T> _layer_output_dropout;
    StridedBatchGemm<T> _attn_scores;
    StridedBatchGemm<T> _attn_context;

    bool _training;

    // Memory saving flags
    bool _attn_dropout_checkpoint;
    bool _normalize_invertible;
    bool _gelu_checkpoint;

    // High Performance flags
    bool _stochastic_mode;
};
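
// A minimal usage sketch (hypothetical, for illustration only): the shapes,
// dropout ratios, and flag values below are placeholders, and gemm_algos is
// simply filled with -1 entries so no explicit cuBLAS algorithm is forced.
//
//   std::array<int, 3> default_algo = {-1, -1, -1};
//   std::vector<std::array<int, 3>> gemm_algos(5, default_algo);
//   BertTransformerLayer<float> layer(/*layer_id=*/0,
//                                     /*batch_size=*/8,
//                                     /*hidden_size=*/1024,
//                                     /*num_heads=*/16,
//                                     /*intermediate_size=*/4096,
//                                     /*seq_length=*/512,
//                                     /*attn_dropout_ratio=*/0.1f,
//                                     /*hidden_output_dropout_ratio=*/0.1f,
//                                     /*layer_norm_eps=*/1e-12f,
//                                     /*pre_or_postLayerNorm=*/true,
//                                     gemm_algos,
//                                     /*attn_dropout_checkpoint=*/false,
//                                     /*normalize_invertible=*/false,
//                                     /*gelu_checkpoint=*/false,
//                                     /*stochastic_mode=*/false);
//   layer.SetTrainingMode(true);
//   // layer.Forward(...) / layer.Backward(...) are then called with device
//   // pointers for the activations, parameters, and gradients listed above.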