/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef __HIP_PLATFORM_AMD__
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#else
#include <iostream>
#include "hipblas_gemm.h"
#include "rocm_gemm.hip"
#endif // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <mutex>

#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#ifdef __HIP_PLATFORM_AMD__
#include "common/util/hip_runtime.h"
#else
#include "common/util/cuda_runtime.h"
#endif

#ifndef __HIP_PLATFORM_AMD__
namespace {

cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return CUDA_R_16F;
    case DType::kFloat32:
      return CUDA_R_32F;
    case DType::kBFloat16:
      return CUDA_R_16BF;
    case DType::kFloat8E4M3:
      return CUDA_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return CUDA_R_8F_E5M2;
    default:
      NVTE_ERROR("Invalid type");
  }
}

uint32_t _getAlignment(uintptr_t address) {
  // alignment is in bytes
  uint32_t alignment = 256;
  for (;; alignment /= 2) {
    if (address % alignment == 0) {
      return alignment;
    }
  }
}
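
// Worked example: for an address of 0x1008 (= 4104), 4104 % 256 == 8, and the
// loop halves the alignment until 4104 % 8 == 0, so _getAlignment returns 8.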

inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}

/* Parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * is different from the row-major ordering that is typically used in
 * Transformer Engine.
 *
 */
struct GemmParam {
  void *A = nullptr;
  void *B = nullptr;
  cublasOperation_t transA = CUBLAS_OP_N;
  cublasOperation_t transB = CUBLAS_OP_N;
  transformer_engine::DType Atype = transformer_engine::DType::kNumTypes;
  transformer_engine::DType Btype = transformer_engine::DType::kNumTypes;
  void *A_scale_inv = nullptr;
  void *B_scale_inv = nullptr;
  int lda = 0;  // leading dimension (column stride) of A
  int ldb = 0;  // leading dimension (column stride) of B
};
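
/* Illustration of the ordering mismatch: a row-major M x K buffer is, when
 * reinterpreted in column-major order by cuBLAS, a K x M matrix (its
 * transpose). The identity (A * B)^T = B^T * A^T is what lets a row-major
 * GEMM run on column-major cuBLAS by swapping and/or transposing operands.
 */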

/* Populate parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * is different from the row-major ordering that is typically used in
 * Transformer Engine.
 *
 */
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                int m, int n, int k) {
  using namespace transformer_engine;
  NVTE_CHECK(
      A.scaling_mode == B.scaling_mode ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
      "Inputs A and B to GEMM need to have compatible scaling modes!");
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
  GemmParam ret;

  // Transpose mode with column-major ordering
  bool is_A_transposed = transA == CUBLAS_OP_T;
  bool is_B_transposed = transB == CUBLAS_OP_T;

  // Configure A matrix
  if (is_tensor_scaling(A.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.A = A.data.dptr;
    ret.transA = transA;
    ret.Atype = A.data.dtype;
    ret.A_scale_inv = A.scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
        ret.A = A.columnwise_data.dptr;
        ret.transA = CUBLAS_OP_T;
        ret.Atype = A.columnwise_data.dtype;
        ret.A_scale_inv = A.columnwise_scale_inv.dptr;
        ret.lda = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
      }
    }
  } else if (is_mxfp_scaling(A.scaling_mode)) {
    // MXFP8
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with matrix interpreted in row-major order).
    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = transA;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
  } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = CUBLAS_OP_T;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = k;

    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.lda % 16) == 0,
               "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
    // Smallest supported CType is 2 bytes in this scaling mode.
    NVTE_CHECK((m % 8) == 0,
               "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
  } else {
    NVTE_ERROR("A has unsupported scaling mode");
  }

  // Configure B matrix
  if (is_tensor_scaling(B.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.B = B.data.dptr;
    ret.transB = transB;
    ret.Btype = B.data.dtype;
    ret.B_scale_inv = B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
        ret.B = B.columnwise_data.dptr;
        ret.transB = CUBLAS_OP_N;
        ret.Btype = B.columnwise_data.dtype;
        ret.B_scale_inv = B.columnwise_scale_inv.dptr;
        ret.ldb = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
      }
    }
  } else if (is_mxfp_scaling(B.scaling_mode)) {
    // MXFP8
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with matrix interpreted in row-major order).
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = transB;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
  } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = CUBLAS_OP_N;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;

    // Requirements from
    // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.ldb % 16) == 0,
               "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
      // Observed this requirement only present for B tensor is 1D quantized.
      NVTE_CHECK((n % 8) == 0,
                 "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    }
  } else {
    NVTE_ERROR("B has unsupported scaling mode");
  }

  return ret;
}
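
/* Illustration (FP8 tensor scaling when non-TN FP8 GEMMs are unsupported,
 * e.g. Hopper): a request for transA == CUBLAS_OP_N with an FP8 A falls back
 * above to A.columnwise_data (the stored transpose), flips transA to
 * CUBLAS_OP_T, and sets lda = k, so the GEMM still runs in the TN layout.
 */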

}  // namespace
#endif // __HIP_PLATFORM_AMD__

namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
// Forward declaration. The implementation is in rocm_gemm.cu.
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, bool transa, bool transb, bool grad,
                 void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt,
                 bool nvte_use_rocblas, int compute_stream_offset);
#else // Use cublasLt
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
                 bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split,
                 int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) {
  // Tensor dims in row-major order
  const int A0 = inputA->flat_first_dim();
  const int A1 = inputA->flat_last_dim();
  const int B0 = inputB->flat_first_dim();
  const int B1 = inputB->flat_last_dim();

  // GEMM dims in column-major order
  const int m = transa == CUBLAS_OP_T ? A0 : A1;
  const int n = transb == CUBLAS_OP_T ? B1 : B0;
  const int k = transa == CUBLAS_OP_T ? A1 : A0;
  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
             "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
             ")");
  const int ldd = m;
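
  // e.g. a row-major 768 x 512 A (A0 = 768, A1 = 512) is a column-major
  // 512 x 768 matrix to cuBLAS; with transa == CUBLAS_OP_T this gives
  // m = 768 and k = 512.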

  // Return immediately if GEMM is trivial
  if (m <= 0 || n <= 0) {
    return;
  }
  NVTE_CHECK(k > 0);

  const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

  void *C = outputD->data.dptr;
  void *D = outputD->data.dptr;
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);

  const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
  const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);

  NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");

  // Check consistency of arguments.
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp8_dtype(outputD->data.dtype)) {
    NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
  }

  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;

  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();

  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  cublasLtMatmulPreference_t preference = nullptr;
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  int64_t ld_gelumat = (int64_t)ldd;

  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
                                               param.transA == CUBLAS_OP_N ? k : m, param.lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
                                               param.transB == CUBLAS_OP_N ? n : k, param.ldb));

  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
                                                   &param.transA, sizeof(param.transA)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                   &param.transB, sizeof(param.transB)));
  // Set math SM count
  if (math_sm_count != 0) {
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
  }

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));

    // Scaling factors.
#if CUDA_VERSION >= 12080
    cublasLtMatmulMatrixScale_t scaling_mode_a;
    cublasLtMatmulMatrixScale_t scaling_mode_b;
#endif
    if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
      void *A_scale_inverse = param.A_scale_inv;
      void *B_scale_inverse = param.B_scale_inv;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
#if CUDA_VERSION >= 12080
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
    } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) {
      fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
      fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
      // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
      if (cublasLtGetVersion() <= 120803) {
        const int64_t dummy_a_vec_stride = 1;
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
            sizeof(dummy_a_vec_stride)));
      }
    } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
               (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUDA_VERSION >= 12090
      float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
      float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
                    inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
                 "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
      scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
      scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
      NVTE_ERROR("FP8 block scaling requires CUDA 12.9+");
#endif  // CUDA_VERSION >= 12090
#endif  // CUDA_VERSION >= 12080
    } else {
      NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and  " +
                 to_string(inputB->scaling_mode) + ".");
    }

#if CUDA_VERSION >= 12080
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b)));
#endif
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUDA_VERSION >= 12080
      // NOTE: In all current cases where FP8 output is supported, the input is
      // scaled identically to the output.
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
                                                       &scaling_mode_a, sizeof(scaling_mode_a)));
#endif
      // For FP8 output, cuBLAS requires C_type to match bias_type and
      // be FP16/BF16
      const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
    if (bias) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
  }

  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  }

  if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) ||
      (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
    NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS ||
                epilogue == CUBLASLT_EPILOGUE_DGELU),
               "Epilogue requested outside of the available and tested cuBLAS functionality for "
               "float8 block scaled GEMM");
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                   &epilogue, sizeof(epilogue)));

#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
  if (counter != nullptr) {
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    }
  }
#endif

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));

  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);
  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
                                   static_cast<const void *>(&one),         /* alpha */
                                   param.A,                                 /* A */
                                   Adesc, param.B,                          /* B */
                                   Bdesc, static_cast<const void *>(&beta), /* beta */
                                   C,                                       /* C */
                                   Cdesc, D,                                /* D */
                                   Ddesc, &heuristicResult.algo,            /* algo */
                                   workspace,                               /* workspace */
                                   workspaceSize, stream));                 /* stream */

  // Update FP8 scale-inv in output tensor
  // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
  // TODO: Change the gemm interface so that D->scale_inv is allocated and the scale_inv can be
  // calculated here.
  if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
    update_tensor_scale_inv(outputD, stream);
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
#endif // __HIP_PLATFORM_AMD__

static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
static cudaEvent_t cublas_event[num_streams];

// Warning: only call once per device!
static void init_streams_and_events() {
  for (int i = 0; i < num_streams; i++) {
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
  }
}

// Streams and events added for batchgemm
static std::once_flag init_flag_batchgemm;
static cudaStream_t compute_streams_batchgemm[num_batchgemm_streams];
static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];

// Warning: only call once per device!
static void init_streams_and_events_batchgemm() {
  for (int i = 0; i < num_batchgemm_streams; i++) {
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[i]));
  }
}

}  // namespace transformer_engine

void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt,
                      bool nvte_use_rocblas, int compute_stream_offset) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

#ifdef __HIP_PLATFORM_AMD__
  const size_t A0 = inputA->flat_first_dim();
  const size_t A1 = inputA->flat_last_dim();
  const size_t B0 = inputB->flat_first_dim();
  const size_t B1 = inputB->flat_last_dim();

  const int m = transa ? A0 : A1;
  const int k = transa ? A1 : A0;
  const int n = transb ? B1 : B0;
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || use_fp8 ||
      (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') ||
      nvte_use_hipblaslt || nvte_use_rocblas) {
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa,
                transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream,
                nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
  } else {
    hipblas_gemm(inputA,
                 inputB,
                 outputD,
                 biasTensor,
                 outputGelu,
                 m, n, k,
                 lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr,
                 wspace->data.shape[0],
                 accumulate, use_split_accumulator,
                 math_sm_count,
                 0,
                 0,
                 false,
                 nullptr,
                 stream);
  }
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
#endif  //__HIP_PLATFORM_AMD__
}
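
/* Usage sketch (illustrative only; preparing the NVTETensor handles is up to
 * the caller): with tensors a, b, d, bias, pre_gelu and workspace ws already
 * created, a D = A^T * B (TN) GEMM on `stream` would be launched as:
 *
 *   nvte_cublas_gemm(a, b, d, bias, pre_gelu,
 *                    true, false,       // transa, transb
 *                    false,             // grad
 *                    ws, false, false,  // workspace, accumulate, use_split_accumulator
 *                    0, stream);        // math_sm_count, stream
 */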

void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
                             cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas,
                             int compute_stream_offset) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);

#ifndef __HIP_PLATFORM_AMD__
  int cudart_version;
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
  NVTE_CHECK(cudart_version >= 12020, "CUDA version 12.2 is required for atomic gemm.");
  NVTE_CHECK(cublasLtGetVersion() >= 120205, "cuBLAS version 12.2.5 is required for atomic gemm.");
#endif

  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                 is_delayed_tensor_scaling(inputB->scaling_mode),
             "Atomic GEMM only supports delayed scaling.");
#ifdef __HIP_PLATFORM_AMD__
  const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || use_fp8 ||
      (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') ||
      nvte_use_hipblaslt || nvte_use_rocblas) {
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa,
                transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
                inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
  } else {
    hipblas_gemm(inputA,
                 inputB,
                 outputD,
                 biasTensor,
                 outputGelu,
                 m, n, k,
                 lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr,
                 wspace->data.shape[0],
                 accumulate, use_split_accumulator,
                 math_sm_count,
                 m_split,
                 n_split,
                 gemm_producer,
                 inputCounter,
                 stream);
  }
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
              inputCounter, stream);
#endif  //__HIP_PLATFORM_AMD__
}

void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;
  // Inits streams and events (once, globally)
  std::call_once(init_flag, init_streams_and_events);

  int num_stream_used = std::min(num_streams, num_gemms);
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
  }
  const char *NVTE_HIPBLAS_MULSTREAM = std::getenv("NVTE_FORCE_HIPBLAS_MULSTREAM");
  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool force_hipblas_mulstream =
      (NVTE_HIPBLAS_MULSTREAM != nullptr && NVTE_HIPBLAS_MULSTREAM[0] == '1');
  if (force_hipblas_mulstream && NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') {
    NVTE_ERROR("NVTE_FORCE_HIPBLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time.");
  }
  if (force_hipblas_mulstream) {
    for (int i = 0; i < num_gemms; i++) {
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                     compute_streams[i % num_streams]);
    }
  } else{
    for (int i = 0; i < num_gemms; i++) {
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                     compute_streams[i % num_streams], 1, 0, i % num_streams);
    }
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
  }
}

#ifndef __HIP_PLATFORM_AMD__
namespace transformer_engine {

using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }

}  //  namespace transformer_engine
#endif

#ifdef __HIP_PLATFORM_AMD__

void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
  using namespace transformer_engine;
  int batch_count = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);
  // Inits streams and events (once, globally)
  std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);

  int num_stream_used = std::min(num_batchgemm_streams, num_gemms);
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0]));
  }
  for (int i = 0; i < num_gemms; i++) {
    nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator, math_sm_count,
                     batch_count, compute_streams_batchgemm[i % num_batchgemm_streams]);
  }
  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[s], compute_streams_batchgemm[s]));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event_batchgemm[s]));
  }
}

// add for batchgemm
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
    NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
  }

  // NT, TN, and NN layouts share the same dimension computation; the batch
  // dimension is folded into dim 0 of the flattened shapes.
  int m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
  int k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
  int n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
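  // e.g. with batch_count = 2 and !transa, a flattened row-major A of shape
  // (2*K, M) yields per-sub-GEMM dims k = K and m = M.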
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m; 
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }
  hipblas_batchgemm(inputA,
                    inputB,
                    outputD,
                    biasTensor,
                    outputGelu,
                    m, n, k,
                    lda, ldb, ldd,
                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    grad, wspace->data.dptr,
                    wspace->data.shape[0],
                    accumulate, use_split_accumulator,
                    math_sm_count,
                    0,
                    0,
                    false,
                    nullptr,
                    batch_count,
                    stream);
}
#endif