mlp_cuda.cu 53.9 KB
Newer Older
1
2
// New MLP with denorm mitigation only for backprop

3
4
5
6
7
8
9
10
11
12
13
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
flyingdown's avatar
flyingdown committed
14
#include "utils.h"
15

16
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
17
18
// includes cublaslt
#include <cublasLt.h>
19
#endif
20
21
22
23
24
// constants for fused bias+relu kernel
#define BIAS_RELU_FW_NTHREADS 128 // forward number of thread per block
#define BIAS_RELU_BW_NTHREADS_X 32 // backward number of thread in feature dim
#define BIAS_RELU_BW_NTHREADS_Y 16 // backward number of thread in batch dim
#define BIAS_RELU_RED_PER_THREAD 16 // backward minimal reduction length per thread
25

26

flyingdown's avatar
flyingdown committed
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// move to a header later on
// Number of consecutive elements moved by one vectorized access.
#define ILP 4

// Returns true when p can be used for an ILP-wide vector access of T, i.e.
// its address is a multiple of ILP * sizeof(T) bytes.
template<typename T>
__host__ __device__ __forceinline__ bool is_aligned(T* p){
  const uint64_t address = (uint64_t)p;
  const uint64_t vector_bytes = ILP * sizeof(T);
  return address % vector_bytes == 0;
}

// One ILP-element vector of T, used to move ILP values in a single access.
template<typename T>
using ilp_vector_t =
    typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type;

// Vectorized copy: dst-vector[dst_offset] = src-vector[src_offset], where both
// offsets are counted in ILP-element vectors (not in elements). Both pointers
// are expected to satisfy is_aligned().
//
// NOTE(review): in the volatile overloads below, the C-style cast drops the
// volatile qualifier, so the vector access itself is a plain (non-volatile)
// access. This matches the historical behavior and is kept unchanged.
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  ilp_vector_t<T>* vdst = (ilp_vector_t<T>*)dst;
  ilp_vector_t<T>* vsrc = (ilp_vector_t<T>*)src;
  vdst[dst_offset] = vsrc[src_offset];
}
// Same as above, reading from a volatile source.
template<typename T>
__device__ __forceinline__ void load_store(T* dst, volatile T* src, int dst_offset, int src_offset){
  ilp_vector_t<T>* vdst = (ilp_vector_t<T>*)dst;
  ilp_vector_t<T>* vsrc = (ilp_vector_t<T>*)src;
  vdst[dst_offset] = vsrc[src_offset];
}
// Same as above, writing to a volatile destination.
template<typename T>
__device__ __forceinline__ void load_store(volatile T* dst, T* src, int dst_offset, int src_offset){
  ilp_vector_t<T>* vdst = (ilp_vector_t<T>*)dst;
  ilp_vector_t<T>* vsrc = (ilp_vector_t<T>*)src;
  vdst[dst_offset] = vsrc[src_offset];
}

// Keep ReLU in float only. When using half, cast to float before calling.
// Returns a when it is positive, otherwise 0.
__device__ __inline__ float relu(float a) {
  return max(a, 0.f);
}

57
58
59
60
61
62
// Keep Sigmoid in float only. When using half, cast to float before calling.
// Logistic function 1 / (1 + e^-a).
__device__ __inline__ float sigmoid(float a) {
  const float denom = 1.f + expf(-a);
  return 1.f / denom;
}

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// FP64 Wrapper around cublas GEMMEx:
//   C = alpha * op(A) * op(B) + beta * C
// with double inputs, double output and double compute.
//
// alpha/beta arrive as float host pointers so every mlp_gemm overload shares
// one calling convention. With a 64-bit compute type, cuBLAS and rocBLAS read
// the scalars as double, though. Passing the float pointers straight through
// would make the library read 8 bytes from a 4-byte float. The scalars are
// therefore widened into local doubles first.
// NOTE(review): this assumes host pointer mode for the handle (the scalars are
// dereferenced here on the host) — confirm for any new callers.
//
// `flag` is forwarded to rocBLAS only; the cuBLAS path does not use it.
cublasStatus_t mlp_gemm(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float* alpha,
    const double* A,
    int lda,
    const double* B,
    int ldb,
    const float* beta,
    double* C,
    int ldc,
    int flag) {
  // Scalar type must match the 64-bit compute type.
  const double alpha_d = static_cast<double>(*alpha);
  const double beta_d = static_cast<double>(*beta);
#ifdef __HIP_PLATFORM_HCC__
  return rocblas_gemm_ex(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      &alpha_d,
      A,
      rocblas_datatype_f64_r,
      lda,
      B,
      rocblas_datatype_f64_r,
      ldb,
      &beta_d,
      C,
      rocblas_datatype_f64_r,
      ldc,
      C,
      rocblas_datatype_f64_r,
      ldc,
      rocblas_datatype_f64_r,
      rocblas_gemm_algo_standard,
      0,
      flag);
#else
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      &alpha_d,
      A,
      CUDA_R_64F,
      lda,
      B,
      CUDA_R_64F,
      ldb,
      &beta_d,
      C,
      CUDA_R_64F,
      ldc,
      CUDA_R_64F,
      CUBLAS_GEMM_DEFAULT);
#endif
}

