// ds_transformer_cuda.cpp
#include <torch/extension.h>

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"

// Registry of the layers built by create_transformer_layer(), keyed by layer id.
// Entries are stored type-erased (shared_ptr<void>) because the layer class is a
// template and one map holds every instantiation.
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

// Sequence length a layer is constructed with; create_transformer_layer() passes it
// both to the GEMM test and to the layer constructor. SetSeqLength() can change the
// layer's sequence length afterwards.
const int init_seq_length = 128;

// C++ interface

// Computes how many scratch elements the transformer layer needs.
//
// The result is an element count, not a byte count: the sizeof(T) scaling is
// intentionally left out (see the trailing comment on the return statement), so the
// template parameter T does not influence the value.
//
//   maxBatchSize, seq_len, hidden_size, intermediate_size, heads - problem dimensions.
//   training        - when true, adds room for the larger of the feed-forward
//                     intermediate activations and twice the attention-score matrix.
//   gelu_checkpoint - only honoured when training is true; adds two more
//                     hidden-sized buffers.
//
// All products are evaluated in size_t so large shapes do not overflow int.
template <typename T>
size_t get_workspace_size(int maxBatchSize,
                          int seq_len,
                          int hidden_size,
                          int intermediate_size,
                          int heads,
                          bool training,
                          bool gelu_checkpoint)
{
    const size_t batch = static_cast<size_t>(maxBatchSize);
    // One [batch, seq_len, hidden_size] activation buffer.
    const size_t hidden_elems = batch * seq_len * hidden_size;

    size_t workSpacesize = 4 * hidden_elems;
    if (training) {
        const size_t intermediate_elems = batch * seq_len * intermediate_size;
        const size_t attn_score_elems = 2 * (batch * heads * seq_len * seq_len);
        // (std::max) is parenthesised so a function-like max() macro cannot hijack it.
        workSpacesize += (std::max)(intermediate_elems, attn_score_elems);
        if (gelu_checkpoint) workSpacesize += 2 * hidden_elems;
    }
    return workSpacesize;  // * sizeof(T);
}

// Tensor-argument validation helpers used by the binding entry points.
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
// The macro argument is parenthesised so any expression can be passed, and the device
// check uses Tensor::is_cuda() directly instead of going through the deprecated type().
#define CHECK_CUDA(x) AT_ASSERTM((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do { } while (0) so both checks form one statement; without it,
// `if (cond) CHECK_INPUT(t);` would run the contiguity check unconditionally.
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)

