/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

yuguo's avatar
yuguo committed
7
#ifndef __HIP_PLATFORM_AMD__
Przemek Tredak's avatar
Przemek Tredak committed
8
9
#include <cublasLt.h>
#include <cublas_v2.h>
Tim Moon's avatar
Tim Moon committed
10
#include <cuda.h>
yuguo's avatar
yuguo committed
11
12
13
14
15
#else
#include <iostream>
#include "hipblas_gemm.h"
#include "rocm_gemm.hip"
#endif // #ifndef __HIP_PLATFORM_AMD__
16
#include <transformer_engine/gemm.h>
17
#include <transformer_engine/multi_stream.h>
18
#include <transformer_engine/recipe.h>
19
20
#include <transformer_engine/transformer_engine.h>

21
#include <algorithm>
22
#include <cstdint>
23
#include <mutex>
24
#include <vector>
Tim Moon's avatar
Tim Moon committed
25

Przemek Tredak's avatar
Przemek Tredak committed
26
#include "../common.h"
27
#include "../util/cuda_runtime.h"
28
#include "../util/handle_manager.h"
Tim Moon's avatar
Tim Moon committed
29
#include "../util/logging.h"
30
#include "../util/multi_stream.h"
31
#include "./config.h"
yuguo's avatar
yuguo committed
32
#ifndef __HIP_PLATFORM_AMD__
33
#include "./cutlass_grouped_gemm.cuh"
yuguo's avatar
yuguo committed
34
#endif
Przemek Tredak's avatar
Przemek Tredak committed
35

yuguo's avatar
yuguo committed
36
#ifndef __HIP_PLATFORM_AMD__
37
38
namespace {

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/* Device-resident scalars 1.0f and 0.0f for cuBLASLt calls whose alpha/beta
 * are read from device memory (CUBLASLT_POINTER_MODE_DEVICE).
 *
 * The values are set with static initializers instead of a one-time
 * cudaMemcpyToSymbol. __constant__ storage exists separately on every device.
 * A host-side copy guarded by std::call_once therefore only initialized the
 * device that happened to be current on the first call. The symbols were left
 * uninitialized on every other GPU used by the process, so e.g. a "beta == 1"
 * pointer handed to cuBLASLt there did not point at 1.0f. Static
 * initialization is part of the module image and is applied on each device
 * the module is loaded on.
 */
__device__ __constant__ float one_device = 1.0f;
__device__ __constant__ float zero_device = 0.0f;

/* Returns the address of `one_device` on the current device.
 * The lookup is repeated on each call because the symbol address is
 * device-specific; failures are reported through NVTE_CHECK_CUDA. */
inline float *GetScalarOne() {
  float *dev_ptr = nullptr;
  NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), one_device));
  return dev_ptr;
}

/* Returns the address of `zero_device` on the current device.
 * The lookup is repeated on each call because the symbol address is
 * device-specific; failures are reported through NVTE_CHECK_CUDA. */
inline float *GetScalarZero() {
  float *dev_ptr = nullptr;
  NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), zero_device));
  return dev_ptr;
}

/* Stores `val` into the single float at device address `ptr`.
 * Declared for a one-thread launch (__launch_bounds__(1)); in this file it is
 * launched <<<1, 1, 0, stream>>> to stage a beta value in GEMM workspace. */
__global__ void __launch_bounds__(1) set_float_kernel(float *ptr, float val) {
  ptr[0] = val;
}

70
71
72
/* Largest power-of-two byte alignment, capped at 256, that `address` satisfies.
 * Always terminates: every address is 1-byte aligned. An address of 0 yields 256. */
uint32_t _getAlignment(uintptr_t address) {
  uint32_t bytes = 256;
  while (address % bytes != 0) {
    bytes >>= 1;
  }
  return bytes;
}

80
81
82
83
/* Factory used by the cuBLASLt handle manager: creates a new cuBLASLt handle
 * in `*handle`, failing through NVTE_CHECK_CUBLAS if cublasLtCreate does. */
inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  const cublasStatus_t create_status = cublasLtCreate(handle);
  NVTE_CHECK_CUBLAS(create_status);
}

84
85
86
87
88
89
90
/* Parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * is different than the row-major that is typically used in
 * Transformer Engine.
 *
 * One instance describes both input operands exactly as they are handed to
 * cuBLASLt: buffer, transpose flag, element type, inverse-scale pointer and
 * leading dimension.
 */
struct GemmParam {
  // Operand buffers passed to cuBLASLt.
  void *A{nullptr};
  void *B{nullptr};
  // cuBLAS transpose flags for each operand.
  cublasOperation_t transA{CUBLAS_OP_N};
  cublasOperation_t transB{CUBLAS_OP_N};
  // Element types; kNumTypes presumably marks "not set yet" (TODO confirm).
  transformer_engine::DType Atype{transformer_engine::DType::kNumTypes};
  transformer_engine::DType Btype{transformer_engine::DType::kNumTypes};
  // Inverse scaling factors matching the chosen buffers (may be null).
  void *A_scale_inv{nullptr};
  void *B_scale_inv{nullptr};
  // Leading dimensions, i.e. column strides in column-major order.
  int lda{0};  // A column strides
  int ldb{0};  // B column strides
};

