/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>

#include <algorithm>
#include <cstdint>
#include <mutex>
#include <vector>

#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "./config.h"
#include "./cutlass_grouped_gemm.cuh"

namespace {

/* Use CUDA constant memory to store scalars 1.0f and 0.0f for cuBLAS usage. */
__device__ __constant__ float one_device;
__device__ __constant__ float zero_device;

inline float *GetScalarOne() {
  static std::once_flag init_flag;
  std::call_once(init_flag, []() {
    float one = 1.0f;
    NVTE_CHECK_CUDA(cudaMemcpyToSymbol(one_device, &one, sizeof(float)));
  });
  // Return the device address of the symbol via cudaGetSymbolAddress
  float *dev_ptr;
  NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), one_device));
  return dev_ptr;
}

inline float *GetScalarZero() {
  static std::once_flag init_flag;
  std::call_once(init_flag, []() {
    float zero = 0.0f;
    NVTE_CHECK_CUDA(cudaMemcpyToSymbol(zero_device, &zero, sizeof(float)));
  });
  // Return the device address of the symbol via cudaGetSymbolAddress
  float *dev_ptr;
  NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), zero_device));
  return dev_ptr;
}

__global__ __launch_bounds__(1) void set_float_kernel(float *ptr, float val) { *ptr = val; }
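
// cuBLASLt reads alpha/beta through the pointer mode set on the matmul
// descriptor. The NVFP4 path below switches to CUBLASLT_POINTER_MODE_DEVICE,
// so alpha and beta must then live in GPU memory; the helpers above provide
// stable device addresses for the common 1.0f/0.0f values. Illustrative
// sketch (workspace_tail and host_beta are hypothetical names):
//
//   float *beta_dev = reinterpret_cast<float *>(workspace_tail);
//   set_float_kernel<<<1, 1, 0, stream>>>(beta_dev, host_beta);
//   // ...or reuse the shared constants:
//   const void *beta = accumulate ? GetScalarOne() : GetScalarZero();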

uint32_t _getAlignment(uintptr_t address) {
  // Alignment is in bytes; start from the 256 B maximum and halve until it divides the address
  uint32_t alignment = 256;
  for (;; alignment /= 2) {
    if (address % alignment == 0) {
      return alignment;
    }
  }
}
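
// Worked examples (illustrative):
//   _getAlignment(0x1000);  // -> 256 (4 KiB-aligned address, capped at 256)
//   _getAlignment(0x1040);  // -> 64  (lowest set bit of the address)
//   _getAlignment(0x1001);  // -> 1   (odd address)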

inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}

/* Parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * differs from the row-major ordering typically used in Transformer
 * Engine.
 */
struct GemmParam {
  void *A = nullptr;
  void *B = nullptr;
  cublasOperation_t transA = CUBLAS_OP_N;
  cublasOperation_t transB = CUBLAS_OP_N;
  transformer_engine::DType Atype = transformer_engine::DType::kNumTypes;
  transformer_engine::DType Btype = transformer_engine::DType::kNumTypes;
  void *A_scale_inv = nullptr;
  void *B_scale_inv = nullptr;
  int lda = 0;  // Leading dimension of A (column stride)
  int ldb = 0;  // Leading dimension of B (column stride)
};
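
// Row-major data can be handed to column-major cuBLAS without copying: a
// contiguous row-major [rows, cols] buffer is bit-identical to a column-major
// [cols, rows] matrix. A minimal sketch of the reinterpretation:
//
//   // Row-major A with shape [M, K] and contiguous rows:
//   //   A[i][j] == A_colmajor(j, i), with leading dimension lda = K
//
// CanonicalizeGemmInput below relies on this to pick data pointers, transpose
// flags, and leading dimensions for each scaling mode.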

/* Populate parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * differs from the row-major ordering typically used in Transformer
 * Engine.
 */
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                int m, int n, int k) {
  using namespace transformer_engine;
  NVTE_CHECK(
      A.scaling_mode == B.scaling_mode ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
      "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " +
          to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode));
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
  GemmParam ret;

  // Transpose mode with column-major ordering
  bool is_A_transposed = transA == CUBLAS_OP_T;
  bool is_B_transposed = transB == CUBLAS_OP_T;

  // Determine whether this is an NVFP4 or MXFP8 GEMM.
  const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
  const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);

  // Configure A matrix
  if (is_tensor_scaling(A.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.A = A.data.dptr;
    ret.transA = transA;
    ret.Atype = A.data.dtype;
    ret.A_scale_inv = A.scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
        ret.A = A.columnwise_data.dptr;
        ret.transA = CUBLAS_OP_T;
        ret.Atype = A.columnwise_data.dtype;
        ret.A_scale_inv = A.columnwise_scale_inv.dptr;
        ret.lda = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
      }
    } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
      // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
      // data with the mirrored transpose flag if we don't have row-wise data.
      NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
                 "Input A is missing column-wise usage");
      ret.A = A.columnwise_data.dptr;
      ret.transA = is_A_transposed ? CUBLAS_OP_N : CUBLAS_OP_T;
      ret.Atype = A.columnwise_data.dtype;
      ret.A_scale_inv = A.columnwise_scale_inv.dptr;
      ret.lda = is_A_transposed ? m : k;
    }

    if (is_fp8_dtype(ret.Atype)) {
      // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
      NVTE_CHECK(ret.lda % 16 == 0,
                 "Leading dimension of A must be a multiple of 16 for FP8 GEMM. Caller must pad.");
    }
  } else if (nvfp4) {
    // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe.

    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(is_nvfp4_scaling(A.scaling_mode),
                 "Input A has unsupported combination of recipe and layout");
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = CUBLAS_OP_T;  // NVFP4 gemm is only supported in TN layout.
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = k;
  } else if (mxfp8) {
    // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe.
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with matrix interpreted in row-major order).

    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = transA;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
  } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = CUBLAS_OP_T;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = k;

    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.lda % 16) == 0,
               "Leading dimension of A must be a multiple of 16 for NVTE_BLOCK_SCALING GEMM. "
               "Caller must pad.");
    // Divisibility by 8 is derived from the FP8 requirement (m * sizeof(CType)) % 16 == 0;
    // the smallest supported CType in this scaling mode is 2 bytes.
    NVTE_CHECK((m % 8) == 0,
               "Outer dimension of A must be a multiple of 8 for NVTE_BLOCK_SCALING GEMM. "
               "Caller must pad.");
  } else {
    NVTE_ERROR("A has unsupported scaling mode");
  }
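
  // Summary of the A-matrix dispatch above (derived from the code paths):
  //   scaling mode      data used                           transA      lda
  //   tensor scaling    row-wise (column-wise when FP8      as given    m or k
  //                     must be TN on pre-Blackwell)        (or T)
  //   NVFP4             row-wise if transposed, else        always T    k
  //                     column-wise
  //   MXFP8             row-wise if transposed, else        as given    m or k
  //                     column-wise
  //   block scaling     row-wise if transposed, else        always T    k
  //                     column-wise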

  // Configure B matrix
  if (is_tensor_scaling(B.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.B = B.data.dptr;
    ret.transB = transB;
    ret.Btype = B.data.dtype;
    ret.B_scale_inv = B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
        ret.B = B.columnwise_data.dptr;
        ret.transB = CUBLAS_OP_N;
        ret.Btype = B.columnwise_data.dtype;
        ret.B_scale_inv = B.columnwise_scale_inv.dptr;
        ret.ldb = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
      }
    } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
      // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
      // data with the mirrored transpose flag if we don't have row-wise data.
      NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
                 "Input B is missing column-wise usage");
      ret.B = B.columnwise_data.dptr;
      ret.transB = is_B_transposed ? CUBLAS_OP_N : CUBLAS_OP_T;
      ret.Btype = B.columnwise_data.dtype;
      ret.B_scale_inv = B.columnwise_scale_inv.dptr;
      ret.ldb = is_B_transposed ? k : n;
    }

    if (is_fp8_dtype(ret.Btype)) {
      // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
      NVTE_CHECK(ret.ldb % 16 == 0,
                 "Leading dimension of B must be a multiple of 16 for FP8 GEMM. Caller must pad.");
    }
  } else if (nvfp4) {
    if (is_B_transposed) {
      NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode),
                 "Input B has unsupported combination of recipe and layout");
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = CUBLAS_OP_N;  // NVFP4 gemm is only supported in TN layout.
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;
  } else if (mxfp8) {
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = transB;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
  } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = CUBLAS_OP_N;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;

    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.ldb % 16) == 0,
               "Leading dimension of B must be a multiple of 16 for NVTE_BLOCK_SCALING GEMM. "
               "Caller must pad.");
    if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
      // This requirement has only been observed when the B tensor is 1D-quantized.
      NVTE_CHECK((n % 8) == 0,
                 "Outer dimension of B must be a multiple of 8 for NVTE_BLOCK_SCALING GEMM. "
                 "Caller must pad.");
    }
  } else {
    NVTE_ERROR("B has unsupported scaling mode");
  }

  return ret;
}
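
/* Minimal usage sketch (assuming row-major TE tensors A: [m, k] and
 * B: [n, k], i.e. the common TN layout):
 *
 *   GemmParam p = CanonicalizeGemmInput(A, CUBLAS_OP_T, B, CUBLAS_OP_N, m, n, k);
 *   // p.A / p.B point at the raw (possibly column-wise) device buffers, and
 *   // p.transA / p.transB / p.lda / p.ldb are what cuBLASLt should be given.
 */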

/* cuBLAS version number at run-time */
size_t cublas_version() {
  // Cache version to avoid cuBLAS logging overhead
  static size_t version = cublasLtGetVersion();
  return version;
}

}  // namespace

namespace transformer_engine {

using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
                 const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count,
                 int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
                 cudaStream_t stream) {
  // Tensor dims in row-major order
  const int A0 = inputA->flat_first_dim();
  const int A1 = inputA->flat_last_dim();
  const int B0 = inputB->flat_first_dim();
  const int B1 = inputB->flat_last_dim();

  // GEMM dims in column-major order
  const int m = transa == CUBLAS_OP_T ? A0 : A1;
  const int n = transb == CUBLAS_OP_T ? B1 : B0;
  const int k = transa == CUBLAS_OP_T ? A1 : A0;
  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
             "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
             ")");
  const int ldd = m;
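
  // Worked example: a row-major A of shape [m, k] passed with
  // transa == CUBLAS_OP_T gives A0 == m and A1 == k, so m = A0 and k = A1;
  // a row-major B of shape [n, k] with transb == CUBLAS_OP_N gives n = B0,
  // and the check above compares B1 against k.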

  // Return immediately if GEMM is trivial
  if (m <= 0 || n <= 0) {
    return;
  }
  NVTE_CHECK(k > 0);

  const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

  void *C = outputD->data.dptr;
  void *D = outputD->data.dptr;
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);
  const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype);

  // Update scaling factors with NVFP4 tensor scales
  // TODO: Check whether scales are on CPU/GPU or add API to control.
  // Currently scales are assumed to be on CPU when amax is provided
  // and on GPU when not provided, but this is brittle.
  if (use_fp4 &&
      ((transa == CUBLAS_OP_T ? inputA->amax.dptr : inputA->columnwise_amax.dptr) != nullptr ||
       (transb == CUBLAS_OP_T ? inputB->columnwise_amax.dptr : inputB->amax.dptr) != nullptr)) {
    // Reserve some workspace for alpha scale
    NVTE_CHECK(workspaceSize >= 4,
               "NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ",
               workspaceSize, " bytes remaining.");
    workspaceSize = (workspaceSize / 4) * 4 - 4;  // Round to 4 B multiple; reserve last 4 bytes
    uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
    float *new_alpha_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);

    // Update alpha scale on device
    // Note: Compute NVFP4 tensor scales based on amaxes and then
    // divide from alpha scale. This way we only need to apply NVFP4
    // tensor scales in matmul output, instead of in matmul inputs.
    float old_alpha = *reinterpret_cast<const float *>(alpha);  // Assumed to be on CPU
    TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector<size_t>{1}, DType::kFloat32);
    bool a_rowwise_amax = transa == CUBLAS_OP_T;
    bool b_rowwise_amax = transb != CUBLAS_OP_T;
    nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, a_rowwise_amax, inputB->nvte_tensor,
                                        b_rowwise_amax, old_alpha, new_alpha_tensor.data(), stream);
    alpha = new_alpha_ptr;

    // Make sure beta scale is on device
    float old_beta = *reinterpret_cast<const float *>(beta);  // Assumed to be on CPU
    if (old_beta == 0) {
      beta = GetScalarZero();  // Device constant memory
    } else if (old_beta == 1) {
      beta = GetScalarOne();  // Device constant memory
    } else {
      // Move beta to workspace
      NVTE_CHECK(workspaceSize >= 4,
                 "NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has ",
                 workspaceSize, " bytes remaining.");
      workspaceSize = (workspaceSize / 4) * 4 - 4;  // Round to 4 B multiple; reserve last 4 bytes
      float *new_beta_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);
      set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta);
      NVTE_CHECK_CUDA(cudaGetLastError());
      beta = new_beta_ptr;
    }
  }

  const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
  const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);

  NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp4_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP4 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp4_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP4 input to GEMM requires inverse of scale!");

  // Check consistency of arguments:
  // FP8/FP4 GEMM with GELU fusion does not support FP8 aux output right now.
  if ((use_fp8 || use_fp4) && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp4_dtype(outputD->data.dtype)) {
    NVTE_ERROR("FP4 GEMM output is not supported!");
  }
  if (use_fp4 && (D_type == CUDA_R_16F)) {
    NVTE_ERROR("FP4 GEMM does not support FP16 output!");
  }

  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();

  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  cublasLtMatmulPreference_t preference = nullptr;
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  int64_t ld_gelumat = static_cast<int64_t>(ldd);

  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
                                               param.transA == CUBLAS_OP_N ? k : m, param.lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
                                               param.transB == CUBLAS_OP_N ? n : k, param.ldb));

  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
                                                   &param.transA, sizeof(param.transA)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                   &param.transB, sizeof(param.transB)));
  // Set math SM count
  if (math_sm_count != 0) {
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
  }

  // Set FP8/FP4 attributes -- input and output types should already be set to FP8/FP4
  // as appropriate. Note: GELU fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode);

  if (use_fp8 || use_fp4) {
    // Fast accumulation is only supported for FP8.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : use_fp8;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));

    // Scaling factors.
#if CUBLAS_VERSION >= 120800
    cublasLtMatmulMatrixScale_t scaling_mode_a;
    cublasLtMatmulMatrixScale_t scaling_mode_b;
#endif  // CUBLAS_VERSION >= 120800
    if (is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode)) {
      void *A_scale_inverse = param.A_scale_inv;
      void *B_scale_inverse = param.B_scale_inv;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
#if CUBLAS_VERSION >= 120800
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
#endif  // CUBLAS_VERSION >= 120800
    } else if (mxfp8_gemm) {
#if CUBLAS_VERSION >= 120800
      NVTE_CHECK(cublas_version() >= 120800,
                 "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
      fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
      fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      // Workaround for a heuristic cache bug in cublasLt. This separates the MXFP8 cache key
      // from non-block scaling. CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for
      // block scaling, so it is safe to set.
      if (cublas_version() <= 120803) {
        const int64_t dummy_a_vec_stride = 1;
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
            sizeof(dummy_a_vec_stride)));
      }
#else
      NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120800
    } else if (use_fp4) {  // NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
      NVTE_CHECK(cublas_version() >= 120800,
                 "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
      // Keep the alpha/beta computation dtype in FP32 via CUBLASLT_MATMUL_DESC_SCALE_TYPE
      cublasDataType_t scale_type = CUDA_R_32F;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));

      // Set pointer mode: alpha and beta are both device pointers
      // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
      cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));

      fp8e4m3 *A_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.A_scale_inv);
      fp8e4m3 *B_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
#else
      NVTE_ERROR("FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120800
    } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
               (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUBLAS_VERSION >= 120900
      NVTE_CHECK(cublas_version() >= 120900,
                 "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
                 cublas_version());
      float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
      float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
                    inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
                 "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
      scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
      scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
      NVTE_ERROR("FP8 block scaling requires cuBLAS 12.9+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120900
    } else {
      NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and  " +
                 to_string(inputB->scaling_mode) + ".");
    }

#if CUBLAS_VERSION >= 120800
    if (cublas_version() >= 120800) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
                                                       &scaling_mode_a, sizeof(scaling_mode_a)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
                                                       &scaling_mode_b, sizeof(scaling_mode_b)));
    }
#endif  // CUBLAS_VERSION >= 120800
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUBLAS_VERSION >= 120800
      if (cublas_version() >= 120800) {
        // NOTE: In all current cases where FP8 output is supported, the input is
        // scaled identically to the output.
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                         CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
                                                         &scaling_mode_a, sizeof(scaling_mode_a)));
      }
#endif  // CUBLAS_VERSION >= 120800
      // For FP8 output, cuBLAS requires C_type to match bias_type and
      // be FP16/BF16
      const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
    if (bias) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
  }

  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  }

  if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) ||
      (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
    NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS ||
                epilogue == CUBLASLT_EPILOGUE_DGELU),
               "Epilogue requested outside of the available and tested cuBLAS functionality for "
               "float8 block scaled GEMM");
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                   &epilogue, sizeof(epilogue)));

  if (counter != nullptr) {
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
    NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
               CUDA_VERSION);
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
    NVTE_ERROR(
        "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
        CUBLAS_VERSION);
#else
    NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
               "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
               cuda::cudart_version());
    NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
               "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
               cublas_version());
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    }
#endif
  }

  // Align the workspace to 256 B
  const int required_alignment = 256;
  const auto original_workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
  uint8_t *aligned_workspace_ptr =
      reinterpret_cast<uint8_t *>(workspace) + required_alignment - original_workspace_alignment;
  workspaceSize = workspaceSize - required_alignment + original_workspace_alignment;
  const auto new_workspace_alignment =
      _getAlignment(reinterpret_cast<uintptr_t>(aligned_workspace_ptr));
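
  // Worked example: a workspace pointer that is only 64 B-aligned has
  // original_workspace_alignment == 64, so the pointer is advanced by
  // 256 - 64 == 192 bytes and workspaceSize shrinks by the same 192 bytes.
  // The NVTE_CHECK further below verifies the adjusted pointer really is
  // 256 B-aligned.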
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
  NVTE_CHECK(new_workspace_alignment % 256 == 0,
             "cuBLAS workspace pointer must be aligned to 256 bytes, got ",
             new_workspace_alignment);

  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);
  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */
                                   param.A,                      /* A */
                                   Adesc, param.B,               /* B */
                                   Bdesc, beta,                  /* beta */
                                   C,                            /* C */
                                   Cdesc, D,                     /* D */
                                   Ddesc, &heuristicResult.algo, /* algo */
                                   aligned_workspace_ptr,        /* workspace */
                                   workspaceSize, stream));      /* stream */

  // Update FP8 scale-inv in output tensor
  // Note: This is a WAR for the case when we have FP8 output but D->scale_inv is not allocated.
  // TODO: Change the GEMM interface so that D->scale_inv is allocated and the scale_inv can be
  // calculated here.
  if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
    update_tensor_scale_inv(outputD, stream);
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}

}  // namespace transformer_engine

void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;

  // Tensors
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensorCheck(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  // Scales
  const float alpha = 1;
  const float beta = accumulate ? 1 : 0;

  // Check for NVFP4
  // TODO Remove once alpha scale logic is moved into cublas_gemm function
  if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
    NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
  }

  // Launch GEMM
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}
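
/* Minimal caller-side sketch (a usage illustration, not part of the API;
 * assumes a, b, d, bias_t, gelu_t, and wspace are valid NVTETensor handles
 * created elsewhere, with wspace sized for cuBLASLt -- its first shape
 * dimension is taken as the workspace byte size):
 *
 *   nvte_cublas_gemm(a, b, d, bias_t, gelu_t,
 *                    true,    // transa: row-major A is [m, k]
 *                    false,   // transb: row-major B is [n, k]
 *                    false,   // grad
 *                    wspace,
 *                    false,   // accumulate (beta = 0)
 *                    false,   // use_split_accumulator
 *                    0,       // math_sm_count: use all SMs
 *                    stream);
 */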

void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
                         const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
                         NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm_v2);
  using namespace transformer_engine;

  // Data tensors
  const Tensor *A_tensor = convertNVTETensorCheck(A);
  const Tensor *B_tensor = convertNVTETensorCheck(B);
  const Tensor *C_tensor = convertNVTETensorCheck(C);
  Tensor *D_tensor = convertNVTETensorCheck(D);
  NVTE_CHECK(C_tensor == D_tensor,
             "Currently nvte_cublas_gemm_v2 does not support different C and D tensors.");

  // Workspace
  void *workspace_ptr = nullptr;
  size_t workspace_size = 0;
  Tensor *workspace_tensor = convertNVTETensor(workspace);
  if (workspace_tensor != nullptr) {
    workspace_ptr = workspace_tensor->data.dptr;
    workspace_size =
        get_buffer_size_bytes(workspace_tensor->data.numel(), workspace_tensor->data.dtype);
  }

  // Additional config
  MatmulConfig config_;
  if (config != nullptr) {
    config_ = *reinterpret_cast<MatmulConfig *>(config);
  }

  // Configure GEMM epilogue
  const bool with_grad_epilogue = (config_.dbias_tensor != nullptr || config_.with_dgelu_epilogue);
  if (with_grad_epilogue) {
    NVTE_CHECK(config_.bias_tensor == nullptr && !config_.with_gelu_epilogue,
               "Invalid epilogue (bias=", config_.bias_tensor != nullptr,
               ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
               ", dgelu=", config_.with_dgelu_epilogue, ").");
  }
  Tensor dummy_tensor;
  Tensor *epilogue_bias_tensor = &dummy_tensor;
  if (!with_grad_epilogue && config_.bias_tensor != nullptr) {
    epilogue_bias_tensor = convertNVTETensorCheck(config_.bias_tensor);
  } else if (with_grad_epilogue && config_.dbias_tensor != nullptr) {
    epilogue_bias_tensor = convertNVTETensorCheck(config_.dbias_tensor);
  }
  Tensor *epilogue_aux_tensor = &dummy_tensor;
  if (config_.with_gelu_epilogue || config_.with_dgelu_epilogue) {
    NVTE_CHECK(config_.epilogue_aux_tensor != nullptr,
               "Requested epilogue (bias=", config_.bias_tensor != nullptr,
               ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
               ", dgelu=", config_.with_dgelu_epilogue, ") without providing aux tensor.");
    epilogue_aux_tensor = convertNVTETensor(config_.epilogue_aux_tensor);
  }

  // Launch GEMM
  cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor,
              transa ? CUBLAS_OP_T : CUBLAS_OP_N, transb ? CUBLAS_OP_T : CUBLAS_OP_N,
              with_grad_epilogue, workspace_ptr, workspace_size, alpha, beta,
              config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr, stream);
}

void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm_scaled);
  using namespace transformer_engine;

  // Tensors
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensorCheck(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  // Check for NVFP4
  // TODO Remove once alpha scale logic is moved into cublas_gemm function
  if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
    NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
  }

  // Launch GEMM
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}

void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
906
907
                             cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);
908
  using namespace transformer_engine;
909
910
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
             CUDA_VERSION);
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
  NVTE_ERROR(
      "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
      CUBLAS_VERSION);
#else
  NVTE_CHECK(
      transformer_engine::cuda::cudart_version() >= 12020 &&
          transformer_engine::cuda::cudart_version() < 13000,
      "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
      transformer_engine::cuda::cudart_version());
  NVTE_CHECK(
      cublas_version() >= 120205 && cublas_version() < 130000,
      "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
      cublas_version());

  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  const Tensor *inputCounter = convertNVTETensor(counter);
  Tensor *wspace = convertNVTETensor(workspace);

  const void *alpha_ptr = GetScalarOne();
  const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero();

  NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                 is_delayed_tensor_scaling(inputB->scaling_mode),
             "Atomic GEMM only supports delayed scaling.");
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split,
              gemm_producer, inputCounter, stream);
#endif
}

void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                              const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                              bool transa, bool transb, bool grad, NVTETensor *workspace,
                              bool accumulate, bool use_split_accumulator, int math_sm_count,
                              cudaStream_t stream) {
  using namespace transformer_engine;

  int num_streams = nvte_get_num_compute_streams();

  int num_stream_used = std::min(num_streams, num_gemms);
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
  }

  for (int i = 0; i < num_gemms; i++) {
    // Check whether GELU or dGELU epilogue is requested
    Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]);
    bool with_gelu_dgelu_epilogue =
        (pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr);

    // Construct config
    MatmulConfig config;
    if (grad) {
      config.dbias_tensor = bias[i];
      config.with_dgelu_epilogue = with_gelu_dgelu_epilogue;
    } else {
      config.bias_tensor = bias[i];
      config.with_gelu_epilogue = with_gelu_dgelu_epilogue;
    }
    config.epilogue_aux_tensor = pre_gelu_out[i];
    config.use_split_accumulator = use_split_accumulator;
    config.sm_count = math_sm_count;

    // Launch GEMM
    const float alpha = 1.f;
    const float beta = accumulate ? 1.f : 0.f;
    nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i],
                        workspace[i % num_streams], &config,
                        detail::get_compute_stream(i % num_streams));
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
  }
}

void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;

  // Deprecation warning
  NVTE_WARN(
      "nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. "
      "Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when "
      "applicable).");

  multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace,
                           accumulate, use_split_accumulator, math_sm_count, stream);
}

namespace transformer_engine {

using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }

}  // namespace transformer_engine

void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                            const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                            bool transa, bool transb, bool grad, NVTETensor *workspace,
                            bool accumulate, bool use_split_accumulator, int math_sm_count,
                            cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_gemm);

  const int current_device = transformer_engine::cuda::current_device();
  const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
  const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
  const bool warn_fallback =
      transformer_engine::getenv<bool>("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false);

  auto cublas_path = [&]() {
    multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
                             workspace, accumulate, use_split_accumulator, math_sm_count, stream);
  };

  // Currently, CUTLASS grouped GEMM is only supported on the Hopper architecture
  if (!(is_hopper && use_cutlass)) {
    cublas_path();
    return;
  }

  auto is_empty_arr = [&](const NVTETensor *p) -> bool {
    if (p == nullptr) return true;
    for (int i = 0; i < num_gemms; ++i) {
      if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false;
    }
    return true;
  };

  auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool {
    int64_t ref_k = -1;
    for (int i = 0; i < num_gemms; i++) {
      const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]);
      const int k = trans ? tensor->data.shape[0] : tensor->data.shape[1];

      if ((k & 127) != 0) return false;

      if (ref_k < 0)
        ref_k = k;
      else if (k != ref_k)
        return false;
    }

    return true;
  };
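
  // Examples: K = 256 for every group qualifies (256 % 128 == 0 and uniform);
  // K = 192 is rejected since 192 & 127 == 64; mixed {256, 384} is rejected
  // as non-uniform even though both are multiples of 128.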

  auto is_supported_dtype = [&]() -> bool {
    auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
    auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
    auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]);
    auto A_type = get_cuda_dtype(inputA->data.dtype);
    auto B_type = get_cuda_dtype(inputB->data.dtype);
    auto D_type = get_cuda_dtype(OutputD->data.dtype);

    return (A_type == B_type) && (A_type == D_type) &&
           ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F));
  };

  // CUTLASS Grouped GEMM fast path (SM90/TMA)
  // Conditions:
  //  - No fused epilogue: both bias and pre_gelu_out are empty.
  //  - Supported dtypes only: FP16/BF16 (FP32 accumulate).
  //  - Uniform K across groups and K % 128 == 0.
  //  - use_split_accumulator is ignored for FP16/BF16.
  //  - grad is irrelevant when bias/pre_gelu_out are empty.
  //
  // Otherwise, fall back to cuBLAS.
  if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() &&
      all_groups_uniform_k128(B, transb)) {
    cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
                         current_device, math_sm_count, stream);
  } else {
    if (warn_fallback) {
      NVTE_WARN("Fallback to cuBLAS grouped GEMM.");
    }
    cublas_path();
  }
}
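
// Example: opting in to the CUTLASS fast path from the environment (both
// variables are read in nvte_multi_tensor_gemm above; the launch command is
// illustrative):
//
//   NVTE_USE_CUTLASS_GROUPED_GEMM=1 NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK=1 ./your_app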