// Constructs one transformer layer and configures every sub-operator it owns.
//
//   layer_id                    - identifier stored in _layer_id.
//   batch_size, hidden_size, num_heads, intermediate_size, seq_length
//                               - problem dimensions used to size the GEMM, softmax,
//                                 dropout, GELU and layer-norm configurations.
//   attn_prob_dropout_ratio     - ratio for the attention-probability dropout.
//   hidden_output_dropout_ratio - ratio shared by the attention-output and
//                                 layer-output dropouts.
//   layer_norm_eps              - epsilon forwarded to both layer-norm configurations.
//   pre_or_postLayerNorm        - true selects the path in Forward()/Backward() that
//                                 normalises the input before the QKV projection.
//   gemm_algos                  - GEMM algorithm selections; entries [0]..[4] are read
//                                 unconditionally, so at least five are required.
//   attn_dropout_checkpoint, normalize_invertible, gelu_checkpoint, stochastic_mode
//                               - stored flags consulted by Forward()/Backward().
//
// Sub-operator configurations are built from the constructor parameters rather than
// from the data members, so their values do not depend on the order in which the
// members are declared in the class.
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
                                              int batch_size,
                                              int hidden_size,
                                              int num_heads,
                                              int intermediate_size,
                                              int seq_length,
                                              float attn_prob_dropout_ratio,
                                              float hidden_output_dropout_ratio,
                                              float layer_norm_eps,
                                              bool pre_or_postLayerNorm,
                                              const std::vector<std::array<int, 3>>& gemm_algos,
                                              bool attn_dropout_checkpoint,
                                              bool normalize_invertible,
                                              bool gelu_checkpoint,
                                              bool stochastic_mode)
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      _stream(Context::Instance().GetCurrentStream()),
      _cublasHandle(Context::Instance().GetCublasHandle()),
      // Fused QKV projection: hidden_size -> 3 * hidden_size.
      _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                  3 * hidden_size,
                                                  hidden_size,
                                                  gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                       hidden_size,
                                                       hidden_size,
                                                       gemm_algos[0])),
      // NOTE(review): the two trailing booleans are presumably "training" and
      // "use mean" - confirm against Normalize_Layer<T>::Config.
      _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                           seq_length,
                                                           hidden_size,
                                                           layer_norm_eps,
                                                           true,
                                                           !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                      seq_length,
                                                      hidden_size,
                                                      layer_norm_eps,
                                                      true,
                                                      !normalize_invertible)),
      _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
                                           intermediate_size,
                                           hidden_size,
                                           gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(batch_size * seq_length,
                                           hidden_size,
                                           intermediate_size,
                                           gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(intermediate_size)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, hidden_size)),
      // Attention scores are scaled by 1 / sqrt(head size); head size uses integer
      // division, exactly as the other head-size expressions below.
      _attn_scores(typename StridedBatchGemm<T>::Config(batch_size * num_heads,
                                                        seq_length,
                                                        seq_length,
                                                        hidden_size / num_heads,
                                                        (T(1.0) / T(sqrt(hidden_size / num_heads))),
                                                        T(0.0),
                                                        CUBLAS_OP_T,
                                                        CUBLAS_OP_N,
                                                        gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(batch_size * num_heads,
                                                         hidden_size / num_heads,
                                                         seq_length,
                                                         seq_length,
                                                         T(1.0),
                                                         T(0.0),
                                                         CUBLAS_OP_N,
                                                         CUBLAS_OP_N,
                                                         gemm_algos[4]))
{
    // The hidden dimension must split evenly across the attention heads.
    assert(_hidden_size % _heads == 0);

    Initialize();
}

// Nothing to release explicitly here; member destructors perform all clean-up.
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer() = default;

// One-time set-up run from the constructor: switches the cuBLAS handle to tensor-op
// math when the layer is instantiated for __half; does nothing for other types.
template <typename T>
void BertTransformerLayer<T>::Initialize()
{
    const bool is_half_precision = std::is_same<T, __half>::value;
    if (is_half_precision) { cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH); }
}

