cublaslt_gemm.cu 47.2 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

yuguo's avatar
yuguo committed
7
#ifndef __HIP_PLATFORM_AMD__
Przemek Tredak's avatar
Przemek Tredak committed
8
9
#include <cublasLt.h>
#include <cublas_v2.h>
Tim Moon's avatar
Tim Moon committed
10
#include <cuda.h>
yuguo's avatar
yuguo committed
11
12
13
14
15
#else
#include <iostream>
#include "hipblas_gemm.h"
#include "rocm_gemm.hip"
#endif // #ifndef __HIP_PLATFORM_AMD__
16
17
18
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

19
#include <cstdint>
20
#include <mutex>
Tim Moon's avatar
Tim Moon committed
21

Przemek Tredak's avatar
Przemek Tredak committed
22
#include "../common.h"
23
#include "../util/handle_manager.h"
Tim Moon's avatar
Tim Moon committed
24
#include "../util/logging.h"
25
#include "common/util/cuda_runtime.h"
Przemek Tredak's avatar
Przemek Tredak committed
26

yuguo's avatar
yuguo committed
27
#ifndef __HIP_PLATFORM_AMD__
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
namespace {

cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  // Translate a Transformer Engine dtype into the cudaDataType_t that the
  // cuBLASLt descriptors below expect. Only floating-point types (full, half,
  // bfloat and the two FP8 encodings) are mapped; anything else is an error.
  using namespace transformer_engine;
  switch (t) {
    // Full / reduced precision floating point.
    case DType::kFloat32:
      return CUDA_R_32F;
    case DType::kBFloat16:
      return CUDA_R_16BF;
    case DType::kFloat16:
      return CUDA_R_16F;
    // 8-bit floating point encodings.
    case DType::kFloat8E5M2:
      return CUDA_R_8F_E5M2;
    case DType::kFloat8E4M3:
      return CUDA_R_8F_E4M3;
    default:
      NVTE_ERROR("Invalid type");
  }
}

48
49
50
uint32_t _getAlignment(uintptr_t address) {
  // Return the largest power of two, capped at 256 bytes, that divides
  // `address`. Used to report the minimum operand alignment to the cuBLASLt
  // heuristic. An address of 0 (e.g. a null pointer) is reported as 256.
  // Terminates because every address is divisible by an alignment of 1.
  uint32_t alignment = 256;  // alignment is in bytes
  while (address % alignment != 0) {
    alignment /= 2;
  }
  return alignment;
}

58
59
60
61
inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  // Create a fresh cuBLASLt handle into *handle. Failure is reported through
  // NVTE_CHECK_CUBLAS. This function is the creation callback handed to
  // detail::HandleManager for cuBLASLt handles.
  const cublasStatus_t create_status = cublasLtCreate(handle);
  NVTE_CHECK_CUBLAS(create_status);
}

62
63
64
65
66
67
68
/* Parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * is different than the row-major that is typically used in
 * Transformer Engine.
 *
 * Holds the per-operand arguments (buffer, transpose flag, dtype,
 * scale-inverse pointer, leading dimension) that CanonicalizeGemmInput
 * resolves for a cuBLASLt matmul call.
 */
struct GemmParam {
  void *A = nullptr;                       // Buffer passed to cuBLASLt as operand A
  void *B = nullptr;                       // Buffer passed to cuBLASLt as operand B
  cublasOperation_t transA = CUBLAS_OP_N;  // Transpose cuBLASLt applies to A
  cublasOperation_t transB = CUBLAS_OP_N;  // Transpose cuBLASLt applies to B
  // Element types of A and B; kNumTypes is used as the "not yet populated" default.
  transformer_engine::DType Atype = transformer_engine::DType::kNumTypes;
  transformer_engine::DType Btype = transformer_engine::DType::kNumTypes;
  void *A_scale_inv = nullptr;  // Scale-inverse buffer for A (required for FP8 inputs)
  void *B_scale_inv = nullptr;  // Scale-inverse buffer for B (required for FP8 inputs)
  int lda = 0;                  // A column strides
  int ldb = 0;                  // B column strides
};

