// fused_dense_cuda.cu
// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
#include <cublasLt.h>
#endif

// FP16 wrapper around cublasGemmEx: A, B and C are described as CUDA_R_16F,
// the compute type is CUDA_R_32F and the algorithm is
// CUBLAS_GEMM_DEFAULT_TENSOR_OP.
//
// alpha and beta are float scalars passed by pointer; every other argument is
// forwarded to cublasGemmEx in the same position. cublasGemmEx takes its
// sizes and leading dimensions as `int`, so a value that does not survive the
// narrowing from int64_t is rejected with CUBLAS_STATUS_INVALID_VALUE instead
// of being silently truncated.
//
// Returns the cublasStatus_t reported by cublasGemmEx (or
// CUBLAS_STATUS_INVALID_VALUE for out-of-range sizes).
cublasStatus_t gemm_bias(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const float* alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    const float* beta,
    at::Half* C,
    int64_t ldc) {
  // True when `v` round-trips through the `int` type cublasGemmEx expects.
  const auto fits_int = [](int64_t v) {
    return static_cast<int64_t>(static_cast<int>(v)) == v;
  };
  if (!(fits_int(m) && fits_int(n) && fits_int(k) &&
        fits_int(lda) && fits_int(ldb) && fits_int(ldc))) {
    return CUBLAS_STATUS_INVALID_VALUE;
  }

  return cublasGemmEx(
      handle,
      transa,
      transb,
      static_cast<int>(m),
      static_cast<int>(n),
      static_cast<int>(k),
      alpha,
      A,
      CUDA_R_16F,
      static_cast<int>(lda),
      B,
      CUDA_R_16F,
      static_cast<int>(ldb),
      beta,
      C,
      CUDA_R_16F,
      static_cast<int>(ldc),
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// BF16 wrapper around cublasGemmEx: A, B and C are described as CUDA_R_16BF,
// the compute type is CUDA_R_32F and the algorithm is
// CUBLAS_GEMM_DEFAULT_TENSOR_OP.
//
// alpha and beta are float scalars passed by pointer; every other argument is
// forwarded to cublasGemmEx in the same position. cublasGemmEx takes its
// sizes and leading dimensions as `int`, so a value that does not survive the
// narrowing from int64_t is rejected with CUBLAS_STATUS_INVALID_VALUE instead
// of being silently truncated.
//
// Returns the cublasStatus_t reported by cublasGemmEx (or
// CUBLAS_STATUS_INVALID_VALUE for out-of-range sizes).
cublasStatus_t gemm_bias(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const float* alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    const float* beta,
    at::BFloat16* C,
    int64_t ldc) {
  // True when `v` round-trips through the `int` type cublasGemmEx expects.
  const auto fits_int = [](int64_t v) {
    return static_cast<int64_t>(static_cast<int>(v)) == v;
  };
  if (!(fits_int(m) && fits_int(n) && fits_int(k) &&
        fits_int(lda) && fits_int(ldb) && fits_int(ldc))) {
    return CUBLAS_STATUS_INVALID_VALUE;
  }

  return cublasGemmEx(
      handle,
      transa,
      transb,
      static_cast<int>(m),
      static_cast<int>(n),
      static_cast<int>(k),
      alpha,
      A,
      CUDA_R_16BF,
      static_cast<int>(lda),
      B,
      CUDA_R_16BF,
      static_cast<int>(ldb),
      beta,
      C,
      CUDA_R_16BF,
      static_cast<int>(ldc),
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600

// Fused GEMM + optional bias + activation, issued through cublasLtMatmul.
//
// The operation descriptor uses CUBLAS_COMPUTE_32F with a CUDA_R_32F scale
// type; A, B and C layouts are fp16 or bf16 depending on Dtype. The epilogue
// is GELU when is_gelu is true and RELU otherwise; the *_BIAS variant is
// selected when bias is non-null and the *_AUX variant when pre_act is
// non-null.
//
// Parameters:
//   transa, transb   operations applied to A and B.
//   m, n, k          C is laid out as m x n. A is m x k when transa is
//                    CUBLAS_OP_N and k x m otherwise; B is k x n when transb
//                    is CUBLAS_OP_N and n x k otherwise.
//   alpha            scalar handed to cublasLtMatmul (beta is fixed at 0).
//   A/lda, B/ldb     input matrices and their leading dimensions.
//   bias             optional bias pointer; nullptr disables the bias term.
//   C/ldc            output matrix and its leading dimension (C is passed as
//                    both the C and D operands of cublasLtMatmul).
//   pre_act          optional epilogue aux buffer; nullptr disables it. Its
//                    leading dimension is registered as ldc.
//   is_gelu          selects the GELU (true) or RELU (false) epilogue family.
//   heuristic        index into the (at most 5) algorithms returned by
//                    cublasLtMatmulAlgoGetHeuristic.
//   lt_workspace,    workspace buffer and its size in bytes; the size is also
//   workspaceSize    used as the heuristic's maximum workspace preference.
//
// Returns 0 on success and 1 on any failure: a cuBLASLt call reporting an
// error, no algorithm being returned, or `heuristic` not naming one of the
// returned algorithms.
template <typename Dtype>
int gemm_bias_act_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    const Dtype* bias,
    Dtype* C,
    int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize
    ) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_bias_act_lt only supports fp16 and bf16");

  constexpr int requestedAlgoCount = 5;
  // An index outside the heuristic array would be an out-of-bounds read.
  if (heuristic < 0 || heuristic >= requestedAlgoCount) {
    return 1;
  }

  const bool save_pre_act = pre_act != nullptr;
  const bool has_bias = bias != nullptr;
  const float beta = 0.0f;
  const cudaDataType_t abcType =
      std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;
  const auto ok = [](cublasStatus_t s) { return s == CUBLAS_STATUS_SUCCESS; };

  // The current cuBLAS handle is reused as the cuBLASLt handle.
  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());

  // Stack-allocated opaque descriptors, filled in by the *Init calls below.
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  // Operation descriptor: compute type plus the transforms for A and B.
  if (!ok(cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F))) return 1;
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)))) return 1;
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)))) return 1;

  // Optional aux output buffer, registered with the same leading dimension
  // as C.
  if (save_pre_act) {
    if (!ok(cublasLtMatmulDescSetAttribute(
            &operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act)))) return 1;
    if (!ok(cublasLtMatmulDescSetAttribute(
            &operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)))) return 1;
  }

  if (has_bias) {
    if (!ok(cublasLtMatmulDescSetAttribute(
            &operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)))) return 1;
  }

  // Pick the epilogue from (activation, bias present, aux output requested).
  cublasLtEpilogue_t epilogue;
  if (is_gelu) {
    epilogue = has_bias
        ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS)
        : (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU);
  } else {
    epilogue = has_bias
        ? (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS : CUBLASLT_EPILOGUE_RELU_BIAS)
        : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU);
  }
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)))) return 1;

  // Matrix layouts; no extra attributes are set.
  if (!ok(cublasLtMatrixLayoutInit(
          &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda))) return 1;
  if (!ok(cublasLtMatrixLayoutInit(
          &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb))) return 1;
  if (!ok(cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc))) return 1;

  // Preference: only the workspace limit is set. A, B and C are assumed to be
  // well aligned (e.g. coming directly from cudaMalloc).
  if (!ok(cublasLtMatmulPreferenceInit(&preference))) return 1;
  if (!ok(cublasLtMatmulPreferenceSetAttribute(
          &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)))) return 1;

  // Ask for up to requestedAlgoCount candidates; the caller picks one by
  // index.
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {};
  if (!ok(cublasLtMatmulAlgoGetHeuristic(
          ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference,
          requestedAlgoCount, heuristicResult, &returnedResults))) return 1;

  // Entries at or beyond returnedResults were not filled in by the heuristic
  // query, so the requested index must be below it (this also covers the
  // "no algorithm at all" case).
  if (heuristic >= returnedResults) {
    return 1;
  }

  // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
  const cublasStatus_t status = cublasLtMatmul(ltHandle,
                                               &operationDesc,
                                               &alpha,
                                               A,
                                               &Adesc,
                                               B,
                                               &Bdesc,
                                               &beta,
                                               C,
                                               &Cdesc,
                                               C,
                                               &Cdesc,
                                               &heuristicResult[heuristic].algo,
                                               lt_workspace,
                                               workspaceSize,
                                               at::cuda::getCurrentCUDAStream());

  return ok(status) ? 0 : 1;
}