104
105
106
107
108
109
110
/* Populate parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * is different than the row-major that is typically used in
 * Transformer Engine.
 *
 * For each operand this chooses the physical buffer handed to cuBLASLt
 * (row-wise `data` or column-wise `columnwise_data`), together with the
 * matching dtype, inverse-scale pointer, transpose flag and leading dimension.
 * The choice depends on the operand's scaling mode:
 *  - tensor scaling (unscaled / FP8 per-tensor): the caller's layout is used.
 *    If nvte_is_non_tn_fp8_gemm_supported() is false and the operand is not
 *    already in TN form, the column-wise FP8 copy (when present) is used instead.
 *  - NVFP4 and FP8 block scaling: always TN (A as CUBLAS_OP_T, B as CUBLAS_OP_N).
 *  - MXFP8: the caller's transpose flags are kept, and the row-wise or
 *    column-wise copy is chosen by those flags.
 *
 * Params:
 *   A, transA  first operand and its cuBLAS transpose flag
 *   B, transB  second operand and its cuBLAS transpose flag
 *   m, n, k    GEMM dimensions in column-major order
 * Returns the canonicalized GemmParam.
 * Fails (NVTE_CHECK / NVTE_ERROR) on incompatible or unsupported scaling modes,
 * missing row-/column-wise data, or block-scaling shapes that need padding.
 */
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                int m, int n, int k) {
  using namespace transformer_engine;

  // Operands must share a scaling mode; FP8 1D and 2D block scaling may be mixed.
  NVTE_CHECK(
      A.scaling_mode == B.scaling_mode ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
      "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " +
          to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode));
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");

  GemmParam ret;

  // Transpose flags in cuBLAS column-major convention.
  const bool A_is_T = transA == CUBLAS_OP_T;
  const bool B_is_T = transB == CUBLAS_OP_T;

  // Recipe detection: NVFP4 takes precedence over MXFP8.
  const bool nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
  const bool mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);

  // Point an operand at its row-wise (`data`) or column-wise (`columnwise_data`)
  // buffer, taking the matching dtype and inverse-scale pointer along with it.
  // "Column-wise data" is the transpose of the row-wise data.
  auto select_A = [&](const bool rowwise) {
    if (rowwise) {
      ret.A = A.data.dptr;
      ret.Atype = A.data.dtype;
      ret.A_scale_inv = A.scale_inv.dptr;
    } else {
      ret.A = A.columnwise_data.dptr;
      ret.Atype = A.columnwise_data.dtype;
      ret.A_scale_inv = A.columnwise_scale_inv.dptr;
    }
  };
  auto select_B = [&](const bool rowwise) {
    if (rowwise) {
      ret.B = B.data.dptr;
      ret.Btype = B.data.dtype;
      ret.B_scale_inv = B.scale_inv.dptr;
    } else {
      ret.B = B.columnwise_data.dptr;
      ret.Btype = B.columnwise_data.dtype;
      ret.B_scale_inv = B.columnwise_scale_inv.dptr;
    }
  };

  // ---------------------------- Operand A ----------------------------
  if (is_tensor_scaling(A.scaling_mode)) {
    // Unscaled or FP8 tensor scaling: row-wise data in the caller's layout.
    select_A(true);
    ret.transA = transA;
    ret.lda = A_is_T ? k : m;
    if (!nvte_is_non_tn_fp8_gemm_supported() && !A_is_T) {
      // Hopper only supports TN GEMMs for FP8: use the column-wise copy so that
      // A can be consumed as CUBLAS_OP_T.
      if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
        select_A(false);
        ret.transA = CUBLAS_OP_T;
        ret.lda = k;
      } else {
        // Non-FP8 data may keep the non-TN layout; FP8 data may not.
        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
      }
    }
  } else if (nvfp4) {
    // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe.
    // Only TN layout is supported, so A is always consumed as CUBLAS_OP_T.
    if (A_is_T) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(is_nvfp4_scaling(A.scaling_mode),
                 "Input A has unsupported combination of recipe and layout");
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    select_A(A_is_T);
    ret.transA = CUBLAS_OP_T;
    ret.lda = k;
  } else if (mxfp8) {
    // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe.
    // Row-wise and column-wise data are scaled along different dimensions
    // (row-major view), so the copy follows the requested layout and the
    // caller's transpose flag is passed through unchanged.
    if (A_is_T) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    select_A(A_is_T);
    ret.transA = transA;
    ret.lda = A_is_T ? k : m;
  } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling: TN only, so A is always consumed as CUBLAS_OP_T.
    if (A_is_T) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    select_A(A_is_T);
    ret.transA = CUBLAS_OP_T;
    ret.lda = k;

    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.lda % 16) == 0,
               "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
    // Smallest supported CType is 2 bytes in this scaling mode.
    NVTE_CHECK((m % 8) == 0,
               "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
  } else {
    NVTE_ERROR("A has unsupported scaling mode");
  }

  // ---------------------------- Operand B ----------------------------
  if (is_tensor_scaling(B.scaling_mode)) {
    // Unscaled or FP8 tensor scaling: row-wise data in the caller's layout.
    select_B(true);
    ret.transB = transB;
    ret.ldb = B_is_T ? n : k;
    if (!nvte_is_non_tn_fp8_gemm_supported() && B_is_T) {
      // Hopper only supports TN GEMMs for FP8: use the column-wise copy so that
      // B can be consumed as CUBLAS_OP_N.
      if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
        select_B(false);
        ret.transB = CUBLAS_OP_N;
        ret.ldb = k;
      } else {
        // Non-FP8 data may keep the non-TN layout; FP8 data may not.
        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
      }
    }
  } else if (nvfp4) {
    // NVFP4 GEMM: TN only, so B is always consumed as CUBLAS_OP_N.
    if (B_is_T) {
      NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode),
                 "Input B has unsupported combination of recipe and layout");
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    select_B(!B_is_T);
    ret.transB = CUBLAS_OP_N;
    ret.ldb = k;
  } else if (mxfp8) {
    // MXFP8 GEMM: copy follows the requested layout; flag passed through.
    if (B_is_T) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    select_B(!B_is_T);
    ret.transB = transB;
    ret.ldb = B_is_T ? n : k;
  } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling: TN only, so B is always consumed as CUBLAS_OP_N.
    if (B_is_T) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    select_B(!B_is_T);
    ret.transB = CUBLAS_OP_N;
    ret.ldb = k;

    // Requirements from
    // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.ldb % 16) == 0,
               "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
      // Observed this requirement only present for B tensor is 1D quantized.
      NVTE_CHECK((n % 8) == 0,
                 "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    }
  } else {
    NVTE_ERROR("B has unsupported scaling mode");
  }

  return ret;
}

282
283
284
285
286
287
288
/* Run-time cuBLASLt version number (compared in this file against values such
 * as 120800 for cuBLAS 12.8). */
size_t cublas_version() {
  // Queried once on first use (function-local static) so that repeated checks do
  // not pay cuBLAS logging overhead.
  static const size_t cached_version = cublasLtGetVersion();
  return cached_version;
}

289
}  // namespace
yuguo's avatar
yuguo committed
290
#endif // __HIP_PLATFORM_AMD__
291

Przemek Tredak's avatar
Przemek Tredak committed
292
namespace transformer_engine {
yuguo's avatar
yuguo committed
293
294
295
296
297
298
299
#ifdef __HIP_PLATFORM_AMD__
// Forward declaration of the ROCm GEMM entry point; the implementation is in rocm_gemm.cu.
// Compared with the cuBLASLt variant in the #else branch, this one takes explicit
// m/n/k and leading dimensions, bool transpose flags, and an `accumulate` flag
// instead of alpha/beta pointers.
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu,
                 int m, int n, int k,
                 int lda, int ldb, int ldd,
                 bool transa, bool transb, bool grad,
                 void *workspace, size_t workspaceSize,
                 bool accumulate, bool use_split_accumulator, int math_sm_count,
                 int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, hipStream_t stream,
                 bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset);
yuguo's avatar
yuguo committed
301
#else // Use cublasLt
302
// Manager for cuBLASLt handles; new handles are produced by CreateCublasHandle
// and obtained via cublasHandleManager::Instance().GetHandle().
typedef detail::HandleManager<cublasLtHandle_t, CreateCublasHandle> cublasHandleManager;
303
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
304
305
                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
306
                 const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count,
Jan Bielak's avatar
Jan Bielak committed
307
308
                 int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
                 cudaStream_t stream) {
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
  // Tensor dims in row-major order
  const int A0 = inputA->flat_first_dim();
  const int A1 = inputA->flat_last_dim();
  const int B0 = inputB->flat_first_dim();
  const int B1 = inputB->flat_last_dim();

  // GEMM dims in column-major order
  const int m = transa == CUBLAS_OP_T ? A0 : A1;
  const int n = transb == CUBLAS_OP_T ? B1 : B0;
  const int k = transa == CUBLAS_OP_T ? A1 : A0;
  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
             "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
             ")");
  const int ldd = m;

324
325
326
327
328
329
  // Return immediately if GEMM is trivial
  if (m <= 0 || n <= 0) {
    return;
  }
  NVTE_CHECK(k > 0);

330
331
  const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

332
  void *C = outputD->data.dptr;
333
  void *D = outputD->data.dptr;
334
335
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
336
337
338
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
339
340
341
342
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
343
  const bool gelu = pre_gelu_out != nullptr;
344
  const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
  const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype);

  // Update scaling factors with NVFP4 tensor scales
  // TODO: Check whether scales are on CPU/GPU or add API to control.
  // Currently scales are assumed to be on CPU when amax is provided
  // and on GPU when not provided, but this is brittle.
  if (use_fp4 && (inputA->amax.dptr != nullptr || inputB->amax.dptr != nullptr)) {
    // Reserve some workspace for alpha scale
    NVTE_CHECK(workspaceSize >= 4,
               "NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ",
               workspaceSize, " bytes remaining.");
    workspaceSize = (workspaceSize / 4) * 4 - 4;  // Remove last 4 aligned bytes
    uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
    float *new_alpha_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);

    // Update alpha scale on device
    // Note: Compute NVFP4 tensor scales based on amaxes and then
    // divide from alpha scale. This way we only need to apply NVFP4
    // tensor scales in matmul output, instead of in matmul inputs.
    float old_alpha = *reinterpret_cast<const float *>(alpha);  // Assumed to be on CPU
    TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector<size_t>{1}, DType::kFloat32);
    nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb,
                                        old_alpha, new_alpha_tensor.data(), stream);
    alpha = new_alpha_ptr;

    // Make sure beta scale is on device
    float old_beta = *reinterpret_cast<const float *>(beta);  // Assumed to be on CPU
    if (old_beta == 0) {
      beta = GetScalarZero();  // Device constant memory
    } else if (old_beta == 1) {
      beta = GetScalarOne();  // Device constant memory
    } else {
      // Move beta to workspace
      NVTE_CHECK(workspaceSize >= 4,
                 "NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has ",
                 workspaceSize, " bytes remaining.");
      workspaceSize = (workspaceSize / 4) * 4 - 4;  // Remove last 4 aligned bytes
      float *new_beta_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);
      set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta);
      NVTE_CHECK_CUDA(cudaGetLastError());
      beta = new_beta_ptr;
    }
  }