// Runs the forward pass of the layer for a batch of `bsz` sequences.
//
// Inputs (read-only): the layer input, the attention mask, and the weight/bias pointers
// for the QKV projection (attn_qkv*), attention output projection (attn_o*), attention
// layer norm (attn_n*), the two feed-forward GEMMs (inter_*, output_*) and the second
// layer norm (norm_*).
//
// Outputs: out_ptr receives the layer output. The remaining non-const pointers
// (inp_norm_ptr ... ff2_inp_ptr) receive intermediate activations; Backward() takes
// pointers with the same names as inputs.
//
// Sequence of operations, as issued below:
//   [pre-LN only] layer norm of the input -> QKV GEMM -> bias add + head transform ->
//   attention scores -> softmax with mask -> attention dropout -> attention context ->
//   transform back -> attention output GEMM -> dropout with bias -> attention layer
//   norm -> FF1 GEMM -> GELU with bias -> FF2 GEMM -> dropout with bias ->
//   [post-LN only] final layer norm.
//
// Everything is issued on _stream; unless _stochastic_mode is set the stream is
// synchronised once on entry.
template <typename T>
void BertTransformerLayer<T>::Forward(int bsz,
                                      const T* input_ptr,
                                      const T* input_mask_ptr,
                                      const T* attn_qkvw_ptr,
                                      const T* attn_qkvb_ptr,
                                      const T* attn_ow_ptr,
                                      const T* attn_ob_ptr,
                                      const T* attn_nw_ptr,
                                      const T* attn_nb_ptr,
                                      const T* inter_w_ptr,
                                      const T* inter_b_ptr,
                                      const T* output_w_ptr,
                                      const T* output_b_ptr,
                                      const T* norm_w_ptr,
                                      const T* norm_b_ptr,
                                      T* out_ptr,
                                      T* inp_norm_ptr,
                                      T* q_tf_ptr,
                                      T* k_tf_ptr,
                                      T* v_tf_ptr,
                                      T* soft_out_ptr,
                                      T* ctx_bufB_ptr,
                                      T* attn_o_inp_ptr,
                                      T* add_res_ptr,
                                      T* ff1_inp_ptr,
                                      T* gelu_inp_ptr,
                                      T* ff2_inp_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    // Widen before multiplying so the element count cannot overflow int.
    size_t small_buf_size = static_cast<size_t>(bsz) * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;

    // With these options the corresponding activations live in the shared workspace
    // instead of the caller-provided buffers.
    if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size;
    if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size;

    int bsz_seq = bsz * _seq_length;

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }

    // Pre-LN projects the normalised input; post-LN projects the raw input.
    const T* qkv_input = _pre_or_postLayerNorm ? inp_norm_ptr : input_ptr;
    _qkv_linear.Forward(bsz_seq, qkv_input, attn_qkvw_ptr, buf_0, _cublasHandle);

    launch_bias_add_transform_0213<T>(
        q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(
        attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    // The attention output projection lands in workspace (pre-LN) or in ff1_inp_ptr
    // (post-LN); the dropout that follows reads from the same place.
    T* attn_out_buf = _pre_or_postLayerNorm ? buf_1 : ff1_inp_ptr;
    _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, attn_out_buf, _cublasHandle);

    // attn output dropout.
    _attn_output_dropout.ForwardWithBias(
        bsz_seq, add_res_ptr, attn_out_buf, input_ptr, attn_ob_ptr, _stream);

    // The attention layer norm is applied identically for both layer-norm placements.
    if (_attn_layer_norm.UseMean())
        _attn_layer_norm.ForwardCheckpoint(
            bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    else
        _attn_layer_norm.Forward(
            bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);

    // With GELU checkpointing the FF1 output is kept in ff2_inp_ptr and the GELU output
    // goes to ctx_bufB_ptr; otherwise they go to gelu_inp_ptr and ff2_inp_ptr.
    T* ff1_out_ptr = _gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr;
    T* gelu_out_ptr = _gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr;

    _ff1.Forward(bsz_seq, ff1_inp_ptr, inter_w_ptr, ff1_out_ptr, _cublasHandle);

    _gelu.ForwardWithBiasAdd(bsz_seq, ff1_out_ptr, inter_b_ptr, gelu_out_ptr, _stream);

    _ff2.Forward(bsz_seq, gelu_out_ptr, output_w_ptr, out_ptr, _cublasHandle);

    // layer output dropout.
    if (_pre_or_postLayerNorm)
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    else
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }
}

