/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>

#include <algorithm>
#include <cstdint>
#include <mutex>
#include <vector>

#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "./config.h"
#include "./cutlass_grouped_gemm.cuh"

namespace {

/* Use CUDA constant memory to store the scalars 1 and 0 for cuBLAS usage
 * (e.g. when alpha/beta must be passed by device pointer).
 */
__device__ __constant__ float one_device;
__device__ __constant__ float zero_device;

inline float *GetScalarOne() {
  static std::once_flag init_flag;
  std::call_once(init_flag, []() {
    float one = 1.0f;
    NVTE_CHECK_CUDA(cudaMemcpyToSymbol(one_device, &one, sizeof(float)));
  });
  // return address by cudaGetSymbolAddress
  float *dev_ptr;
  NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), one_device));
  return dev_ptr;
}

inline float *GetScalarZero() {
  static std::once_flag init_flag;
  std::call_once(init_flag, []() {
    float zero = 0.0f;
    NVTE_CHECK_CUDA(cudaMemcpyToSymbol(zero_device, &zero, sizeof(float)));
  });
  // return address by cudaGetSymbolAddress
  float *dev_ptr;
  NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), zero_device));
  return dev_ptr;
}

// Single-thread helper kernel that stages a host-side scalar (e.g. a
// non-trivial beta) into device memory.
__global__ __launch_bounds__(1) void set_float_kernel(float *ptr, float val) { *ptr = val; }

uint32_t _getAlignment(uintptr_t address) {
  // Alignment is in bytes; find the largest power-of-two alignment (up to 256)
  // that the address satisfies.
  uint32_t alignment = 256;
  for (;; alignment /= 2) {
    if (address % alignment == 0) {
      return alignment;
    }
  }
}

inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}

/* Parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering, which
 * differs from the row-major ordering typically used in Transformer
 * Engine.
 */
struct GemmParam {
  void *A = nullptr;
  void *B = nullptr;
  cublasOperation_t transA = CUBLAS_OP_N;
  cublasOperation_t transB = CUBLAS_OP_N;
  transformer_engine::DType Atype = transformer_engine::DType::kNumTypes;
  transformer_engine::DType Btype = transformer_engine::DType::kNumTypes;
  void *A_scale_inv = nullptr;
  void *B_scale_inv = nullptr;
  int lda = 0;  // A column strides
  int ldb = 0;  // B column strides
};
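
// Convention used throughout CanonicalizeGemmInput below: a row-major M x N
// Transformer Engine tensor occupies the same memory as a column-major N x M
// matrix, so the leading dimensions (lda/ldb) set below are the innermost
// (last) dimension of whichever data buffer is actually handed to cuBLAS.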

/* Populate parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering, which
 * differs from the row-major ordering typically used in Transformer
 * Engine.
 */
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                int m, int n, int k) {
  using namespace transformer_engine;
  NVTE_CHECK(
      A.scaling_mode == B.scaling_mode ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
      "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " +
          to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode));
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
  GemmParam ret;

  // Transpose mode with column-major ordering
  bool is_A_transposed = transA == CUBLAS_OP_T;
  bool is_B_transposed = transB == CUBLAS_OP_T;

  // Set conditions for MXFP8 and NVFP4 gemm execution.
  const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
  const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);

  // Configure A matrix
  if (is_tensor_scaling(A.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.A = A.data.dptr;
    ret.transA = transA;
    ret.Atype = A.data.dtype;
    ret.A_scale_inv = A.scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
        ret.A = A.columnwise_data.dptr;
        ret.transA = CUBLAS_OP_T;
        ret.Atype = A.columnwise_data.dtype;
        ret.A_scale_inv = A.columnwise_scale_inv.dptr;
        ret.lda = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
      }
    }
  } else if (nvfp4) {
    // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe.

    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(is_nvfp4_scaling(A.scaling_mode),
                 "Input A has unsupported combination of recipe and layout");
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = CUBLAS_OP_T;  // NVFP4 gemm is only supported in TN layout.
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = k;
  } else if (mxfp8) {
    // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe.
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with matrix interpreted in row-major order).

    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = transA;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
  } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = CUBLAS_OP_T;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = k;

    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.lda % 16) == 0,
               "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
    // Smallest supported CType is 2 bytes in this scaling mode.
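    // (e.g. m = 8 with a 2-byte CType gives exactly 16 bytes, the minimum.)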
    NVTE_CHECK((m % 8) == 0,
               "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
  } else {
    NVTE_ERROR("A has unsupported scaling mode");
  }

  // Configure B matrix
  if (is_tensor_scaling(B.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.B = B.data.dptr;
    ret.transB = transB;
    ret.Btype = B.data.dtype;
    ret.B_scale_inv = B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
        ret.B = B.columnwise_data.dptr;
        ret.transB = CUBLAS_OP_N;
        ret.Btype = B.columnwise_data.dtype;
        ret.B_scale_inv = B.columnwise_scale_inv.dptr;
        ret.ldb = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
      }
    }
  } else if (nvfp4) {
    if (is_B_transposed) {
      NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode),
                 "Input B has unsupported combination of recipe and layout");
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = CUBLAS_OP_N;  // NVFP4 gemm is only supported in TN layout.
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;
  } else if (mxfp8) {
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = transB;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
  } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = CUBLAS_OP_N;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;

    // Requirements from
    // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.ldb % 16) == 0,
               "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
      // Observed this requirement only present for B tensor is 1D quantized.
      NVTE_CHECK((n % 8) == 0,
                 "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    }
  } else {
    NVTE_ERROR("B has unsupported scaling mode");
  }

  return ret;
}

/* cuBLAS version number at run-time */
size_t cublas_version() {
  // Cache version to avoid cuBLAS logging overhead
  static size_t version = cublasLtGetVersion();
  return version;
}

}  // namespace

namespace transformer_engine {

using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
                 const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count,
                 int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
                 cudaStream_t stream) {
  // Tensor dims in row-major order
  const int A0 = inputA->flat_first_dim();
  const int A1 = inputA->flat_last_dim();
  const int B0 = inputB->flat_first_dim();
  const int B1 = inputB->flat_last_dim();

  // GEMM dims in column-major order
  const int m = transa == CUBLAS_OP_T ? A0 : A1;
  const int n = transb == CUBLAS_OP_T ? B1 : B0;
  const int k = transa == CUBLAS_OP_T ? A1 : A0;
  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
             "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
             ")");
  const int ldd = m;
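  // D is described to cuBLAS as an m x n column-major matrix (i.e. an n x m
  // row-major tensor), hence ldd = m.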

  // Return immediately if GEMM is trivial
  if (m <= 0 || n <= 0) {
    return;
  }
  NVTE_CHECK(k > 0);

  const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

  void *C = outputD->data.dptr;
  void *D = outputD->data.dptr;
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);
  const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype);

  // Update scaling factors with NVFP4 tensor scales
  // TODO: Check whether scales are on CPU/GPU or add API to control.
  // Currently scales are assumed to be on CPU when amax is provided
  // and on GPU when not provided, but this is brittle.
  if (use_fp4 && (inputA->amax.dptr != nullptr || inputB->amax.dptr != nullptr)) {
    // Reserve some workspace for alpha scale
    NVTE_CHECK(workspaceSize >= 4,
               "NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ",
               workspaceSize, " bytes remaining.");
    workspaceSize = (workspaceSize / 4) * 4 - 4;  // Remove last 4 aligned bytes
    uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
    float *new_alpha_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);

    // Update alpha scale on device
    // Note: Compute NVFP4 tensor scales based on amaxes and then
    // divide from alpha scale. This way we only need to apply NVFP4
    // tensor scales in matmul output, instead of in matmul inputs.
    float old_alpha = *reinterpret_cast<const float *>(alpha);  // Assumed to be on CPU
    TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector<size_t>{1}, DType::kFloat32);
    nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb,
                                        old_alpha, new_alpha_tensor.data(), stream);
    alpha = new_alpha_ptr;
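
    // From here on, alpha points to device memory; the NVFP4 branch below sets
    // CUBLASLT_MATMUL_DESC_POINTER_MODE to CUBLASLT_POINTER_MODE_DEVICE so
    // cuBLAS reads alpha/beta from the GPU.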

    // Make sure beta scale is on device
    float old_beta = *reinterpret_cast<const float *>(beta);  // Assumed to be on CPU
    if (old_beta == 0) {
      beta = GetScalarZero();  // Device constant memory
    } else if (old_beta == 1) {
      beta = GetScalarOne();  // Device constant memory
    } else {
      // Move beta to workspace
      NVTE_CHECK(workspaceSize >= 4,
                 "NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has ",
                 workspaceSize, " bytes remaining.");
      workspaceSize = (workspaceSize / 4) * 4 - 4;  // Remove last 4 aligned bytes
      float *new_beta_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);
      set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta);
      NVTE_CHECK_CUDA(cudaGetLastError());
      beta = new_beta_ptr;
    }
  }

  const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
  const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);

  NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp4_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP4 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp4_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP4 input to GEMM requires inverse of scale!");

  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if ((use_fp8 || use_fp4) && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp4_dtype(outputD->data.dtype)) {
    NVTE_ERROR("FP4 GEMM output is not supported!");
  }
  if (use_fp4 && (D_type == CUDA_R_16F)) {
    NVTE_ERROR("FP4 GEMM does not support FP16 output!");
  }

  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();

  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  cublasLtMatmulPreference_t preference = nullptr;
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  int64_t ld_gelumat = (int64_t)ldd;

  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }
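  // All other input types (FP16/BF16/FP8/FP4) accumulate in FP32 via
  // CUBLAS_COMPUTE_32F.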

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
                                               param.transA == CUBLAS_OP_N ? k : m, param.lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
                                               param.transB == CUBLAS_OP_N ? n : k, param.ldb));

  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
                                                   &param.transA, sizeof(param.transA)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                   &param.transB, sizeof(param.transB)));
  // Set math SM count
  if (math_sm_count != 0) {
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
  }

  // set fp8/fp4 attributes -- input and output types should already be set to fp8/fp4
  // as appropriate. Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode);

  if (use_fp8 || use_fp4) {
    // Fast accumulation is only supported for FP8.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : use_fp8;
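    // i.e. fast accumulation is enabled only for FP8 without split
    // accumulation; FP4 GEMMs always leave it disabled.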
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));

    // Scaling factors.
#if CUBLAS_VERSION >= 120800
    cublasLtMatmulMatrixScale_t scaling_mode_a;
    cublasLtMatmulMatrixScale_t scaling_mode_b;
#endif  // CUBLAS_VERSION >= 120800
    if (is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode)) {
      void *A_scale_inverse = param.A_scale_inv;
      void *B_scale_inverse = param.B_scale_inv;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
#if CUBLAS_VERSION >= 120800
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
#endif  // CUBLAS_VERSION >= 120800
    } else if (mxfp8_gemm) {
#if CUBLAS_VERSION >= 120800
      NVTE_CHECK(cublas_version() >= 120800,
                 "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
      fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
      fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
      // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
      if (cublas_version() <= 120803) {
        const int64_t dummy_a_vec_stride = 1;
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
            sizeof(dummy_a_vec_stride)));
      }
#else
      NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
#endif                     // CUBLAS_VERSION >= 120800
    } else if (use_fp4) {  // NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
      NVTE_CHECK(cublas_version() >= 120800,
                 "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
      // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE
      cublasDataType_t scale_type = CUDA_R_32F;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));

      // Set pointer mode: alpha and beta are both device pointers
      // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
      cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));

      fp8e4m3 *A_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.A_scale_inv);
      fp8e4m3 *B_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
#else
      NVTE_ERROR("FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120800
    } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
               (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUBLAS_VERSION >= 120900
      NVTE_CHECK(cublas_version() >= 120900,
                 "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
                 cublas_version());
      float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
      float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
                    inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
                 "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
      scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
      scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
      NVTE_ERROR("FP8 block scaling requires cuBLAS 12.9+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120900
    } else {
      NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and  " +
                 to_string(inputB->scaling_mode) + ".");
    }

#if CUBLAS_VERSION >= 120800
    if (cublas_version() >= 120800) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
                                                       &scaling_mode_a, sizeof(scaling_mode_a)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
                                                       &scaling_mode_b, sizeof(scaling_mode_b)));
    }
#endif  // CUBLAS_VERSION >= 120800
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUBLAS_VERSION >= 120800
      if (cublas_version() >= 120800) {
        // NOTE: In all current cases where FP8 output is supported, the input is
        // scaled identically to the output.
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                         CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
                                                         &scaling_mode_a, sizeof(scaling_mode_a)));
      }
#endif  // CUBLAS_VERSION >= 120800
      // For FP8 output, cuBLAS requires C_type to match bias_type and
      // be FP16/BF16
      const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
    if (bias) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
  }

  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  }

  if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) ||
      (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
    NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS ||
                epilogue == CUBLASLT_EPILOGUE_DGELU),
               "Epilogue requested outside of the available and tested cuBLAS functionality for "
               "float8 block scaled GEMM");
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                   &epilogue, sizeof(epilogue)));

  if (counter != nullptr) {
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
    NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
               CUDA_VERSION);
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
    NVTE_ERROR(
        "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
        CUBLAS_VERSION);
#else
    NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
               "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
               cuda::cudart_version());
    NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
               "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
               cublas_version());
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    }
#endif
  }

  // align the workspace to 256 B
  const int required_alignment = 256;
  const auto original_workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
  uint8_t *aligned_workspace_ptr =
      reinterpret_cast<uint8_t *>(workspace) + required_alignment - original_workspace_alignment;
  workspaceSize = workspaceSize - required_alignment + original_workspace_alignment;
  const auto new_workspace_alignment =
      _getAlignment(reinterpret_cast<uintptr_t>(aligned_workspace_ptr));
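
  // (The NVTE_CHECK below verifies that the adjusted pointer is in fact
  // 256-byte aligned.)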
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
  NVTE_CHECK(new_workspace_alignment % 256 == 0,
             "cuBLAS workspace pointer must be aligned to 256 bytes, got ",
             new_workspace_alignment);

  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);
  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */
                                   param.A,                      /* A */
                                   Adesc, param.B,               /* B */
                                   Bdesc, beta,                  /* beta */
                                   C,                            /* C */
                                   Cdesc, D,                     /* D */
                                   Ddesc, &heuristicResult.algo, /* algo */
                                   aligned_workspace_ptr,        /* workspace */
                                   workspaceSize, stream));      /* workspace size, stream */

  // Update FP8 scale-inv in output tensor
  // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
  // TODO: Changing gemm interface so that D->scale_inv is allocated and the scale_inv can be
  // calculated here.
  if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
    update_tensor_scale_inv(outputD, stream);
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}

}  // namespace transformer_engine

void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;

  // Tensors
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensorCheck(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  // Scales
  const float alpha = 1;
  const float beta = accumulate ? 1 : 0;

  // Check for NVFP4
  // TODO Remove once alpha scale logic is moved into cublas_gemm function
  if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
    NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
  }

  // Launch GEMM
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}

void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
                         const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
                         NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm_v2);
  using namespace transformer_engine;

  // Data tensors
  const Tensor *A_tensor = convertNVTETensorCheck(A);
  const Tensor *B_tensor = convertNVTETensorCheck(B);
  const Tensor *C_tensor = convertNVTETensorCheck(C);
  Tensor *D_tensor = convertNVTETensorCheck(D);
  NVTE_CHECK(C_tensor == D_tensor,
             "Currently nvte_cublas_gemm_v2 does not support different C and D tensors.");

  // Workspace
  void *workspace_ptr = nullptr;
  size_t workspace_size = 0;
  Tensor *workspace_tensor = convertNVTETensor(workspace);
  if (workspace_tensor != nullptr) {
    workspace_ptr = workspace_tensor->data.dptr;
    workspace_size =
        get_buffer_size_bytes(workspace_tensor->data.numel(), workspace_tensor->data.dtype);
  }

  // Additional config
  MatmulConfig config_;
  if (config != nullptr) {
    config_ = *reinterpret_cast<MatmulConfig *>(config);
  }

  // Configure GEMM epilogue
  const bool with_grad_epilogue = (config_.dbias_tensor != nullptr || config_.with_dgelu_epilogue);
  if (with_grad_epilogue) {
    NVTE_CHECK(config_.bias_tensor == nullptr && !config_.with_gelu_epilogue,
               "Invalid epilogue (bias=", config_.bias_tensor != nullptr,
               ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
               ", dgelu=", config_.with_dgelu_epilogue, ").");
  }
  Tensor dummy_tensor;
  Tensor *epilogue_bias_tensor = &dummy_tensor;
  if (!with_grad_epilogue && config_.bias_tensor != nullptr) {
    epilogue_bias_tensor = convertNVTETensorCheck(config_.bias_tensor);
  } else if (with_grad_epilogue && config_.dbias_tensor != nullptr) {
    epilogue_bias_tensor = convertNVTETensorCheck(config_.dbias_tensor);
  }
  Tensor *epilogue_aux_tensor = &dummy_tensor;
  if (config_.with_gelu_epilogue || config_.with_dgelu_epilogue) {
    NVTE_CHECK(config_.epilogue_aux_tensor != nullptr,
               "Requested epilogue (bias=", config_.bias_tensor != nullptr,
               ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
               ", dgelu=", config_.with_dgelu_epilogue, ") without providing aux tensor.");
    epilogue_aux_tensor = convertNVTETensor(config_.epilogue_aux_tensor);
  }

  // Launch GEMM
  cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor,
              transa ? CUBLAS_OP_T : CUBLAS_OP_N, transb ? CUBLAS_OP_T : CUBLAS_OP_N,
              with_grad_epilogue, workspace_ptr, workspace_size, alpha, beta,
              config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr, stream);
}
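
// Minimal usage sketch for nvte_cublas_gemm_v2 (caller-side tensor names are
// illustrative, not part of this file):
//
//   const float alpha = 1.f, beta = 0.f;
//   nvte_cublas_gemm_v2(/*transa=*/1, /*transb=*/0, &alpha, A, B, &beta,
//                       /*C=*/D, D, workspace, /*config=*/nullptr, stream);
//
// Passing the same tensor for C and D is currently required (see the check
// above), and a null config selects the default epilogue-free GEMM.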

void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm_scaled);
  using namespace transformer_engine;

  // Tensors
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensorCheck(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  // Check for NVFP4
  // TODO Remove once alpha scale logic is moved into cublas_gemm function
  if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
    NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
  }

  // Launch GEMM
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}

void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
                             cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);
  using namespace transformer_engine;
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
             CUDA_VERSION);
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
  NVTE_ERROR(
      "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
      CUBLAS_VERSION);
#else
  NVTE_CHECK(
      transformer_engine::cuda::cudart_version() >= 12020 &&
          transformer_engine::cuda::cudart_version() < 13000,
      "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
      transformer_engine::cuda::cudart_version());
  NVTE_CHECK(
      cublas_version() >= 120205 && cublas_version() < 130000,
      "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
      cublas_version());

  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  const Tensor *inputCounter = convertNVTETensor(counter);
  Tensor *wspace = convertNVTETensor(workspace);

  const void *alpha_ptr = GetScalarOne();
  const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero();

  NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                 is_delayed_tensor_scaling(inputB->scaling_mode),
             "Atomic GEMM only supports delayed scaling.");
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split,
              gemm_producer, inputCounter, stream);
#endif
}

void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                              const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                              bool transa, bool transb, bool grad, NVTETensor *workspace,
                              bool accumulate, bool use_split_accumulator, int math_sm_count,
                              cudaStream_t stream) {
  using namespace transformer_engine;

  int num_streams = nvte_get_num_compute_streams();

  int num_stream_used = std::min(num_streams, num_gemms);
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
  }

  for (int i = 0; i < num_gemms; i++) {
    // Check whether GELU or dGELU epilogue is requested
    Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]);
    bool with_gelu_dgelu_epilogue =
        (pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr);

    // Construct config
    MatmulConfig config;
    if (grad) {
      config.dbias_tensor = bias[i];
      config.with_dgelu_epilogue = with_gelu_dgelu_epilogue;
    } else {
      config.bias_tensor = bias[i];
      config.with_gelu_epilogue = with_gelu_dgelu_epilogue;
    }
    config.epilogue_aux_tensor = pre_gelu_out[i];
    config.use_split_accumulator = use_split_accumulator;
    config.sm_count = math_sm_count;

    // Launch GEMM
    const float alpha = 1.f;
    const float beta = accumulate ? 1.f : 0.f;
    nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i],
                        workspace[i % num_streams], &config,
                        detail::get_compute_stream(i % num_streams));
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
  }
}

void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;

  // Deprecation warning
  NVTE_WARN(
      "nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. "
      "Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when "
      "applicable).");

  multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace,
                           accumulate, use_split_accumulator, math_sm_count, stream);
}

namespace transformer_engine {

using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }

}  //  namespace transformer_engine

void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                            const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                            bool transa, bool transb, bool grad, NVTETensor *workspace,
                            bool accumulate, bool use_split_accumulator, int math_sm_count,
                            cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_gemm);

  const int current_device = transformer_engine::cuda::current_device();
  const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
  const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
  const bool warn_fallback =
      transformer_engine::getenv<bool>("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false);

  auto cublas_path = [&]() {
    multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
                             workspace, accumulate, use_split_accumulator, math_sm_count, stream);
  };

  // Currently only support cutlass group gemm on Hopper Arch
  if (!(is_hopper && use_cutlass)) {
    cublas_path();
    return;
  }

  auto is_empty_arr = [&](const NVTETensor *p) -> bool {
    if (p == nullptr) return true;
    for (int i = 0; i < num_gemms; ++i) {
      if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false;
    }
    return true;
  };

  auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool {
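    // K must be identical across all groups and a multiple of 128;
    // (k & 127) != 0 is equivalent to k % 128 != 0.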
    int64_t ref_k = -1;
    for (size_t i = 0; i < num_gemms; i++) {
      const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]);
      const int k = trans ? tensor->data.shape[0] : tensor->data.shape[1];

      if ((k & 127) != 0) return false;

      if (ref_k < 0)
        ref_k = k;
      else if (k != ref_k)
        return false;
    }

    return true;
  };

  auto is_supported_dtype = [&]() -> bool {
    auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
    auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
    auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]);
    auto A_type = get_cuda_dtype(inputA->data.dtype);
    auto B_type = get_cuda_dtype(inputB->data.dtype);
    auto D_type = get_cuda_dtype(OutputD->data.dtype);

    return (A_type == B_type) && (A_type == D_type) &&
           ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F));
  };

  // CUTLASS Grouped GEMM fast path (SM90/TMA)
  // Conditions:
  //  - No fused epilogue: both bias and pre_gelu_out are empty.
  //  - Supported dtypes only: FP16/BF16 (FP32 accumulate).
  //  - Uniform K across groups and K % 128 == 0.
  //  - use_split_accumulator is ignored for FP16/BF16.
  //  - grad is irrelevant when bias/pre_gelu_out are empty.
  //
  // Otherwise, fall back to cuBLAS.
  if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() &&
      all_groups_uniform_k128(B, transb)) {
    cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
                         current_device, math_sm_count, stream);
  } else {
    if (warn_fallback) {
      NVTE_WARN("Fallback to cuBLAS grouped GEMM.");
    }
    cublas_path();
  }
}