388
389
390

  const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
  const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
391
392
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);
Przemek Tredak's avatar
Przemek Tredak committed
393

394
  NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
395
             "FP8 input to GEMM requires inverse of scale!");
396
  NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
397
             "FP8 input to GEMM requires inverse of scale!");
398
399
400
401
  NVTE_CHECK(!is_fp4_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP4 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp4_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP4 input to GEMM requires inverse of scale!");
Przemek Tredak's avatar
Przemek Tredak committed
402

403
404
  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
405
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
406
  if ((use_fp8 || use_fp4) && gelu) {
407
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
408
               "fp8 Aux output for gemm + gelu fusion not supported!");
409
  }
410
411
412
413
414
  if (is_fp4_dtype(outputD->data.dtype)) {
    NVTE_ERROR("FP4 GEMM output is not supported!");
  }
  if (use_fp4 && (D_type == CUDA_R_16F)) {
    NVTE_ERROR("FP4 GEMM does not support FP16 output!");
415
  }
Przemek Tredak's avatar
Przemek Tredak committed
416

417
  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
Przemek Tredak's avatar
Przemek Tredak committed
418

419
420
  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
421
  cublasLtMatmulPreference_t preference = nullptr;
422
  int returnedResults = 0;
423
424
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
Przemek Tredak's avatar
Przemek Tredak committed
425

426
  int64_t ld_gelumat = (int64_t)ldd;
Przemek Tredak's avatar
Przemek Tredak committed
427

428
429
430
431
432
  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }
Przemek Tredak's avatar
Przemek Tredak committed
433

434
  // Create matrix descriptors. Not setting any extra attributes.
435
436
437
438
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
                                               param.transA == CUBLAS_OP_N ? k : m, param.lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
                                               param.transB == CUBLAS_OP_N ? n : k, param.ldb));
439

440
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
Przemek Tredak's avatar
Przemek Tredak committed
441

442
443
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
444
                                                   &param.transA, sizeof(param.transA)));
445
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
446
                                                   &param.transB, sizeof(param.transB)));
447
448
  // Set math SM count
  if (math_sm_count != 0) {
449
450
451
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
452
453
  }

454
455
  // set fp8/fp4 attributes -- input and output types should already be set to fp8/fp4
  // as appropriate. Note: gelu fusion isn't available right now, and we don't need
456
  // amax(D) either (next op is high precision).
457
458
459
460
461
  const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode);

  if (use_fp8 || use_fp4) {
    // Fast accumulation is only supported for FP8.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : use_fp8;
462
463
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));
464
465

    // Scaling factors.
466
#if CUBLAS_VERSION >= 120800
467
468
    cublasLtMatmulMatrixScale_t scaling_mode_a;
    cublasLtMatmulMatrixScale_t scaling_mode_b;
469
#endif  // CUBLAS_VERSION >= 120800
470
    if (is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode)) {
471
472
473
474
475
476
477
478
      void *A_scale_inverse = param.A_scale_inv;
      void *B_scale_inverse = param.B_scale_inv;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
479
#if CUBLAS_VERSION >= 120800
480
481
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
482
#endif  // CUBLAS_VERSION >= 120800
483
    } else if (mxfp8_gemm) {
484
485
486
#if CUBLAS_VERSION >= 120800
      NVTE_CHECK(cublas_version() >= 120800,
                 "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
487
488
489
490
491
492
493
494
      fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
      fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
495
496
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
497
498
      // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
      // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
499
      if (cublas_version() <= 120803) {
500
501
502
503
504
        const int64_t dummy_a_vec_stride = 1;
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
            sizeof(dummy_a_vec_stride)));
      }
505
506
507
#else
      NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
#endif                     // CUBLAS_VERSION >= 120800
    } else if (use_fp4) {  // NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
      NVTE_CHECK(cublas_version() >= 120800,
                 "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
      // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE
      cublasDataType_t scale_type = CUDA_R_32F;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));

      // Set pointer mode: alpha and beta are both device pointers
      // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
      cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));

      fp8e4m3 *A_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.A_scale_inv);
      fp8e4m3 *B_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
#else
      NVTE_ERROR("FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION);
536
#endif  // CUBLAS_VERSION >= 120800
537
538
539
540
    } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
               (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
541
542
543
544
#if CUBLAS_VERSION >= 120900
      NVTE_CHECK(cublas_version() >= 120900,
                 "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
                 cublas_version());
545
546
547
548
549
550
551
552
553
554
      float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
      float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
                    inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
555
                 "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
556
557
558
559
560
561
562
      scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
      scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
563
564
565
      NVTE_ERROR("FP8 block scaling requires cuBLAS 12.9+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120900
566
567
568
569
570
    } else {
      NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and  " +
                 to_string(inputB->scaling_mode) + ".");
    }

571
572
573
574
575
576
577
578
579
580
#if CUBLAS_VERSION >= 120800
    if (cublas_version() >= 120800) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
                                                       &scaling_mode_a, sizeof(scaling_mode_a)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
                                                       &scaling_mode_b, sizeof(scaling_mode_b)));
    }
#endif  // CUBLAS_VERSION >= 120800
581
582
583
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
584
585
586
587
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
588
589
590
591
592
593
594
595
596
#if CUBLAS_VERSION >= 120800
      if (cublas_version() >= 120800) {
        // NOTE: In all current cases where FP8 output is supported, the input is
        // scaled identically to the output.
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                         CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
                                                         &scaling_mode_a, sizeof(scaling_mode_a)));
      }
#endif  // CUBLAS_VERSION >= 120800
597
598
599
600
      // For FP8 output, cuBLAS requires C_type to match bias_type and
      // be FP16/BF16
      const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
601
602
603
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
604
    if (bias) {
605
606
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
607
    }
608
609
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
610
  }
Przemek Tredak's avatar
Przemek Tredak committed
611

612
613
614
615
616
617
618
  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
619
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
620
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
621
622
623
624
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
625
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
626
627
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
628
629
630
631
632
633
634
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
635
636
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
637
638
639
640
641
642
643
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
644
645
646
647
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
648
649
650
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
651
  }
Przemek Tredak's avatar
Przemek Tredak committed
652

653
654
655
656
657
658
659
660
  if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) ||
      (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
    NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS ||
                epilogue == CUBLASLT_EPILOGUE_DGELU),
               "Epilogue requested outside of the available and tested cuBLAS functionality for "
               "float8 block scaled GEMM");
  }

661
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
662
                                                   &epilogue, sizeof(epilogue)));
663

664
  if (counter != nullptr) {
665
666
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
    NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
667
               CUDA_VERSION);
668
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
669
    NVTE_ERROR(
670
        "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
671
        CUBLAS_VERSION);
672
#else
673
    NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
674
               "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
675
676
               cuda::cudart_version());
    NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
677
               "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
678
               cublas_version());
679
680
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
681
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
682
683
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
684
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
685
686
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
687
688
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
689
690
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
691
692
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
693
694
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
695
696
    }
#endif
697
  }
Przemek Tredak's avatar
Przemek Tredak committed
698