// Runs the backward pass of the layer for a batch of `bsz` sequences.
//
// Inputs (read-only): the gradient of the layer output (grad_output_ptr), the layer
// input and output, the intermediate activation pointers named as in Forward()
// (inp_norm_ptr ... ff2_inp_ptr), the attention mask, and the layer weights/biases.
//
// Outputs: grad_input_ptr receives the gradient with respect to the layer input; the
// remaining grad_* pointers receive the parameter gradients.
//
// The computation walks the forward pipeline in reverse (final layer norm for post-LN,
// output dropout, FF2, GELU, FF1, attention layer norm, attention-output dropout and
// projection, attention context, attention dropout, softmax, attention scores, QKV
// projection, input layer norm for pre-LN). Scratch buffers are carved out of the
// shared workspace. Everything is issued on _stream; unless _stochastic_mode is set
// the stream is synchronised once on entry.
template <typename T>
void BertTransformerLayer<T>::Backward(int bsz,
                                       const T* grad_output_ptr,
                                       const T* input_ptr,
                                       const T* output_ptr,
                                       const T* inp_norm_ptr,
                                       const T* q_tf_ptr,
                                       const T* k_tf_ptr,
                                       const T* v_tf_ptr,
                                       const T* soft_out_ptr,
                                       const T* ctx_bufB_ptr,
                                       const T* attn_o_inp_ptr,
                                       const T* add_res_ptr,
                                       const T* ff1_inp_ptr,
                                       const T* gelu_inp_ptr,
                                       const T* ff2_inp_ptr,
                                       const T* input_mask_ptr,
                                       const T* attn_qkvw_ptr,
                                       const T* attn_ow_ptr,
                                       const T* attn_nw_ptr,
                                       const T* attn_nb_ptr,
                                       const T* inter_w_ptr,
                                       const T* inter_b_ptr,
                                       const T* output_w_ptr,
                                       const T* norm_w_ptr,
                                       const T* norm_b_ptr,

                                       T* grad_input_ptr,
                                       T* grad_attn_qkvw_ptr,
                                       T* grad_attn_qkvb_ptr,
                                       T* grad_attn_ow_ptr,
                                       T* grad_attn_ob_ptr,
                                       T* grad_attn_nw_ptr,
                                       T* grad_attn_nb_ptr,
                                       T* grad_inter_w_ptr,
                                       T* grad_inter_b_ptr,
                                       T* grad_output_w_ptr,
                                       T* grad_output_b_ptr,
                                       T* grad_norm_w_ptr,
                                       T* grad_norm_b_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    // All buffer offsets are computed in size_t so large shapes cannot overflow int.
    size_t small_buf_size = static_cast<size_t>(bsz) * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1 + small_buf_size;
    T* buf_3 = buf_2 + small_buf_size;

    // With GELU checkpointing buf_2 also holds the recomputed GELU output, so ff2_buf
    // starts after an intermediate-sized region instead of right after buf_3.
    T* ff2_buf =
        (_gelu_checkpoint ? buf_2 + (static_cast<size_t>(bsz) * _seq_length * _intermediate_size)
                          : buf_3 + small_buf_size);
    T* ctx_bufB_ptr_recomp =
        ff2_buf + (static_cast<size_t>(_seq_length) * _seq_length * bsz * _heads);

    cudaStream_t streams[2] = {_stream, _stream};

    int bsz_seq = bsz * _seq_length;
    int bsz_heads = bsz * _heads;

    // Post-LN: start by differentiating the final layer norm into buf_1.
    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 inp_norm_ptr);
        else
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 norm_b_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 output_ptr);
    }

    if (_pre_or_postLayerNorm)
        _layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
    else
        _layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);

    // When the dropout is inactive its input gradient is used directly instead of buf_0.
    const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
                                     ? buf_0
                                     : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);

    // GELU checkpointing: recompute the GELU output into buf_2 from the saved FF1 output.
    if (_gelu_checkpoint)
        _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
    _ff2.Backward(bsz_seq,
                  layer_dropout_buf,
                  (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                  output_w_ptr,
                  grad_output_w_ptr,
                  grad_output_b_ptr,
                  _cublasHandle,
                  _stream,
                  ff2_buf);

    _gelu.Backward(
        bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);

    _ff1.Backward(bsz_seq,
                  ff2_buf,
                  ff1_inp_ptr,
                  inter_w_ptr,
                  grad_inter_w_ptr,
                  grad_inter_b_ptr,
                  _cublasHandle,
                  _stream,
                  buf_3);

    if (!_pre_or_postLayerNorm)
        launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              add_res_ptr);
        else
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              attn_nb_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              ff1_inp_ptr);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      add_res_ptr);
        else
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      attn_nb_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      ff1_inp_ptr);
    }

    _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);

    T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;

    _attn_out_linear.Backward(bsz_seq,
                              attn_output_dropout_buf,
                              attn_o_inp_ptr,
                              attn_ow_ptr,
                              grad_attn_ow_ptr,
                              grad_attn_ob_ptr,
                              _cublasHandle,
                              _stream,
                              buf_1);

    launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);

    if (_attn_prob_dropout.HasDropout()) {
        // Attention-dropout checkpointing: recompute the dropped-out probabilities
        // into the workspace instead of reading a saved copy.
        if (_attn_dropout_checkpoint)
            _attn_prob_dropout.Forward(
                bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);

        _attn_context.Backward(bsz_heads,
                               buf_2,
                               v_tf_ptr,
                               (_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
                               _cublasHandle,
                               buf_3,
                               ff2_buf);
    } else
        _attn_context.Backward(
            bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);

    _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);

    _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);

    _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);

    launch_transform4d_0213<T>(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);

    // The QKV projection saw the normalised input (pre-LN) or the raw input (post-LN).
    _qkv_linear.Backward(bsz_seq,
                         ff2_buf,
                         (_pre_or_postLayerNorm ? inp_norm_ptr : input_ptr),
                         attn_qkvw_ptr,
                         grad_attn_qkvw_ptr,
                         grad_attn_qkvb_ptr,
                         _cublasHandle,
                         _stream,
                         buf_2);

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         input_ptr);
        else
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         norm_b_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         inp_norm_ptr);
    } else
        launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}

