// fused_dense_cuda.cu
// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
#include <cublasLt.h>
#endif

// FP16 Tensor core wrapper around cublas GEMMEx.
//
// Computes C = alpha * op(A) * op(B) + beta * C using cuBLAS (column-major)
// conventions. A, B and C are stored as fp16 (CUDA_R_16F); the compute type is
// fp32 (CUDA_R_32F), so alpha and beta are host-side floats.
// m, n, k, lda, ldb, ldc are forwarded to cublasGemmEx, which takes 32-bit int
// dimensions, so values larger than INT_MAX are narrowed by the call.
// Returns the cublasStatus_t produced by cublasGemmEx unchanged.
cublasStatus_t gemm_bias(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const float* alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    const float* beta,
    at::Half* C,
    int64_t ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16F,
      lda,
      B,
      CUDA_R_16F,
      ldb,
      beta,
      C,
      CUDA_R_16F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// BF16 Tensor core wrapper around cublas GEMMEx.
//
// Computes C = alpha * op(A) * op(B) + beta * C using cuBLAS (column-major)
// conventions. A, B and C are stored as bf16 (CUDA_R_16BF); the compute type
// is fp32 (CUDA_R_32F), so alpha and beta are host-side floats.
// m, n, k, lda, ldb, ldc are forwarded to cublasGemmEx, which takes 32-bit int
// dimensions, so values larger than INT_MAX are narrowed by the call.
// Returns the cublasStatus_t produced by cublasGemmEx unchanged.
cublasStatus_t gemm_bias(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const float* alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    const float* beta,
    at::BFloat16* C,
    int64_t ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16BF,
      lda,
      B,
      CUDA_R_16BF,
      ldb,
      beta,
      C,
      CUDA_R_16BF,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600

// Fused GEMM + optional bias + activation using cublasLt epilogues.
//
// Computes C = act(alpha * op(A) * op(B) [+ bias]) with cuBLAS (column-major)
// conventions. A, B and C are Dtype (fp16 or bf16), and compute is fp32.
//   is_gelu      selects the GELU epilogues; otherwise the RELU ones are used.
//   bias         if non-null, is passed as CUBLASLT_MATMUL_DESC_BIAS_POINTER
//                and a *_BIAS epilogue is selected.
//   pre_act      if non-null, is passed as the epilogue auxiliary pointer
//                (leading dimension ldc) and an *_AUX epilogue is selected.
//   heuristic    index into the (up to requestedAlgoCount) algorithms returned
//                by cublasLtMatmulAlgoGetHeuristic. An index outside
//                [0, returnedResults) falls back to the top-ranked result 0,
//                instead of reading past the returned results.
// Returns 0 on success and 1 if any cublasLt call fails or no algorithm is
// found.
template <typename Dtype>
int gemm_bias_act_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    const Dtype* bias,
    Dtype* C,
    int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic
    ) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_bias_act_lt only supports fp16 and bf16");
  const bool save_pre_act = pre_act != nullptr;
  float beta = 0.0f;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;

  // PyTorch's current cuBLAS handle is reinterpreted as a cublasLt handle.
  cublasLtHandle_t ltHandle =
    reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  size_t workspaceSize = 1024 * 1024;
  // Hold the workspace tensor for the whole call. Taking data_ptr() of a
  // temporary would return the storage to the allocator before the matmul has
  // even been enqueued.
  at::Tensor workspaceTensor = at::empty(
    {static_cast<int64_t>(workspaceSize)},
    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
  void* workspace = workspaceTensor.data_ptr();

  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults = 0;
  constexpr int requestedAlgoCount = 5;
  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {};
  // Declared (and initialized) before the first goto so no jump crosses it.
  int algoIdx = 0;

  // Pick the epilogue from the bias / pre-activation / activation options.
  const cublasLtEpilogue_t epilogue = (bias != nullptr)
      ? (is_gelu
             ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS)
             : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS : CUBLASLT_EPILOGUE_RELU_BIAS))
      : (is_gelu
             ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU)
             : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU));

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (save_pre_act) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  }

  if (bias != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
    &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
    &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Ask for up to requestedAlgoCount candidate algorithms. The caller picks
  // one via `heuristic`.
  status = cublasLtMatmulAlgoGetHeuristic(
    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  // Only the first returnedResults entries are filled in. Never index beyond
  // them.
  algoIdx = (heuristic >= 0 && heuristic < returnedResults) ? heuristic : 0;

  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          &beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
                          &heuristicResult[algoIdx].algo,
                          workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());

CLEANUP:
  // The descriptors are stack-allocated opaque structs, so there is nothing to
  // destroy here.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}

// Explicit instantiations of gemm_bias_act_lt for the two element types it
// accepts (fp16 and bf16).
template int gemm_bias_act_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    const at::Half* bias,
    at::Half* C,
    int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic);

template int gemm_bias_act_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    const at::BFloat16* bias,
    at::BFloat16* C,
    int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic);