// FP32 Wrapper around cublas GEMMEx:
//   C = alpha * op(A) * op(B) + beta * C
// with float inputs, float output and float compute. alpha/beta are float
// scalars, matching the float compute type. `flag` is forwarded to rocBLAS
// only; the cuBLAS path does not use it.
cublasStatus_t mlp_gemm(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float* alpha,
    const float* A,
    int lda,
    const float* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc,
    int flag) {
#ifndef __HIP_PLATFORM_HCC__
  // cuBLAS: every operand and the compute type are 32-bit float.
  return cublasGemmEx(handle, transa, transb, m, n, k,
                      alpha,
                      A, CUDA_R_32F, lda,
                      B, CUDA_R_32F, ldb,
                      beta,
                      C, CUDA_R_32F, ldc,
                      CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
#else
  // rocBLAS: C is passed both as the input matrix and as the output matrix;
  // standard algorithm, solution index 0.
  return rocblas_gemm_ex(handle, transa, transb, m, n, k,
                         alpha,
                         A, rocblas_datatype_f32_r, lda,
                         B, rocblas_datatype_f32_r, ldb,
                         beta,
                         C, rocblas_datatype_f32_r, ldc,
                         C, rocblas_datatype_f32_r, ldc,
                         rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                         0, flag);
#endif
}

// FP16 Tensor core wrapper around cublas GEMMEx:
//   C = alpha * op(A) * op(B) + beta * C
// with half inputs and half output.
//  - cuBLAS: float compute, tensor-op algorithm selection.
//  - rocBLAS: float compute by default. When
//    parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF") reports the flag as set
//    (helper from utils.h), half compute is used instead, with alpha/beta
//    converted to half on the host.
// `flag` is forwarded to rocBLAS only.
cublasStatus_t mlp_gemm(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float* alpha,
    const at::Half* A,
    int lda,
    const at::Half* B,
    int ldb,
    float* beta,
    at::Half* C,
    int ldc,
    int flag) {
#ifndef __HIP_PLATFORM_HCC__
  return cublasGemmEx(handle, transa, transb, m, n, k,
                      alpha,
                      A, CUDA_R_16F, lda,
                      B, CUDA_R_16F, ldb,
                      beta,
                      C, CUDA_R_16F, ldc,
                      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#else
  if (!parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
    // Default: float compute, so alpha/beta stay float.
    return rocblas_gemm_ex(handle, transa, transb, m, n, k,
                           alpha,
                           A, rocblas_datatype_f16_r, lda,
                           B, rocblas_datatype_f16_r, ldb,
                           beta,
                           C, rocblas_datatype_f16_r, ldc,
                           C, rocblas_datatype_f16_r, ldc,
                           rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                           0, flag);
  }
  // Opt-in half compute: the scalars must be half to match the compute type.
  half h_alpha = __float2half(*alpha);
  half h_beta = __float2half(*beta);
  return rocblas_gemm_ex(handle, transa, transb, m, n, k,
                         &h_alpha,
                         A, rocblas_datatype_f16_r, lda,
                         B, rocblas_datatype_f16_r, ldb,
                         &h_beta,
                         C, rocblas_datatype_f16_r, ldc,
                         C, rocblas_datatype_f16_r, ldc,
                         rocblas_datatype_f16_r, rocblas_gemm_algo_standard,
                         0, flag);
#endif
}
294
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
// cuBLASLt GEMM on half-precision A, B and C with float accumulation:
//   C = alpha * op(A) * op(B) + beta * C
// optionally fused with a bias-add and/or ReLU epilogue.
//
// alpha, beta    : host pointers to float scalars (scale type CUDA_R_32F).
// use_bias, bias : when use_bias is true, `bias` is installed as
//                  CUBLASLT_MATMUL_DESC_BIAS_POINTER (per the cuBLASLt docs, a
//                  device vector with one entry per row of C); ignored otherwise.
// use_relu       : apply ReLU (after the bias add, when there is one).
// workspace      : scratch buffer of workspaceSize bytes the algorithm may use.
// stream         : stream the matmul is enqueued on.
// Returns 0 on success and 1 on any failure, including when the heuristic
// finds no algorithm for this problem.
int mlp_gemm_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float *alpha, /* host pointer */
    const at::Half* A,
    int lda,
    const at::Half* B,
    int ldb,
    float *beta, /* host pointer */
    at::Half* C,
    int ldc,
    void *workspace,
    size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    bool use_relu,
    const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  // All descriptors are stack-allocated opaque objects set up with *Init, so
  // nothing has to be destroyed on exit. Every local is declared up front
  // because the error path jumps to CLEANUP.
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults                             = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Operation descriptor: float compute and float scale type; set the
  // transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Pick the fused epilogue from the (use_bias, use_relu) combination.
  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
  } else {
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU;
    }
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Matrix layouts (column major). A and B are described in their stored,
  // pre-transpose shape.
  status = cublasLtMatrixLayoutInit(
    &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
    &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Preference: only cap the workspace size. A, B and C are assumed to be
  // well aligned (e.g. directly from cudaMalloc); no alignment hints are set.
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Ask for the single best heuristic. This is not guaranteed to succeed
  // (e.g. with badly aligned A); requesting several algorithms and trying
  // them in turn would be more robust.
  status = cublasLtMatmulAlgoGetHeuristic(
    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  // C is both the input (beta term) and the output matrix.
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          &heuristicResult.algo,
                          workspace,
                          workspaceSize,
                          stream);

CLEANUP:
  // Nothing to release: the descriptors are stack objects and the matmul (if
  // any) has already been enqueued on `stream`.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}

// cuBLASLt GEMM for double precision: not implemented.
// Always returns 1 (the failure code of the other mlp_gemm_lt overloads)
// without reading or writing any of its arguments.
int mlp_gemm_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float *alpha, /* host pointer */
    const double* A,
    int lda,
    const double* B,
    int ldb,
    float *beta, /* host pointer */
    double* C,
    int ldc,
    void *workspace,
    size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    bool use_relu,
    const void* bias) {
  constexpr int kNotSupported = 1;
  return kNotSupported;
}

// cuBLASLt GEMM on single-precision A, B and C with float compute:
//   C = alpha * op(A) * op(B) + beta * C
// optionally fused with a bias-add and/or ReLU epilogue.
//
// alpha, beta    : host pointers to float scalars (scale type CUDA_R_32F).
// use_bias, bias : when use_bias is true, `bias` is installed as
//                  CUBLASLT_MATMUL_DESC_BIAS_POINTER (per the cuBLASLt docs, a
//                  device vector with one entry per row of C); ignored otherwise.
// use_relu       : apply ReLU (after the bias add, when there is one).
// workspace      : scratch buffer of workspaceSize bytes the algorithm may use.
// stream         : stream the matmul is enqueued on.
// Returns 0 on success and 1 on any failure, including when the heuristic
// finds no algorithm for this problem.
int mlp_gemm_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float *alpha, /* host pointer */
    const float *A,
    int lda,
    const float *B,
    int ldb,
    float *beta, /* host pointer */
    float *C,
    int ldc,
    void *workspace,
    size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    bool use_relu,
    const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  // All descriptors are stack-allocated opaque objects set up with *Init, so
  // nothing has to be destroyed on exit. Every local is declared up front
  // because the error path jumps to CLEANUP.
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults                             = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Operation descriptor: float compute and float scale type; set the
  // transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Pick the fused epilogue from the (use_bias, use_relu) combination.
  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
  } else {
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU;
    }
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Matrix layouts (column major). A and B are described in their stored,
  // pre-transpose shape.
  status = cublasLtMatrixLayoutInit(
    &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
    &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Preference: only cap the workspace size. A, B and C are assumed to be
  // well aligned (e.g. directly from cudaMalloc); no alignment hints are set.
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Ask for the single best heuristic. This is not guaranteed to succeed
  // (e.g. with badly aligned A); requesting several algorithms and trying
  // them in turn would be more robust.
  status = cublasLtMatmulAlgoGetHeuristic(
    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }

  // C is both the input (beta term) and the output matrix.
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          &heuristicResult.algo,
                          workspace,
                          workspaceSize,
                          stream);

CLEANUP:
  // Nothing to release: the descriptors are stack objects and the matmul (if
  // any) has already been enqueued on `stream`.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
555
#endif
556