// Forwards the training/evaluation switch to the three dropout sub-layers
// (dropout is skipped when not in training mode).
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
    _layer_output_dropout.SetTrainingMode(training);
    _attn_output_dropout.SetTrainingMode(training);
    _attn_prob_dropout.SetTrainingMode(training);
}

// Hands caller-owned buffers to the sub-layers that keep per-step state:
// a mask buffer for each of the three dropouts, and variance/mean buffers for each of
// the two layer norms. The pointers are only forwarded here; nothing is allocated,
// copied or validated.
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                                     uint8_t* attn_output_dropout_mask_ptr,
                                                     uint8_t* layer_output_dropout_mask_ptr,
                                                     T* attn_layer_norm_var,
                                                     T* attn_layer_norm_mean,
                                                     T* layer_norm_var,
                                                     T* layer_norm_mean)
{
    _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
    _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
    _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);

    _attn_layer_norm.SetVar(attn_layer_norm_var);
    _attn_layer_norm.SetMean(attn_layer_norm_mean);
    _layer_norm.SetVar(layer_norm_var);
    _layer_norm.SetMean(layer_norm_mean);
}

template <typename T>
571
void BertTransformerLayer<T>::SetSeqLength(int seq_len)
572
573
574
575
576
577
578
{
    _seq_length = seq_len;

    _softmax.SetSeqLength(_seq_length);
    _attn_prob_dropout.SetDimension(_seq_length);
    _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
579
580
581
582
583
584
585
586
587
588
}

template <typename T>
int create_transformer_layer(int layer_id,
                             int batch_size,
                             int hidden_dim,
                             int num_heads,
                             int intermediate_size,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
589
                             float layer_norm_eps,
590
591
592
593
594
595
596
597
598
599
                             int seed,
                             bool pre_or_postLayerNorm,
                             bool test_gemm,
                             bool attn_dropout_checkpoint,
                             bool normalize_invertible,
                             bool gelu_checkpoint,
                             bool stochastic_mode)
{
    Context::Instance().SetSeed(seed);
    Context::Instance().TestGemmFP16(
600
        test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);
601
602
603
604
605
606

    auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
                                                           batch_size,
                                                           hidden_dim,
                                                           num_heads,
                                                           intermediate_size,
607
                                                           init_seq_length,
608
609
                                                           attn_dropout_ratio,
                                                           hidden_dropout_ratio,
610
                                                           layer_norm_eps,
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
                                                           pre_or_postLayerNorm,
                                                           Context::Instance().GetGemmAlgos(),
                                                           attn_dropout_checkpoint,
                                                           normalize_invertible,
                                                           gelu_checkpoint,
                                                           stochastic_mode);

    s_transformer_layers[layer_id] = layer;

    std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";

    std::cout << "layer #" << layer_id << " is created with date type [" << dtype << "]."
              << std::endl;

    return 0;
}