82
83
84
85
86
87
88
/* Populate parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * is different than the row-major that is typically used in
 * Transformer Engine.
 *
 * For each operand this picks which buffer to hand to cuBLASLt: the
 * row-wise "data" or the "columnwise_data" (the transpose of data). It
 * also picks the transpose flag, dtype, scale-inverse pointer and leading
 * dimension, based on the operand's scaling mode and the requested
 * transposes. m, n and k are the column-major GEMM dimensions.
 *
 * Errors (NVTE_CHECK / NVTE_ERROR) when:
 *   - the scaling modes of A and B are incompatible or unsupported;
 *   - an operand holds no data;
 *   - the required row-wise / column-wise usage is missing;
 *   - FP8 block-scaling padding requirements are not met.
 */
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                int m, int n, int k) {
  using namespace transformer_engine;
  NVTE_CHECK(
      A.scaling_mode == B.scaling_mode ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
      "Inputs A and B to GEMM need to have compatible scaling modes!");
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
  GemmParam ret;

  // Transpose mode with column-major ordering
  bool is_A_transposed = transA == CUBLAS_OP_T;
  bool is_B_transposed = transB == CUBLAS_OP_T;

  // Configure A matrix
  if (is_tensor_scaling(A.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.A = A.data.dptr;
    ret.transA = transA;
    ret.Atype = A.data.dtype;
    ret.A_scale_inv = A.scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
        ret.A = A.columnwise_data.dptr;
        ret.transA = CUBLAS_OP_T;
        ret.Atype = A.columnwise_data.dtype;
        ret.A_scale_inv = A.columnwise_scale_inv.dptr;
        ret.lda = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
      }
    }
  } else if (is_mxfp_scaling(A.scaling_mode)) {
    // MXFP8
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with matrix interpreted in row-major order).
    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = transA;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
  } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = CUBLAS_OP_T;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = k;

    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.lda % 16) == 0,
               "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
    // Smallest supported CType is 2 bytes in this scaling mode.
    NVTE_CHECK((m % 8) == 0,
               "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
  } else {
    NVTE_ERROR("A has unsupported scaling mode");
  }

  // Configure B matrix
  if (is_tensor_scaling(B.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.B = B.data.dptr;
    ret.transB = transB;
    ret.Btype = B.data.dtype;
    ret.B_scale_inv = B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
        ret.B = B.columnwise_data.dptr;
        ret.transB = CUBLAS_OP_N;
        ret.Btype = B.columnwise_data.dtype;
        ret.B_scale_inv = B.columnwise_scale_inv.dptr;
        ret.ldb = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
      }
    }
  } else if (is_mxfp_scaling(B.scaling_mode)) {
    // MXFP8
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with matrix interpreted in row-major order).
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = transB;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
  } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = CUBLAS_OP_N;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;

    // Requirements from
    // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.ldb % 16) == 0,
               "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
      // Observed this requirement only present for B tensor is 1D quantized.
      NVTE_CHECK((n % 8) == 0,
                 "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    }
  } else {
    NVTE_ERROR("B has unsupported scaling mode");
  }

  return ret;
}

229
}  // namespace
yuguo's avatar
yuguo committed
230
#endif // __HIP_PLATFORM_AMD__
231

Przemek Tredak's avatar
Przemek Tredak committed
232
namespace transformer_engine {
yuguo's avatar
yuguo committed
233
234
235
236
237
238
239
#ifdef __HIP_PLATFORM_AMD__
// Forward declaration of the ROCm GEMM entry point. It is defined in the ROCm
// GEMM sources (rocm_gemm.cu, brought in on HIP builds through the
// "rocm_gemm.hip" include at the top of this file).
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, bool transa, bool transb, bool grad,
                 void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset);
yuguo's avatar
yuguo committed
241
#else // Use cublasLt
242
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

