cublaslt_gemm.cu 50.7 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
3
4
5
6
7
8
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cublasLt.h>
#include <cublas_v2.h>
Tim Moon's avatar
Tim Moon committed
9
#include <cuda.h>
10
#include <transformer_engine/gemm.h>
11
#include <transformer_engine/multi_stream.h>
12
#include <transformer_engine/recipe.h>
13
14
#include <transformer_engine/transformer_engine.h>

15
#include <algorithm>
16
#include <cstdint>
17
#include <mutex>
18
#include <vector>
Tim Moon's avatar
Tim Moon committed
19

Przemek Tredak's avatar
Przemek Tredak committed
20
#include "../common.h"
21
#include "../util/cuda_runtime.h"
22
#include "../util/handle_manager.h"
Tim Moon's avatar
Tim Moon committed
23
#include "../util/logging.h"
24
#include "../util/multi_stream.h"
25
26
#include "./config.h"
#include "./cutlass_grouped_gemm.cuh"
Przemek Tredak's avatar
Przemek Tredak committed
27

28
29
namespace {

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
/* Use CUDA const memory to store scalar 1 and 0 for cublas usage
 *
 * The values are provided through static initializers rather than a
 * one-time cudaMemcpyToSymbol. A process-wide std::call_once would
 * only write the constant on whichever device happened to be current
 * during the first call, leaving the copy of `one_device` on any
 * other device at its default value instead of 1.
 */
__device__ __constant__ float one_device = 1.0f;
__device__ __constant__ float zero_device = 0.0f;

/* Device address of a float holding 1.0f.
 *
 * The address is looked up with cudaGetSymbolAddress on every call.
 * Failures are reported through NVTE_CHECK_CUDA.
 */
inline float *GetScalarOne() {
  void *dev_ptr = nullptr;
  NVTE_CHECK_CUDA(cudaGetSymbolAddress(&dev_ptr, one_device));
  return static_cast<float *>(dev_ptr);
}

/* Device address of a float holding 0.0f.
 *
 * The address is looked up with cudaGetSymbolAddress on every call.
 * Failures are reported through NVTE_CHECK_CUDA.
 */
inline float *GetScalarZero() {
  void *dev_ptr = nullptr;
  NVTE_CHECK_CUDA(cudaGetSymbolAddress(&dev_ptr, zero_device));
  return static_cast<float *>(dev_ptr);
}

/* Store a single float to device memory.
 *
 * Writes `val` to `*ptr`. Declared with __launch_bounds__(1); there is
 * no thread-index guard, so every launched thread performs the same
 * store.
 */
__global__ __launch_bounds__(1) void set_float_kernel(float *ptr, float val) { *ptr = val; }

61
62
63
/* Largest power-of-two alignment (in bytes, capped at 256) of an address.
 *
 * Starts at 256 and halves the candidate until it divides `address`.
 * A candidate of 1 divides everything, so the search always ends.
 */
uint32_t _getAlignment(uintptr_t address) {
  uint32_t candidate = 256;
  while (address % candidate != 0) {
    candidate >>= 1;
  }
  return candidate;
}

71
72
73
74
/* Create a cuBLASLt handle and store it in `*handle`.
 *
 * The status returned by cublasLtCreate is checked with
 * NVTE_CHECK_CUBLAS.
 */
inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}

75
76
77
78
79
80
81
/* Parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * is different than the row-major that is typically used in
 * Transformer Engine.
 *
 * Every field starts out in an "unset" state (null pointers, no
 * transpose, DType::kNumTypes, zero strides).
 */
struct GemmParam {
  // Operand buffers
  void *A{nullptr};
  void *B{nullptr};
  // Transpose ops for each operand
  cublasOperation_t transA{CUBLAS_OP_N};
  cublasOperation_t transB{CUBLAS_OP_N};
  // Element types of the operand buffers
  transformer_engine::DType Atype{transformer_engine::DType::kNumTypes};
  transformer_engine::DType Btype{transformer_engine::DType::kNumTypes};
  // Scale-inverse buffers matching the operand buffers
  void *A_scale_inv{nullptr};
  void *B_scale_inv{nullptr};
  int lda{0};  // A column strides
  int ldb{0};  // B column strides
};