/**
 * Runs the forward pass of a previously created transformer layer.
 *
 * Looks up the layer registered under `layer_id`, allocates the output tensor,
 * a scratch workspace and every intermediate buffer as torch tensors, hands
 * their raw pointers to BertTransformerLayer<T>::Forward and returns the
 * tensors to the caller.
 *
 * @param layer_id  Key of the layer in the file-level `s_transformer_layers` map.
 * @param input     Layer input; size(0) is used as the batch size and size(1)
 *                  as the sequence length.
 * @param input_mask, attn_qkvw ... norm_b
 *                  Mask and layer parameters; each must be a contiguous CUDA
 *                  tensor (enforced by CHECK_INPUT).
 * @param training_mode            Forwarded to the layer via SetTrainingMode.
 * @param prelayernorm             Together with `normalize_invertible`, decides
 *                                 whether `inp_norm` gets its own buffer.
 * @param attn_dropout_checkpoint  When true, `ctx_bufB` is the same tensor as `soft_out`.
 * @param normalize_invertible     When true, `add_res` is the same tensor as `inp_norm`.
 * @param gelu_checkpoint          When true, `gelu_inp` is the same tensor as `ff2_inp`.
 *
 * @return 17 tensors, in order: output, inp_norm, qkv_tf, soft_out, ctx_bufB,
 *         attn_o_inp, add_res, ff1_inp, gelu_inp, ff2_inp,
 *         attn_prob_dropout_mask, attn_output_dropout_mask,
 *         layer_output_dropout_mask, attn_layer_norm_var, attn_layer_norm_mean,
 *         layer_norm_var, layer_norm_mean.
 *
 * Fails with a torch assertion error if no layer is registered for `layer_id`.
 */
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
                                                  const torch::Tensor& input,
                                                  const torch::Tensor& input_mask,
                                                  const torch::Tensor& attn_qkvw,
                                                  const torch::Tensor& attn_qkvb,
                                                  const torch::Tensor& attn_ow,
                                                  const torch::Tensor& attn_ob,
                                                  const torch::Tensor& attn_nw,
                                                  const torch::Tensor& attn_nb,
                                                  const torch::Tensor& inter_w,
                                                  const torch::Tensor& inter_b,
                                                  const torch::Tensor& output_w,
                                                  const torch::Tensor& output_b,
                                                  const torch::Tensor& norm_w,
                                                  const torch::Tensor& norm_b,
                                                  bool training_mode,
                                                  bool prelayernorm,
                                                  bool attn_dropout_checkpoint,
                                                  bool normalize_invertible,
                                                  bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    // Use find() rather than operator[]: operator[] would silently insert a null
    // entry for an unknown id, which would then be dereferenced below.
    auto layer_it = s_transformer_layers.find(layer_id);
    const bool layer_exists = (layer_it != s_transformer_layers.end());
    AT_ASSERTM(layer_exists, "transformer layer #", layer_id, " has not been created");
    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(layer_it->second);

    int bsz = input.size(0);

    const T* input_ptr = (const T*)input.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* output_b_ptr = (const T*)output_b.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    auto output = torch::empty_like(input);
    T* out_ptr = (T*)output.data_ptr();

    // Options for buffers that share the input's dtype.
    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    // Options for the dropout masks.
    // NOTE(review): the dtype is kInt8 although the buffers are later handed to
    // the layer as uint8_t* — confirm the signedness mismatch is intentional.
    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    // Follow the sequence length of the incoming batch if it differs from the
    // one currently configured on the layer.
    int seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    // NOTE(review): the workspace is sized from the layer's training flag as it
    // was *before* SetTrainingMode(training_mode) is applied further down —
    // confirm this ordering is intended.
    auto workspace =
        torch::empty({static_cast<int64_t>(get_workspace_size<T>(bsz,
                                                                 seq_len,
                                                                 layer->GetHiddenSize(),
                                                                 layer->GetIntermediateSize(),
                                                                 layer->GetNumHeads(),
                                                                 layer->IsTrainingMode(),
                                                                 layer->GeluCheckpoint()))},
                     options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    // Row counts are computed in 64 bits so large batch * sequence products do
    // not overflow int before being widened.
    const int64_t tokens = static_cast<int64_t>(bsz) * seq_len;
    const int64_t attn_rows = static_cast<int64_t>(bsz) * layer->GetNumHeads() * seq_len;

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    // One tensor holds Q, K and V back to back (three chunks of output_w.size(0) columns).
    auto qkv_tf = torch::empty({tokens, output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask = torch::empty({attn_rows, seq_len}, uint8_options);
    auto attn_output_dropout_mask = torch::empty({tokens, layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask = torch::empty({tokens, layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({tokens}, options);
    auto attn_layer_norm_mean = torch::empty({tokens}, options);
    auto layer_norm_var = torch::empty({tokens}, options);
    auto layer_norm_mean = torch::empty({tokens}, options);

    T* inp_norm_ptr = (T*)inp_norm.data_ptr();
    T* add_res_ptr = (T*)add_res.data_ptr();
    // K and V start one and two chunks after Q inside qkv_tf.
    const int64_t qkv_chunk = tokens * output_w.size(0);
    T* q_tf_ptr = (T*)qkv_tf.data_ptr();
    T* k_tf_ptr = q_tf_ptr + qkv_chunk;
    T* v_tf_ptr = k_tf_ptr + qkv_chunk;
    T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();

    torch::Tensor ff2_inp = torch::empty({tokens, output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({tokens, output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
    T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
    T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();

    torch::Tensor soft_out = torch::empty({attn_rows, seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint ? soft_out : torch::empty({attn_rows, seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Forward(bsz,
                   input_ptr,
                   input_mask_ptr,
                   attn_qkvw_ptr,
                   attn_qkvb_ptr,
                   attn_ow_ptr,
                   attn_ob_ptr,
                   attn_nw_ptr,
                   attn_nb_ptr,
                   inter_w_ptr,
                   inter_b_ptr,
                   output_w_ptr,
                   output_b_ptr,
                   norm_w_ptr,
                   norm_b_ptr,
                   out_ptr,
                   inp_norm_ptr,
                   q_tf_ptr,
                   k_tf_ptr,
                   v_tf_ptr,
                   soft_out_ptr,
                   ctx_bufB_ptr,
                   attn_o_inp_ptr,
                   add_res_ptr,
                   ff1_inp_ptr,
                   gelu_inp_ptr,
                   ff2_inp_ptr);

    return {output,
            inp_norm,
            qkv_tf,
            soft_out,
            ctx_bufB,
            attn_o_inp,
            add_res,
            ff1_inp,
            gelu_inp,
            ff2_inp,
            attn_prob_dropout_mask,
            attn_output_dropout_mask,
            layer_output_dropout_mask,
            attn_layer_norm_var,
            attn_layer_norm_mean,
            layer_norm_var,
            layer_norm_mean};
}

/**
 * Runs the backward pass of a previously created transformer layer.
 *
 * Takes the upstream gradient, the tensors returned by the forward pass and
 * the layer parameters, allocates one gradient tensor per differentiable
 * input/parameter plus a scratch workspace, and calls
 * BertTransformerLayer<T>::Backward with the raw pointers.
 *
 * @param layer_id     Key of the layer in the file-level `s_transformer_layers` map.
 * @param grad_output  Upstream gradient; made contiguous before use. Its
 *                     size(0) is used as the batch size and size(1) as the
 *                     sequence length.
 * @param output ... ff2_inp
 *                     Activations saved by the forward pass; each must be a
 *                     contiguous CUDA tensor (enforced by CHECK_INPUT).
 * @param attn_prob_dropout_mask, attn_output_dropout_mask, layer_output_dropout_mask
 *                     Dropout masks, handed to the layer as uint8_t buffers.
 * @param attn_layer_norm_var ... layer_norm_mean
 *                     Layer-norm statistics buffers, handed to the layer as T buffers.
 * @param input, input_mask, attn_qkvw ... norm_b
 *                     Forward input, mask and layer parameters (contiguous CUDA tensors).
 *
 * @return 13 gradient tensors, in order: grad_input, grad_attn_qkvw,
 *         grad_attn_qkvb, grad_attn_ow, grad_attn_ob, grad_attn_nw,
 *         grad_attn_nb, grad_inter_w, grad_inter_b, grad_output_w,
 *         grad_output_b, grad_norm_w, grad_norm_b. Each has the shape and
 *         dtype of the tensor it is the gradient of (created with empty_like).
 *
 * Fails with a torch assertion error if no layer is registered for `layer_id`.
 */
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
                                                   const torch::Tensor& grad_output,
                                                   const torch::Tensor& output,
                                                   const torch::Tensor& inp_norm,
                                                   const torch::Tensor& qkv_tf,
                                                   const torch::Tensor& soft_out,
                                                   const torch::Tensor& ctx_bufB,
                                                   const torch::Tensor& attn_o_inp,
                                                   const torch::Tensor& add_res,
                                                   const torch::Tensor& ff1_inp,
                                                   const torch::Tensor& gelu_inp,
                                                   const torch::Tensor& ff2_inp,
                                                   const torch::Tensor& attn_prob_dropout_mask,
                                                   const torch::Tensor& attn_output_dropout_mask,
                                                   const torch::Tensor& layer_output_dropout_mask,
                                                   const torch::Tensor& attn_layer_norm_var,
                                                   const torch::Tensor& attn_layer_norm_mean,
                                                   const torch::Tensor& layer_norm_var,
                                                   const torch::Tensor& layer_norm_mean,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
                                                   const torch::Tensor& attn_qkvb,
                                                   const torch::Tensor& attn_ow,
                                                   const torch::Tensor& attn_ob,
                                                   const torch::Tensor& attn_nw,
                                                   const torch::Tensor& attn_nb,
                                                   const torch::Tensor& inter_w,
                                                   const torch::Tensor& inter_b,
                                                   const torch::Tensor& output_w,
                                                   const torch::Tensor& output_b,
                                                   const torch::Tensor& norm_w,
                                                   const torch::Tensor& norm_b)
{
    // The incoming gradient may be non-contiguous; work on a contiguous copy/view.
    auto g_output = grad_output.contiguous();
    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    int bsz = g_output.size(0);

    // Use find() rather than operator[]: operator[] would silently insert a null
    // entry for an unknown id, which would then be dereferenced below.
    auto layer_it = s_transformer_layers.find(layer_id);
    const bool layer_exists = (layer_it != s_transformer_layers.end());
    AT_ASSERTM(layer_exists, "transformer layer #", layer_id, " has not been created");
    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(layer_it->second);

    // Follow the sequence length of the incoming gradient if it differs from
    // the one currently configured on the layer.
    int seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    // Scratch memory for the layer; the explicit cast avoids a size_t -> int64_t
    // narrowing conversion inside the braced shape list.
    auto workspace =
        torch::empty({static_cast<int64_t>(get_workspace_size<T>(bsz,
                                                                 seq_len,
                                                                 layer->GetHiddenSize(),
                                                                 layer->GetIntermediateSize(),
                                                                 layer->GetNumHeads(),
                                                                 layer->IsTrainingMode(),
                                                                 layer->GeluCheckpoint()))},
                     options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    // K and V start one and two chunks after Q inside qkv_tf; the chunk size is
    // computed in 64 bits to avoid int overflow in batch * sequence.
    const int64_t qkv_chunk =
        static_cast<int64_t>(bsz) * layer->GetSeqLength() * output_w.size(0);
    const T* k_tf_ptr = q_tf_ptr + qkv_chunk;
    const T* v_tf_ptr = k_tf_ptr + qkv_chunk;
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    // Point the layer at the masks and layer-norm statistics saved by forward.
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz,
                    grad_output_ptr,
                    input_ptr,
                    output_ptr,
                    inp_norm_ptr,
                    q_tf_ptr,
                    k_tf_ptr,
                    v_tf_ptr,
                    soft_out_ptr,
                    ctx_bufB_ptr,
                    attn_o_inp_ptr,
                    add_res_ptr,
                    ff1_inp_ptr,
                    gelu_inp_ptr,
                    ff2_inp_ptr,
                    input_mask_ptr,
                    attn_qkvw_ptr,
                    attn_ow_ptr,
                    attn_nw_ptr,
                    attn_nb_ptr,
                    inter_w_ptr,
                    inter_b_ptr,
                    output_w_ptr,
                    norm_w_ptr,
                    norm_b_ptr,

                    grad_input_ptr,
                    grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr,
                    grad_attn_ob_ptr,
                    grad_attn_nw_ptr,
                    grad_attn_nb_ptr,
                    grad_inter_w_ptr,
                    grad_inter_b_ptr,
                    grad_output_w_ptr,
                    grad_output_b_ptr,
                    grad_norm_w_ptr,
                    grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}

// Python bindings: every entry point is exported twice, once instantiated for
// float (suffix _fp32) and once for __half (suffix _fp16).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // ---- single precision ----
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)");
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");

    // ---- half precision ----
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
}