/* Run one GEMM through cuBLASLt on `stream`.
 *
 * Computes D = A * B (D += A * B when `accumulate`), optionally fused with
 * bias / GELU epilogues and FP8 input/output scaling. Operand buffers,
 * transposes and leading dimensions are resolved by CanonicalizeGemmInput.
 * Returns without launching anything when m or n is non-positive.
 *
 * All cuBLASLt descriptors created here are released on every exit path,
 * including when one of the checks below throws.
 */
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
                 bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split,
                 int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) {
  // Tensor dims in row-major order
  const int A0 = inputA->flat_first_dim();
  const int A1 = inputA->flat_last_dim();
  const int B0 = inputB->flat_first_dim();
  const int B1 = inputB->flat_last_dim();

  // GEMM dims in column-major order
  const int m = transa == CUBLAS_OP_T ? A0 : A1;
  const int n = transb == CUBLAS_OP_T ? B1 : B0;
  const int k = transa == CUBLAS_OP_T ? A1 : A0;
  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
             "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
             ")");
  const int ldd = m;

  // Return immediately if GEMM is trivial
  if (m <= 0 || n <= 0) {
    return;
  }
  NVTE_CHECK(k > 0);

  const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

  void *C = outputD->data.dptr;
  void *D = outputD->data.dptr;
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);

  const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
  const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);

  NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");

  // Check consistency of arguments:
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp8_dtype(outputD->data.dtype)) {
    NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
  }

  float one = 1.0f;
  float zero = 0.0f;
  float beta = (accumulate) ? one : zero;

  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();

  // Owns the cuBLASLt objects created below. Any NVTE_CHECK / NVTE_CHECK_CUBLAS /
  // NVTE_ERROR after the first descriptor is created may throw (e.g. when no
  // heuristic algorithm is found). Without this owner, every descriptor created
  // so far would be leaked on that path.
  struct CublasLtObjects {
    cublasLtMatmulDesc_t operationDesc = nullptr;
    cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
    cublasLtMatmulPreference_t preference = nullptr;

    // Normal-exit teardown: destroy statuses are still checked. Each handle is
    // cleared before its destroy call so the destructor never frees it twice.
    void destroy_checked() {
      cublasLtMatmulPreference_t pref = preference;
      preference = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(pref));
      cublasLtMatrixLayout_t layout = Ddesc;
      Ddesc = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(layout));
      layout = Cdesc;
      Cdesc = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(layout));
      layout = Bdesc;
      Bdesc = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(layout));
      layout = Adesc;
      Adesc = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(layout));
      cublasLtMatmulDesc_t op = operationDesc;
      operationDesc = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(op));
    }

    // Exceptional-exit teardown. Destructors must not throw, so destroy
    // statuses are deliberately ignored here.
    ~CublasLtObjects() {
      if (preference != nullptr) cublasLtMatmulPreferenceDestroy(preference);
      if (Ddesc != nullptr) cublasLtMatrixLayoutDestroy(Ddesc);
      if (Cdesc != nullptr) cublasLtMatrixLayoutDestroy(Cdesc);
      if (Bdesc != nullptr) cublasLtMatrixLayoutDestroy(Bdesc);
      if (Adesc != nullptr) cublasLtMatrixLayoutDestroy(Adesc);
      if (operationDesc != nullptr) cublasLtMatmulDescDestroy(operationDesc);
    }
  } lt_objects;
  // Aliases so the descriptor setup below reads the same as plain locals.
  cublasLtMatmulDesc_t &operationDesc = lt_objects.operationDesc;
  cublasLtMatrixLayout_t &Adesc = lt_objects.Adesc, &Bdesc = lt_objects.Bdesc,
                         &Cdesc = lt_objects.Cdesc, &Ddesc = lt_objects.Ddesc;
  cublasLtMatmulPreference_t &preference = lt_objects.preference;

  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  int64_t ld_gelumat = (int64_t)ldd;

  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
                                               param.transA == CUBLAS_OP_N ? k : m, param.lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
                                               param.transB == CUBLAS_OP_N ? n : k, param.ldb));

  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
                                                   &param.transA, sizeof(param.transA)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                   &param.transB, sizeof(param.transB)));
  // Set math SM count
  if (math_sm_count != 0) {
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
  }

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));

    // Scaling factors.
#if CUDA_VERSION >= 12080
    cublasLtMatmulMatrixScale_t scaling_mode_a;
    cublasLtMatmulMatrixScale_t scaling_mode_b;
#endif
    if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
      void *A_scale_inverse = param.A_scale_inv;
      void *B_scale_inverse = param.B_scale_inv;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