95
96
97
98
99
100
101
/* Populate parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * is different than the row-major that is typically used in
 * Transformer Engine.
 *
 * For each operand this picks the buffer (row-wise `data` or
 * `columnwise_data`), the matching scale-inverse pointer, the dtype,
 * the transpose op and the leading dimension, depending on the
 * operand's scaling mode and the requested transpose.
 *
 * Arguments:
 *   A, transA: first operand and its requested transpose op.
 *   B, transB: second operand and its requested transpose op.
 *   m, n, k:   GEMM dimensions used to derive leading dimensions.
 *
 * Returns the populated GemmParam. Unsupported combinations, missing
 * row-wise/column-wise data and violated padding requirements are
 * reported through NVTE_CHECK / NVTE_ERROR.
 */
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                int m, int n, int k) {
  using namespace transformer_engine;
  NVTE_CHECK(
      A.scaling_mode == B.scaling_mode ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
      "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " +
          to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode));
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
  GemmParam ret;

  // Transpose mode with column-major ordering
  bool is_A_transposed = transA == CUBLAS_OP_T;
  bool is_B_transposed = transB == CUBLAS_OP_T;

  // Set conditions for MXFP8 and NVFP4 gemm execution.
  const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
  const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);

  // Configure A matrix
  if (is_tensor_scaling(A.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.A = A.data.dptr;
    ret.transA = transA;
    ret.Atype = A.data.dtype;
    ret.A_scale_inv = A.scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
        ret.A = A.columnwise_data.dptr;
        ret.transA = CUBLAS_OP_T;
        ret.Atype = A.columnwise_data.dtype;
        ret.A_scale_inv = A.columnwise_scale_inv.dptr;
        ret.lda = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
      }
    }

    if (is_fp8_dtype(ret.Atype)) {
      // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
      NVTE_CHECK(ret.lda % 16 == 0,
                 "Leading dimension requirement on A for FP8 GEMM. Caller must pad.");
    }
  } else if (nvfp4) {
    // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe.

    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(is_nvfp4_scaling(A.scaling_mode),
                 "Input A has unsupported combination of recipe and layout");
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = CUBLAS_OP_T;  // NVFP4 gemm is only supported in TN layout.
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = k;
  } else if (mxfp8) {
    // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe.
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with matrix interpreted in row-major order).

    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = transA;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
  } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = CUBLAS_OP_T;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = k;

    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.lda % 16) == 0,
               "Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
    // Smallest supported CType is 2 bytes in this scaling mode.
    NVTE_CHECK((m % 8) == 0,
               "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
  } else {
    NVTE_ERROR("A has unsupported scaling mode");
  }

  // Configure B matrix
  if (is_tensor_scaling(B.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.B = B.data.dptr;
    ret.transB = transB;
    ret.Btype = B.data.dtype;
    ret.B_scale_inv = B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
        ret.B = B.columnwise_data.dptr;
        ret.transB = CUBLAS_OP_N;
        ret.Btype = B.columnwise_data.dtype;
        ret.B_scale_inv = B.columnwise_scale_inv.dptr;
        ret.ldb = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
      }
    }

    // The ldb requirement depends on B's own dtype (mirrors the A branch,
    // which keys its lda check on ret.Atype).
    if (is_fp8_dtype(ret.Btype)) {
      // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
      NVTE_CHECK(ret.ldb % 16 == 0,
                 "Leading dimension requirement on B for FP8 GEMM. Caller must pad.");
    }
  } else if (nvfp4) {
    if (is_B_transposed) {
      NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode),
                 "Input B has unsupported combination of recipe and layout");
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = CUBLAS_OP_N;  // NVFP4 gemm is only supported in TN layout.
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;
  } else if (mxfp8) {
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = transB;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
  } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = CUBLAS_OP_N;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;

    // Requirements from
    // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.ldb % 16) == 0,
               "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
      // Observed this requirement only present for B tensor is 1D quantized.
      NVTE_CHECK((n % 8) == 0,
                 "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    }
  } else {
    NVTE_ERROR("B has unsupported scaling mode");
  }

  return ret;
}

285
286
287
288
289
290
291
/* cuBLAS version number at run-time
 *
 * cublasLtGetVersion() is queried on the first call only; the result
 * is kept in a function-local static to avoid cuBLAS logging overhead
 * on subsequent calls.
 */
size_t cublas_version() {
  static const size_t cached_version = cublasLtGetVersion();
  return cached_version;
}

292
293
}  // namespace