557
// Bias ADD. Assume input X is [features x batch size], column major, so the
// element at flat index idx belongs to row (feature) idx % features.
// Bias is one 'features' long vector, with implicit broadcast across the
// batch. Updates X in place; the addition is done in float.
//
// Vector path: taken when X and b are ILP-aligned and features % ILP == 0.
// Each ILP-wide vector then lies inside a single column and lines up with one
// ILP-wide slice of b.
// Scalar path: otherwise each thread handles up to ILP elements per iteration,
// spaced blockDim.x * gridDim.x apart.
template <typename T>
__global__ void biasAdd_fprop(T *X, T *b, uint batch_size, uint features) {
  T r_x[ILP];
  T r_b[ILP];
  if(is_aligned(X) && is_aligned(b) && features % ILP ==0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      // tid counts ILP-wide vectors; each column holds features / ILP of them.
      int row = tid % (features / ILP);
      load_store(r_x, X, 0 , tid);
      load_store(r_b, b, 0 , row);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = bias_sum;
      }
      load_store(X, r_x, tid , 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          // The bias row must come from this element's own index: idx differs
          // from tid by ii * stride, which need not be a multiple of features.
          int row = idx % features;
          r_x[ii] = X[idx];
          r_b[ii] = b[row];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = bias_sum;
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          X[idx] = r_x[ii];
        }
      }
    }
  }
}

// Bias ADD + ReLU. Assume input X is [features x batch size], column major, so
// the element at flat index idx belongs to row (feature) idx % features.
// Activation supports fused ReLU: X = relu(X + b), with b broadcast across
// the batch and the math done in float. Safe to call in-place (X is both
// input and output).
//
// Vector path: taken when X and b are ILP-aligned and features % ILP == 0.
// Each ILP-wide vector then lies inside a single column and lines up with one
// ILP-wide slice of b.
// Scalar path: otherwise each thread handles up to ILP elements per iteration,
// spaced blockDim.x * gridDim.x apart.
template <typename T>
__global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
  T r_x[ILP];
  T r_b[ILP];
  if(is_aligned(X) && is_aligned(b) && features % ILP ==0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      // tid counts ILP-wide vectors; each column holds features / ILP of them.
      int row = tid % (features / ILP);
      load_store(r_x, X, 0 , tid);
      load_store(r_b, b, 0 , row);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = relu(bias_sum);
      }
      load_store(X, r_x, tid , 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          // The bias row must come from this element's own index: idx differs
          // from tid by ii * stride, which need not be a multiple of features.
          int row = idx % features;
          r_x[ii] = X[idx];
          r_b[ii] = b[row];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = relu(bias_sum);
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          X[idx] = r_x[ii];
        }
      }
    }
  }
}

651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
// ReLU. Assume input X is [features x batch size], column major.
// Safe to call in-place: every element x of X is replaced by max(x, 0),
// evaluated in float.
// Vector path (X ILP-aligned and features % ILP == 0): each thread processes
// whole ILP-wide vectors. Scalar path: each thread processes up to ILP
// elements per iteration, spaced blockDim.x * gridDim.x apart.
template <typename T>
__global__ void Relu_fprop(T *X, uint batch_size, uint features) {
  T vals[ILP];
  const unsigned int nthreads = blockDim.x * gridDim.x;
  const int first = blockIdx.x * blockDim.x + threadIdx.x;

  if (!is_aligned(X) || features % ILP != 0) {
    for (int base = first; base < features * batch_size; base += ILP * nthreads) {
#pragma unroll
      for (int j = 0; j < ILP; j++) {
        const int e = base + j * nthreads;
        if (e < features * batch_size) vals[j] = X[e];
      }
#pragma unroll
      for (int j = 0; j < ILP; j++) {
        vals[j] = relu(static_cast<float>(vals[j]));
      }
#pragma unroll
      for (int j = 0; j < ILP; j++) {
        const int e = base + j * nthreads;
        if (e < features * batch_size) X[e] = vals[j];
      }
    }
    return;
  }

  for (int v = first; v * ILP < features * batch_size; v += nthreads) {
    load_store(vals, X, 0, v);
#pragma unroll
    for (int j = 0; j < ILP; j++) {
      vals[j] = relu(static_cast<float>(vals[j]));
    }
    load_store(X, vals, v, 0);
  }
}

// Sigmoid. Assume input X is [features x batch size], column major.
// Safe to call in-place: every element x of X is replaced by 1 / (1 + e^-x),
// evaluated in float.
// Vector path (X ILP-aligned and features % ILP == 0): each thread processes
// whole ILP-wide vectors. Scalar path: each thread processes up to ILP
// elements per iteration, spaced blockDim.x * gridDim.x apart.
template <typename T>
__global__ void Sigmoid_fprop(T *X, uint batch_size, uint features) {
  T vals[ILP];
  const unsigned int nthreads = blockDim.x * gridDim.x;
  const int first = blockIdx.x * blockDim.x + threadIdx.x;

  if (!is_aligned(X) || features % ILP != 0) {
    for (int base = first; base < features * batch_size; base += ILP * nthreads) {
#pragma unroll
      for (int j = 0; j < ILP; j++) {
        const int e = base + j * nthreads;
        if (e < features * batch_size) vals[j] = X[e];
      }
#pragma unroll
      for (int j = 0; j < ILP; j++) {
        vals[j] = sigmoid(static_cast<float>(vals[j]));
      }
#pragma unroll
      for (int j = 0; j < ILP; j++) {
        const int e = base + j * nthreads;
        if (e < features * batch_size) X[e] = vals[j];
      }
    }
    return;
  }

  for (int v = first; v * ILP < features * batch_size; v += nthreads) {
    load_store(vals, X, 0, v);
#pragma unroll
    for (int j = 0; j < ILP; j++) {
      vals[j] = sigmoid(static_cast<float>(vals[j]));
    }
    load_store(X, vals, v, 0);
  }
}

// ReLU backward. dY, Y and dX are [features x batch size], column major, with
// Y the ReLU forward output. dX = dY where Y > 0, and 0 where Y <= 0 (a NaN in
// Y fails the <= test, so dY passes through there).
// Safe to call in-place (dX == dY): each thread reads its elements before
// writing them.
template <typename T>
__global__ void Relu_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
  T grad[ILP];
  T out[ILP];
  const unsigned int nthreads = blockDim.x * gridDim.x;
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const bool vectorized = is_aligned(dY) && is_aligned(Y) && is_aligned(dX) &&
                          features % ILP == 0;

  if (vectorized) {
    for (int v = first; v * ILP < features * batch_size; v += nthreads) {
      load_store(grad, dY, 0, v);
      load_store(out, Y, 0, v);
#pragma unroll
      for (int j = 0; j < ILP; j++) {
        if (static_cast<float>(out[j]) <= 0.f) grad[j] = 0;
      }
      load_store(dX, grad, v, 0);
    }
    return;
  }

  for (int base = first; base < features * batch_size; base += ILP * nthreads) {
#pragma unroll
    for (int j = 0; j < ILP; j++) {
      const int e = base + j * nthreads;
      if (e < features * batch_size) {
        grad[j] = dY[e];
        out[j] = Y[e];
      }
    }
#pragma unroll
    for (int j = 0; j < ILP; j++) {
      if (static_cast<float>(out[j]) <= 0.f) grad[j] = 0;
    }
#pragma unroll
    for (int j = 0; j < ILP; j++) {
      const int e = base + j * nthreads;
      if (e < features * batch_size) dX[e] = grad[j];
    }
  }
}