// Explicit instantiations of gemm_bias_act_lt for the two supported element
// types: fp16 (at::Half) and bf16 (at::BFloat16).
template int gemm_bias_act_lt<at::Half>(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half* A, int64_t lda,
    const at::Half* B, int64_t ldb,
    const at::Half* bias,
    at::Half* C, int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize);

template int gemm_bias_act_lt<at::BFloat16>(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16* A, int64_t lda,
    const at::BFloat16* B, int64_t ldb,
    const at::BFloat16* bias,
    at::BFloat16* C, int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize);
// GEMM with an optional fused BGRADB epilogue, issued through
// cublasLtMatmul.
//
// The operation descriptor uses CUBLAS_COMPUTE_32F with a CUDA_R_32F scale
// type; A, B and C layouts are fp16 or bf16 depending on Dtype. When bgrad is
// non-null it is registered as the descriptor's bias pointer and the epilogue
// is CUBLASLT_EPILOGUE_BGRADB; otherwise the default epilogue is used.
//
// Parameters:
//   transa, transb   operations applied to A and B.
//   m, n, k          C is laid out as m x n. A is m x k when transa is
//                    CUBLAS_OP_N and k x m otherwise; B is k x n when transb
//                    is CUBLAS_OP_N and n x k otherwise.
//   alpha            scalar handed to cublasLtMatmul (beta is fixed at 0).
//   A/lda, B/ldb     input matrices and their leading dimensions.
//   C/ldc            output matrix and its leading dimension (C is passed as
//                    both the C and D operands of cublasLtMatmul).
//   bgrad            optional buffer for the BGRADB epilogue; nullptr
//                    disables it.
//   lt_workspace,    workspace buffer and its size in bytes; the size is also
//   workspaceSize    used as the heuristic's maximum workspace preference.
//
// Returns 0 on success and 1 if any cuBLASLt call fails or the heuristic
// query returns no algorithm.
template <typename Dtype>
int gemm_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    Dtype* C,
    int64_t ldc,
    Dtype* bgrad,
    void *lt_workspace,
    size_t workspaceSize) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_bgradb_lt only supports fp16 and bf16");

  const float beta = 0.0f;
  const cudaDataType_t abcType =
      std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;
  const auto ok = [](cublasStatus_t s) { return s == CUBLAS_STATUS_SUCCESS; };

  // The current cuBLAS handle is reused as the cuBLASLt handle.
  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());

  // Stack-allocated opaque descriptors, filled in by the *Init calls below.
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  // Operation descriptor: compute type plus the transforms for A and B.
  if (!ok(cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F))) return 1;
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)))) return 1;
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)))) return 1;

  // The bias-gradient epilogue is only enabled when an output buffer is
  // supplied.
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  if (bgrad != nullptr) {
    if (!ok(cublasLtMatmulDescSetAttribute(
            &operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)))) return 1;
    epilogue = CUBLASLT_EPILOGUE_BGRADB;
  }
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)))) return 1;

  // Matrix layouts; no extra attributes are set.
  if (!ok(cublasLtMatrixLayoutInit(
          &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda))) return 1;
  if (!ok(cublasLtMatrixLayoutInit(
          &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb))) return 1;
  if (!ok(cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc))) return 1;

  // Preference: only the workspace limit is set. A, B and C are assumed to be
  // well aligned (e.g. coming directly from cudaMalloc).
  if (!ok(cublasLtMatmulPreferenceInit(&preference))) return 1;
  if (!ok(cublasLtMatmulPreferenceSetAttribute(
          &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)))) return 1;

  // The heuristic query is used here only to detect an unsupported
  // configuration; the returned algorithm itself is not passed to
  // cublasLtMatmul below.
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  if (!ok(cublasLtMatmulAlgoGetHeuristic(
          ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference,
          1, &heuristicResult, &returnedResults))) return 1;
  if (returnedResults == 0) {
    return 1;
  }

  const cublasStatus_t status = cublasLtMatmul(ltHandle,
                                               &operationDesc,
                                               &alpha,
                                               A,
                                               &Adesc,
                                               B,
                                               &Bdesc,
                                               &beta,
                                               C,
                                               &Cdesc,
                                               C,
                                               &Cdesc,
                                               // NULL algo (rather than &heuristicResult.algo).
                                               NULL,
                                               lt_workspace,
                                               workspaceSize,
                                               at::cuda::getCurrentCUDAStream());

  return ok(status) ? 0 : 1;
}