699
700
701
702
703
704
705
706
  // align the workspace to 256 B
  const int required_alignment = 256;
  const auto original_workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
  uint8_t *aligned_workspace_ptr =
      reinterpret_cast<uint8_t *>(workspace) + required_alignment - original_workspace_alignment;
  workspaceSize = workspaceSize - required_alignment + original_workspace_alignment;
  const auto new_workspace_alignment =
      _getAlignment(reinterpret_cast<uintptr_t>(aligned_workspace_ptr));
707
708
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
709
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
710
711
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
712
713
714
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
715
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
716
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
717
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
718
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
719
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
720
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
721
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
722
723
724
  NVTE_CHECK(new_workspace_alignment % 256 == 0,
             "cuBLAS workspace pointer must be aligned to 256 bytes, got ",
             new_workspace_alignment);
Przemek Tredak's avatar
Przemek Tredak committed
725

726
727
728
  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
Tim Moon's avatar
Tim Moon committed
729
730
731
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);
732
  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");
Przemek Tredak's avatar
Przemek Tredak committed
733

734
  // D = alpha * (A * B) + beta * C
735
736
737
738
739
740
741
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */
                                   param.A,                      /* A */
                                   Adesc, param.B,               /* B */
                                   Bdesc, beta,                  /* beta */
                                   C,                            /* C */
                                   Cdesc, D,                     /* D */
                                   Ddesc, &heuristicResult.algo, /* algo */
742
                                   aligned_workspace_ptr,        /* workspace */
743
                                   workspaceSize, stream));      /* stream */
Przemek Tredak's avatar
Przemek Tredak committed
744

745
  // Update FP8 scale-inv in output tensor
746
747
748
749
  // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
  // TODO: Changing gemm interface so that D->scale_inv is allocated and the scale_inv can be
  // calculated here.
  if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
750
751
752
    update_tensor_scale_inv(outputD, stream);
  }

753
754
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
755
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
756
757
758
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
Przemek Tredak's avatar
Przemek Tredak committed
759
}
yuguo's avatar
yuguo committed
760
#endif // __HIP_PLATFORM_AMD__
Przemek Tredak's avatar
Przemek Tredak committed
761

yuguo's avatar
yuguo committed
762
763
764
765
766
767
768
// Add for batchgemm
// File-scope state for the batched-GEMM side streams: one stream and one event
// per slot. Presumably initialization is guarded by `init_flag_batchgemm` at the
// call site (not visible here) -- TODO confirm.
static std::once_flag init_flag_batchgemm;
static cudaStream_t compute_streams_batchgemm[num_batchgemm_streams];
static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];

// Creates the batched-GEMM streams and one event per stream.
//
// The value of TORCH_COMM_CU_NUMS (read through getIntEnv) selects how many of
// the low bits of the first CU-mask word are cleared; only 4, 8, 16 and 32 are
// accepted, anything else fails an NVTE_CHECK.
// On ROCm builds, when TORCH_COMM_CU_NUMS is set to a non-empty string, each
// stream is created with that CU mask via hipExtStreamCreateWithCUMask.
// Otherwise (and on every CUDA build) each stream is a non-blocking stream
// with priority -1.
//
// Warning: only call once per device!
static void init_streams_and_events_batchgemm() {
  const int comm_cu_nums = getIntEnv("TORCH_COMM_CU_NUMS", 8, 4);

  // Only the first mask word depends on comm_cu_nums; the other three words
  // keep every bit set.
  unsigned int first_word = 0;
  switch (comm_cu_nums) {
    case 4:
      first_word = 0xfffffff0;
      break;
    case 8:
      first_word = 0xffffff00;
      break;
    case 16:
      first_word = 0xffff0000;
      break;
    case 32:
      first_word = 0x00000000;
      break;
    default:
      NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
  }
  unsigned int cuMask[4] = {first_word, 0xffffffff, 0xffffffff, 0xffffffff};
  unsigned int cuMaskSize = 4;

  // The CU mask is only applied when the variable is explicitly non-empty.
  const char *cu_nums_env = std::getenv("TORCH_COMM_CU_NUMS");
  const bool apply_cu_mask = (cu_nums_env != nullptr) && (cu_nums_env[0] != '\0');

  for (int slot = 0; slot < num_batchgemm_streams; ++slot) {
    cudaStream_t *stream_slot = &compute_streams_batchgemm[slot];
#ifdef __HIP_PLATFORM_AMD__
    if (apply_cu_mask) {
      NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(stream_slot, cuMaskSize, cuMask));
    } else {
      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(stream_slot, cudaStreamNonBlocking, -1));
    }
#else
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(stream_slot, cudaStreamNonBlocking, -1));
#endif
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[slot]));
  }
}

810
}  // namespace transformer_engine
Przemek Tredak's avatar
Przemek Tredak committed
811

812
813
814
// GEMM entry point: D = op(A) * op(B), with D accumulated into when
// `accumulate` is set. It takes optional bias / pre-GELU epilogue tensors and
// a workspace tensor whose data.shape[0] is passed as the workspace size.
//
// ROCm builds dispatch either to `cublas_gemm`, which also receives the
// hipBLASLt/rocBLAS selection flags, or to `hipblas_gemm`.
// CUDA builds always call `cublas_gemm` with host alpha = 1 and
// beta = accumulate ? 1 : 0.
// NVFP4-scaled inputs are rejected here; callers must use nvte_cublas_gemm_v2.
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;

  // Operands: A, B and D are mandatory; bias, pre-GELU output and workspace are not checked.
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensorCheck(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  // TODO Remove once alpha scale logic is moved into cublas_gemm function
  const bool has_nvfp4_input =
      is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode);
  if (has_nvfp4_input) {
    NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
  }

#ifdef __HIP_PLATFORM_AMD__
  // Problem size taken from the flattened 2-D views of A and B.
  const size_t A0 = inputA->flat_first_dim();
  const size_t A1 = inputA->flat_last_dim();
  const size_t B0 = inputB->flat_first_dim();
  const size_t B1 = inputB->flat_last_dim();

  const int m = transa ? A0 : A1;
  const int k = transa ? A1 : A0;
  const int n = transb ? B1 : B0;

  // Only the TN, NN and NT layouts are supported.
  if (transa && transb) {
    NVTE_ERROR("TT layout not allowed.");
  }
  int lda = transa ? k : m;
  int ldb = transb ? n : k;
  int ldd = m;

  const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);

  const char *force_rocm_env = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool force_rocm_gemm = (force_rocm_env != nullptr) && (force_rocm_env[0] == '1');

  // INT8 GEMMs that simulate FP8 tensorwise scaling with split accumulation
  // force the hipBLASLt flag on.
  const char *int8_sim_env = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
  if (int8_sim_env != nullptr && int8_sim_env[0] == '1' && use_int8 && use_split_accumulator) {
    nvte_use_hipblaslt = true;
  }

  const bool has_epilogue_output =
      (biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr);
  const bool route_to_cublas_gemm = has_epilogue_output || use_fp8 || force_rocm_gemm ||
                                    nvte_use_hipblaslt || nvte_use_rocblas;

  if (route_to_cublas_gemm) {
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa,
                transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream,
                nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
  } else {
    hipblas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 transa ? HIPBLAS_OP_T : HIPBLAS_OP_N, transb ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                 use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
  }
#else
  // Host-side scales for cuBLASLt: D = 1 * op(A) * op(B) + beta * D.
  const float alpha = 1;
  const float beta = accumulate ? 1 : 0;
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, transa ? CUBLAS_OP_T : CUBLAS_OP_N,
              transb ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
#endif
}