// Sigmoid backward. dY, Y and dX are [features x batch size], column major,
// with Y the sigmoid forward output: dX = Y * (1 - Y) * dY, computed in float.
// Safe to call in-place (dX == dY): each thread reads its elements before
// writing them.
template <typename T>
__global__ void Sigmoid_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
  T grad[ILP];
  T out[ILP];
  const unsigned int nthreads = blockDim.x * gridDim.x;
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const bool vectorized = is_aligned(dY) && is_aligned(Y) && is_aligned(dX) &&
                          features % ILP == 0;

  if (vectorized) {
    for (int v = first; v * ILP < features * batch_size; v += nthreads) {
      load_store(grad, dY, 0, v);
      load_store(out, Y, 0, v);
#pragma unroll
      for (int j = 0; j < ILP; j++) {
        const float y = out[j];
        const float dy = grad[j];
        grad[j] = y * ( 1.f - y) * dy;
      }
      load_store(dX, grad, v, 0);
    }
    return;
  }

  for (int base = first; base < features * batch_size; base += ILP * nthreads) {
#pragma unroll
    for (int j = 0; j < ILP; j++) {
      const int e = base + j * nthreads;
      if (e < features * batch_size) {
        grad[j] = dY[e];
        out[j] = Y[e];
      }
    }
#pragma unroll
    for (int j = 0; j < ILP; j++) {
      const float y = out[j];
      const float dy = grad[j];
      grad[j] = y * ( 1.f - y) * dy;
    }
#pragma unroll
    for (int j = 0; j < ILP; j++) {
      const int e = base + j * nthreads;
      if (e < features * batch_size) dX[e] = grad[j];
    }
  }
}

831
// Compute grid size for pointwise backward kernel.
// block_x/block_y are the total number of elements handled per block in the
// feature and batch dimensions (not the number of threads).
//  - grid_x: enough blocks to cover all yfeat features.
//  - grid_y: ceil(batch_size / block_y) reduction splits, capped so that
//    grid_x * grid_y is roughly 4 blocks per SM of the current device.
void get_biasAddRelu_bprop_grid_size(
    int yfeat,
    int batch_size,
    int block_x,
    int block_y,
    int* grid_x,
    int* grid_y) {
  const int blocks_x = (yfeat + block_x - 1) / block_x;
  *grid_x = blocks_x;

  // The SM count bounds how many blocks are worth launching for the reduction.
  const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  // 4 blocks per SM was chosen for sm_70; an occupancy calculation could
  // replace it.
  const int max_blocks_y = (sm_count * 4 + blocks_x - 1) / blocks_x;

  // block_y is the minimal reduction work per block. Launching no more than
  // max_blocks_y means the kernel gives each thread more elements instead.
  const int red_splits = (batch_size + block_y - 1) / block_y;
  *grid_y = std::min(red_splits, max_blocks_y);
}

854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
// Bias gradient: db[f] = sum over the batch of dY[f, n], where dY is
// [features x batch_size] column major (element (f, n) at n * features + f).
//
// Addition done deterministically via a 2-pass approach, accumulated in float:
//  1. Block (bx, by) sums its slice of the batch for feature
//     f = bx * blockDim.x + threadIdx.x. threadIdx.y splits that slice
//     further, and the per-thread sums are combined through shared memory.
//     The threads with threadIdx.y == 0 write the block's partial sums to
//     intermediate[by * features + f].
//  2. Each block increments semaphores[bx]. The block that sees the old
//     count equal to gridDim.y - 1 is the last one; it adds the gridDim.y
//     partials serially, in order, and writes db[f].
//
// Launch assumptions (not checked here):
//  - blockDim.x * blockDim.y <= BIAS_RELU_BW_NTHREADS_X * BIAS_RELU_BW_NTHREADS_Y
//    (size of the static shared buffer).
//  - intermediate holds at least gridDim.y * features floats.
//  - semaphores[0 .. gridDim.x) start at zero. This kernel only increments
//    them, so presumably the caller clears them before each launch — confirm.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAdd_bprop(
    T* dY,
    int features,
    int batch_size,
    volatile float* intermediate,
    int* semaphores,
    T* db) {
  // Feature (row of dY) owned by this thread.
  const int f = blockIdx.x * blockDim.x + threadIdx.x;

  // Batch range [blk_begin, blk_begin + blk_len) owned by this block...
  const int blk_chunk = (batch_size + gridDim.y - 1) / gridDim.y;
  const int blk_begin = blockIdx.y * blk_chunk;
  const int blk_len = min(batch_size, blk_begin + blk_chunk) - blk_begin;
  // ...and the sub-range [my_begin, my_begin + my_len) owned by this thread.
  const int thr_chunk = (blk_chunk + blockDim.y - 1) / blockDim.y;
  const int my_begin = threadIdx.y * thr_chunk + blk_begin;
  const int my_len = min(blk_begin + blk_len, my_begin + thr_chunk) - my_begin;

  // This block's row of partial sums in the intermediate buffer.
  volatile float* block_partials = intermediate + blockIdx.y * features;

  // Set by thread (0,0): whether this block finishes the grid reduction.
  __shared__ bool isLastBlock;
  // Per-thread partial sums, sized for the expected block shape.
  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];

  // Pass 1a: per-thread partial sum over its batch sub-range, in float.
  float partial = 0.f;
  if (f < features) {
    int n = 0;
    // Peel the leading my_len % UNROLL_FACTOR samples first...
    const int residue = my_len % UNROLL_FACTOR;
    for (; n < residue; n++) {
      const int64_t idx = (int64_t)(my_begin + n) * features + f;
      partial += (float)dY[idx];
    }
    // ...then walk the rest UNROLL_FACTOR columns at a time.
    for (; (n + UNROLL_FACTOR - 1) < my_len; n += UNROLL_FACTOR) {
      int64_t idx = (int64_t)(my_begin + n) * features + f;
#pragma unroll 4
      for (int u = 0; u < UNROLL_FACTOR; u++) {
        partial += (float)dY[idx];
        idx += features;
      }
    }
    smem[threadIdx.y * blockDim.x + threadIdx.x] = partial;
  }
  __syncthreads();

  // Pass 1b: threads with threadIdx.y == 0 fold in the other rows (naive
  // serial reduction over y) and publish the block partial.
  if (f < features && threadIdx.y == 0) {
    for (int r = 1; r < blockDim.y; r++) {
      partial += smem[r * blockDim.x + threadIdx.x];
    }
    block_partials[f] = partial;
  }
  // Make the partial visible device-wide before signalling arrival.
  __threadfence();
  __syncthreads();

  // Pass 2: one thread per block records arrival on this column of blocks.
  if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
    const unsigned int arrived = atomicAdd(&(semaphores[blockIdx.x]), 1);
    isLastBlock = (arrived == (gridDim.y - 1));
  }
  __syncthreads();

  // The last block adds all gridDim.y partials in a fixed order.
  if (isLastBlock && f < features && threadIdx.y == 0) {
    float total = 0.f;
    for (int g = 0; g < gridDim.y; g++) {
      total += (float)(intermediate[g * features + f]);
    }
    db[f] = (T)total;
  }
}