#if CUDA_VERSION >= 12080
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
    } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) {
      fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
      fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
      // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
      if (cublasLtGetVersion() <= 120803) {
        const int64_t dummy_a_vec_stride = 1;
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
            sizeof(dummy_a_vec_stride)));
      }
    } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
               (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUDA_VERSION >= 12090
      float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
      float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
                    inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
                 "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
      scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
      scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
      NVTE_ERROR("FP8 block scaling requires CUDA 12.9+");
#endif  // CUDA_VERSION >= 12090
#endif  // CUDA_VERSION >= 12080
    } else {
      NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and  " +
                 to_string(inputB->scaling_mode) + ".");
    }

#if CUDA_VERSION >= 12080
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b)));
#endif
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUDA_VERSION >= 12080
      // NOTE: In all current cases where FP8 output is supported, the input is
      // scaled identically to the output.
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
                                                       &scaling_mode_a, sizeof(scaling_mode_a)));
#endif
      // For FP8 output, cuBLAS requires C_type to match bias_type and
      // be FP16/BF16
      const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
    if (bias) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
  }

  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  }

  if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) ||
      (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
    NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS ||
                epilogue == CUBLASLT_EPILOGUE_DGELU),
               "Epilogue requested outside of the available and tested cuBLAS functionality for "
               "float8 block scaled GEMM");
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                   &epilogue, sizeof(epilogue)));

#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
  if (counter != nullptr) {
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    }
  }
#endif

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));

  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);
  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
                                   static_cast<const void *>(&one),         /* alpha */
                                   param.A,                                 /* A */
                                   Adesc, param.B,                          /* B */
                                   Bdesc, static_cast<const void *>(&beta), /* beta */
                                   C,                                       /* C */
                                   Cdesc, D,                                /* D */
                                   Ddesc, &heuristicResult.algo,            /* algo */
                                   workspace,                               /* workspace */
                                   workspaceSize, stream));                 /* stream */

  // Update FP8 scale-inv in output tensor
  // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
  // TODO: Changing gemm interface so that D->scale_inv is allocated and the scale_inv can be
  // calculated here.
  if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
    update_tensor_scale_inv(outputD, stream);
  }

  // Normal exit: destroy the descriptors with status checks. If any of these
  // throws, lt_objects' destructor releases whatever is left.
  lt_objects.destroy_checked();
}
yuguo's avatar
yuguo committed
583
#endif // __HIP_PLATFORM_AMD__
Przemek Tredak's avatar
Przemek Tredak committed
584

585
586
587
588
589
590
static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
static cudaEvent_t cublas_event[num_streams];

// Create the num_streams entries of compute_streams and one cublas_event per stream.
// Warning: only call once per device!
//
// TORCH_COMM_CU_NUMS (read via getIntEnv, which owns the default/limit handling)
// must be 4, 8, 16 or 32. The resulting CU mask clears the lowest comm_cu_nums
// bits of the first mask word and keeps all other bits set.
//
// On ROCm the mask is applied to every compute stream only when the
// TORCH_COMM_CU_NUMS variable is set to a non-empty value. Otherwise, and on
// CUDA, streams are created non-blocking with priority -1.
// NOTE(review): the masked CUs appear to be reserved for communication kernels
// (inferred from the variable name) — confirm.
static void init_streams_and_events() {
  // Validated on every platform, matching the previous behaviour, even though
  // the mask is only consumed on ROCm.
  const int comm_cu_nums = getIntEnv("TORCH_COMM_CU_NUMS", 8, 4);
  // Only the first 32-bit word differs between the supported settings.
  unsigned int cuMask[4] = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff};
  unsigned int cuMaskSize = 4;
  switch (comm_cu_nums) {
    case 4:
      cuMask[0] = 0xfffffff0;
      break;
    case 8:
      cuMask[0] = 0xffffff00;
      break;
    case 16:
      cuMask[0] = 0xffff0000;
      break;
    case 32:
      cuMask[0] = 0x00000000;
      break;
    default:
      NVTE_ERROR("comm_cu_nums must be 4,8,16,32");
  }
#ifdef __HIP_PLATFORM_AMD__
  // Loop-invariant: apply the CU mask only when the variable is explicitly set.
  const char *TORCH_COMM_CU_NUMS = std::getenv("TORCH_COMM_CU_NUMS");
  const bool use_cu_mask = TORCH_COMM_CU_NUMS != nullptr && TORCH_COMM_CU_NUMS[0] != '\0';