// Explicit instantiations of gemm_bgradb_lt for the two supported element
// types: fp16 (at::Half) and bf16 (at::BFloat16).
template int gemm_bgradb_lt<at::Half>(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half* A, int64_t lda,
    const at::Half* B, int64_t ldb,
    at::Half* C, int64_t ldc,
    at::Half* bgrad,
    void *lt_workspace,
    size_t workspaceSize);

template int gemm_bgradb_lt<at::BFloat16>(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16* A, int64_t lda,
    const at::BFloat16* B, int64_t ldb,
    at::BFloat16* C, int64_t ldc,
    at::BFloat16* bgrad,
    void *lt_workspace,
    size_t workspaceSize);
// GEMM with a fused activation-gradient + bias-gradient epilogue, issued
// through cublasLtMatmul.
//
// The operation descriptor uses CUBLAS_COMPUTE_32F with a CUDA_R_32F scale
// type; A, B and C layouts are fp16 or bf16 depending on Dtype. The epilogue
// is CUBLASLT_EPILOGUE_DGELU_BGRAD when is_gelu is true and
// CUBLASLT_EPILOGUE_DRELU_BGRAD otherwise. bgrad is registered as the
// descriptor's bias pointer and pre_act as the epilogue aux pointer, with ldc
// as the aux leading dimension.
//
// Parameters:
//   transa, transb   operations applied to A and B.
//   m, n, k          C is laid out as m x n. A is m x k when transa is
//                    CUBLAS_OP_N and k x m otherwise; B is k x n when transb
//                    is CUBLAS_OP_N and n x k otherwise.
//   alpha            scalar handed to cublasLtMatmul (beta is fixed at 0).
//   A/lda, B/ldb     input matrices and their leading dimensions.
//   pre_act          epilogue aux buffer.
//   C/ldc            output matrix and its leading dimension (C is passed as
//                    both the C and D operands of cublasLtMatmul).
//   bgrad            buffer registered as the bias pointer for the *_BGRAD
//                    epilogue.
//   is_gelu          selects the DGELU (true) or DRELU (false) epilogue.
//   heuristic        index into the (at most 5) algorithms returned by
//                    cublasLtMatmulAlgoGetHeuristic.
//   lt_workspace,    workspace buffer and its size in bytes; the size is also
//   workspaceSize    used as the heuristic's maximum workspace preference.
//
// Returns 0 on success and 1 on any failure: a cuBLASLt call reporting an
// error, no algorithm being returned, or `heuristic` not naming one of the
// returned algorithms.
template <typename Dtype>
int gemm_dact_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    const void* pre_act,
    Dtype* C,
    int64_t ldc,
    Dtype* bgrad,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_dact_bgradb_lt only supports fp16 and bf16");

  constexpr int requestedAlgoCount = 5;
  // An index outside the heuristic array would be an out-of-bounds read.
  if (heuristic < 0 || heuristic >= requestedAlgoCount) {
    return 1;
  }

  const float beta = 0.0f;
  const cudaDataType_t abcType =
      std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;
  const auto ok = [](cublasStatus_t s) { return s == CUBLAS_STATUS_SUCCESS; };

  // The current cuBLAS handle is reused as the cuBLASLt handle.
  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());

  // Stack-allocated opaque descriptors, filled in by the *Init calls below.
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  // Operation descriptor: compute type plus the transforms for A and B.
  if (!ok(cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F))) return 1;
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)))) return 1;
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)))) return 1;

  // Epilogue buffers: bias-gradient output and the aux input, which shares
  // C's leading dimension.
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)))) return 1;
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act)))) return 1;
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)))) return 1;

  cublasLtEpilogue_t epilogue =
      is_gelu ? CUBLASLT_EPILOGUE_DGELU_BGRAD : CUBLASLT_EPILOGUE_DRELU_BGRAD;
  if (!ok(cublasLtMatmulDescSetAttribute(
          &operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)))) return 1;

  // Matrix layouts; no extra attributes are set.
  if (!ok(cublasLtMatrixLayoutInit(
          &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda))) return 1;
  if (!ok(cublasLtMatrixLayoutInit(
          &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb))) return 1;
  if (!ok(cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc))) return 1;

  // Preference: only the workspace limit is set. A, B and C are assumed to be
  // well aligned (e.g. coming directly from cudaMalloc).
  if (!ok(cublasLtMatmulPreferenceInit(&preference))) return 1;
  if (!ok(cublasLtMatmulPreferenceSetAttribute(
          &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)))) return 1;

  // Ask for up to requestedAlgoCount candidates; the caller picks one by
  // index.
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {};
  if (!ok(cublasLtMatmulAlgoGetHeuristic(
          ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference,
          requestedAlgoCount, heuristicResult, &returnedResults))) return 1;

  // Entries at or beyond returnedResults were not filled in by the heuristic
  // query, so the requested index must be below it (this also covers the
  // "no algorithm at all" case).
  if (heuristic >= returnedResults) {
    return 1;
  }

  const cublasStatus_t status = cublasLtMatmul(ltHandle,
                                               &operationDesc,
                                               &alpha,
                                               A,
                                               &Adesc,
                                               B,
                                               &Bdesc,
                                               &beta,
                                               C,
                                               &Cdesc,
                                               C,
                                               &Cdesc,
                                               &heuristicResult[heuristic].algo,
                                               lt_workspace,
                                               workspaceSize,
                                               at::cuda::getCurrentCUDAStream());

  return ok(status) ? 0 : 1;
}