953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
// Fused ReLU backward + bias gradient for a (batch_size x features) tensor whose
// element (n, f) lives at flat index n * features + f:
//   dX[n, f] = (Y[n, f] > 0) ? dY[n, f] : 0
//   db[f]    = sum over n of dX[n, f]   (accumulated in FP32)
// Launch layout: blockDim must be (BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y) because the
// shared reduction buffer below is statically sized for it; grid.x covers the features (one per
// thread in x) and grid.y splits the batch into chunks.
// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
// Requirements: `intermediate` holds gridDim.y * features floats and semaphores[blockIdx.x] is
// zero on entry (the host clears it with cudaMemsetAsync before each launch).
template <typename T, int UNROLL_FACTOR>
__global__ void biasAddRelu_bprop(
    T* Y,
    T* dY,
    int features,
    int batch_size,
    T* dX,
    volatile float* intermediate,
    int* semaphores,
    T* db) {
  // The feature that this thread is responsible for
  int f = blockIdx.x * blockDim.x + threadIdx.x;

  // Compute the span this thread is responsible for
  // For this block
  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
  int b_nStart = blockIdx.y * b_chunkSize;
  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
  // For this thread
  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
  int nStart = threadIdx.y * chunkSize + b_nStart;
  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;

  // This CTA's row of per-feature partial sums in the scratch buffer.
  volatile float* out = intermediate + blockIdx.y * features;

  // Flag to trigger last reduction.
  __shared__ bool isLastBlock;
  // we know block size for now
  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];

  // Accumulate db in FP32 always
  float db_local = 0;
  if (f < features) {
    int nidx = 0;
    // Handle non-multiple of UNROLL_FACTOR residue
    for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
      // 64-bit indices (as in biasAdd_bprop): col * features exceeds INT_MAX for large batches.
      int64_t row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
      T y_val = Y[flat_idx];
      T dy_val = dY[flat_idx];
      T dx_val;
      if ((float)y_val > 0.f)
        dx_val = dy_val;
      else
        dx_val = 0;
      dX[flat_idx] = dx_val;
      db_local += (float)dx_val;
    }

    // Handle meat of work
    for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
      int64_t row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
#pragma unroll 4
      for (int u = 0; u < UNROLL_FACTOR; u++) {
        T y_val = Y[flat_idx];
        T dy_val = dY[flat_idx];
        T dx_val;
        if ((float)y_val > 0.f)
          dx_val = dy_val;
        else
          dx_val = 0;
        dX[flat_idx] = dx_val;
        db_local += (float)dx_val;
        flat_idx += features;
      }
    }

    // naive block reduction on y-dim
    int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
    smem[linear_idx] = db_local;
  }
  __syncthreads();
  if (f < features) {
    if(threadIdx.y == 0) {
      for(int yidx = 1; yidx < blockDim.y; yidx++){
        db_local += smem[yidx * blockDim.x + threadIdx.x];
      }

      // block result is in db_local now for all threadIdx.y == 0
      // Write out partial result
      out[f] = db_local;
    }
  }
  // Make this CTA's partial sums visible device-wide before signalling via the semaphore.
  __threadfence();
  __syncthreads();

  // Increment semaphore and check if this is the last CTA in the grid_y dimension.
  // Only thread (0,0) calls this
  if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
    unsigned int sum_idx;
    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
    isLastBlock = (sum_idx == (gridDim.y - 1));
  }
  __syncthreads();

  db_local = 0;
  // No block reduction for now, only thread (*,0) do grid reduction
  if (isLastBlock && f < features) {
    if(threadIdx.y == 0) {
      // Serial sum over the grid.y partials keeps the result deterministic.
      for (int n = 0; n < gridDim.y; n++) {
        int64_t row, col;
        row = f;
        col = n;
        db_local += (float)(intermediate[col * features + row]);
      }
      db[f] = (T)db_local;
    }
  }
}

// Vectorised variant of biasAddRelu_bprop: each thread owns ILP consecutive features and moves
// them with a single ILP-wide load/store (load_store). The host only selects this kernel when
// features % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0 and Y, dY, dX and db pass is_aligned().
// Launch layout: blockDim must be (BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y) because the
// shared reduction buffer is statically sized for it; grid.y splits the batch into chunks.
// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
// Requirements: `intermediate` holds gridDim.y * features floats, is ILP*sizeof(float) aligned,
// and semaphores[blockIdx.x] is zero on entry.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAddRelu_bprop_aligned(
    T* Y,
    T* dY,
    int features,
    int batch_size,
    T* dX,
    volatile float* intermediate,
    int* semaphores,
    T* db) {
  // The ILP-wide feature group this thread is responsible for (index in units of ILP elements)
  int f = blockIdx.x * blockDim.x + threadIdx.x;
  // Row stride in ILP-wide vectors. features is a multiple of ILP here, so
  // col * vec_features == col * features / ILP without forming the overflow-prone product
  // col * features first.
  int vec_features = features / ILP;

  // Compute the span this thread is responsible for
  // For this block
  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
  int b_nStart = blockIdx.y * b_chunkSize;
  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
  // For this thread
  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
  int nStart = threadIdx.y * chunkSize + b_nStart;
  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;

  // This CTA's row of per-feature partial sums in the scratch buffer.
  volatile float* out = intermediate + blockIdx.y * features;

  // Flag to trigger last reduction.
  __shared__ bool isLastBlock;

  // Accumulate db in FP32 always
  float db_local[ILP];
  T r_y[ILP];
  T r_dy[ILP];
#pragma unroll
  for(int ii=0;ii<ILP;ii++){
    db_local[ii] = 0.f;
  }

  // No bounds check on f: this assumes the host sizes grid.x so the threads cover exactly
  // features / ILP vectors (TODO confirm against get_biasAddRelu_bprop_grid_size).
  int nidx = 0;

  // Handle non-multiple of UNROLL_FACTOR residue
  for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
    int row, col, flat_idx;
    row = f;
    col = nStart + nidx;
    flat_idx = col * vec_features + row;

    load_store(r_y, Y, 0, flat_idx);
    load_store(r_dy, dY, 0, flat_idx);
#pragma unroll
    for(int ii=0;ii<ILP;ii++){
      if ((float)r_y[ii] <= 0.f)
        r_dy[ii] = 0;
      db_local[ii] += (float)r_dy[ii];
    }
    load_store(dX, r_dy, flat_idx, 0);
  }

  // Handle meat of work
  for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
    int row, col, flat_idx;
    row = f;
    col = nStart + nidx;
    flat_idx = col * vec_features + row; // total threads in x == features/ILP
