#include "custom_cuda_layers.h"

const int unroll_factor = 4;

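// Forward dropout (float). Each thread draws four Philox uniforms per loop
// iteration, zeroes elements where rand <= ratio, scales survivors by
// 1 / (1 - ratio), and records the keep/drop bytes in `mask` for the backward
// pass. A scalar tail loop handles any remainder not covered by the
// vectorized grid-stride loop.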
__global__ void dropout_kernel(const int N,
                               const float ratio,
                               float* out,
                               const float* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float4 rand = curand_uniform4(&state);
        uint8_t m[unroll_factor];

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        int i = j * unroll_factor;

        mask[i] = (uint8_t)m[0];
        mask[i + 1] = (uint8_t)m[1];
        mask[i + 2] = (uint8_t)m[2];
        mask[i + 3] = (uint8_t)m[3];

        out[i] = Xdata[i] * scale * m[0];
        out[i + 1] = Xdata[i + 1] * scale * m[1];
        out[i + 2] = Xdata[i + 2] * scale * m[2];
        out[i + 3] = Xdata[i + 3] * scale * m[3];
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            out[i] = Xdata[i] * scale * m;
            mask[i] = m;
        }
    }
}

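// Forward dropout (__half). Under __STOCHASTIC_MODE__ the data is processed as
// vectorized __half2 pairs (loaded as float2) and the four mask bytes are
// stored as one uint32_t; the default path widens each element to float,
// applies scale and mask, and narrows back to __half.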
__global__ void dropout_kernel(const int N,
                               const float ratio,
                               __half* out,
                               const __half* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

#ifdef __STOCHASTIC_MODE__

    const __half2 h_scale = __float2half2_rn(scale);
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    uint32_t m_32;
    uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);

    float2 result_f;
    __half2* result_h = reinterpret_cast<__half2*>(&result_f);
    __half2 mask_h[2];
    float2 mask_f[2];

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_f = x_cast[j];
        __half2* x_h = reinterpret_cast<__half2*>(&x_f);

        float4 rand = curand_uniform4(&state);

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);

        mask_h[0] = __float22half2_rn(mask_f[0]);
        mask_h[1] = __float22half2_rn(mask_f[1]);

        result_h[0] = x_h[0] * h_scale * mask_h[0];
        result_h[1] = x_h[1] * h_scale * mask_h[1];

        out_cast[j] = result_f;

        mask_cast[j] = m_32;
    }

#else

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;

        const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
        float2 vals_half_f[2];
        vals_half_f[0] = __half22float2(vals_half[0]);
        vals_half_f[1] = __half22float2(vals_half[1]);

        uint8_t m[unroll_factor];
        float4 rand = curand_uniform4(&state);
        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
        out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
        out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
        out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);

        mask[i] = m[0];
        mask[i + 1] = m[1];
        mask[i + 2] = m[2];
        mask[i + 3] = m[3];
    }

#endif
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            out[i] = __float2half((float)Xdata[i] * scale * m);
            mask[i] = m;
        }
    }
}

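// Backward dropout (float): re-applies the saved mask to the incoming values,
// scaling kept elements by 1 / (1 - ratio) and zeroing the rest. The seed
// parameter is unused here because the mask is read back, not regenerated.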
__global__ void dropout_kernel_bwd(const int N,
                                   const float ratio,
                                   const float* Xdata,
                                   float* out,
                                   uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;

        out[i] = mask[i] ? Xdata[i] * scale : 0.0;
        out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
        out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
        out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
    }
}

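// Backward dropout (__half), mirroring the float version. The stochastic path
// multiplies __half2 pairs by a half-precision mask; the default path computes
// in float and converts back to __half on store.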
__global__ void dropout_kernel_bwd(const int N,
                                   const float ratio,
                                   const __half* Xdata,
                                   __half* out,
                                   uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);

#ifdef __STOCHASTIC_MODE__

    const __half2 h_scale = __float2half2_rn(scale);

    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_f = x_cast[j];
        __half2* x_h = reinterpret_cast<__half2*>(&x_f);

        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        __half2 mask_h[2];
        float2 mask_f[2];

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);

#pragma unroll
        for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = x_h[0] * h_scale * mask_h[0];
        result_h[1] = x_h[1] * h_scale * mask_h[1];

        out_cast[j] = result_f;
    }

#else


    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;

        const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);

        uint8_t* m = mask + i;

        float2 vals_half_f[2];

        vals_half_f[0] = __half22float2(vals_half[0]);
        vals_half_f[1] = __half22float2(vals_half[1]);

        out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
        out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
        out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
        out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
    }

#endif
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

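// Shared launcher for forward and backward dropout. Each thread covers
// unroll_factor elements, so the grid is sized from total_count / unroll_factor;
// for wide rows (dim > 512) the block is halved and the grid doubled. The Philox
// offset is advanced by the per-thread element count so repeated launches draw
// fresh random numbers.
//
// Illustrative call (pointer names are hypothetical, not part of this file):
//   launch_dropout(d_out, d_in, d_mask, batch * dim, dim, 0.1f, stream, false);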
template <typename T>
void launch_dropout(T* out,
                    const T* vals,
                    uint8_t* mask,
                    int total_count,
                    int dim,
                    float ratio,
                    cudaStream_t stream,
                    bool bwd)
{
    assert(unroll_factor == 4);

    dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    if (dim > 512) {
        block_dim.x >>= 1;
        grid_dim.x <<= 1;
    }
    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
    if (bwd)
        dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
            total_count, ratio, vals, out, mask, seed);
    else
        dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
            total_count, ratio, out, vals, mask, seed);
}

template void launch_dropout(float* out,
                             const float* vals,
                             uint8_t* mask,
                             int total_count,
                             int dim,
                             float ratio,
                             cudaStream_t stream,
                             bool);
template void launch_dropout(__half* out,
                             const __half* vals,
                             uint8_t* mask,
                             int total_count,
                             int dim,
                             float ratio,
                             cudaStream_t stream,
                             bool);

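// In-place gradient scaling (float): Xdata[i] *= scale * mask[i], which zeroes
// dropped positions and rescales the rest.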
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}

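// In-place gradient scaling (__half): four elements per thread via float2 /
// uint32_t vector accesses, with a scalar tail loop for the remainder.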
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
    const __half2 h_scale = __float2half2_rn(scale);
    float2* x_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_data = x_cast[j];
        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

#ifdef __STOCHASTIC_MODE__

        __half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
        __half2 mask_h[2];
        float2 mask_f[2];

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);

        mask_h[0] = __float22half2_rn(mask_f[0]);
        mask_h[1] = __float22half2_rn(mask_f[1]);

        result_h[0] = x_data_h[0] * h_scale * mask_h[0];
        result_h[1] = x_data_h[1] * h_scale * mask_h[1];

#else

        __half* x_data_h = reinterpret_cast<__half*>(&x_data);
        float2 result[2];

        result[0].x = (float)x_data_h[0] * scale * m[0];
        result[0].y = (float)x_data_h[1] * scale * m[1];
        result[1].x = (float)x_data_h[2] * scale * m[2];
        result[1].y = (float)x_data_h[3] * scale * m[3];

        result_h[0] = __float22half2_rn(result[0]);
        result_h[1] = __float22half2_rn(result[1]);

#endif
        x_cast[j] = result_f;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

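// Launcher for the in-place gradient kernel; the scale is recomputed from ratio
// exactly as in the forward pass.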
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream)
{
    assert(unroll_factor == 4);

    const float scale = 1. / (1. - ratio);
    dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
                          DS_CUDA_NUM_THREADS,
                          0,
                          stream>>>(total_count, scale, vals, mask);
}

template void launch_dropout_grad(float* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);
template void launch_dropout_grad(__half* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);

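// Out-of-place gradient scaling (float): writes to `out` instead of updating
// Xdata in place.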
__global__ void dropout_grad_kernel(const int N,
                                    const float scale,
                                    const float* Xdata,
                                    float* out,
                                    uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}

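// Out-of-place gradient scaling (__half); same math as the in-place kernel, but
// reading from Xdata and writing to out via vectorized float2 accesses.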
__global__ void dropout_grad_kernel(const int N,
                                    const float scale,
                                    const __half* Xdata,
                                    __half* out,
                                    uint8_t* mask)
{
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);

    float2 result_f;
    __half2* result_h = reinterpret_cast<__half2*>(&result_f);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_data = x_cast[j];
        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        __half* x_data_h = reinterpret_cast<__half*>(&x_data);
        float2 result[2];

        result[0].x = (float)x_data_h[0] * scale * m[0];
        result[0].y = (float)x_data_h[1] * scale * m[1];
        result[1].x = (float)x_data_h[2] * scale * m[2];
        result[1].y = (float)x_data_h[3] * scale * m[3];

        result_h[0] = __float22half2_rn(result[0]);
        result_h[1] = __float22half2_rn(result[1]);

        out_cast[j] = result_f;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

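// Launcher for the out-of-place gradient kernel (vals -> vals_out).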
template <typename T>
void launch_dropout_grad(T* vals_out,
                         const T* vals,
                         uint8_t* mask,
                         int total_count,
                         float ratio,
                         cudaStream_t stream)
{
    assert(unroll_factor == 4);

    const float scale = 1. / (1. - ratio);
    dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
                          DS_CUDA_NUM_THREADS,
                          0,
                          stream>>>(total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*,
                                  const float* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);
template void launch_dropout_grad(__half*,
                                  const __half* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);

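// Fused bias + dropout (float), in place on Xdata. The bias is periodic with
// period `dim`, indexed as j % (dim / unroll_factor) on float4 lanes. Note that
// N here is the element count already divided by unroll_factor.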
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const float* bias,
                               float* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);
        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float4 x_data = Xdata_cast[j];
        float4 b_data = bias_cast[j % (dim / unroll_factor)];

        x_data.x += b_data.x;
        x_data.y += b_data.y;
        x_data.z += b_data.z;
        x_data.w += b_data.w;

        x_data.x = x_data.x * scale * m[0];
        x_data.y = x_data.y * scale * m[1];
        x_data.z = x_data.z * scale * m[2];
        x_data.w = x_data.w * scale * m[3];

        mask_32[j] = m_32;
        Xdata_cast[j] = x_data;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = Xdata[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = x_data * scale * m;
            mask[i] = m;
        }
    }
}

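// Fused bias + dropout (__half): loads four halves as one float2, widens to
// float for the bias add and scaling, then narrows back and stores the packed
// 4-byte mask word.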
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const __half* bias,
                               __half* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        float2 data_f;
        __half2* data_h = reinterpret_cast<__half2*>(&data_f);

        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        data_f = Xdata_cast[j];
        bias_f = bias_cast[j % (dim / unroll_factor)];

        float2 data_h_0 = __half22float2(data_h[0]);
        float2 data_h_1 = __half22float2(data_h[1]);

        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);

        data_h_0.x += bias_h_0.x;
        data_h_0.y += bias_h_0.y;
        data_h_1.x += bias_h_1.x;
        data_h_1.y += bias_h_1.y;

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        data_h_0.x = data_h_0.x * scale * m[0];
        data_h_0.y = data_h_0.y * scale * m[1];
        data_h_1.x = data_h_1.x * scale * m[2];
        data_h_1.y = data_h_1.y * scale * m[3];

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        Xdata_cast[j] = result_f;
        mask_32[j] = m_32;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)Xdata[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = __float2half(x_data * scale * m);
            mask[i] = m;
        }
    }
}

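// Launcher for fused bias + dropout over a [batch, dim] activation. total_count
// is pre-divided by unroll_factor, and the Philox offset advances by the number
// of elements each thread consumes.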
template <typename T>
void launch_dropout(T* out,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream)
{
    assert(unroll_factor == 4);

    int total_count = batch * dim / unroll_factor;

    dim3 grid_dim = DS_GET_BLOCKS(total_count);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
        total_count, dim, ratio, bias, out, mask, seed);
}

template void launch_dropout(float*,
                             const float* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);
template void launch_dropout(__half*,
                             const __half* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);

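// Fused bias + dropout + residual (float): out = dropout(input + bias) + residual,
// one float4 per thread per iteration. The residual is added after masking, so it
// is never dropped.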
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const float* input,
                               const float* residual,
                               const float* bias,
                               float* out,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float4* out_cast = reinterpret_cast<float4*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);

    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    const float4* residual_cast = reinterpret_cast<const float4*>(residual);
    const float4* input_cast = reinterpret_cast<const float4*>(input);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float4 out_data;
        float4 b_data = bias_cast[j % (dim / unroll_factor)];
        float4 res_data = residual_cast[j];
        float4 inp_data = input_cast[j];

        out_data.x = (b_data.x + inp_data.x);
        out_data.y = (b_data.y + inp_data.y);
        out_data.z = (b_data.z + inp_data.z);
        out_data.w = (b_data.w + inp_data.w);

        out_data.x = out_data.x * scale * m[0];
        out_data.y = out_data.y * scale * m[1];
        out_data.z = out_data.z * scale * m[2];
        out_data.w = out_data.w * scale * m[3];

        out_data.x += res_data.x;
        out_data.y += res_data.y;
        out_data.z += res_data.z;
        out_data.w += res_data.w;

        mask_32[j] = m_32;
        out_cast[j] = out_data;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = input[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += residual[i];

            out[i] = x_data;
            mask[i] = m;
        }
    }
}

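// __half version of the fused bias + dropout + residual kernel; all arithmetic
// is done in float and the result is narrowed back to __half on store.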
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const __half* input,
                               const __half* residual,
                               const __half* bias,
                               __half* out,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);

    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
    const float2* residual_cast = reinterpret_cast<const float2*>(residual);
    const float2* input_cast = reinterpret_cast<const float2*>(input);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        float2 residual_f;
        __half2* residual_h = reinterpret_cast<__half2*>(&residual_f);

        float2 input_f;
        __half2* input_h = reinterpret_cast<__half2*>(&input_f);

        bias_f = bias_cast[j % (dim / unroll_factor)];
        residual_f = residual_cast[j];
        input_f = input_cast[j];

        float2 data_h_0;
        float2 data_h_1;

        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);

        float2 residual_h_0 = __half22float2(residual_h[0]);
        float2 residual_h_1 = __half22float2(residual_h[1]);

        float2 input_h_0 = __half22float2(input_h[0]);
        float2 input_h_1 = __half22float2(input_h[1]);

        data_h_0.x = (bias_h_0.x + input_h_0.x);
        data_h_0.y = (bias_h_0.y + input_h_0.y);
        data_h_1.x = (bias_h_1.x + input_h_1.x);
        data_h_1.y = (bias_h_1.y + input_h_1.y);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        data_h_0.x = data_h_0.x * scale * m[0];
        data_h_0.y = data_h_0.y * scale * m[1];
        data_h_1.x = data_h_1.x * scale * m[2];
        data_h_1.y = data_h_1.y * scale * m[3];

        data_h_0.x += residual_h_0.x;
        data_h_0.y += residual_h_0.y;
        data_h_1.x += residual_h_1.x;
        data_h_1.y += residual_h_1.y;

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        out_cast[j] = result_f;
        mask_32[j] = m_32;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)input[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += (float)residual[i];

            out[i] = __float2half(x_data);
            mask[i] = m;
        }
    }
}

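// Launcher for the fused bias + dropout + residual kernel; grid sizing and seed
// handling match the bias-only launcher above.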
template <typename T>
void launch_dropout(T* out,
                    const T* input,
                    const T* residual,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream)
{
    assert(unroll_factor == 4);

    int total_count = batch * dim / unroll_factor;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
        total_count, dim, ratio, input, residual, bias, out, mask, seed);
}

template void launch_dropout(float*,
                             const float*,
                             const float* residual,
                             const float* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);
template void launch_dropout(__half*,
                             const __half*,
                             const __half* residual,
                             const __half* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);