Przemek Tredak's avatar
Przemek Tredak committed
294
295
namespace transformer_engine {

296
297
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

298
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
299
300
                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
301
                 const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count,
Jan Bielak's avatar
Jan Bielak committed
302
303
                 int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
                 cudaStream_t stream) {
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
  // Tensor dims in row-major order
  const int A0 = inputA->flat_first_dim();
  const int A1 = inputA->flat_last_dim();
  const int B0 = inputB->flat_first_dim();
  const int B1 = inputB->flat_last_dim();

  // GEMM dims in column-major order
  const int m = transa == CUBLAS_OP_T ? A0 : A1;
  const int n = transb == CUBLAS_OP_T ? B1 : B0;
  const int k = transa == CUBLAS_OP_T ? A1 : A0;
  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
             "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
             ")");
  const int ldd = m;

319
320
321
322
323
324
  // Return immediately if GEMM is trivial
  if (m <= 0 || n <= 0) {
    return;
  }
  NVTE_CHECK(k > 0);

325
326
  const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

327
  void *C = outputD->data.dptr;
328
  void *D = outputD->data.dptr;
329
330
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
331
332
333
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
334
335
336
337
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
338
  const bool gelu = pre_gelu_out != nullptr;
339
  const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
  const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype);

  // Update scaling factors with NVFP4 tensor scales
  // TODO: Check whether scales are on CPU/GPU or add API to control.
  // Currently scales are assumed to be on CPU when amax is provided
  // and on GPU when not provided, but this is brittle.
  if (use_fp4 && (inputA->amax.dptr != nullptr || inputB->amax.dptr != nullptr)) {
    // Reserve some workspace for alpha scale
    NVTE_CHECK(workspaceSize >= 4,
               "NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ",
               workspaceSize, " bytes remaining.");
    workspaceSize = (workspaceSize / 4) * 4 - 4;  // Remove last 4 aligned bytes
    uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
    float *new_alpha_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);

    // Update alpha scale on device
    // Note: Compute NVFP4 tensor scales based on amaxes and then
    // divide from alpha scale. This way we only need to apply NVFP4
    // tensor scales in matmul output, instead of in matmul inputs.
    float old_alpha = *reinterpret_cast<const float *>(alpha);  // Assumed to be on CPU
    TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector<size_t>{1}, DType::kFloat32);
    nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb,
                                        old_alpha, new_alpha_tensor.data(), stream);
    alpha = new_alpha_ptr;

    // Make sure beta scale is on device
    float old_beta = *reinterpret_cast<const float *>(beta);  // Assumed to be on CPU
    if (old_beta == 0) {
      beta = GetScalarZero();  // Device constant memory
    } else if (old_beta == 1) {
      beta = GetScalarOne();  // Device constant memory
    } else {
      // Move beta to workspace
      NVTE_CHECK(workspaceSize >= 4,
                 "NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has ",
                 workspaceSize, " bytes remaining.");
      workspaceSize = (workspaceSize / 4) * 4 - 4;  // Remove last 4 aligned bytes
      float *new_beta_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);
      set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta);
      NVTE_CHECK_CUDA(cudaGetLastError());
      beta = new_beta_ptr;
    }
  }
383
384
385

  const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
  const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
386
387
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);
Przemek Tredak's avatar
Przemek Tredak committed
388

389
  NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
390
             "FP8 input to GEMM requires inverse of scale!");
391
  NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
392
             "FP8 input to GEMM requires inverse of scale!");
393
394
395
396
  NVTE_CHECK(!is_fp4_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP4 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp4_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP4 input to GEMM requires inverse of scale!");
Przemek Tredak's avatar
Przemek Tredak committed
397

398
399
  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
400
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
401
  if ((use_fp8 || use_fp4) && gelu) {
402
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
403
               "fp8 Aux output for gemm + gelu fusion not supported!");
404
  }
405
406
407
408
409
  if (is_fp4_dtype(outputD->data.dtype)) {
    NVTE_ERROR("FP4 GEMM output is not supported!");
  }
  if (use_fp4 && (D_type == CUDA_R_16F)) {
    NVTE_ERROR("FP4 GEMM does not support FP16 output!");
410
  }
Przemek Tredak's avatar
Przemek Tredak committed
411

412
  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
Przemek Tredak's avatar
Przemek Tredak committed
413

414
415
  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