// Explicit instantiations of gemm_dact_bgradb_lt for the two supported
// element types: fp16 (at::Half) and bf16 (at::BFloat16).
template int gemm_dact_bgradb_lt<at::Half>(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half* A, int64_t lda,
    const at::Half* B, int64_t ldb,
    const void* pre_act,
    at::Half* C, int64_t ldc,
    at::Half* bgrad,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize);

template int gemm_dact_bgradb_lt<at::BFloat16>(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16* A, int64_t lda,
    const at::BFloat16* B, int64_t ldb,
    const void* pre_act,
    at::BFloat16* C, int64_t ldc,
    at::BFloat16* bgrad,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize);

#endif

// Weight-gradient GEMM with an optional fused bias gradient.
//
// When cuBLAS >= 11.6 is available, gemm_bgradb_lt is tried first with
// (OP_N, OP_T), sizes (in_features, out_features, batch_size), `input` as A,
// `d_output` as B, `d_weight` as C and `d_bias` as the bias-gradient buffer.
// If that path is compiled out or reports failure, the plain gemm_bias
// wrapper is called with the same operands; note that this fallback call is
// not given d_bias.
//
// Returns 0 on success, otherwise the non-zero status of the fallback GEMM.
template <typename T>
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize) {
    const float one  = 1.0f;
    const float zero = 0.0f;

    // Starts as "failed" so the fallback runs when the cuBLASLt path is not
    // compiled in.
    int lt_rc = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    lt_rc = gemm_bgradb_lt(
        CUBLAS_OP_N, CUBLAS_OP_T,
        in_features, out_features, batch_size,
        one,
        input, in_features,
        d_output, out_features,
        d_weight, in_features,
        d_bias,
        lt_workspace, workspaceSize);
#endif
    if (lt_rc == 0) {
        return lt_rc;
    }

    // TD [2023-01-17]: I can't call Pytorch's gemm (at::cuda::blas::gemm<T>)
    // for now, due to linking error
    // https://discuss.pytorch.org/t/how-can-i-use-the-function-at-gemm-float/95341
    // so the local cublasGemmEx wrapper is used instead.
    return gemm_bias(
        at::cuda::getCurrentCUDABlasHandle(),
        CUBLAS_OP_N, CUBLAS_OP_T,
        in_features, out_features, batch_size,
        &one,
        input, in_features,
        d_output, out_features,
        &zero,
        d_weight, in_features);
}

