cublaslt_gemm.cu 18.6 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cublasLt.h>
#include <cublas_v2.h>
Tim Moon's avatar
Tim Moon committed
9
#include <cuda.h>
10
11
12
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

13
#include <cstdint>
14
#include <mutex>
Tim Moon's avatar
Tim Moon committed
15

Przemek Tredak's avatar
Przemek Tredak committed
16
#include "../common.h"
Tim Moon's avatar
Tim Moon committed
17
#include "../util/logging.h"
Przemek Tredak's avatar
Przemek Tredak committed
18

19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
namespace {

// Translate a Transformer Engine element type into the matching cudaDataType_t.
// Any type without a mapping below is reported through NVTE_ERROR("Invalid type").
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  if (t == DType::kFloat16) {
    return CUDA_R_16F;
  }
  if (t == DType::kFloat32) {
    return CUDA_R_32F;
  }
  if (t == DType::kBFloat16) {
    return CUDA_R_16BF;
  }
  if (t == DType::kFloat8E4M3) {
    return CUDA_R_8F_E4M3;
  }
  if (t == DType::kFloat8E5M2) {
    return CUDA_R_8F_E5M2;
  }
  NVTE_ERROR("Invalid type");
}

39
40
41
// Returns the largest power-of-two alignment in bytes, capped at 256, that
// evenly divides `address`. An address of 0 reports the cap (256).
uint32_t _getAlignment(uintptr_t address) {
  uint32_t candidate = 256;
  // Halve the candidate until it divides the address; 1 always divides,
  // so the loop is guaranteed to stop.
  while (address % candidate != 0) {
    candidate /= 2;
  }
  return candidate;
}

49
50
}  // namespace

Przemek Tredak's avatar
Przemek Tredak committed
51
52
namespace transformer_engine {

53
54
55
56
57
58
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
                 void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, cudaStream_t stream) {
59
60
61
62
  void *A = inputA->data.dptr;
  void *A_scale_inverse = inputA->scale_inv.dptr;
  void *B = inputB->data.dptr;
  void *B_scale_inverse = inputB->scale_inv.dptr;
63
  void *C = outputD->data.dptr;
64
  void *D = outputD->data.dptr;
65
66
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
67
68
69
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
70
71
72
73
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
74
  const bool gelu = pre_gelu_out != nullptr;
75
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
76
77
78
79
  const cudaDataType_t A_type = get_cuda_dtype(inputA->data.dtype);
  const cudaDataType_t B_type = get_cuda_dtype(inputB->data.dtype);
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);
Przemek Tredak's avatar
Przemek Tredak committed
80

81
82
83
84
  NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
Przemek Tredak's avatar
Przemek Tredak committed
85

86
87
  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
88
89
90
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
91
               "fp8 Aux output for gemm + gelu fusion not supported!");
92
  }
93
  if (is_fp8_dtype(outputD->data.dtype)) {
94
    NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
95
  }
Przemek Tredak's avatar
Przemek Tredak committed
96

97
98
99
  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;
Przemek Tredak's avatar
Przemek Tredak committed
100

101
102
  cublasLtHandle_t handle;
  NVTE_CHECK_CUBLAS(cublasLtCreate(&handle));
Przemek Tredak's avatar
Przemek Tredak committed
103

104
105
  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
106
  cublasLtMatmulPreference_t preference = nullptr;
107
  int returnedResults = 0;
108
109
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
Przemek Tredak's avatar
Przemek Tredak committed
110

111
  int64_t ld_gelumat = (int64_t)ldd;
Przemek Tredak's avatar
Przemek Tredak committed
112

113
114
115
116
117
  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }
Przemek Tredak's avatar
Przemek Tredak committed
118

119
  // Create matrix descriptors. Not setting any extra attributes.
120
121
122
123
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, transa == CUBLAS_OP_N ? m : k,
                                               transa == CUBLAS_OP_N ? k : m, lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, transb == CUBLAS_OP_N ? k : n,
                                               transb == CUBLAS_OP_N ? n : k, ldb));
124
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
Przemek Tredak's avatar
Przemek Tredak committed
125

126
127
128
129
130
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
                                                   &transa, sizeof(transa)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                   &transb, sizeof(transb)));
131
132
  // Set math SM count
  if (math_sm_count != 0) {
133
134
135
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
136
137
  }