416
  cublasLtMatmulPreference_t preference = nullptr;
417
  int returnedResults = 0;
418
419
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
Przemek Tredak's avatar
Przemek Tredak committed
420

421
  int64_t ld_gelumat = (int64_t)ldd;
Przemek Tredak's avatar
Przemek Tredak committed
422

423
424
425
426
427
  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }
Przemek Tredak's avatar
Przemek Tredak committed
428

429
  // Create matrix descriptors. Not setting any extra attributes.
430
431
432
433
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
                                               param.transA == CUBLAS_OP_N ? k : m, param.lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
                                               param.transB == CUBLAS_OP_N ? n : k, param.ldb));
434

435
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
Przemek Tredak's avatar
Przemek Tredak committed
436

437
438
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
439
                                                   &param.transA, sizeof(param.transA)));
440
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
441
                                                   &param.transB, sizeof(param.transB)));
442
443
  // Set math SM count
  if (math_sm_count != 0) {
444
445
446
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
447
448
  }

449
450
  // set fp8/fp4 attributes -- input and output types should already be set to fp8/fp4
  // as appropriate. Note: gelu fusion isn't available right now, and we don't need
451
  // amax(D) either (next op is high precision).
452
453
454
455
456
  const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode);

  if (use_fp8 || use_fp4) {
    // Fast accumulation is only supported for FP8.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : use_fp8;
457
458
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));
459
460

    // Scaling factors.
461
#if CUBLAS_VERSION >= 120800
462
463
    cublasLtMatmulMatrixScale_t scaling_mode_a;
    cublasLtMatmulMatrixScale_t scaling_mode_b;
464
#endif  // CUBLAS_VERSION >= 120800
465
    if (is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode)) {
466
467
468
469
470
471
472
473
      void *A_scale_inverse = param.A_scale_inv;
      void *B_scale_inverse = param.B_scale_inv;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
474
#if CUBLAS_VERSION >= 120800
475
476
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
477
#endif  // CUBLAS_VERSION >= 120800
478
    } else if (mxfp8_gemm) {
479
480
481
#if CUBLAS_VERSION >= 120800
      NVTE_CHECK(cublas_version() >= 120800,
                 "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
482
483
484
485
486
487
488
489
      fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
      fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
490
491
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
492
493
      // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
      // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
494
      if (cublas_version() <= 120803) {
495
496
497
498
499
        const int64_t dummy_a_vec_stride = 1;
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
            sizeof(dummy_a_vec_stride)));
      }
500
501
502
#else
      NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
#endif                     // CUBLAS_VERSION >= 120800
    } else if (use_fp4) {  // NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
      NVTE_CHECK(cublas_version() >= 120800,
                 "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
      // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE
      cublasDataType_t scale_type = CUDA_R_32F;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));

      // Set pointer mode: alpha and beta are both device pointers
      // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
      cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));

      fp8e4m3 *A_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.A_scale_inv);
      fp8e4m3 *B_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
#else
      NVTE_ERROR("FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION);
531
#endif  // CUBLAS_VERSION >= 120800
532
533
534
535
    } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
               (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
536
537
538
539
#if CUBLAS_VERSION >= 120900
      NVTE_CHECK(cublas_version() >= 120900,
                 "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
                 cublas_version());
540
541
542
543
544
545
546
547
548
549
      float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
      float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
                    inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
550
                 "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
551
552
553
554
555
556
557
      scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
      scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
558
559
560
      NVTE_ERROR("FP8 block scaling requires cuBLAS 12.9+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120900
561
562
563
564
565
    } else {
      NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and  " +
                 to_string(inputB->scaling_mode) + ".");
    }

566
567
568
569
570
571
572
573
574
575
#if CUBLAS_VERSION >= 120800
    if (cublas_version() >= 120800) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
                                                       &scaling_mode_a, sizeof(scaling_mode_a)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
                                                       &scaling_mode_b, sizeof(scaling_mode_b)));
    }
#endif  // CUBLAS_VERSION >= 120800
576
577
578
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
579
580
581
582
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
583
584
585
586
587
588
589
590
591
#if CUBLAS_VERSION >= 120800
      if (cublas_version() >= 120800) {
        // NOTE: In all current cases where FP8 output is supported, the input is
        // scaled identically to the output.
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                         CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
                                                         &scaling_mode_a, sizeof(scaling_mode_a)));
      }