// Config-driven GEMM entry point: D = alpha * op(A) * op(B) + beta * C.
// C must be the same tensor as D. Epilogues (bias / dbias / GELU / dGELU) and
// split-accumulator / SM-count settings come from the optional
// NVTEMatmulConfig. The workspace size is the byte size of the workspace
// buffer, or 0 when no workspace tensor is given.
//
// ROCm builds accept only alpha == 1 and beta in {0, 1}; beta == 1 means
// accumulate into D. The GEMM is then dispatched to `cublas_gemm` or
// `hipblas_gemm`. CUDA builds forward alpha/beta pointers to `cublas_gemm`.
void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
                         const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
                         NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
  NVTE_API_CALL(nvte_cublas_gemm_v2);
  using namespace transformer_engine;

  // Data tensors; in-place output (C aliasing D) is the only supported mode.
  const Tensor *A_tensor = convertNVTETensorCheck(A);
  const Tensor *B_tensor = convertNVTETensorCheck(B);
  const Tensor *C_tensor = convertNVTETensorCheck(C);
  Tensor *D_tensor = convertNVTETensorCheck(D);
  NVTE_CHECK(C_tensor == D_tensor,
             "Currently nvte_cublas_gemm_v2 does not support different C and D tensors.");

  // Optional workspace buffer.
  void *workspace_ptr = nullptr;
  size_t workspace_size = 0;
  if (Tensor *workspace_tensor = convertNVTETensor(workspace)) {
    workspace_ptr = workspace_tensor->data.dptr;
    workspace_size =
        get_buffer_size_bytes(workspace_tensor->data.numel(), workspace_tensor->data.dtype);
  }

  // Optional matmul configuration; defaults are used when none is provided.
  MatmulConfig config_;
  if (config != nullptr) {
    config_ = *reinterpret_cast<MatmulConfig *>(config);
  }

  // Backward-style epilogues (dbias / dGELU) cannot be combined with the
  // forward-style ones (bias / GELU).
  const bool with_grad_epilogue = (config_.dbias_tensor != nullptr || config_.with_dgelu_epilogue);
  NVTE_CHECK(!with_grad_epilogue ||
                 (config_.bias_tensor == nullptr && !config_.with_gelu_epilogue),
             "Invalid epilogue (bias=", config_.bias_tensor != nullptr,
             ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
             ", dgelu=", config_.with_dgelu_epilogue, ").");

  // Unused epilogue slots point at an empty placeholder tensor.
  Tensor dummy_tensor;
  Tensor *epilogue_bias_tensor = &dummy_tensor;
  if (with_grad_epilogue) {
    if (config_.dbias_tensor != nullptr) {
      epilogue_bias_tensor = convertNVTETensorCheck(config_.dbias_tensor);
    }
  } else if (config_.bias_tensor != nullptr) {
    epilogue_bias_tensor = convertNVTETensorCheck(config_.bias_tensor);
  }

  Tensor *epilogue_aux_tensor = &dummy_tensor;
  const bool needs_aux_tensor = config_.with_gelu_epilogue || config_.with_dgelu_epilogue;
  if (needs_aux_tensor) {
    NVTE_CHECK(config_.epilogue_aux_tensor != nullptr,
               "Requested epilogue (bias=", config_.bias_tensor != nullptr,
               ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
               ", dgelu=", config_.with_dgelu_epilogue, ") without providing aux tensor.");
    epilogue_aux_tensor = convertNVTETensor(config_.epilogue_aux_tensor);
  }

#ifdef __HIP_PLATFORM_AMD__
  // The ROCm calls below take an accumulate flag instead of alpha/beta.
  NVTE_CHECK(*alpha == 1.0f, "alpha must be 1.0 for hip");
  NVTE_CHECK(*beta == 1.0f || *beta == 0.0f, "beta must be 1.0 or 0.0 for hip");
  bool accumulate = (*alpha == 1.0f) && (*beta == 1.0f);

  // Problem size taken from the flattened 2-D views of A and B.
  const size_t A0 = A_tensor->flat_first_dim();
  const size_t A1 = A_tensor->flat_last_dim();
  const size_t B0 = B_tensor->flat_first_dim();
  const size_t B1 = B_tensor->flat_last_dim();

  const int m = transa ? A0 : A1;
  const int k = transa ? A1 : A0;
  const int n = transb ? B1 : B0;

  // Only the TN, NN and NT layouts are supported.
  if (transa && transb) {
    NVTE_ERROR("TT layout not allowed.");
  }
  int lda = transa ? k : m;
  int ldb = transb ? n : k;
  int ldd = m;

  const bool use_int8 = is_int8_dtype(A_tensor->data.dtype) || is_int8_dtype(B_tensor->data.dtype);
  const bool use_fp8 = is_fp8_dtype(A_tensor->data.dtype) || is_fp8_dtype(B_tensor->data.dtype);

  const char *force_rocm_env = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool force_rocm_gemm = (force_rocm_env != nullptr) && (force_rocm_env[0] == '1');

  // INT8 GEMMs that simulate FP8 tensorwise scaling with split accumulation
  // force the hipBLASLt flag on.
  const char *int8_sim_env = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
  if (int8_sim_env != nullptr && int8_sim_env[0] == '1' && use_int8 &&
      config_.use_split_accumulator) {
    nvte_use_hipblaslt = true;
  }

  const bool has_epilogue_output = (epilogue_bias_tensor->data.dptr != nullptr) ||
                                   (epilogue_aux_tensor->data.dptr != nullptr);
  const bool route_to_cublas_gemm = has_epilogue_output || use_fp8 || force_rocm_gemm ||
                                    nvte_use_hipblaslt || nvte_use_rocblas;

  if (route_to_cublas_gemm) {
    cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor, m, n, k,
                lda, ldb, ldd, transa, transb, with_grad_epilogue, workspace_ptr, workspace_size,
                accumulate, config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr,
                stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
  } else {
    hipblas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor, m, n, k,
                 lda, ldb, ldd, transa ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 transb ? HIPBLAS_OP_T : HIPBLAS_OP_N, with_grad_epilogue, workspace_ptr,
                 workspace_size, accumulate, config_.use_split_accumulator, config_.sm_count, 0, 0,
                 false, nullptr, stream);
  }
#else
  // Launch GEMM
  cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor,
              transa ? CUBLAS_OP_T : CUBLAS_OP_N, transb ? CUBLAS_OP_T : CUBLAS_OP_N,
              with_grad_epilogue, workspace_ptr, workspace_size, alpha, beta,
              config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr, stream);
#endif
}

// GEMM entry point with explicit host scales: D = alpha * op(A) * op(B) + beta * D.
// It takes optional bias / pre-GELU epilogue tensors and a workspace tensor
// whose data.shape[0] is passed as the workspace size.
//
// ROCm builds accept only alpha == 1 and beta in {0, 1}; beta == 1 means
// accumulate into D. The GEMM is then dispatched to `cublas_gemm` or
// `hipblas_gemm`. CUDA builds pass &alpha / &beta to `cublas_gemm`.
// NVFP4-scaled inputs are rejected here; callers must use nvte_cublas_gemm_v2.
void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
  NVTE_API_CALL(nvte_cublas_gemm_scaled);
  using namespace transformer_engine;

  // Tensors
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensorCheck(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  // Check for NVFP4
  // TODO Remove once alpha scale logic is moved into cublas_gemm function
  // The error names this entry point (it previously said "nvte_cublas_gemm",
  // which pointed users at the wrong function).
  if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
    NVTE_ERROR(
        "nvte_cublas_gemm_scaled does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
  }

#ifdef __HIP_PLATFORM_AMD__
  // The ROCm calls below take an accumulate flag instead of alpha/beta.
  NVTE_CHECK(alpha == 1.0f, "alpha must be 1.0 for hip");
  NVTE_CHECK(beta == 1.0f || beta == 0.0f, "beta must be 1.0 or 0.0 for hip");
  bool accumulate = false;
  if (alpha == 1.0f and beta == 1.0f) {
    accumulate = true;
  }

  // Problem size taken from the flattened 2-D views of A and B.
  const size_t A0 = inputA->flat_first_dim();
  const size_t A1 = inputA->flat_last_dim();
  const size_t B0 = inputB->flat_first_dim();
  const size_t B1 = inputB->flat_last_dim();

  const int m = transa ? A0 : A1;
  const int k = transa ? A1 : A0;
  const int n = transb ? B1 : B0;
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);

  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);

  // INT8 GEMMs that simulate FP8 tensorwise scaling with split accumulation
  // force the hipBLASLt flag on.
  const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
  if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' &&
      use_int8 && use_split_accumulator) {
    nvte_use_hipblaslt = 1;
  }

  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || (use_fp8) ||
      (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') ||
      (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa,
                transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream,
                nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
  } else {
    hipblas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                 use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
  }
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
#endif
}