#endif
  for (int i = 0; i < num_streams; i++) {
#ifdef __HIP_PLATFORM_AMD__
    if (use_cu_mask) {
      NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams[i], cuMaskSize, cuMask));
    } else {
      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
    }
#else
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
#endif
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
  }
}

yuguo's avatar
yuguo committed
632
633
634
635
636
637
638
// Stream/event pool used by nvte_multi_stream_cublas_batchgemm.
static std::once_flag init_flag_batchgemm;
static cudaStream_t compute_streams_batchgemm[num_batchgemm_streams];
static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];

// Creates the `num_batchgemm_streams` compute streams (non-blocking, priority -1) and one
// event per stream for the batched-GEMM path.
//
// ROCm only: when TORCH_COMM_CU_NUMS is set to a non-empty value, every compute stream is
// instead created with hipExtStreamCreateWithCUMask. The lowest TORCH_COMM_CU_NUMS bits of
// the first 32-bit mask word are cleared (allowed values: 4, 8, 16, 32), and all other mask
// bits stay set. The variable is not read at all on CUDA builds.
//
// Warning: only call once per device!
static void init_streams_and_events_batchgemm() {
#ifdef __HIP_PLATFORM_AMD__
  const char *TORCH_COMM_CU_NUMS = std::getenv("TORCH_COMM_CU_NUMS");
  const bool use_cu_mask = TORCH_COMM_CU_NUMS != nullptr && TORCH_COMM_CU_NUMS[0] != '\0';
  unsigned int cuMask[4] = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff};
  unsigned int cuMaskSize = 4;
  if (use_cu_mask) {
    // Only parse/validate the value when it will actually be used.
    const int comm_cu_nums = getIntEnv("TORCH_COMM_CU_NUMS", 8, 4);
    switch (comm_cu_nums) {
      case 4:
        cuMask[0] = 0xfffffff0;
        break;
      case 8:
        cuMask[0] = 0xffffff00;
        break;
      case 16:
        cuMask[0] = 0xffff0000;
        break;
      case 32:
        cuMask[0] = 0x00000000;
        break;
      default:
        NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
    }
  }
#endif  // __HIP_PLATFORM_AMD__

  for (int i = 0; i < num_batchgemm_streams; i++) {
#ifdef __HIP_PLATFORM_AMD__
    if (use_cu_mask) {
      NVTE_CHECK_CUDA(
          hipExtStreamCreateWithCUMask(&compute_streams_batchgemm[i], cuMaskSize, cuMask));
    } else {
      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i],
                                                   cudaStreamNonBlocking, -1));
    }
#else
    NVTE_CHECK_CUDA(
        cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
#endif  // __HIP_PLATFORM_AMD__
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[i]));
  }
}

680
}  // namespace transformer_engine
Przemek Tredak's avatar
Przemek Tredak committed
681

682
683
684
/* Computes D = op(A) * op(B), with an optional bias / pre-GELU epilogue, on `stream`.
 *
 * ROCm: the call is routed to cublas_gemm when any of these holds:
 *   - a bias or pre-GELU output buffer is supplied;
 *   - either input is FP8;
 *   - NVTE_FORCE_ROCM_GEMM starts with '1';
 *   - nvte_use_hipblaslt or nvte_use_rocblas is set.
 * Otherwise it goes to hipblas_gemm. INT8 inputs are rejected on the cublas_gemm route, and
 * the TT layout is rejected. Problem sizes come from the flattened 2-D view of A and B and
 * must fit in `int`.
 *
 * CUDA: forwards to the cuBLASLt cublas_gemm overload; nvte_use_hipblaslt,
 * nvte_use_rocblas and compute_stream_offset are not used.
 */
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