#endif  // CUBLAS_VERSION >= 120800
592
593
594
595
      // For FP8 output, cuBLAS requires C_type to match bias_type and
      // be FP16/BF16
      const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
596
597
598
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
599
    if (bias) {
600
601
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
602
    }
603
604
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
605
  }
Przemek Tredak's avatar
Przemek Tredak committed
606

607
608
609
610
611
612
613
  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
614
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
615
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
616
617
618
619
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
620
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
621
622
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
623
624
625
626
627
628
629
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
630
631
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
632
633
634
635
636
637
638
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
639
640
641
642
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
643
644
645
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
646
  }
Przemek Tredak's avatar
Przemek Tredak committed
647

648
649
650
651
652
653
654
655
  if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) ||
      (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
    NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS ||
                epilogue == CUBLASLT_EPILOGUE_DGELU),
               "Epilogue requested outside of the available and tested cuBLAS functionality for "
               "float8 block scaled GEMM");
  }

656
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
657
                                                   &epilogue, sizeof(epilogue)));
658

659
  if (counter != nullptr) {
660
661
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
    NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
662
               CUDA_VERSION);
663
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
664
    NVTE_ERROR(
665
        "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
666
        CUBLAS_VERSION);
667
#else
668
    NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
669
               "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
670
671
               cuda::cudart_version());
    NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
672
               "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
673
               cublas_version());
674
675
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
676
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
677
678
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
679
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
680
681
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
682
683
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
684
685
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
686
687
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
688
689
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
690
691
    }
#endif
692
  }
Przemek Tredak's avatar
Przemek Tredak committed
693

694
695
696
697
698
699
700
701
  // align the workspace to 256 B
  const int required_alignment = 256;
  const auto original_workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
  uint8_t *aligned_workspace_ptr =
      reinterpret_cast<uint8_t *>(workspace) + required_alignment - original_workspace_alignment;
  workspaceSize = workspaceSize - required_alignment + original_workspace_alignment;
  const auto new_workspace_alignment =
      _getAlignment(reinterpret_cast<uintptr_t>(aligned_workspace_ptr));
702
703
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
704
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
705
706
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
707
708
709
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
710
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
711
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
712
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
713
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
714
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
715
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
716
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
717
718
719
  NVTE_CHECK(new_workspace_alignment % 256 == 0,
             "cuBLAS workspace pointer must be aligned to 256 bytes, got ",
             new_workspace_alignment);
Przemek Tredak's avatar
Przemek Tredak committed
720

721
722
723
  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
Tim Moon's avatar
Tim Moon committed
724
725
726
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);
727
  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");
Przemek Tredak's avatar
Przemek Tredak committed
728

729
  // D = alpha * (A * B) + beta * C
730
731
732
733
734
735
736
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */
                                   param.A,                      /* A */
                                   Adesc, param.B,               /* B */
                                   Bdesc, beta,                  /* beta */
                                   C,                            /* C */
                                   Cdesc, D,                     /* D */
                                   Ddesc, &heuristicResult.algo, /* algo */
737
                                   aligned_workspace_ptr,        /* workspace */
738
                                   workspaceSize, stream));      /* stream */
Przemek Tredak's avatar
Przemek Tredak committed
739

740
  // Update FP8 scale-inv in output tensor
741
742
743
744
  // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
  // TODO: Changing gemm interface so that D->scale_inv is allocated and the scale_inv can be
  // calculated here.
  if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
745
746
747
    update_tensor_scale_inv(outputD, stream);
  }

748
749
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
750
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
751
752
753
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
Przemek Tredak's avatar
Przemek Tredak committed
754
755
}

756
}  // namespace transformer_engine
Przemek Tredak's avatar
Przemek Tredak committed
757

758
759
760
761
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream) {
762
  NVTE_API_CALL(nvte_cublas_gemm);
Przemek Tredak's avatar
Przemek Tredak committed
763
  using namespace transformer_engine;
764
765

  // Tensors
766
767
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
768
  Tensor *outputD = convertNVTETensorCheck(D);
769
770
771
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);
Przemek Tredak's avatar
Przemek Tredak committed
772

773
774
775
776
777
778
779
780
781
782
783
  // Scales
  const float alpha = 1;
  const float beta = accumulate ? 1 : 0;

  // Check for NVFP4
  // TODO Remove once alpha scale logic is moved into cublas_gemm function
  if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
    NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
  }

  // Launch GEMM
784
785
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
              &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}

// C API entry point for a cuBLASLt GEMM with caller-provided scaling factors
// and an optional NVTEMatmulConfig describing the epilogue.
//
// transa / transb: non-zero selects CUBLAS_OP_T for the respective operand.
// alpha / beta:    forwarded unchanged to cublas_gemm.
// C / D:           must currently refer to the same tensor.
// workspace:       optional; when the handle does not resolve to a tensor, a
//                  null pointer and a size of 0 bytes are forwarded.
// config:          optional; a default-constructed MatmulConfig is used when
//                  null.
//
// Epilogue selection: a "grad" epilogue is requested when a dbias tensor is
// set or the dGELU epilogue is enabled, and it is mutually exclusive with the
// forward bias / GELU epilogues. When no bias or aux tensor applies, an empty
// dummy tensor is passed to cublas_gemm instead.
void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
                         const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
                         NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm_v2);
  using namespace transformer_engine;

  // Data tensors
  const Tensor *A_tensor = convertNVTETensorCheck(A);
  const Tensor *B_tensor = convertNVTETensorCheck(B);
  const Tensor *C_tensor = convertNVTETensorCheck(C);
  Tensor *D_tensor = convertNVTETensorCheck(D);
  NVTE_CHECK(C_tensor == D_tensor,
             "Currently nvte_cublas_gemm_v2 does not support different C and D tensors.");

  // Workspace
  void *workspace_ptr = nullptr;
  size_t workspace_size = 0;
  Tensor *workspace_tensor = convertNVTETensor(workspace);
  if (workspace_tensor != nullptr) {
    workspace_ptr = workspace_tensor->data.dptr;
    workspace_size =
        get_buffer_size_bytes(workspace_tensor->data.numel(), workspace_tensor->data.dtype);
  }

  // Additional config
  MatmulConfig config_;
  if (config != nullptr) {
    config_ = *reinterpret_cast<MatmulConfig *>(config);
  }

  // Configure GEMM epilogue
  const bool with_grad_epilogue = (config_.dbias_tensor != nullptr || config_.with_dgelu_epilogue);
  if (with_grad_epilogue) {
    NVTE_CHECK(config_.bias_tensor == nullptr && !config_.with_gelu_epilogue,
               "Invalid epilogue (bias=", config_.bias_tensor != nullptr,
               ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
               ", dgelu=", config_.with_dgelu_epilogue, ").");
  }
  Tensor dummy_tensor;
  Tensor *epilogue_bias_tensor = &dummy_tensor;
  if (!with_grad_epilogue && config_.bias_tensor != nullptr) {
    epilogue_bias_tensor = convertNVTETensorCheck(config_.bias_tensor);
  } else if (with_grad_epilogue && config_.dbias_tensor != nullptr) {
    epilogue_bias_tensor = convertNVTETensorCheck(config_.dbias_tensor);
  }
  Tensor *epilogue_aux_tensor = &dummy_tensor;
  if (config_.with_gelu_epilogue || config_.with_dgelu_epilogue) {
    NVTE_CHECK(config_.epilogue_aux_tensor != nullptr,
               "Requested epilogue (bias=", config_.bias_tensor != nullptr,
               ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
               ", dgelu=", config_.with_dgelu_epilogue, ") without providing aux tensor.");
    // The aux tensor is mandatory for these epilogues, so validate the handle
    // the same way the bias/dbias handles are validated above.
    epilogue_aux_tensor = convertNVTETensorCheck(config_.epilogue_aux_tensor);
  }

  // Launch GEMM
  // Atomic-GEMM arguments (m_split, n_split, gemm_producer, counter) are disabled.
  cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor,
              transa ? CUBLAS_OP_T : CUBLAS_OP_N, transb ? CUBLAS_OP_T : CUBLAS_OP_N,
              with_grad_epilogue, workspace_ptr, workspace_size, alpha, beta,
              config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr, stream);
}

void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
854
  NVTE_API_CALL(nvte_cublas_gemm);
Jan Bielak's avatar
Jan Bielak committed
855
  using namespace transformer_engine;
856
857

  // Tensors
Jan Bielak's avatar
Jan Bielak committed
858
859
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
860
  Tensor *outputD = convertNVTETensorCheck(D);
Jan Bielak's avatar
Jan Bielak committed
861
862
863
864
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

865
866
867
868
869
870
871
  // Check for NVFP4
  // TODO Remove once alpha scale logic is moved into cublas_gemm function
  if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
    NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
  }

  // Launch GEMM
Jan Bielak's avatar
Jan Bielak committed
872
873
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
874
              &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
875
876
}

877
878
879
880
881
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
882
883
                             cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);
884
  using namespace transformer_engine;
885
886
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
887
             CUDA_VERSION);
888
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
889
890
891
  NVTE_ERROR(
      "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
      CUBLAS_VERSION);
892
#else
893
  NVTE_CHECK(
894
895
      transformer_engine::cuda::cudart_version() >= 12020 &&
          transformer_engine::cuda::cudart_version() < 13000,
896
      "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
897
      transformer_engine::cuda::cudart_version());
898
899
  NVTE_CHECK(
      cublas_version() >= 120205 && cublas_version() < 130000,
900
      "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
901
      cublas_version());
902

903
904
905
906
907
908
909
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  const Tensor *inputCounter = convertNVTETensor(counter);
  Tensor *wspace = convertNVTETensor(workspace);
910

911
912
913
  const void *alpha_ptr = GetScalarOne();
  const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero();

914
915
916
  NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                 is_delayed_tensor_scaling(inputB->scaling_mode),
             "Atomic GEMM only supports delayed scaling.");
917
918
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
919
920
921
              alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split,
              gemm_producer, inputCounter, stream);
#endif
Przemek Tredak's avatar
Przemek Tredak committed
922
}
923

924
925
926
927
928
void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                              const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                              bool transa, bool transb, bool grad, NVTETensor *workspace,
                              bool accumulate, bool use_split_accumulator, int math_sm_count,
                              cudaStream_t stream) {
929
  using namespace transformer_engine;
930
931

  int num_streams = nvte_get_num_compute_streams();
932

933
  int num_stream_used = std::min(num_streams, num_gemms);
934
  // wait for current stream to finish
935
  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
936
  for (int s = 0; s < num_stream_used; s++) {
937
938
    NVTE_CHECK_CUDA(
        cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
939
940
  }

941
  for (int i = 0; i < num_gemms; i++) {
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
    // Check whether GELU or dGELU epilogue is requested
    Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]);
    bool with_gelu_dgelu_epilogue =
        (pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr);

    // Construct config
    MatmulConfig config;
    if (grad) {
      config.dbias_tensor = bias[i];
      config.with_dgelu_epilogue = with_gelu_dgelu_epilogue;
    } else {
      config.bias_tensor = bias[i];
      config.with_gelu_epilogue = with_gelu_dgelu_epilogue;
    }
    config.epilogue_aux_tensor = pre_gelu_out[i];
    config.use_split_accumulator = use_split_accumulator;
    config.sm_count = math_sm_count;

    // Launch GEMM
    const float alpha = 1.f;
    const float beta = accumulate ? 1.f : 0.f;
    nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i],
                        workspace[i % num_streams], &config,
                        detail::get_compute_stream(i % num_streams));
966
967
968
969
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
970
971
    NVTE_CHECK_CUDA(
        cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
972
973
974
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
975
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
976
977
  }
}
978

979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
// Deprecated C API wrapper around multi_stream_cublas_gemm.
//
// Emits a deprecation warning on every call and then forwards all arguments
// unchanged. New code should call nvte_multi_tensor_gemm instead.
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;

  // Tell callers that this entry point is going away
  NVTE_WARN(
      "nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. "
      "Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when "
      "applicable).");

  // Delegate to the shared multi-stream implementation
  multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace,
                           accumulate, use_split_accumulator, math_sm_count, stream);
}

998
999
1000
1001
1002
1003
1004
namespace transformer_engine {

// Handle manager specialised for cuBLASLt handles created via CreateCublasHandle
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

// Requests a cuBLASLt handle from the manager and discards the result.
// Presumably used to create the handle ahead of the first GEMM -- verify
// against HandleManager::GetHandle.
void nvte_cublas_handle_init() {
  static_cast<void>(cublasHandleManager::Instance().GetHandle());
}

}  //  namespace transformer_engine
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086

// Runs a batch of `num_gemms` GEMMs, using the CUTLASS grouped GEMM when it is
// enabled and applicable and the multi-stream cuBLAS implementation otherwise.
//
// The CUTLASS path is taken only when all of the following hold:
//  - the current device reports SM arch 90 and the environment variable
//    NVTE_USE_CUTLASS_GROUPED_GEMM is enabled,
//  - the batch is non-empty,
//  - no entry of `bias` or `pre_gelu_out` holds data,
//  - A, B and D of every group share one dtype, which is FP16 or BF16,
//  - K (read from B, honoring `transb`) is the same for every group and is a
//    multiple of 128.
//
// If NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK is enabled, a warning is emitted
// when CUTLASS was requested on SM 90 for a non-empty batch but the cuBLAS
// path had to be used.
void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                            const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                            bool transa, bool transb, bool grad, NVTETensor *workspace,
                            bool accumulate, bool use_split_accumulator, int math_sm_count,
                            cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_gemm);

  const int current_device = transformer_engine::cuda::current_device();
  const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
  const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
  const bool warn_fallback =
      transformer_engine::getenv<bool>("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false);

  auto cublas_path = [&]() {
    multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
                             workspace, accumulate, use_split_accumulator, math_sm_count, stream);
  };

  // Currently only support cutlass group gemm on Hopper Arch.
  // An empty batch also goes to the cuBLAS path: the eligibility checks below
  // inspect the tensors of each group and must not read A[0]/B[0]/D[0] when
  // there are no groups.
  if (!(is_hopper && use_cutlass) || num_gemms <= 0) {
    cublas_path();
    return;
  }

  // True when the array is absent or none of its entries holds data.
  // Handles that do not resolve to a tensor are treated as empty.
  auto is_empty_arr = [&](const NVTETensor *p) -> bool {
    if (p == nullptr) return true;
    for (int i = 0; i < num_gemms; ++i) {
      const auto *tensor = transformer_engine::convertNVTETensor(p[i]);
      if (tensor != nullptr && tensor->has_data()) return false;
    }
    return true;
  };

  // True when every group has the same K and K is a multiple of 128.
  // K is shape[0] for a transposed operand and shape[1] otherwise.
  auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool {
    int64_t ref_k = -1;
    for (int i = 0; i < num_gemms; ++i) {
      const auto *tensor = transformer_engine::convertNVTETensorCheck(p[i]);
      const int64_t k =
          static_cast<int64_t>(trans ? tensor->data.shape[0] : tensor->data.shape[1]);

      if ((k % 128) != 0) return false;

      if (ref_k < 0) {
        ref_k = k;
      } else if (k != ref_k) {
        return false;
      }
    }

    return true;
  };

  // True when A, B and D of every group share a single FP16/BF16 dtype.
  auto is_supported_dtype = [&]() -> bool {
    const auto ref_type =
        get_cuda_dtype(transformer_engine::convertNVTETensorCheck(A[0])->data.dtype);
    if (ref_type != CUDA_R_16BF && ref_type != CUDA_R_16F) return false;

    for (int i = 0; i < num_gemms; ++i) {
      const auto *inputA = transformer_engine::convertNVTETensorCheck(A[i]);
      const auto *inputB = transformer_engine::convertNVTETensorCheck(B[i]);
      const auto *outputD = transformer_engine::convertNVTETensorCheck(D[i]);
      const auto A_type = get_cuda_dtype(inputA->data.dtype);
      const auto B_type = get_cuda_dtype(inputB->data.dtype);
      const auto D_type = get_cuda_dtype(outputD->data.dtype);

      if (A_type != ref_type || B_type != ref_type || D_type != ref_type) return false;
    }
    return true;
  };

  // CUTLASS Grouped GEMM fast path (SM90/TMA)
  // Conditions:
  //  - No fused epilogue: both bias and pre_gelu_out are empty.
  //  - Supported dtypes only: FP16/BF16 (FP32 accumulate).
  //  - Uniform K across groups and K % 128 == 0.
  //  - use_split_accumulator is ignored for FP16/BF16.
  //  - grad is irrelevant when bias/pre_gelu_out are empty.
  //
  // Otherwise, fall back to cuBLAS.
  if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() &&
      all_groups_uniform_k128(B, transb)) {
    cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
                         current_device, math_sm_count, stream);
  } else {
    if (warn_fallback) {
      NVTE_WARN("Fallback to cuBLAS grouped GEMM.");
    }
    cublas_path();
  }
}