// GEMM with an optional fused bias-gradient epilogue using cublasLt.
//
// Computes C = alpha * op(A) * op(B) with cuBLAS (column-major) conventions.
// A, B and C are Dtype (fp16 or bf16), and compute is fp32. When bgrad is
// non-null it is set as CUBLASLT_MATMUL_DESC_BIAS_POINTER and the
// CUBLASLT_EPILOGUE_BGRADB epilogue writes the bias gradient derived from B
// into it. Otherwise the default epilogue is used.
// Returns 0 on success and 1 if any cublasLt call fails or no algorithm is
// found.
template <typename Dtype>
int gemm_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    Dtype* C,
    int64_t ldc,
    Dtype* bgrad) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_bgradb_lt only supports fp16 and bf16");
  float beta = 0.0f;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;

  // PyTorch's current cuBLAS handle is reinterpreted as a cublasLt handle.
  cublasLtHandle_t ltHandle =
    reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  size_t workspaceSize = 1024 * 1024;
  // Hold the workspace tensor for the whole call. Taking data_ptr() of a
  // temporary would return the storage to the allocator before the matmul has
  // even been enqueued.
  at::Tensor workspaceTensor = at::empty(
    {static_cast<int64_t>(workspaceSize)},
    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
  void* workspace = workspaceTensor.data_ptr();

  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  const cublasLtEpilogue_t epilogue =
      bgrad != nullptr ? CUBLASLT_EPILOGUE_BGRADB : CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (bgrad != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
    &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
    &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // The heuristic query is only used to check that at least one algorithm
  // supports this problem. The matmul below passes a NULL algo and lets
  // cublasLt choose.
  status = cublasLtMatmulAlgoGetHeuristic(
    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          &beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          //&heuristicResult.algo,
                          NULL,
                          workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());

CLEANUP:
  // The descriptors are stack-allocated opaque structs, so there is nothing to
  // destroy here.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}

// Explicit instantiations of gemm_bgradb_lt for the two element types it
// accepts (fp16 and bf16).
template int gemm_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    at::Half* C,
    int64_t ldc,
    at::Half* bgrad);

template int gemm_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    at::BFloat16* C,
    int64_t ldc,
    at::BFloat16* bgrad);

// GEMM with a fused activation-derivative + bias-gradient epilogue
// (cublasLt).
//
// Computes C = alpha * op(A) * op(B) with cuBLAS (column-major) conventions.
// A, B and C are Dtype (fp16 or bf16), and compute is fp32. The epilogue is
// CUBLASLT_EPILOGUE_DGELU_BGRAD when is_gelu is true, else
// CUBLASLT_EPILOGUE_DRELU_BGRAD.
//   pre_act    set as the epilogue auxiliary input (leading dimension ldc).
//              Presumably this is the buffer saved by the forward *_AUX
//              epilogue; confirm with callers.
//   bgrad      set as CUBLASLT_MATMUL_DESC_BIAS_POINTER and receives the bias
//              gradient. It is set unconditionally, so callers are assumed to
//              pass non-null.
//   heuristic  index into the (up to requestedAlgoCount) algorithms returned
//              by cublasLtMatmulAlgoGetHeuristic. An index outside
//              [0, returnedResults) falls back to the top-ranked result 0.
// Returns 0 on success and 1 if any cublasLt call fails or no algorithm is
// found.
template <typename Dtype>
int gemm_dact_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    const void* pre_act,
    Dtype* C,
    int64_t ldc,
    Dtype* bgrad,
    bool is_gelu,
    int heuristic) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_dact_bgradb_lt only supports fp16 and bf16");
  float beta = 0.0f;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;

  // PyTorch's current cuBLAS handle is reinterpreted as a cublasLt handle.
  cublasLtHandle_t ltHandle =
    reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  size_t workspaceSize = 1024 * 1024;
  // Hold the workspace tensor for the whole call. Taking data_ptr() of a
  // temporary would return the storage to the allocator before the matmul has
  // even been enqueued.
  at::Tensor workspaceTensor = at::empty(
    {static_cast<int64_t>(workspaceSize)},
    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
  void* workspace = workspaceTensor.data_ptr();

  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults = 0;
  constexpr int requestedAlgoCount = 5;
  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {};
  // Declared (and initialized) before the first goto so no jump crosses it.
  int algoIdx = 0;
  const cublasLtEpilogue_t epilogue =
      is_gelu ? CUBLASLT_EPILOGUE_DGELU_BGRAD : CUBLASLT_EPILOGUE_DRELU_BGRAD;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
    &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
    &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Ask for up to requestedAlgoCount candidate algorithms. The caller picks
  // one via `heuristic`.
  status = cublasLtMatmulAlgoGetHeuristic(
    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  // Only the first returnedResults entries are filled in. Never index beyond
  // them.
  algoIdx = (heuristic >= 0 && heuristic < returnedResults) ? heuristic : 0;

  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          &beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          &heuristicResult[algoIdx].algo,
                          workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());

CLEANUP:
  // The descriptors are stack-allocated opaque structs, so there is nothing to
  // destroy here.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}

// Explicit instantiations of gemm_dact_bgradb_lt for the two element types it
// accepts (fp16 and bf16).
template int gemm_dact_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    const void* pre_act,
    at::Half* C,
    int64_t ldc,
    at::Half* bgrad,
    bool is_gelu,
    int heuristic);