#ifdef __HIP_PLATFORM_AMD__
  // The BLAS back-ends take `int` sizes; fail loudly instead of silently wrapping.
  auto to_gemm_dim = [](size_t dim) -> int {
    NVTE_CHECK(dim <= static_cast<size_t>(INT32_MAX), "GEMM dimension ", dim,
               " exceeds the int range supported by the BLAS back-end.");
    return static_cast<int>(dim);
  };

  const size_t A0 = inputA->flat_first_dim();
  const size_t A1 = inputA->flat_last_dim();
  const size_t B0 = inputB->flat_first_dim();
  const size_t B1 = inputB->flat_last_dim();

  const int m = to_gemm_dim(transa ? A0 : A1);
  const int k = to_gemm_dim(transa ? A1 : A0);
  const int n = to_gemm_dim(transb ? B1 : B0);

  // Leading dimensions for each supported layout; TT is not supported.
  int lda = 0, ldb = 0, ldd = 0;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  const bool use_int8 = is_int8_dtype(inputA->data.dtype) ||
                        is_int8_dtype(inputB->data.dtype);
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool force_rocm_gemm = NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1';
  const bool has_epilogue =
      (biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr);

  if (has_epilogue || use_fp8 || force_rocm_gemm || nvte_use_hipblaslt || nvte_use_rocblas) {
    NVTE_CHECK(!use_int8, "Int8 gemm just surpport pure int8 gemm without any epilogue.");
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa,
                transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream,
                nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
  } else {
    hipblas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                 use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
  }
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
#endif  //__HIP_PLATFORM_AMD__
}

759
760
761
762
763
/* Atomic (split, counter-synchronised) GEMM variant of nvte_cublas_gemm.
 *
 * Only delayed tensor scaling is accepted for A and B.
 *
 * On CUDA, CUDA runtime >= 12.2 and cuBLASLt >= 12.2.5 are required.
 *
 * On ROCm the routing matches nvte_cublas_gemm. cublas_gemm is used when any of these holds:
 *   - a bias or pre-GELU output buffer is supplied;
 *   - either input is FP8;
 *   - NVTE_FORCE_ROCM_GEMM starts with '1';
 *   - nvte_use_hipblaslt or nvte_use_rocblas is set.
 * Otherwise hipblas_gemm is used. In both cases m_split / n_split / gemm_producer / counter
 * are passed through. Sizes come from the flattened 2-D view of A and B, as in
 * nvte_cublas_gemm, and must fit in `int`. The TT layout is rejected.
 */
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
                             cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);

#ifndef __HIP_PLATFORM_AMD__
  int cudart_version;
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
  NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
  NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");
#endif

  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                 is_delayed_tensor_scaling(inputB->scaling_mode),
             "Atomic GEMM only supports delayed scaling.");
#ifdef __HIP_PLATFORM_AMD__
  // The BLAS back-ends take `int` sizes; fail loudly instead of silently wrapping.
  auto to_gemm_dim = [](size_t dim) -> int {
    NVTE_CHECK(dim <= static_cast<size_t>(INT32_MAX), "GEMM dimension ", dim,
               " exceeds the int range supported by the BLAS back-end.");
    return static_cast<int>(dim);
  };

  // Use the flattened 2-D view (as nvte_cublas_gemm does). Raw data.shape[1] is not the
  // last dimension of an N-D tensor and does not exist for a 1-D one.
  const size_t A0 = inputA->flat_first_dim();
  const size_t A1 = inputA->flat_last_dim();
  const size_t B0 = inputB->flat_first_dim();
  const size_t B1 = inputB->flat_last_dim();

  const int m = to_gemm_dim(transa ? A0 : A1);
  const int k = to_gemm_dim(transa ? A1 : A0);
  const int n = to_gemm_dim(transb ? B1 : B0);

  // Leading dimensions for each supported layout; TT is not supported.
  int lda = 0, ldb = 0, ldd = 0;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool force_rocm_gemm = NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1';
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  const bool has_epilogue =
      (biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr);

  if (has_epilogue || use_fp8 || force_rocm_gemm || nvte_use_hipblaslt || nvte_use_rocblas) {
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa,
                transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
                inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas,
                compute_stream_offset);
  } else {
    hipblas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                 use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
                 inputCounter, stream);
  }
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
              inputCounter, stream);
#endif  //__HIP_PLATFORM_AMD__
}


843