// Forward pass of a linear layer with fused bias and activation.
//
// With cuBLAS >= 11.6 this forwards to gemm_bias_act_lt using (OP_T, OP_N),
// sizes (out_features, batch_size, in_features), alpha = 1, `weight` as A and
// `input` as B (both with leading dimension in_features), and `output` as C
// with leading dimension out_features. bias, pre_act, is_gelu, heuristic and
// the workspace are passed through unchanged.
//
// Returns the status of gemm_bias_act_lt, or 1 when the cuBLASLt path is not
// compiled in.
template <typename T>
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize) {
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    return gemm_bias_act_lt(
        CUBLAS_OP_T, CUBLAS_OP_N,
        out_features, batch_size, in_features,
        /*alpha=*/1.0f,
        weight, in_features,
        input, in_features,
        bias,
        output, out_features,
        pre_act,
        is_gelu,
        heuristic,
        lt_workspace, workspaceSize);
#else
    // No cuBLASLt epilogue support: report failure to the caller.
    return 1;
#endif
}

// Backward pass producing the input gradient with a fused
// activation-gradient / bias-gradient epilogue.
//
// With cuBLAS >= 11.6 this forwards to gemm_dact_bgradb_lt using (OP_N, OP_N),
// sizes (in_features, batch_size, out_features), alpha = 1, `weight` as A
// (leading dimension in_features), `d_output` as B (leading dimension
// out_features) and `d_input` as C (leading dimension in_features). pre_act,
// d_bias, is_gelu, heuristic and the workspace are passed through unchanged.
//
// Returns the status of gemm_dact_bgradb_lt, or 1 when the cuBLASLt path is
// not compiled in.
template <typename T>
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize) {
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    const float one = 1.0f;
    return gemm_dact_bgradb_lt(
        CUBLAS_OP_N, CUBLAS_OP_N,
        in_features, batch_size, out_features,
        one,
        weight, in_features,
        d_output, out_features,
        pre_act,
        d_input, in_features,
        d_bias,
        is_gelu,
        heuristic,
        lt_workspace, workspaceSize);
#else
    // No cuBLASLt epilogue support: report failure to the caller.
    return 1;
#endif
}