138
139
140
141
142
143
  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
144
145
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));
146
147
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
148
                                                     &A_scale_inverse, sizeof(A_scale_inverse)));
149
150
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
151
                                                     &B_scale_inverse, sizeof(B_scale_inverse)));
152
153
154
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
155
156
157
158
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
159
160
161
162
163
      // For FP8 output, cuBLAS requires C_type to be same as bias_type
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, bias_type, m, n, ldd));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
164
    if (bias) {
165
166
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
167
    }
168
169
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
170
  }
Przemek Tredak's avatar
Przemek Tredak committed
171

172
173
174
175
176
177
178
  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
179
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
180
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
181
182
183
184
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
185
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
186
187
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
188
189
190
191
192
193
194
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
195
196
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
197
198
199
200
201
202
203
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
204
205
206
207
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
208
  }
Przemek Tredak's avatar
Przemek Tredak committed
209

210
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
211
                                                   &epilogue, sizeof(epilogue)));
212

213
214
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
  if (counter != nullptr) {
215
216
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
217
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
218
219
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
220
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
221
222
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
223
224
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
225
226
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
227
228
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
229
230
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
231
232
233
    }
  }
#endif
Przemek Tredak's avatar
Przemek Tredak committed
234

235
236
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
237
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
238
239
240
241
242
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(B));
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
243
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
244
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
245
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
246
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
247
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
248
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
249
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
Przemek Tredak's avatar
Przemek Tredak committed
250

251
252
253
  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
Tim Moon's avatar
Tim Moon committed
254
255
256
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);
Przemek Tredak's avatar
Przemek Tredak committed
257

258
  if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
Przemek Tredak's avatar
Przemek Tredak committed
259

260
  // D = alpha * (A * B) + beta * C
261
262
263
264
265
266
267
268
269
270
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
                                   static_cast<const void *>(&one),         /* alpha */
                                   A,                                       /* A */
                                   Adesc, B,                                /* B */
                                   Bdesc, static_cast<const void *>(&beta), /* beta */
                                   C,                                       /* C */
                                   Cdesc, D,                                /* D */
                                   Ddesc, &heuristicResult.algo,            /* algo */
                                   workspace,                               /* workspace */
                                   workspaceSize, stream));                 /* stream */
Przemek Tredak's avatar
Przemek Tredak committed
271

272
273
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
274
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
275
276
277
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
Przemek Tredak's avatar
Przemek Tredak committed
278
279
}

280
281
282
283
284
285
286
287
288
289
290
291
static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
static cudaEvent_t cublas_event[num_streams];

// Warning: only call once per device!
static void init_streams_and_events() {
  for (int i = 0; i < num_streams; i++) {
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
  }
}

292
}  // namespace transformer_engine
Przemek Tredak's avatar
Przemek Tredak committed
293

294
295
296
297
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream) {
298
  NVTE_API_CALL(nvte_cublas_gemm);
Przemek Tredak's avatar
Przemek Tredak committed
299
  using namespace transformer_engine;
300
301
302
303
304
305
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
Przemek Tredak's avatar
Przemek Tredak committed
306

307
308
309
  const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
Przemek Tredak's avatar
Przemek Tredak committed
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

327
328
329
330
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
              math_sm_count, 0, 0, false, nullptr, stream);
331
332
}

333
334
335
336
337
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
338
339
340
341
342
343
344
345
346
                             cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);

  int cudart_version;
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
  NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
  NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");

  using namespace transformer_engine;
347
348
349
350
351
352
353
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374

  const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

375
376
377
378
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
Przemek Tredak's avatar
Przemek Tredak committed
379
}
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413

void nvte_multi_stream_cublas_gemm(std::vector<NVTETensor> A, std::vector<NVTETensor> B,
                                   std::vector<NVTETensor> D, std::vector<NVTETensor> bias,
                                   std::vector<NVTETensor> pre_gelu_out, bool transa, bool transb,
                                   bool grad, std::vector<NVTETensor> workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;
  // Inits streams and events (once, globally)
  std::call_once(init_flag, init_streams_and_events);

  int num_stream_used = std::min(num_streams, static_cast<int>(A.size()));
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
  }

  for (size_t i = 0; i < A.size(); i++) {
    nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                     compute_streams[i % num_streams]);
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
  }
}