1119
1120
1121
1122
1123
// Atomic (chunk-synchronized) GEMM entry point: D = op(A) * op(B), with D
// accumulated into when `accumulate` is set. Only delayed tensor scaling is
// accepted for A and B.
//
// m_split / n_split are the number of row / column chunks of D used for
// atomic synchronization; the cuBLASLt path treats 0 as 1. `counter` is used
// as the output-counter array when `gemm_producer` is true and as the
// input-counter array otherwise.
//
// CUDA builds require CUDA >= 12.2.0, < 13.0.0 and cuBLAS >= 12.2.5, < 13.0.0.
// Versions are checked at compile time and again at run time.
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
                             cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);
  using namespace transformer_engine;

#ifndef __HIP_PLATFORM_AMD__
  // Check CUDA and cuBLAS versions
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
             CUDA_VERSION);
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
  NVTE_ERROR(
      "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
      CUBLAS_VERSION);
#else
  // Compile-time versions are acceptable: enable the GEMM body below.
#define NVTE_CUBLAS_ATOMIC_GEMM_COMPILE 1
  NVTE_CHECK(
      transformer_engine::cuda::cudart_version() >= 12020 &&
          transformer_engine::cuda::cudart_version() < 13000,
      "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
      transformer_engine::cuda::cudart_version());
  NVTE_CHECK(
      cublas_version() >= 120205 && cublas_version() < 130000,
      "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
      cublas_version());
#endif
#else
  // ROCm builds always compile the GEMM body.
#define NVTE_CUBLAS_ATOMIC_GEMM_COMPILE 1
#endif // __HIP_PLATFORM_AMD__

#ifdef NVTE_CUBLAS_ATOMIC_GEMM_COMPILE
  // D is mandatory, like in the other nvte_cublas_gemm* entry points.
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensorCheck(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  const Tensor *inputCounter = convertNVTETensor(counter);
  Tensor *wspace = convertNVTETensor(workspace);

  NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                 is_delayed_tensor_scaling(inputB->scaling_mode),
             "Atomic GEMM only supports delayed scaling.");

#ifdef __HIP_PLATFORM_AMD__
  // Problem size taken from the flattened 2-D views of A and B, matching the
  // other entry points. Reading data.shape[1] directly is out of bounds for
  // 1-D tensors and is not the last dimension for tensors of rank > 2.
  const size_t A0 = inputA->flat_first_dim();
  const size_t A1 = inputA->flat_last_dim();
  const size_t B0 = inputB->flat_first_dim();
  const size_t B1 = inputB->flat_last_dim();

  const int m = transa ? A0 : A1;
  const int k = transa ? A1 : A0;
  const int n = transb ? B1 : B0;
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);

  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);

  // INT8 GEMMs that simulate FP8 tensorwise scaling with split accumulation
  // force the hipBLASLt flag on.
  const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
  if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' &&
      use_int8 && use_split_accumulator) {
    nvte_use_hipblaslt = 1;
  }

  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || (use_fp8) ||
      (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') ||
      (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa,
                transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
                inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas,
                compute_stream_offset);
  } else {
    hipblas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                 use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
                 inputCounter, stream);
  }
#else
  // alpha = 1 and beta = accumulate ? 1 : 0, taken from GetScalarOne()/GetScalarZero().
  const void *alpha_ptr = GetScalarOne();
  const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero();
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split,
              gemm_producer, inputCounter, stream);
#endif
#endif // NVTE_CUBLAS_ATOMIC_GEMM_COMPILE
}

1230
1231
1232
1233
1234
/* Launch `num_gemms` independent GEMMs, spreading them round-robin over the
 * compute streams returned by detail::get_compute_stream() and joining them
 * back onto `stream`.
 *
 * Each GEMM uses alpha = 1 and beta = (accumulate ? 1 : 0) and writes into
 * D[i], which also serves as the C input. GEMM i runs on compute stream
 * (i % num_streams) with workspace[i % num_streams], where
 * num_streams = nvte_get_num_compute_streams(). Only the first
 * min(num_streams, num_gemms) compute streams are fenced against `stream`.
 *
 * Epilogue per GEMM: a non-empty pre_gelu_out[i] enables the dGELU epilogue
 * when `grad` is true and the GELU epilogue otherwise; bias[i] is wired as
 * dbias or bias accordingly.
 *
 * Environment:
 *   NVTE_FORCE_BLAS_MULSTREAM=1 -> nvte_cublas_gemm_v2 is called with its
 *                                  default trailing arguments.
 *   otherwise                   -> nvte_cublas_gemm_v2 is called with the
 *                                  extra trailing arguments (1, 0, stream index).
 *                                  NOTE(review): these look like (use_hipblaslt,
 *                                  use_rocblas, compute_stream_offset); confirm
 *                                  against the nvte_cublas_gemm_v2 declaration.
 *   Setting NVTE_FORCE_BLAS_MULSTREAM=1 together with NVTE_FORCE_ROCM_GEMM=1 is
 *   an error, raised before any work is enqueued.
 */
void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                              const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                              bool transa, bool transb, bool grad, NVTETensor *workspace,
                              bool accumulate, bool use_split_accumulator, int math_sm_count,
                              cudaStream_t stream) {
  using namespace transformer_engine;

  // Resolve the backend selection first, so an invalid configuration is rejected
  // before any cross-stream dependencies are recorded.
  const char *env_blas_mulstream = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
  const char *env_force_rocm_gemm = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool force_blas_mulstream =
      (env_blas_mulstream != nullptr && env_blas_mulstream[0] == '1');
  const bool force_rocm_gemm = (env_force_rocm_gemm != nullptr && env_force_rocm_gemm[0] == '1');
  if (force_blas_mulstream && force_rocm_gemm) {
    NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time.");
  }

  const int num_streams = nvte_get_num_compute_streams();
  const int num_stream_used = std::min(num_streams, num_gemms);

  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
  }

  // Scalars are identical for every GEMM. They are hoisted so their addresses stay
  // valid for the whole launch loop.
  const float alpha = 1.f;
  const float beta = accumulate ? 1.f : 0.f;

  for (int i = 0; i < num_gemms; i++) {
    const int stream_idx = i % num_streams;

    // Check whether GELU or dGELU epilogue is requested
    Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]);
    const bool with_gelu_dgelu_epilogue =
        (pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr);

    // Construct config
    MatmulConfig config;
    if (grad) {
      config.dbias_tensor = bias[i];
      config.with_dgelu_epilogue = with_gelu_dgelu_epilogue;
    } else {
      config.bias_tensor = bias[i];
      config.with_gelu_epilogue = with_gelu_dgelu_epilogue;
    }
    config.epilogue_aux_tensor = pre_gelu_out[i];
    config.use_split_accumulator = use_split_accumulator;
    config.sm_count = math_sm_count;

    // Launch GEMM; the two paths differ only in the trailing backend arguments.
    if (force_blas_mulstream) {
      nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i],
                          workspace[stream_idx], &config, detail::get_compute_stream(stream_idx));
    } else {
      nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i],
                          workspace[stream_idx], &config, detail::get_compute_stream(stream_idx),
                          1, 0, stream_idx);
    }
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
  }
}
yuguo's avatar
yuguo committed
1322