// Explicit instantiations of the three entry points for fp16 (at::Half) and
// bf16 (at::BFloat16).

// linear_bias_wgrad_cuda
template int linear_bias_wgrad_cuda<at::Half>(
    const at::Half *input, const at::Half *d_output,
    int64_t in_features, int64_t batch_size, int64_t out_features,
    at::Half *d_weight, at::Half *d_bias,
    void *lt_workspace, size_t workspaceSize);
template int linear_bias_wgrad_cuda<at::BFloat16>(
    const at::BFloat16 *input, const at::BFloat16 *d_output,
    int64_t in_features, int64_t batch_size, int64_t out_features,
    at::BFloat16 *d_weight, at::BFloat16 *d_bias,
    void *lt_workspace, size_t workspaceSize);

// linear_act_forward_cuda
template int linear_act_forward_cuda<at::Half>(
    const at::Half *input, const at::Half *weight, const at::Half *bias,
    int64_t in_features, int64_t batch_size, int64_t out_features,
    bool is_gelu, int heuristic,
    at::Half *output, void *pre_act,
    void *lt_workspace, size_t workspaceSize);
template int linear_act_forward_cuda<at::BFloat16>(
    const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias,
    int64_t in_features, int64_t batch_size, int64_t out_features,
    bool is_gelu, int heuristic,
    at::BFloat16 *output, void *pre_act,
    void *lt_workspace, size_t workspaceSize);

// bias_act_linear_dgrad_bgrad_cuda
template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(
    const at::Half *weight, const at::Half *d_output, const void *pre_act,
    int64_t in_features, int64_t batch_size, int64_t out_features,
    bool is_gelu, int heuristic,
    at::Half *d_input, at::Half *d_bias,
    void *lt_workspace, size_t workspaceSize);
template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(
    const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act,
    int64_t in_features, int64_t batch_size, int64_t out_features,
    bool is_gelu, int heuristic,
    at::BFloat16 *d_input, at::BFloat16 *d_bias,
    void *lt_workspace, size_t workspaceSize);