#pragma unroll
    for (int u = 0; u < UNROLL_FACTOR; u++) {
      load_store(r_y, Y, 0, flat_idx);
      load_store(r_dy, dY, 0, flat_idx);
#pragma unroll
      for(int ii=0;ii<ILP;ii++){
        if ((float)r_y[ii] <= 0.f)
          r_dy[ii] = 0;
        db_local[ii] += (float)r_dy[ii];
      }
      load_store(dX, r_dy, flat_idx, 0);
      flat_idx += vec_features;
    }
  }

  // we know block size for now
  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y*ILP];
  // naive block reduction on y-dim
  int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
  float* smem_out = smem + ILP * linear_idx;
#pragma unroll
  for(int ii=0;ii<ILP;ii++){
    smem_out[ii] = db_local[ii];
  }
  __syncthreads();
  if(threadIdx.y == 0) {
    for(int yidx = 1; yidx < blockDim.y; yidx++){
      float* smem_in = smem + ILP * (yidx * blockDim.x + threadIdx.x);
#pragma unroll
      for(int ii=0;ii<ILP;ii++){
        db_local[ii] += smem_in[ii];
      }
    }

    // block result is in db_local now for all threadIdx.y == 0
    if(gridDim.y == 1) {
      // Single CTA along y: the block result is already final.
#pragma unroll
      for(int ii=0;ii<ILP;ii++){
        r_dy[ii] = db_local[ii]; // reuse local dy buffer
      }
      load_store(db, r_dy, f, 0);
    } else {
      // Write out partial result
      load_store(out, db_local, f, 0);
    }
  }
  // gridDim.y is identical for every thread of the block, so the whole block leaves together.
  // Returning only from the threadIdx.y == 0 threads would leave the rest waiting on the
  // barriers below (divergent __syncthreads) and reading an isLastBlock nobody set.
  if (gridDim.y == 1) {
    return;
  }

  // Make this CTA's partial sums visible device-wide before signalling via the semaphore.
  __threadfence();
  __syncthreads();

  // Increment semaphore and check if this is the last CTA in the grid_y dimension.
  // Only thread (0,0) calls this
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    unsigned int sum_idx;
    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
    isLastBlock = (sum_idx == (gridDim.y - 1));
  }
  __syncthreads();

#pragma unroll
  for(int ii=0;ii<ILP;ii++){
    db_local[ii] = 0.f;
  }
  float r_db[ILP];

  // No block reduction for now, only thread (*,0) do grid reduction
  if (isLastBlock) {
    if(threadIdx.y == 0){
      // Serial sum over the grid.y partials keeps the result deterministic.
      for (int n = 0; n < gridDim.y; n++) {
        int row, col;
        row = f;
        col = n;
        load_store(r_db, intermediate, 0, col * vec_features + row);
#pragma unroll
        for(int ii=0;ii<ILP;ii++){
          db_local[ii] += r_db[ii];
        }
      }
#pragma unroll
      for(int ii=0;ii<ILP;ii++){
        r_dy[ii] = db_local[ii]; // reuse local dy buffer
      }
      load_store(db, r_dy, f, 0);
    }
  }
}

// Writes, for each layer i, the element offset at which that layer's output starts in the fprop
// reserved space. This is an exclusive prefix sum of batch_size * output_features[i]; entry 0 is
// always written (as 0). The final layer's output itself lives in the user-supplied Y buffer.
void get_y_offsets(
    int batch_size,
    int num_layers,
    const int* output_features,
    int* y_start_offsets) {
  int running_offset = 0;
  y_start_offsets[0] = running_offset;
  for (int layer = 1; layer < num_layers; ++layer) {
    running_offset += batch_size * output_features[layer - 1];
    y_start_offsets[layer] = running_offset;
  }
}

// Returns the reserved space (in elements) needed for the MLP: one batch_size x output_features[l]
// slot for every layer l in [0, num_layers). mlp_fp only writes the first num_layers - 1 slots
// (the last layer writes straight into Y), but a slot is still counted for the last layer.
size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features) {
  size_t total_elems = 0;
  int layer = 0;
  while (layer < num_layers) {
    const int64_t layer_elems = static_cast<int64_t>(output_features[layer]) * batch_size;
    total_elems += layer_elems;
    ++layer;
  }
  return total_elems;
}

// Returns the total number of elements across the outputs of all num_layers layers, i.e. the sum
// of output_features[l] * batch_size. Products are formed in 64-bit because batch_size is int64_t.
size_t get_all_activations_size(int64_t batch_size, int num_layers, const int* output_features) {
  size_t total = 0;
  for (const int* feat = output_features; feat != output_features + num_layers; ++feat) {
    total += static_cast<int64_t>(*feat) * batch_size;
  }
  return total;
}

// Disabled legacy helper, kept for reference only (compiled out by #if 0). It sized just the two
// per-layer dY buffers, in elements. get_mlp_bp_workspace_in_bytes further below is the version in
// use; it also accounts for the float reduction scratch and the int semaphores.
#if 0
// Returns the work space (in elements) needed for the MLP bprop.
size_t get_mlp_bp_workspace (int batch_size, int num_layers, const int* output_features) {
    /*
       Workspace is partitioned as
       DY_GEMMs : DX_GEMMs
    */
    size_t work_space = 0;

    // Store each intermediate dY explicitly. Need 2 dYs per MLP layer (one for o/p
    // of biasReLU_bp and one for o/p of dgrad GEMM).
    work_space += 2*get_all_activations_size(batch_size, num_layers, output_features);

    return work_space;
}
#endif

// Scratch space needed for reductions, in number of (float) elements.
// The bias-grad kernels keep one row of `features` partial sums per grid.y CTA, so a layer needs
// features * grid_y floats. Which kernel runs (scalar with block_x = BIAS_RELU_BW_NTHREADS_X, or
// vectorised with block_x = ILP * BIAS_RELU_BW_NTHREADS_X) is decided at bprop time, so both grid
// shapes are evaluated and the largest requirement over all layers is returned.
size_t get_reduction_scratch_space(int batch_size, int num_layers, const int* output_features) {
  const int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
  const int block_x_choices[2] = {BIAS_RELU_BW_NTHREADS_X, ILP * BIAS_RELU_BW_NTHREADS_X};

  size_t largest = 0;
  for (int layer = 0; layer < num_layers; ++layer) {
    const int feats = output_features[layer];
    for (int choice = 0; choice < 2; ++choice) {
      int unused_grid_x, grid_y;
      get_biasAddRelu_bprop_grid_size(
        feats, batch_size, block_x_choices[choice], block_y, &unused_grid_x, &grid_y);
      largest = std::max(largest, (size_t)(feats * grid_y));
    }
  }
  return largest;
}

// Number of int semaphores the bias-grad reduction needs. The kernels index semaphores by
// blockIdx.x, and grid.x never exceeds the feature count, so one per feature of the widest layer
// is an upper bound. Returns 0 when there are no layers.
size_t get_semaphores_size(int num_layers, const int* output_features) {
  int widest = 0;
  for (int layer = 0; layer < num_layers; ++layer) {
    if (output_features[layer] > widest) {
      widest = output_features[layer];
    }
  }
  return static_cast<size_t>(widest);
}