844
845
846
847
/* Runs `num_gemms` independent GEMMs spread round-robin over the internal compute streams.
 *
 * GEMM i uses compute stream and workspace slot `i % num_streams`. Work is forked from
 * `stream` with an event, and `stream` waits on every compute stream that was used before
 * returning.
 *
 * Env vars:
 *   NVTE_FORCE_BLAS_MULSTREAM=1  call nvte_cublas_gemm with its default back-end arguments.
 *   otherwise                    call it with nvte_use_hipblaslt=true, nvte_use_rocblas=false
 *                                and compute_stream_offset set to the stream slot.
 * Setting both NVTE_FORCE_BLAS_MULSTREAM=1 and NVTE_FORCE_ROCM_GEMM=1 is an error. It is
 * reported before any stream work is enqueued.
 */
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;

  // Resolve and validate the back-end selection first, so a bad configuration fails
  // without having forked any compute streams.
  const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool force_blas_mulstream =
      NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1';
  const bool force_rocm_gemm = NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1';
  if (force_blas_mulstream && force_rocm_gemm) {
    NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time.");
  }

  // Inits streams and events (once, globally)
  std::call_once(init_flag, init_streams_and_events);

  const int num_stream_used = std::min(num_streams, num_gemms);

  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
  }

  for (int i = 0; i < num_gemms; i++) {
    const int slot = i % num_streams;
    if (force_blas_mulstream) {
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                       workspace[slot], accumulate, use_split_accumulator, math_sm_count,
                       compute_streams[slot]);
    } else {
      // nvte_use_hipblaslt = true, nvte_use_rocblas = false, compute_stream_offset = slot.
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                       workspace[slot], accumulate, use_split_accumulator, math_sm_count,
                       compute_streams[slot], true, false, slot);
    }
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
  }
}
yuguo's avatar
yuguo committed
894

895
#ifndef __HIP_PLATFORM_AMD__
896
897
898
899
900
901
902
namespace transformer_engine {

// Per-process accessor for cuBLASLt handles, built via CreateCublasHandle.
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

// Asks the handle manager for a cuBLASLt handle and discards it. Presumably this makes the
// handle exist before the first GEMM rather than being created lazily (TODO confirm
// against HandleManager).
void nvte_cublas_handle_init() {
  const auto handle = cublasHandleManager::Instance().GetHandle();
  static_cast<void>(handle);
}

}  //  namespace transformer_engine
903
904
#endif

yuguo's avatar
yuguo committed
905
#ifdef __HIP_PLATFORM_AMD__
yuguo's avatar
yuguo committed
906

yuguo's avatar
yuguo committed
907
908
909
910
911
912
913
914
/* Issues `num_gemms` MoE batched GEMMs round-robin over the batched-GEMM stream pool.
 *
 * The batch count for every GEMM comes from NVTE_MOE_BATCHCOUNT (read via getIntEnv with
 * default 2). GEMM g runs on compute stream / workspace slot `g % num_batchgemm_streams`.
 * The compute streams first wait for the work already queued on `stream`, and `stream`
 * finally waits for every compute stream that was used.
 */
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
  using namespace transformer_engine;
  int moe_batch_count = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);

  // Lazily build the stream/event pool the first time any caller gets here.
  std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);

  const int active_streams = std::min(num_batchgemm_streams, num_gemms);

  // Fork: each participating compute stream waits on a marker recorded on `stream`.
  cudaEvent_t fork_marker = cublas_event_batchgemm[0];
  NVTE_CHECK_CUDA(cudaEventRecord(fork_marker, stream));
  for (int slot = 0; slot < active_streams; ++slot) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[slot], fork_marker));
  }

  for (int g = 0; g < num_gemms; ++g) {
    const int slot = g % num_batchgemm_streams;
    nvte_cublas_batchgemm(A[g], B[g], D[g], bias[g], pre_gelu_out[g], transa, transb, grad,
                          workspace[slot], accumulate, use_split_accumulator, math_sm_count,
                          moe_batch_count, compute_streams_batchgemm[slot]);
  }

  // Join, step 1: mark the tail of every used compute stream.
  for (int slot = 0; slot < active_streams; ++slot) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[slot], compute_streams_batchgemm[slot]));
  }
  // Join, step 2: make `stream` wait for all of those marks.
  for (int slot = 0; slot < active_streams; ++slot) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event_batchgemm[slot]));
  }
}