1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
/* Deprecated public entry point for the multi-stream GEMM batch.
 *
 * Emits a deprecation warning on every call, then forwards every argument
 * unchanged to multi_stream_cublas_gemm. New callers should use
 * nvte_multi_tensor_gemm instead.
 */
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;

  NVTE_WARN(
      "nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. "
      "Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when "
      "applicable).");

  // Pure pass-through: same arguments, same order.
  multi_stream_cublas_gemm(A, B, D,
                           bias, pre_gelu_out,
                           num_gemms,
                           transa, transb, grad,
                           workspace,
                           accumulate, use_split_accumulator,
                           math_sm_count,
                           stream);
}

1342
#ifndef __HIP_PLATFORM_AMD__
namespace transformer_engine {

// cuBLASLt handle cache managed by detail::HandleManager; handles are produced
// by CreateCublasHandle.
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

// Request the cuBLASLt handle once and discard the result. Only the side effect
// of asking the manager for a handle matters (presumably it creates it on first
// use; confirm in HandleManager).
void nvte_cublas_handle_init() {
  static_cast<void>(cublasHandleManager::Instance().GetHandle());
}

}  //  namespace transformer_engine
#endif

yuguo's avatar
yuguo committed
1352
#ifdef __HIP_PLATFORM_AMD__
1353
1354
1355
1356
1357
1358
1359
/* HIP-only grouped GEMM: collects the `num_gemms` problems and dispatches them
 * in a single hipblaslt_groupedgemm call on `stream`.
 *
 * Per-group problem sizes:
 *   m = transa ? A.flat_first_dim() : A.flat_last_dim()
 *   k = transa ? A.flat_last_dim()  : A.flat_first_dim()
 *   n = transb ? B.flat_last_dim()  : B.flat_first_dim()
 * Every group gets a batch count of 1, and only workspace[0] is used.
 *
 * Raises an error when:
 *   - any pre_gelu_out[i] is non-empty (fused GELU is not supported);
 *   - bias is present for some groups but not others. hipblaslt_groupedgemm
 *     receives a single use_bias flag, so presence must be uniform.
 * Does nothing when num_gemms <= 0.
 */
void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  using namespace transformer_engine;

  // Empty (or invalid negative) group: nothing to launch. This also guarantees
  // that index 0 is valid below.
  if (num_gemms <= 0) { return; }

  const size_t count = static_cast<size_t>(num_gemms);
  std::vector<const Tensor*> inputA;
  std::vector<const Tensor*> inputB;
  std::vector<Tensor*> outputD;
  std::vector<const Tensor*> biasTensor;
  std::vector<Tensor*> outputGelu;
  std::vector<int64_t> m;
  std::vector<int64_t> n;
  std::vector<int64_t> k;
  std::vector<int64_t> b;
  inputA.reserve(count);
  inputB.reserve(count);
  outputD.reserve(count);
  biasTensor.reserve(count);
  outputGelu.reserve(count);
  m.reserve(count);
  n.reserve(count);
  k.reserve(count);
  b.reserve(count);

  bool use_bias = false;
  for (int i = 0; i < num_gemms; i++) {
    inputA.push_back(convertNVTETensorCheck(A[i]));
    inputB.push_back(convertNVTETensorCheck(B[i]));
    outputD.push_back(convertNVTETensorCheck(D[i]));
    biasTensor.push_back(convertNVTETensorCheck(bias[i]));
    outputGelu.push_back(convertNVTETensorCheck(pre_gelu_out[i]));
    b.push_back(1);

    // Validate every group. Checking only group 0 would silently drop a GELU
    // epilogue requested on a later group.
    if (outputGelu.back()->data.dptr != nullptr) {
      NVTE_ERROR("MOE nvte_grouped_gemm does not support gelu.");
    }
    // A single use_bias flag is passed to hipBLASLt, so bias presence must agree
    // across all groups. A mismatch would read a null bias or silently drop one.
    const bool has_bias = biasTensor.back()->data.dptr != nullptr;
    if (i == 0) {
      use_bias = has_bias;
    } else if (has_bias != use_bias) {
      NVTE_ERROR("MOE nvte_grouped_gemm requires bias to be set for all GEMMs or for none.");
    }

    const size_t A0 = inputA.back()->flat_first_dim();
    const size_t A1 = inputA.back()->flat_last_dim();
    const size_t B0 = inputB.back()->flat_first_dim();
    const size_t B1 = inputB.back()->flat_last_dim();

    m.push_back(transa ? A0 : A1);
    k.push_back(transa ? A1 : A0);
    n.push_back(transb ? B1 : B0);
  }

  Tensor *wspace = convertNVTETensorCheck(workspace[0]);

  hipblaslt_groupedgemm(inputA, inputB, outputD, biasTensor, use_bias, grad, m, n, k, b,
                        (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                        (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                        wspace->data.dptr, wspace->data.shape[0],
                        accumulate, use_split_accumulator,
                        math_sm_count, stream);
}
yuguo's avatar
yuguo committed
1413

yuguo's avatar
yuguo committed
1414
1415
1416
1417
1418
1419
1420
1421
/* Launch `num_gemms` batched GEMMs (each via nvte_cublas_batchgemm), distributed
 * round-robin over the dedicated batch-GEMM compute streams and fenced against
 * `stream` before and after.
 *
 * The per-call batch count comes from getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1).
 * NOTE(review): presumably default 2 with a lower bound of 1; confirm against
 * getIntEnv. GEMM g uses compute_streams_batchgemm[g % num_batchgemm_streams]
 * and the workspace at the same index.
 */
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
  using namespace transformer_engine;

  const int batches_per_gemm = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);

  // Streams and events are created once for the whole process.
  std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);

  const int active_streams = std::min(num_batchgemm_streams, num_gemms);

  // Fork: every participating compute stream waits for work already queued on `stream`.
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
  for (int lane = 0; lane < active_streams; ++lane) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[lane], cublas_event_batchgemm[0]));
  }

  for (int g = 0; g < num_gemms; ++g) {
    const int lane = g % num_batchgemm_streams;
    nvte_cublas_batchgemm(A[g], B[g], D[g], bias[g], pre_gelu_out[g], transa, transb, grad,
                          workspace[lane], accumulate, use_split_accumulator, math_sm_count,
                          batches_per_gemm, compute_streams_batchgemm[lane]);
  }

  // Join, step 1: mark completion on each participating compute stream.
  for (int lane = 0; lane < active_streams; ++lane) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[lane], compute_streams_batchgemm[lane]));
  }
  // Join, step 2: `stream` waits for all of them.
  for (int lane = 0; lane < active_streams; ++lane) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event_batchgemm[lane]));
  }
}

// add for batchgemm
/* HIP-only batched GEMM dispatched to hipblas_batchgemm on `stream`.
 *
 * Dim 0 of A (and of B when !transb) is divided by `batch_count`, i.e. the
 * tensors are expected to hold `batch_count` equally sized matrices stacked
 * along dim 0. Per-batch sizes and leading dimensions:
 *   m = transa ? A.shape[0] / batch_count : A.shape[1]
 *   k = transa ? A.shape[1] : A.shape[0] / batch_count
 *   n = transb ? B.shape[1] : B.shape[0] / batch_count
 *   lda = transa ? k : m,  ldb = transb ? n : k,  ldd = m
 *
 * Supported layouts: TN, NN, NT. Raises an error for:
 *   - TT layout;
 *   - a non-empty bias or pre_gelu_out (fused epilogues are not supported);
 *   - batch_count <= 0, fewer than 2 dims, or a stacked dim not divisible
 *     by batch_count.
 */
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm);
  using namespace transformer_engine;

  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  // D is the GEMM output and wspace is dereferenced below, so both are mandatory.
  Tensor *outputD = convertNVTETensorCheck(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensorCheck(workspace);

  // A null bias/GELU handle is treated like an empty tensor instead of being
  // dereferenced. The same (possibly null) pointers are forwarded as before.
  const bool has_bias = biasTensor != nullptr && biasTensor->data.dptr != nullptr;
  const bool has_gelu = outputGelu != nullptr && outputGelu->data.dptr != nullptr;
  if (has_bias || has_gelu) {
    NVTE_ERROR("MOE batchgemm does not support bias or gelu.");
  }
  if (transa && transb) {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  NVTE_CHECK(batch_count > 0, "MOE batchgemm requires batch_count > 0, got ", batch_count, ".");
  NVTE_CHECK(inputA->data.shape.size() >= 2 && inputB->data.shape.size() >= 2,
             "MOE batchgemm expects A and B to have at least 2 dimensions.");
  const size_t batches = static_cast<size_t>(batch_count);
  // Only the dims that are actually divided must split evenly; otherwise the
  // division would silently truncate.
  NVTE_CHECK(inputA->data.shape[0] % batches == 0,
             "MOE batchgemm: A.shape[0] (", inputA->data.shape[0],
             ") is not divisible by batch_count (", batch_count, ").");
  if (!transb) {
    NVTE_CHECK(inputB->data.shape[0] % batches == 0,
               "MOE batchgemm: B.shape[0] (", inputB->data.shape[0],
               ") is not divisible by batch_count (", batch_count, ").");
  }

  // Per-batch views of A and B (dim 0 split into batch_count slices).
  const size_t a_dim0 = inputA->data.shape[0] / batches;
  const size_t a_dim1 = inputA->data.shape[1];
  const size_t b_dim0 = inputB->data.shape[0] / batches;
  const size_t b_dim1 = inputB->data.shape[1];

  const int m = static_cast<int>(transa ? a_dim0 : a_dim1);
  const int k = static_cast<int>(transa ? a_dim1 : a_dim0);
  const int n = static_cast<int>(transb ? b_dim1 : b_dim0);

  // TN: (k, k, m)   NN: (m, k, m)   NT: (m, n, m)
  const int lda = transa ? k : m;
  const int ldb = transb ? n : k;
  const int ldd = m;

  hipblas_batchgemm(inputA,
            inputB,
            outputD,
            biasTensor,
            outputGelu,
            m, n, k,
            lda, ldb, ldd,
            (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
            (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
            grad, wspace->data.dptr,
            wspace->data.shape[0],
            accumulate, use_split_accumulator,
            math_sm_count,
            0,
            0,
            false,
            nullptr,
            batch_count,
            stream);
}

yuguo's avatar
yuguo committed
1517
1518

// add for batchgemm
yuguo's avatar
yuguo committed
1519
/* INT8 batched GEMM with scale tensors A_scales / B_scales (HIP only).
 *
 * Currently disabled: every call raises an error. The signature is kept so the
 * C API stays stable. The earlier body only computed m/n/k/ld* values that were
 * never used before the unconditional error, and it could dereference a null
 * bias/GELU handle (crashing) before reaching that error. It has been removed.
 * NOTE(review): when re-enabling, presumably the dimension logic should mirror
 * nvte_cublas_batchgemm; confirm the scale semantics first.
 */
void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm_tensorwise_int8);
  using namespace transformer_engine;

  NVTE_ERROR("Remove nvte_cublas_batchgemm_tensorwise_int8 for now.");
}
wenjh's avatar
wenjh committed
1573
#endif
1574
1575
1576
1577
1578
1579
1580

/* Grouped-GEMM entry point: run `num_gemms` independent GEMMs.
 *
 * HIP builds: NVTE_USE_HIPBLASLT_GROUPEDGEMM=1 routes to nvte_grouped_gemm
 * (hipBLASLt grouped GEMM). Otherwise the GEMMs are launched individually
 * through multi_stream_cublas_gemm.
 *
 * CUDA builds: the CUTLASS grouped GEMM is used only on SM90 with
 * NVTE_USE_CUTLASS_GROUPED_GEMM enabled and when every group meets the
 * fast-path conditions listed below. All other cases fall back to
 * multi_stream_cublas_gemm, which logs a warning when
 * NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK is enabled.
 */
void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                            const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                            bool transa, bool transb, bool grad, NVTETensor *workspace,
                            bool accumulate, bool use_split_accumulator, int math_sm_count,
                            cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_gemm);

#ifdef __HIP_PLATFORM_AMD__
  const char *NVTE_USE_HIPBLASLT_GROUPEDGEMM = std::getenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM");
  if (NVTE_USE_HIPBLASLT_GROUPEDGEMM != nullptr && NVTE_USE_HIPBLASLT_GROUPEDGEMM[0] == '1') {
    nvte_grouped_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
                      workspace, accumulate, use_split_accumulator, math_sm_count, stream);
  } else {
    multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
                             workspace, accumulate, use_split_accumulator, math_sm_count, stream);
  }
#else
  const int current_device = transformer_engine::cuda::current_device();
  const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
  const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
  const bool warn_fallback =
      transformer_engine::getenv<bool>("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false);

  auto cublas_path = [&]() {
    multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
                             workspace, accumulate, use_split_accumulator, math_sm_count, stream);
  };

  // Currently only support cutlass group gemm on Hopper Arch. An empty group
  // also takes the cuBLAS path, which keeps the per-group inspections below
  // (they read index 0 onward) in bounds.
  if (num_gemms <= 0 || !(is_hopper && use_cutlass)) {
    cublas_path();
    return;
  }

  // True when no group carries data. A null handle counts as empty.
  auto is_empty_arr = [&](const NVTETensor *p) -> bool {
    if (p == nullptr) return true;
    for (int i = 0; i < num_gemms; ++i) {
      const auto *tensor = transformer_engine::convertNVTETensor(p[i]);
      if (tensor != nullptr && tensor->has_data()) return false;
    }
    return true;
  };

  // True when every group has the same K and that K is a multiple of 128.
  auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool {
    int64_t ref_k = -1;
    for (int i = 0; i < num_gemms; ++i) {
      const auto *tensor = transformer_engine::convertNVTETensorCheck(p[i]);
      const int64_t k =
          static_cast<int64_t>(trans ? tensor->data.shape[0] : tensor->data.shape[1]);

      if ((k & 127) != 0) return false;

      if (ref_k < 0)
        ref_k = k;
      else if (k != ref_k)
        return false;
    }

    return true;
  };

  // True when every group has A, B and D of one common dtype, and that dtype
  // is FP16 or BF16.
  auto is_supported_dtype = [&]() -> bool {
    for (int i = 0; i < num_gemms; ++i) {
      auto *inputA = transformer_engine::convertNVTETensorCheck(A[i]);
      auto *inputB = transformer_engine::convertNVTETensorCheck(B[i]);
      auto *OutputD = transformer_engine::convertNVTETensorCheck(D[i]);
      const auto A_type = get_cuda_dtype(inputA->data.dtype);
      const auto B_type = get_cuda_dtype(inputB->data.dtype);
      const auto D_type = get_cuda_dtype(OutputD->data.dtype);

      const bool ok = (A_type == B_type) && (A_type == D_type) &&
                      ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F));
      if (!ok) return false;
    }
    return true;
  };

  // CUTLASS Grouped GEMM fast path (SM90/TMA)
  // Conditions:
  //  - No fused epilogue: both bias and pre_gelu_out are empty.
  //  - Supported dtypes only: FP16/BF16 (FP32 accumulate), same in every group.
  //  - Uniform K across groups and K % 128 == 0.
  //  - use_split_accumulator is ignored for FP16/BF16.
  //  - grad is irrelevant when bias/pre_gelu_out are empty.
  //
  // Otherwise, fall back to cuBLAS.
  if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() &&
      all_groups_uniform_k128(B, transb)) {
    cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
                         current_device, math_sm_count, stream);
  } else {
    if (warn_fallback) {
      NVTE_WARN("Fallback to cuBLAS grouped GEMM.");
    }
    cublas_path();
  }
#endif
}
wenjh's avatar
wenjh committed
1666
1667