// Returns the work space needed for the MLP bprop, in BYTES. The layout mirrors
// partition_mlp_bp_workspace:
//   2 x all activations (type T)  - dy of each GEMM (o/p of biasReLU_bp) and dx of each dgrad GEMM
//   reduction scratch  (float)    - per-CTA partial bias sums
//   semaphores         (int)      - one counter per grid.x column of the bias reduction
template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features) {
  const size_t activation_elems =
      get_all_activations_size(batch_size, num_layers, output_features);
  const size_t scratch_elems =
      get_reduction_scratch_space(batch_size, num_layers, output_features);
  const size_t semaphore_elems = get_semaphores_size(num_layers, output_features);

  const size_t gemm_bytes = 2 * activation_elems * sizeof(T);
  const size_t scratch_bytes = scratch_elems * sizeof(float);
  const size_t semaphore_bytes = semaphore_elems * sizeof(int);
  return gemm_bytes + scratch_bytes + semaphore_bytes;
}

// Carves the bprop workspace into its segments and returns their start addresses:
//   DY_GEMMs (T) : DX_GEMMs (T) : DB_SCRATCH (float) : SEMAPHORES (int)
// Each T segment holds get_all_activations_size(...) elements; the scratch holds
// get_reduction_scratch_space(...) floats. Sizes must agree with get_mlp_bp_workspace_in_bytes.
template <typename T>
void partition_mlp_bp_workspace(
    int batch_size,
    int num_layers,
    const int* output_features,
    void* work_space,
    T** dy_gemms,
    T** dx_gemms,
    float** db_scratch,
    int** semaphores) {
  const size_t activation_elems =
      get_all_activations_size(batch_size, num_layers, output_features);
  const size_t scratch_elems =
      get_reduction_scratch_space(batch_size, num_layers, output_features);

  T* base = reinterpret_cast<T*>(work_space);
  float* scratch_start = reinterpret_cast<float*>(base + 2 * activation_elems);

  *dy_gemms = base;
  *dx_gemms = base + activation_elems;
  *db_scratch = scratch_start;
  *semaphores = reinterpret_cast<int*>(scratch_start + scratch_elems);
}

// Does a simple MLP fprop (GEMM + optional bias + optional activation).
// Can handle num_layers number of layers, each with its own shape. Output of layer i is assumed
// to be input of layer i+1. output_features, WPtr and BPtr are arrays of length num_layers, and
// must be in the same order i.e. WPtr[i] and BPtr[i] are respectively the weight and bias of layer
// 'i'.
// activation: 0 = none, 1 = ReLU, 2 = sigmoid. Outputs of layers 0..num_layers-2 are written
// consecutively into reserved_space; the last layer writes into Y.
// Returns 0 on success, 1 if the fallback cuBLAS GEMM fails.
template <typename T>
int mlp_fp(
    T* X,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T** BPtr,
    T* Y,
    T* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace) {
  T *weight, *input, *output;
  // Only assigned when use_bias is set. Initialised so the value handed to mlp_gemm_lt is never
  // an indeterminate pointer when bias is disabled.
  T* bias = nullptr;
  T *reserved_space_x, *reserved_space_y;
  reserved_space_x = NULL;
  reserved_space_y = reserved_space;

  // Get cublas handle from Pytorch
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Get the stream from cublas handle to reuse for biasReLU kernel.
  cudaStream_t stream;
  cublasGetStream(handle, &stream);

  // Loop-invariant: query the SM count once instead of once per layer.
  int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  for (int layer = 0; layer < num_layers; layer++) {
    weight = WPtr[layer];
    input = (layer == 0) ? X : reserved_space_x;
    output = (layer == num_layers - 1) ? Y : reserved_space_y;
    if (use_bias) {
      bias = BPtr[layer];
    }
    int ifeat = (layer == 0) ? input_features : output_features[layer - 1];
    int ofeat = output_features[layer];

    float one = 1.f;
    float zero = 0.f;

    // try with cublaslt first for supported case with valid handle
    // (only the no-activation case is routed there; it is also told whether to apply the bias).
    int cublaslt_status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
    if(activation < 1){
        cublaslt_status = mlp_gemm_lt(
          //ltHandle,
          (cublasLtHandle_t)handle,
          CUBLAS_OP_T,
          CUBLAS_OP_N,
          ofeat,
          batch_size,
          ifeat,
          &one,
          weight,
          ifeat,
          input,
          ifeat,
          &zero,
          output,
          ofeat,
          lt_workspace,
          1 << 22,
          stream,
          use_bias == 1,
          activation == 1,
          bias);
    }
#endif

    // if cublaslt failed or not executed, fallback to cublas
    if (cublaslt_status != 0) {
      cublasStatus_t cublas_status;
      // Call GEMM: fprop is Y = W'X
      cublas_status = mlp_gemm(
        handle,
        CUBLAS_OP_T,
        CUBLAS_OP_N,
        ofeat,
        batch_size,
        ifeat,
        &one,
        weight,
        ifeat,
        input,
        ifeat,
        &zero,
        output,
        ofeat,
        int(0)); // Do nothing for forward prop

      if (cublas_status != CUBLAS_STATUS_SUCCESS) {
        printf("GEMM fprop failed with %d\n", cublas_status);
        return 1;
      }

      const uint &input_size = ofeat;
      int num_blocks = 0;
      // Call biasReLU
      if(use_bias == 1) {
        if (activation == 0) { // no activation
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
        } else if (activation == 1) { // relu
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAddRelu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
        } else if (activation == 2) { // sigmoid
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        }
      } else {
        // don't need to do anything in case of no activation and no bias
        if (activation == 1) { // relu
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Relu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        } else if (activation == 2) { // sigmoid
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        }
      }
    }

    // Set current output as next layer input
    reserved_space_x = reserved_space_y;
    // Set next layer output
    reserved_space_y += ofeat * batch_size;
  }

  return 0;
}