template int gemm_dact_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    const void* pre_act,
    at::BFloat16* C,
    int64_t ldc,
    at::BFloat16* bgrad,
    bool is_gelu,
    int heuristic);

#endif

// Weight (and optional bias) gradient of a linear layer.
//
// In cuBLAS column-major terms:
//   d_weight (in_features x out_features, ld in_features)
//       = input (in_features x batch_size, ld in_features)
//         * d_output^T, where d_output is viewed as
//           (out_features x batch_size, ld out_features),
//   d_bias (out_features), if non-null, = sum of d_output over batch_size.
// The fused cublasLt BGRADB path is tried first (cuBLAS >= 11.6). If it is
// unavailable or fails, plain cublasGemmEx calls are used. That fallback now
// also computes d_bias (as d_output * ones(batch_size)), so d_bias is written
// whenever a success status is returned.
// Returns 0 on success, non-zero otherwise.
template <typename T>
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias) {
    const float alpha          = 1.0f;
    const float beta_zero      = 0.0f;
    int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    status = gemm_bgradb_lt(
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      in_features,
      out_features,
      batch_size,
      alpha,
      input,
      in_features,
      d_output,
      out_features,
      d_weight,
      in_features,
      d_bias);
#endif

    if (status != 0) {
        cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
        // d_weight = input * d_output^T. This path has no fused bias epilogue.
        status = gemm_bias(
          handle,
          CUBLAS_OP_N,
          CUBLAS_OP_T,
          in_features,
          out_features,
          batch_size,
          &alpha,
          input,
          in_features,
          d_output,
          out_features,
          &beta_zero,
          d_weight,
          in_features);
        // TD [2023-01-17]: I can't call Pytorch's gemm for now, due to linking error
        // https://discuss.pytorch.org/t/how-can-i-use-the-function-at-gemm-float/95341

        if (status == 0 && d_bias != nullptr) {
            // Without the BGRADB epilogue, d_bias must be produced explicitly.
            // Reduce d_output over the batch with a GEMM against a ones
            // vector:
            //   d_bias (out_features x 1)
            //     = d_output (out_features x batch_size) * ones (batch_size x 1).
            at::Tensor ones = at::ones(
                {batch_size},
                at::device({at::kCUDA, at::cuda::current_device()})
                    .dtype(c10::CppTypeToScalarType<T>::value));
            status = gemm_bias(
              handle,
              CUBLAS_OP_N,
              CUBLAS_OP_N,
              out_features,
              1,
              batch_size,
              &alpha,
              d_output,
              out_features,
              ones.data_ptr<T>(),
              // cuBLAS requires ldb >= max(1, k).
              batch_size > 0 ? batch_size : 1,
              &beta_zero,
              d_bias,
              out_features);
        }
    }

    return status;
}

// Forward pass of a linear layer with fused bias and activation.
//
// In cuBLAS column-major terms:
//   output (out_features x batch_size, ld out_features)
//     = act(weight^T * input + bias),
// where weight is viewed as (in_features x out_features, ld in_features) and
// input as (in_features x batch_size, ld in_features). GELU is used when
// is_gelu is true, else ReLU. pre_act, if non-null, receives the epilogue
// auxiliary output, and heuristic selects the cublasLt algorithm index.
// Requires cuBLAS >= 11.6. On older cuBLAS it always returns 1 (failure).
// Returns 0 on success, non-zero otherwise.
template <typename T>
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act) {
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    return gemm_bias_act_lt(
      CUBLAS_OP_T,
      CUBLAS_OP_N,
      out_features,
      batch_size,
      in_features,
      /*alpha=*/1.0f,
      weight,
      in_features,
      input,
      in_features,
      bias,
      output,
      out_features,
      pre_act,
      is_gelu,
      heuristic);
#else
    return 1;
#endif
}

// Input gradient of a linear layer fused with the activation derivative and
// bias gradient.
//
// In cuBLAS column-major terms:
//   d_input (in_features x batch_size, ld in_features)
//     = act'(weight (in_features x out_features, ld in_features)
//            * d_output (out_features x batch_size, ld out_features)).
// The derivative uses pre_act as the epilogue auxiliary input and writes the
// bias gradient into d_bias. GELU is used when is_gelu is true, else ReLU, and
// heuristic selects the cublasLt algorithm index.
// Requires cuBLAS >= 11.6. On older cuBLAS it returns 1 (failure).
// Returns 0 on success, non-zero otherwise.
template <typename T>
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias) {
    const float alpha = 1.0f;
    int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    status = gemm_dact_bgradb_lt(
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      in_features,
      batch_size,
      out_features,
      alpha,
      weight,
      in_features,
      d_output,
      out_features,
      pre_act,
      d_input,
      in_features,
      d_bias,
      is_gelu,
      heuristic);
#endif
    return status;

}

// Explicit instantiations of the host entry points for fp16 and bf16.
template int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias);
template int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias);

template int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act);
template int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act);

template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias);
template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias);