/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <mutex>
#include <unordered_map>

#include "../common.h"
#include "../util/logging.h"
Przemek Tredak's avatar
Przemek Tredak committed
18

19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
namespace {

// Translates a Transformer Engine element type into the matching cudaDataType_t
// enumerator. Any DType not listed below is reported through NVTE_ERROR.
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    // 32- and 16-bit floating point
    case DType::kFloat32: return CUDA_R_32F;
    case DType::kFloat16: return CUDA_R_16F;
    case DType::kBFloat16: return CUDA_R_16BF;
    // 8-bit floating point
    case DType::kFloat8E4M3: return CUDA_R_8F_E4M3;
    case DType::kFloat8E5M2: return CUDA_R_8F_E5M2;
    default: NVTE_ERROR("Invalid type");
  }
}

39
40
41
// Returns the largest power of two, capped at 256, that divides `address`.
//
// The result is in bytes. An address of 0 is divisible by every candidate and
// therefore yields the cap (256). The search always terminates because the
// candidate eventually reaches 1, which divides every address.
uint32_t _getAlignment(uintptr_t address) {
  // alignment are in bytes
  constexpr uint32_t kMaxAlignment = 256;
  uint32_t alignment = kMaxAlignment;
  while (address % alignment != 0) {
    alignment /= 2;
  }
  return alignment;
}

49
50
}  // namespace

Przemek Tredak's avatar
Przemek Tredak committed
51
52
namespace transformer_engine {

53
54
55
56
57
58
// Returns a cuBLASLt handle for the CUDA device that is current on the calling
// thread, creating it on first use and caching it for the rest of the process.
//
// Creating a fresh handle for every GEMM and never destroying it leaks one
// library context per call. The cache keeps exactly one handle per device
// ordinal. Cached handles are intentionally never destroyed; they live until
// process exit.
static cublasLtHandle_t get_cublaslt_handle() {
  static std::mutex handle_mutex;
  static std::unordered_map<int, cublasLtHandle_t> handles;

  int device = 0;
  NVTE_CHECK_CUDA(cudaGetDevice(&device));

  std::lock_guard<std::mutex> lock(handle_mutex);
  auto it = handles.find(device);
  if (it == handles.end()) {
    cublasLtHandle_t new_handle = nullptr;
    NVTE_CHECK_CUBLAS(cublasLtCreate(&new_handle));
    it = handles.emplace(device, new_handle).first;
  }
  return it->second;
}

// Configures and enqueues a single cuBLASLt matmul on `stream`:
//
//   D = alpha * (A * B) + beta * C,   alpha = 1, beta = accumulate ? 1 : 0
//
// C and D both alias outputD's data buffer, except that C is passed as null
// when the output type is FP8 (accumulation is rejected in that case).
//
// Parameters:
//   inputA, inputB   Operands. If either has an FP8 dtype, its scale_inv
//                    pointer must be non-null.
//   outputD          Result tensor. For FP8 output its scale and amax pointers
//                    are handed to cuBLASLt and its scale-inv is refreshed
//                    after the matmul is enqueued.
//   inputBias        A non-null data pointer enables a bias epilogue. The
//                    tensor itself is dereferenced unconditionally.
//   outputPreGelu    A non-null data pointer enables a GELU epilogue using
//                    that buffer as the auxiliary tensor. The tensor itself is
//                    dereferenced unconditionally.
//   m, n, k          GEMM dimensions as used for the matrix layouts below.
//   lda, ldb, ldd    Leading dimensions for A, B and C/D. ldd is also used as
//                    the leading dimension of the GELU auxiliary buffer.
//   transa, transb   cuBLAS operation applied to A and B.
//   grad             Selects the backward variants of the bias/GELU epilogues.
//   workspace, workspaceSize
//                    Scratch buffer passed to cuBLASLt, and its size in bytes.
//   accumulate       Selects beta = 1 instead of beta = 0.
//   use_split_accumulator
//                    When true, FP8 fast accumulation is turned off.
//   math_sm_count    If non-zero, forwarded as the SM count target.
//   m_split, n_split, gemm_producer, inputCounter
//                    Atomic-sync configuration. Used only when inputCounter
//                    carries a non-null data pointer and the file is compiled
//                    with CUDA >= 12.2 and cuBLAS >= 12.2.5.
//   stream           Stream on which the matmul is enqueued.
//
// Failures are reported through NVTE_CHECK / NVTE_CHECK_CUBLAS / NVTE_ERROR.
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
                 void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, cudaStream_t stream) {
  // Raw device pointers pulled out of the tensor wrappers.
  void *A = inputA->data.dptr;
  void *A_scale_inverse = inputA->scale_inv.dptr;
  void *B = inputB->data.dptr;
  void *B_scale_inverse = inputB->scale_inv.dptr;
  void *C = outputD->data.dptr;
  void *D = outputD->data.dptr;
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  const cudaDataType_t A_type = get_cuda_dtype(inputA->data.dtype);
  const cudaDataType_t B_type = get_cuda_dtype(inputB->data.dtype);
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);

  NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");

  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp8_dtype(outputD->data.dtype)) {
    NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
  }

  float one = 1.0f;
  float zero = 0.0f;
  float beta = (accumulate) ? one : zero;

  // Cached per-device handle; it is not owned here and must not be destroyed.
  cublasLtHandle_t handle = get_cublaslt_handle();

  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  cublasLtMatmulPreference_t preference = nullptr;
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // The GELU auxiliary buffer shares the output's leading dimension.
  int64_t ld_gelumat = (int64_t)ldd;

  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, transa == CUBLAS_OP_N ? m : k,
                                               transa == CUBLAS_OP_N ? k : m, lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, transb == CUBLAS_OP_N ? k : n,
                                               transb == CUBLAS_OP_N ? n : k, ldb));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
                                                   &transa, sizeof(transa)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                   &transb, sizeof(transb)));
  // Set math SM count
  if (math_sm_count != 0) {
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
  }

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                     &A_scale_inverse, sizeof(A_scale_inverse)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                     &B_scale_inverse, sizeof(B_scale_inverse)));
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
      // For FP8 output, cuBLAS requires C_type to be same as bias_type
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, bias_type, m, n, ldd));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
    if (bias) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
  }

  // Pick the epilogue from the (bias, gelu, grad) combination and attach the
  // pointers that epilogue needs.
  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                   &epilogue, sizeof(epilogue)));

  // Atomic-sync configuration. This section is compiled out for older
  // toolkits, in which case the counter and split arguments have no effect.
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
  if (counter != nullptr) {
    // A split count of 0 is treated as "one chunk".
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    }
  }
#endif

  // Algorithm-selection preferences: workspace budget and the actual pointer
  // alignments, so the heuristic only proposes algorithms that can run here.
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(B));
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));

  // Ask for a single candidate algorithm.
  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);

  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
                                   static_cast<const void *>(&one),         /* alpha */
                                   A,                                       /* A */
                                   Adesc, B,                                /* B */
                                   Bdesc, static_cast<const void *>(&beta), /* beta */
                                   C,                                       /* C */
                                   Cdesc, D,                                /* D */
                                   Ddesc, &heuristicResult.algo,            /* algo */
                                   workspace,                               /* workspace */
                                   workspaceSize, stream));                 /* stream */

  // Update FP8 scale-inv in output tensor
  if (is_fp8_dtype(outputD->data.dtype)) {
    update_tensor_scale_inv(outputD, stream);
  }

  // Release the per-call descriptors. The handle is cached and stays alive.
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}

285
286
287
288
289
290
291
292
293
294
295
296
// State shared by the multi-stream GEMM entry point: a once-flag guarding
// initialisation, plus one side stream and one event per slot.
static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
static cudaEvent_t cublas_event[num_streams];

// Creates every side stream (non-blocking, priority -1) together with its
// companion event.
// Warning: only call once per device!
// NOTE(review): the caller in this file guards this with a single
// process-wide std::once_flag -- confirm that matches the intended
// multi-device behaviour.
static void init_streams_and_events() {
  for (int slot = 0; slot < num_streams; ++slot) {
    cudaStream_t &side_stream = compute_streams[slot];
    cudaEvent_t &slot_event = cublas_event[slot];
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&side_stream, cudaStreamNonBlocking, -1));
    NVTE_CHECK_CUDA(cudaEventCreate(&slot_event));
  }
}

297
}  // namespace transformer_engine
Przemek Tredak's avatar
Przemek Tredak committed
298

299
300
301
302
// C API entry point for a single GEMM.
//
// Reinterprets the opaque tensor handles as Tensor objects, derives m, n, k
// and the leading dimensions from the operand shapes and transpose flags, and
// forwards to cublas_gemm with the atomic-sync options disabled (zero splits,
// no counter). The TT layout (both operands transposed) is rejected.
// The first element of the workspace tensor's shape is passed on as the
// workspace size.
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;

  const auto *a_tensor = reinterpret_cast<const Tensor *>(A);
  const auto *b_tensor = reinterpret_cast<const Tensor *>(B);
  auto *d_tensor = reinterpret_cast<Tensor *>(D);
  const auto *bias_tensor = reinterpret_cast<const Tensor *>(bias);
  auto *gelu_tensor = reinterpret_cast<Tensor *>(pre_gelu_out);
  auto *workspace_tensor = reinterpret_cast<Tensor *>(workspace);

  // GEMM extents taken from the operand shapes; which axis is which depends
  // on the transpose flags.
  const int m = transa ? a_tensor->data.shape[0] : a_tensor->data.shape[1];
  const int k = transa ? a_tensor->data.shape[1] : a_tensor->data.shape[0];
  const int n = transb ? b_tensor->data.shape[1] : b_tensor->data.shape[0];

  if (transa && transb) {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  // Remaining layouts: TN -> (k, k, m), NN -> (m, k, m), NT -> (m, n, m).
  const int lda = transa ? k : m;
  const int ldb = transb ? n : k;
  const int ldd = m;

  const cublasOperation_t op_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t op_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;

  cublas_gemm(a_tensor, b_tensor, d_tensor, bias_tensor, gelu_tensor, m, n, k, lda, ldb, ldd,
              op_a, op_b, grad, workspace_tensor->data.dptr, workspace_tensor->data.shape[0],
              accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}

338
339
340
341
342
// C API entry point for a GEMM that uses cuBLASLt atomic-sync counters.
//
// Verifies at run time that the CUDA runtime reports version >= 12020 and
// cuBLASLt reports version >= 120205, then behaves like the plain GEMM entry
// point while additionally forwarding m_split, n_split, gemm_producer and the
// counter tensor to cublas_gemm. The TT layout is rejected.
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
                             cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);

  // Run-time library version gates.
  int cudart_version;
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
  NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
  NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");

  using namespace transformer_engine;

  const auto *a_tensor = reinterpret_cast<const Tensor *>(A);
  const auto *b_tensor = reinterpret_cast<const Tensor *>(B);
  auto *d_tensor = reinterpret_cast<Tensor *>(D);
  const auto *bias_tensor = reinterpret_cast<const Tensor *>(bias);
  auto *gelu_tensor = reinterpret_cast<Tensor *>(pre_gelu_out);
  const auto *counter_tensor = reinterpret_cast<const Tensor *>(counter);
  auto *workspace_tensor = reinterpret_cast<Tensor *>(workspace);

  // GEMM extents taken from the operand shapes; which axis is which depends
  // on the transpose flags.
  const int m = transa ? a_tensor->data.shape[0] : a_tensor->data.shape[1];
  const int k = transa ? a_tensor->data.shape[1] : a_tensor->data.shape[0];
  const int n = transb ? b_tensor->data.shape[1] : b_tensor->data.shape[0];

  if (transa && transb) {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  // Remaining layouts: TN -> (k, k, m), NN -> (m, k, m), NT -> (m, n, m).
  const int lda = transa ? k : m;
  const int ldb = transb ? n : k;
  const int ldd = m;

  const cublasOperation_t op_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t op_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;

  cublas_gemm(a_tensor, b_tensor, d_tensor, bias_tensor, gelu_tensor, m, n, k, lda, ldb, ldd,
              op_a, op_b, grad, workspace_tensor->data.dptr, workspace_tensor->data.shape[0],
              accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
              counter_tensor, stream);
}
385

386
387
388
389
// Runs num_gemms independent GEMMs spread round-robin over the file's side
// streams, ordered after prior work on `stream` and joined back into it.
//
// GEMM i uses side stream (i % num_streams) and workspace[i % num_streams],
// so GEMMs mapped to the same slot share a workspace and execute in order on
// that slot's stream. All other arguments are forwarded unchanged to
// nvte_cublas_gemm for each problem.
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;
  // Side streams and events are created lazily, once for the whole process.
  std::call_once(init_flag, init_streams_and_events);

  // Only this many side streams actually receive work.
  const int active_streams = std::min(num_streams, num_gemms);

  // Fork: each active side stream waits for everything already queued on the
  // caller's stream.
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
  for (int slot = 0; slot < active_streams; ++slot) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[slot], cublas_event[0]));
  }

  // Dispatch the GEMMs round-robin over the side streams.
  for (int gemm_idx = 0; gemm_idx < num_gemms; ++gemm_idx) {
    const int slot = gemm_idx % num_streams;
    nvte_cublas_gemm(A[gemm_idx], B[gemm_idx], D[gemm_idx], bias[gemm_idx],
                     pre_gelu_out[gemm_idx], transa, transb, grad, workspace[slot], accumulate,
                     use_split_accumulator, math_sm_count, compute_streams[slot]);
  }

  // Join, step 1: mark the tail of every active side stream.
  for (int slot = 0; slot < active_streams; ++slot) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[slot], compute_streams[slot]));
  }
  // Join, step 2: the caller's stream waits on each of those markers.
  for (int slot = 0; slot < active_streams; ++slot) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[slot]));
  }
}