// Does a simple MLP bprop (GEMM + bias + activation backward).
// Needs reserved space to come back exactly as it was populated in fprop.
// Does dgrad and wgrad sequentially, walking layers from last to first.
// work_space must be partitioned as described by partition_mlp_bp_workspace and sized by
// get_mlp_bp_workspace_in_bytes<T>. The dgrad GEMM of layer 0 (into dX) only runs when
// requires_grad is set. Returns 0 on success, 1 on a GEMM or host allocation failure.
template <typename T>
int mlp_bp(
    T* X,
    T* Y,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T* dY,
    T* reserved_space,
    T* work_space,
    T* dX,
    T** dwPtr,
    T** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation) {
  T* weight;
  T *dweight, *dx, *dy, *dbias;
  T *x, *y;

  // Where the dx of the biasReLU (== dy of gemm) is stored. Can be thrown away
  // after bp call.
  T* dy_gemm_base;
  // Where the dx after GEMM is stored.
  T* dx_gemm_base;
  // Where partial reduction results are stored.
  float* db_scratch;
  // Semaphores for reduction.
  int* semaphores;

  partition_mlp_bp_workspace<T>(
      batch_size,
      num_layers,
      output_features,
      work_space,
      &dy_gemm_base,
      &dx_gemm_base,
      &db_scratch,
      &semaphores);

  size_t semaphore_size = get_semaphores_size(num_layers, output_features) * sizeof(int);

  // Get cublas handle from Pytorch
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Get the stream from cublas handle to reuse for biasReLU kernel.
  cudaStream_t stream;
  cublasGetStream(handle, &stream);

  int flag = 0;
  #ifdef __HIP_PLATFORM_HCC__
    #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
    #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
    #if USE_GEMM_FLAGS_FP16_ALT_IMPL
      #ifdef BACKWARD_PASS_GUARD
        flag = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
      #endif
    #endif
  #endif

  // Host-side per-layer offsets into reserved space / workspace. get_y_offsets always writes
  // entry 0, so allocate at least one slot. The buffer is released on every exit path below.
  int* y_offsets = (int*)malloc(std::max(num_layers, 1) * sizeof(int));
  if (y_offsets == NULL) {
    printf("Failed to allocate y_offsets for MLP bprop\n");
    return 1;
  }
  get_y_offsets(batch_size, num_layers, output_features, y_offsets);

  for (int layer = num_layers - 1; layer >= 0; layer--) {
    weight = WPtr[layer];
    dweight = dwPtr[layer];

    // x is read from reserved space
    x = (layer == 0) ? X : reserved_space + y_offsets[layer - 1];
    // dx is written in workspace for all but layer==0
    dx = (layer == 0) ? dX : dx_gemm_base + y_offsets[layer - 1];

    // y is read from reserved space
    y = (layer == num_layers - 1) ? Y : reserved_space + y_offsets[layer];
    // dx from layer+1
    dy = (layer == num_layers - 1) ? dY : dx_gemm_base + y_offsets[layer];
    // dy_gemm is written to and read immediately
    T* dy_gemm = dy_gemm_base + y_offsets[layer];

    dbias = dbPtr[layer];
    int xfeat = (layer == 0) ? input_features : output_features[layer - 1];
    int yfeat = output_features[layer];

    float one = 1.f;
    float zero = 0.f;

    if (use_bias == 1) {
      if (activation == 0) { // no acitvation
        // bgrad
        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
        int grid_x, grid_y;
        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);

        int block_x = BIAS_RELU_BW_NTHREADS_X;
        int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
        get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
        dim3 grid(grid_x, grid_y);
        biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
          dy, yfeat, batch_size, db_scratch, semaphores, dbias);
        // bypass dgrad through reset pointer
        dy_gemm = dy;
      } else if (activation == 1) { // relu
        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
        int grid_x, grid_y;
        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);

        if(yfeat % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0 &&
           is_aligned(y) &&
           is_aligned(dy) &&
           is_aligned(dy_gemm) &&
           is_aligned(dbias)){
          int block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
          int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
          get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
          dim3 grid(grid_x, grid_y);
          biasAddRelu_bprop_aligned<T, 4><<<grid, block, 0, stream>>>(
            y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
        } else {
          int block_x = BIAS_RELU_BW_NTHREADS_X;
          int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
          get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
          dim3 grid(grid_x, grid_y);
          biasAddRelu_bprop<T, 4><<<grid, block, 0, stream>>>(
            y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
        }
      } else if (activation == 2) { // sigmoid
        // activation backward
        int num_blocks = 0;
        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
        Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);

        // bgrad, from dy_gemm
        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
        int grid_x, grid_y;
        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);

        int block_x = BIAS_RELU_BW_NTHREADS_X;
        int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
        get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
        dim3 grid(grid_x, grid_y);
        biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
          dy_gemm, yfeat, batch_size, db_scratch, semaphores, dbias);
      }
    } else { // no bias below
      if (activation == 0) {
        // bypass dgrad through reset pointer
        dy_gemm = dy;
      } else if (activation == 1) { // relu
        int num_blocks = 0;
        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
        Relu_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
      } else if (activation == 2) { // sigmoid
        int num_blocks = 0;
        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
        Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
      }
    }

    cublasStatus_t cublas_status;
    // Call GEMM dgrad
    if (layer > 0 || requires_grad == 1) {
      cublas_status = mlp_gemm(
        handle,
        CUBLAS_OP_N,
        CUBLAS_OP_N,
        xfeat,
        batch_size,
        yfeat,
        &one,
        weight,
        xfeat,
        dy_gemm,
        yfeat,
        &zero,
        dx,
        xfeat,
        flag); //

      if (cublas_status != CUBLAS_STATUS_SUCCESS) {
        printf("GEMM dgrad failed with %d\n", cublas_status);
        free(y_offsets);
        return 1;
      }
    }

    // Call GEMM wgrad
    cublas_status = mlp_gemm(
        handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        xfeat,
        yfeat,
        batch_size,
        &one,
        x,
        xfeat,
        dy_gemm,
        yfeat,
        &zero,
        dweight,
        xfeat,
        flag); //

    if (cublas_status != CUBLAS_STATUS_SUCCESS) {
      printf("GEMM wgrad failed with %d\n", cublas_status);
      free(y_offsets);
      return 1;
    }
  }

  free(y_offsets);
  return 0;
}

// Instantiate the host entry points for every supported floating point element type
// (float, at::Half, double): forward pass, backward pass and the bprop workspace-size query.
#define MLP_INSTANTIATE_FOR(T)                                        \
  template int mlp_fp<T>(                                             \
      T* X,                                                           \
      int input_features,                                             \
      int batch_size,                                                 \
      T** WPtr,                                                       \
      int num_layers,                                                 \
      int* output_features,                                           \
      T** BPtr,                                                       \
      T* Y,                                                           \
      T* reserved_space,                                              \
      int use_bias,                                                   \
      int activation,                                                 \
      void* lt_workspace);                                            \
  template int mlp_bp<T>(                                             \
      T* X,                                                           \
      T* Y,                                                           \
      int input_features,                                             \
      int batch_size,                                                 \
      T** WPtr,                                                       \
      int num_layers,                                                 \
      int* output_features,                                           \
      T* dY,                                                          \
      T* reserved_space,                                              \
      T* work_space,                                                  \
      T* dX,                                                          \
      T** dwPtr,                                                      \
      T** dbPtr,                                                      \
      bool requires_grad,                                             \
      int use_bias,                                                   \
      int activation);                                                \
  template size_t get_mlp_bp_workspace_in_bytes<T>(                   \
      int batch_size,                                                 \
      int num_layers,                                                 \
      const int* output_features);

MLP_INSTANTIATE_FOR(float)
MLP_INSTANTIATE_FOR(at::Half)
MLP_INSTANTIATE_FOR(double)

#undef MLP_INSTANTIATE_FOR