// add for batchgemm
/* MoE batched GEMM on `stream` (ROCm).
 *
 * Per-batch sizes are obtained by dividing dim 0 of A, and dim 0 of B when `transb` is
 * false, by `batch_count`. Those dims must therefore be exact multiples of a positive
 * `batch_count`. Bias / GELU epilogues and the TT layout are rejected.
 */
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
    NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
  }
  if (transa && transb) {
    NVTE_ERROR("TT layout not allowed.");
  }

  // Guard the divisions below. batch_count <= 0 would divide by zero or, once converted to
  // size_t, by a huge value. A non-multiple would silently truncate the per-batch size.
  NVTE_CHECK(batch_count > 0, "MOE batchgemm requires batch_count > 0, got ", batch_count, ".");
  const size_t batch = static_cast<size_t>(batch_count);

  const size_t a_rows = inputA->data.shape[0];
  NVTE_CHECK(a_rows % batch == 0, "MOE batchgemm: dim 0 of A (", a_rows,
             ") is not divisible by batch_count (", batch_count, ").");
  const int m = static_cast<int>(transa ? a_rows / batch : inputA->data.shape[1]);
  const int k = static_cast<int>(transa ? inputA->data.shape[1] : a_rows / batch);

  int n = 0;
  if (transb) {
    n = static_cast<int>(inputB->data.shape[1]);
  } else {
    const size_t b_rows = inputB->data.shape[0];
    NVTE_CHECK(b_rows % batch == 0, "MOE batchgemm: dim 0 of B (", b_rows,
               ") is not divisible by batch_count (", batch_count, ").");
    n = static_cast<int>(b_rows / batch);
  }

  // Same leading dims as the per-layout table: TN -> (k, k, m), NN -> (m, k, m), NT -> (m, n, m).
  const int lda = transa ? k : m;
  const int ldb = transb ? n : k;
  const int ldd = m;

  hipblas_batchgemm(inputA,
            inputB,
            outputD,
            biasTensor,
            outputGelu,
            m, n, k,
            lda, ldb, ldd,
            (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
            (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
            grad, wspace->data.dptr,
            wspace->data.shape[0],
            accumulate, use_split_accumulator,
            math_sm_count,
            0,
            0,
            false,
            nullptr,
            batch_count,
            stream);
}

// add for batchgemm
/* MoE batched GEMM on `stream` (ROCm); variant that unpacks its NVTETensor arguments with
 * convertNVTETensorCheck / convertNVTETensor.
 *
 * Per-batch sizes are obtained by dividing dim 0 of A, and dim 0 of B when `transb` is
 * false, by `batch_count`. Those dims must therefore be exact multiples of a positive
 * `batch_count`. Bias / GELU epilogues and the TT layout are rejected.
 */
void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm_v2);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
    NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
  }
  if (transa && transb) {
    NVTE_ERROR("TT layout not allowed.");
  }

  // Guard the divisions below. batch_count <= 0 would divide by zero or, once converted to
  // size_t, by a huge value. A non-multiple would silently truncate the per-batch size.
  NVTE_CHECK(batch_count > 0, "MOE batchgemm requires batch_count > 0, got ", batch_count, ".");
  const size_t batch = static_cast<size_t>(batch_count);

  const size_t a_rows = inputA->data.shape[0];
  NVTE_CHECK(a_rows % batch == 0, "MOE batchgemm: dim 0 of A (", a_rows,
             ") is not divisible by batch_count (", batch_count, ").");
  const int m = static_cast<int>(transa ? a_rows / batch : inputA->data.shape[1]);
  const int k = static_cast<int>(transa ? inputA->data.shape[1] : a_rows / batch);

  int n = 0;
  if (transb) {
    n = static_cast<int>(inputB->data.shape[1]);
  } else {
    const size_t b_rows = inputB->data.shape[0];
    NVTE_CHECK(b_rows % batch == 0, "MOE batchgemm: dim 0 of B (", b_rows,
               ") is not divisible by batch_count (", batch_count, ").");
    n = static_cast<int>(b_rows / batch);
  }

  // Same leading dims as the per-layout table: TN -> (k, k, m), NN -> (m, k, m), NT -> (m, n, m).
  const int lda = transa ? k : m;
  const int ldb = transb ? n : k;
  const int ldd = m;

  hipblas_batchgemm(inputA,
            inputB,
            outputD,
            biasTensor,
            outputGelu,
            m, n, k,
            lda, ldb, ldd,
            (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
            (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
            grad, wspace->data.dptr,
            wspace->data.shape[0],
            accumulate, use_split_accumulator,
            math_sm_count,
            0,
            0,
            false,
            nullptr,
            batch_count,
            stream);
}
#endif