// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
#include <cublasLt.h>
#endif
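
// Note: cuBLAS/cuBLASLt operate on column-major matrices. The host wrappers at
// the bottom of this file pass row-major PyTorch tensors directly: a row-major
// (rows, cols) tensor is read by cuBLAS as a column-major (cols, rows) matrix,
// and the transpose flags and m/n/k arguments are chosen accordingly. All GEMMs
// below accumulate in FP32 (CUDA_R_32F / CUBLAS_COMPUTE_32F) regardless of the
// FP16/BF16 storage type.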

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const float* alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    const float* beta,
    at::Half* C,
    int64_t ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16F,
      lda,
      B,
      CUDA_R_16F,
      ldb,
      beta,
      C,
      CUDA_R_16F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const float* alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    const float* beta,
    at::BFloat16* C,
    int64_t ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16BF,
      lda,
      B,
      CUDA_R_16BF,
      ldb,
      beta,
      C,
      CUDA_R_16BF,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
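// Fused GEMM + bias + activation using cuBLASLt epilogues:
//   C = act(alpha * op(A) * op(B) + bias), with act = GELU or ReLU.
// If pre_act is non-null, the epilogue's auxiliary output (the pre-GELU input,
// or the ReLU bit-mask) is written to pre_act so the backward pass can apply
// the activation gradient without recomputing the GEMM. `heuristic` selects
// among the algorithms returned by cublasLtMatmulAlgoGetHeuristic.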

template <typename Dtype>
int gemm_bias_act_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    const Dtype* bias,
    Dtype* C,
    int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic
    ) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_bias_act_lt only supports fp16 and bf16");
  bool save_pre_act = pre_act != nullptr;
  float beta = 0.0;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;

  cublasLtHandle_t ltHandle =
    reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  // However, Apex sets it to 4M and TransformerEngine sets it to 32M for Hopper and 4M for other GPUs
  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
  void* workspace = at::empty(
    {static_cast<int64_t>(workspaceSize)},
    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();

  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults                             = 0;
  constexpr int requestedAlgoCount = 5;
  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};
  // constexpr int requestedAlgoCount = 1;
  // cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = is_gelu
      ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU)
      : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU);

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (save_pre_act) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  }

  if (bias != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    epilogue = is_gelu
        ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS)
        : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS : CUBLASLT_EPILOGUE_RELU_BIAS);
  } else {
    epilogue = is_gelu
        ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU)
        : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU);
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
    &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
    &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
    // ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          &beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          // &heuristicResult.algo,
                          // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
                          &heuristicResult[heuristic].algo,
                          // NULL,
                          workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());

CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}

template int gemm_bias_act_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    const at::Half* bias,
    at::Half* C,
    int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic);

template int gemm_bias_act_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    const at::BFloat16* bias,
    at::BFloat16* C,
    int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic);

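// GEMM with an optional fused bias-gradient epilogue (CUBLASLT_EPILOGUE_BGRADB):
// in addition to C = alpha * op(A) * op(B), operand B is reduced over the k
// dimension into bgrad. In linear_bias_wgrad_cuda below this produces
// d_bias = sum of d_output over the batch in the same kernel as the weight
// gradient GEMM.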
template <typename Dtype>
int gemm_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    Dtype* C,
    int64_t ldc,
    Dtype* bgrad) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_bgradb_lt only supports fp16 and bf16");
  float beta = 0.0;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;

  cublasLtHandle_t ltHandle =
    reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  // However, Apex sets it to 4M and TransformerEngine sets it to 32M for Hopper and 4M for other GPUs
  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
  void* workspace = at::empty(
    {static_cast<int64_t>(workspaceSize)},
    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();

  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults                             = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (bgrad != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    epilogue = CUBLASLT_EPILOGUE_BGRADB;
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
    &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
    &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          &beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          //&heuristicResult.algo,
                          NULL,
                          workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());

CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}

template int gemm_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    at::Half* C,
    int64_t ldc,
    at::Half* bgrad);

template int gemm_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    at::BFloat16* C,
    int64_t ldc,
    at::BFloat16* bgrad);

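// Backward counterpart of gemm_bias_act_lt: the epilogue
// (CUBLASLT_EPILOGUE_DGELU_BGRAD / DRELU_BGRAD) applies the activation gradient
// using the auxiliary buffer saved in the forward pass (pre_act) and reduces
// the result into bgrad, fusing dgrad, activation backward, and bias gradient
// into a single GEMM.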
template <typename Dtype>
int gemm_dact_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    const void* pre_act,
    Dtype* C,
    int64_t ldc,
    Dtype* bgrad,
    bool is_gelu,
    int heuristic) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_dact_bgradb_lt only supports fp16 and bf16");
  float beta = 0.0;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;

  cublasLtHandle_t ltHandle =
    reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  // However, Apex sets it to 4M and TransformerEngine sets it to 32M for Hopper and 4M for other GPUs
  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
  void* workspace = at::empty(
    {static_cast<int64_t>(workspaceSize)},
    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();

  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults                             = 0;
  constexpr int requestedAlgoCount = 5;
  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};
  cublasLtEpilogue_t epilogue = is_gelu ? CUBLASLT_EPILOGUE_DGELU_BGRAD : CUBLASLT_EPILOGUE_DRELU_BGRAD;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
    &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
    &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          &beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          //&heuristicResult.algo,
                          &heuristicResult[heuristic].algo,
                          // NULL,
                          workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());

CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}

template int gemm_dact_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    const void* pre_act,
    at::Half* C,
    int64_t ldc,
    at::Half* bgrad,
    bool is_gelu,
    int heuristic);

template int gemm_dact_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    const void* pre_act,
    at::BFloat16* C,
    int64_t ldc,
    at::BFloat16* bgrad,
    bool is_gelu,
    int heuristic);

#endif
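
// linear_bias_wgrad_cuda: given row-major input (batch_size, in_features) and
// d_output (batch_size, out_features), computes
//   d_weight = d_output^T @ input   -> (out_features, in_features)
//   d_bias   = sum of d_output over the batch
// The cuBLASLt path fuses the bias reduction into the GEMM; the cublasGemmEx
// fallback below computes d_weight only.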

template <typename T>
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias) {
    const float alpha          = 1.0;
    const float beta_zero      = 0.0;
    int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    status = gemm_bgradb_lt(
    // (cublasLtHandle_t)handle,
    CUBLAS_OP_N,
    CUBLAS_OP_T,
    in_features,
    out_features,
    batch_size,
    alpha,
    input,
    in_features,
    d_output,
    out_features,
    d_weight,
    in_features,
    d_bias);
#endif

    if (status != 0) {
        cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
        status = gemm_bias(
          handle,
          CUBLAS_OP_N,
          CUBLAS_OP_T,
          in_features,
          out_features,
          batch_size,
          &alpha,
          input,
          in_features,
          d_output,
          out_features,
          &beta_zero,
          d_weight,
          in_features);
        // TD [2023-01-17]: I can't call Pytorch's gemm for now, due to linking error
        // https://discuss.pytorch.org/t/how-can-i-use-the-function-at-gemm-float/95341
        // at::cuda::blas::gemm<T>(
        //   'N',
        //   'T',
        //   in_features,
        //   out_features,
        //   batch_size,
        //   alpha,
        //   input,
        //   in_features,
        //   d_output,
        //   out_features,
        //   beta_zero,
        //   d_weight,
        //   in_features);
    }

    return status;
}
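
// linear_act_forward_cuda: fused forward of a linear layer plus activation.
// In row-major terms, output = act(input @ weight^T + bias) with
// input (batch_size, in_features), weight (out_features, in_features) and
// output (batch_size, out_features); bias may be null. If pre_act is non-null,
// the epilogue's auxiliary output is saved there for the backward pass.
// Requires the cuBLASLt path (CUBLAS >= 11.6); returns 1 otherwise.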

template <typename T>
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act) {
    int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    status = gemm_bias_act_lt(
    CUBLAS_OP_T,
    CUBLAS_OP_N,
    out_features,
    batch_size,
    in_features,
    /*alpha=*/1.0,
    weight,
    in_features,
    input,
    in_features,
    bias,
    output,
    out_features,
    pre_act,
    is_gelu,
    heuristic);
    return status;
#else
    return 1;
#endif
}
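
// bias_act_linear_dgrad_bgrad_cuda: fused backward through the second GEMM and
// the preceding bias + activation. In row-major terms,
//   d_input = act'(pre_act) * (d_output @ weight)   (elementwise act')
//   d_bias  = sum of d_input over the batch
// with weight (out_features, in_features), d_output (batch_size, out_features)
// and d_input (batch_size, in_features); pre_act is the auxiliary buffer saved
// by linear_act_forward_cuda. Requires the cuBLASLt path (CUBLAS >= 11.6).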

template <typename T>
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias) {
    const float alpha          = 1.0;
    int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    status = gemm_dact_bgradb_lt(
    CUBLAS_OP_N,
    CUBLAS_OP_N,
    in_features,
    batch_size,
    out_features,
    alpha,
    weight,
    in_features,
    d_output,
    out_features,
    pre_act,
    d_input,
    in_features,
    d_bias,
    is_gelu,
    heuristic);
#endif
    return status;
}

template int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias);
template int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias);

template int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act);
template int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act);

template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias);
template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias);