/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef __HIP_PLATFORM_AMD__
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#else
#include <iostream>
#include "hipblas_gemm.h"
#include "rocm_gemm.hip"
#endif // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <mutex>

#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "common/util/cuda_runtime.h"

#ifndef __HIP_PLATFORM_AMD__
namespace {

cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return CUDA_R_16F;
    case DType::kFloat32:
      return CUDA_R_32F;
    case DType::kBFloat16:
      return CUDA_R_16BF;
    case DType::kFloat8E4M3:
      return CUDA_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return CUDA_R_8F_E5M2;
    default:
      NVTE_ERROR("Invalid type");
  }
}

uint32_t _getAlignment(uintptr_t address) {
  // Alignment is expressed in bytes.
  uint32_t alignment = 256;
  for (;; alignment /= 2) {
    if (address % alignment == 0) {
      return alignment;
    }
  }
}
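// Illustration (hypothetical addresses): the helper returns the largest
// power-of-two divisor of the address, capped at 256 bytes, e.g.
//   _getAlignment(4096);  // -> 256 (4096 % 256 == 0)
//   _getAlignment(4104);  // -> 8   (4104 = 8 * 513, and 513 is odd)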

inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}

/* Parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * is different from the row-major ordering typically used in
 * Transformer Engine.
 *
 */
struct GemmParam {
  void *A = nullptr;
  void *B = nullptr;
  cublasOperation_t transA = CUBLAS_OP_N;
  cublasOperation_t transB = CUBLAS_OP_N;
  transformer_engine::DType Atype = transformer_engine::DType::kNumTypes;
  transformer_engine::DType Btype = transformer_engine::DType::kNumTypes;
  void *A_scale_inv = nullptr;
  void *B_scale_inv = nullptr;
  int lda = 0;  // Leading dimension (column stride) of A
  int ldb = 0;  // Leading dimension (column stride) of B
};
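/* Illustration of the convention above: a row-major buffer of shape
 * [rows, cols] holds the same bytes as a column-major matrix of shape
 * [cols, rows]. A Transformer Engine tensor with flat shape [A0, A1] is
 * therefore presented to cuBLAS as a column-major A1 x A0 matrix, and
 * lda/ldb are column strides in that column-major view.
 */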

/* Populate parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * is different from the row-major ordering typically used in
 * Transformer Engine.
 *
 */
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                int m, int n, int k) {
  using namespace transformer_engine;
  NVTE_CHECK(
      A.scaling_mode == B.scaling_mode ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
      "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " +
          to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode));
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
  GemmParam ret;

  // Transpose mode with column-major ordering
  bool is_A_transposed = transA == CUBLAS_OP_T;
  bool is_B_transposed = transB == CUBLAS_OP_T;

  // Configure A matrix
  if (is_tensor_scaling(A.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.A = A.data.dptr;
    ret.transA = transA;
    ret.Atype = A.data.dtype;
    ret.A_scale_inv = A.scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
        ret.A = A.columnwise_data.dptr;
        ret.transA = CUBLAS_OP_T;
        ret.Atype = A.columnwise_data.dtype;
        ret.A_scale_inv = A.columnwise_scale_inv.dptr;
        ret.lda = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
      }
    }
  } else if (is_mxfp_scaling(A.scaling_mode)) {
    // MXFP8
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with matrix interpreted in row-major order).
    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = transA;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = is_A_transposed ? k : m;
  } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = CUBLAS_OP_T;
    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = k;

    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.lda % 16) == 0,
               "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    // Divisibility by 8 is derived from the FP8 (m * CTypeSize) % 16 == 0 requirement.
    // Smallest supported CType is 2 bytes in this scaling mode.
    NVTE_CHECK((m % 8) == 0,
               "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
  } else {
    NVTE_ERROR("A has unsupported scaling mode");
  }

  // Configure B matrix
  if (is_tensor_scaling(B.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.B = B.data.dptr;
    ret.transB = transB;
    ret.Btype = B.data.dtype;
    ret.B_scale_inv = B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
      if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
        ret.B = B.columnwise_data.dptr;
        ret.transB = CUBLAS_OP_N;
        ret.Btype = B.columnwise_data.dtype;
        ret.B_scale_inv = B.columnwise_scale_inv.dptr;
        ret.ldb = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
      }
    }
  } else if (is_mxfp_scaling(B.scaling_mode)) {
    // MXFP8
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with matrix interpreted in row-major order).
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = transB;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = is_B_transposed ? n : k;
  } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = CUBLAS_OP_N;
    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;

    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.ldb % 16) == 0,
               "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
      // This requirement has only been observed when the B tensor is 1D-quantized.
      NVTE_CHECK((n % 8) == 0,
                 "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    }
  } else {
    NVTE_ERROR("B has unsupported scaling mode");
  }

  return ret;
}
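/* Example of the TN workaround applied above (illustrative): if A is FP8
 * with tensor scaling, the caller requests transA == CUBLAS_OP_N, and the
 * device only supports TN FP8 GEMMs, CanonicalizeGemmInput substitutes the
 * column-wise copy of A (its transpose) and flips transA to CUBLAS_OP_T,
 * which yields the same product in the supported TN layout.
 */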

/* cuBLAS version number at run-time */
size_t cublas_version() {
  // Cache version to avoid cuBLAS logging overhead
  static size_t version = cublasLtGetVersion();
  return version;
}

}  // namespace
#endif // __HIP_PLATFORM_AMD__

namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
// Forward declaration. The implementation is in rocm_gemm.cu.
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, bool transa, bool transb, bool grad,
                 void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset);
#else // Use cublasLt
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
                 float alpha, float beta, bool use_split_accumulator, int math_sm_count,
                 int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
                 cudaStream_t stream) {
  // Tensor dims in row-major order
  const int A0 = inputA->flat_first_dim();
  const int A1 = inputA->flat_last_dim();
  const int B0 = inputB->flat_first_dim();
  const int B1 = inputB->flat_last_dim();

  // GEMM dims in column-major order
  const int m = transa == CUBLAS_OP_T ? A0 : A1;
  const int n = transb == CUBLAS_OP_T ? B1 : B0;
  const int k = transa == CUBLAS_OP_T ? A1 : A0;
  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
             "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
             ")");
  const int ldd = m;
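  // Worked example (illustrative): a row-major A with flat shape
  // [768, 1024] is a column-major 1024 x 768 matrix to cuBLAS. With
  // transa == CUBLAS_OP_T, op(A) is 768 x 1024, so m = A0 = 768 and
  // k = A1 = 1024.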

  // Return immediately if GEMM is trivial
  if (m <= 0 || n <= 0) {
    return;
  }
  NVTE_CHECK(k > 0);

  const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

  void *C = outputD->data.dptr;
  void *D = outputD->data.dptr;
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);

  const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
  const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);

  NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");

  // Check consistency of FP8-related arguments.
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp8_dtype(outputD->data.dtype)) {
    NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!");
  }

  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();

  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  cublasLtMatmulPreference_t preference = nullptr;
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  int64_t ld_gelumat = (int64_t)ldd;

  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
                                               param.transA == CUBLAS_OP_N ? k : m, param.lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
                                               param.transB == CUBLAS_OP_N ? n : k, param.ldb));

  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
                                                   &param.transA, sizeof(param.transA)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                   &param.transB, sizeof(param.transB)));
  // Set math SM count
  if (math_sm_count != 0) {
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
  }

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));

    // Scaling factors.
#if CUBLAS_VERSION >= 120800
    cublasLtMatmulMatrixScale_t scaling_mode_a;
    cublasLtMatmulMatrixScale_t scaling_mode_b;
#endif  // CUBLAS_VERSION >= 120800
    if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
      void *A_scale_inverse = param.A_scale_inv;
      void *B_scale_inverse = param.B_scale_inv;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
#if CUBLAS_VERSION >= 120800
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
#endif  // CUBLAS_VERSION >= 120800
    } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) {
#if CUBLAS_VERSION >= 120800
      NVTE_CHECK(cublas_version() >= 120800,
                 "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
      fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
      fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
      // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
      if (cublas_version() <= 120803) {
        const int64_t dummy_a_vec_stride = 1;
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
            sizeof(dummy_a_vec_stride)));
      }
#else
      NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120800
    } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
               (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUBLAS_VERSION >= 120900
      NVTE_CHECK(cublas_version() >= 120900,
                 "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
                 cublas_version());
      float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
      float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
                    inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
                 "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
      scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
      scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D
                           ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                           : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
      NVTE_ERROR("FP8 block scaling requires cuBLAS 12.9+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120900
    } else {
      NVTE_ERROR("Unsupported combination of scaling modes: " + to_string(inputA->scaling_mode) +
                 " and " + to_string(inputB->scaling_mode) + ".");
    }

#if CUBLAS_VERSION >= 120800
    if (cublas_version() >= 120800) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
                                                       &scaling_mode_a, sizeof(scaling_mode_a)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
                                                       &scaling_mode_b, sizeof(scaling_mode_b)));
    }
#endif  // CUBLAS_VERSION >= 120800
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUBLAS_VERSION >= 120800
      if (cublas_version() >= 120800) {
        // NOTE: In all current cases where FP8 output is supported, the input is
        // scaled identically to the output.
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                         CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
                                                         &scaling_mode_a, sizeof(scaling_mode_a)));
      }
#endif  // CUBLAS_VERSION >= 120800
      // For FP8 output, cuBLAS requires C_type to match bias_type and
      // be FP16/BF16
      const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
    if (bias) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
  }

  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  }
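  // Recap of the epilogue selection above:
  //   bias && gelu : grad ? DGELU_BGRAD : GELU_AUX_BIAS
  //   bias only    : grad ? BGRADB      : BIAS
  //   gelu only    : grad ? DGELU       : GELU_AUX
  //   neither      : DEFAULT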

  if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) ||
      (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
    NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS ||
                epilogue == CUBLASLT_EPILOGUE_DGELU),
               "Epilogue requested outside of the available and tested cuBLAS functionality for "
               "float8 block scaled GEMM");
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                   &epilogue, sizeof(epilogue)));

  if (counter != nullptr) {
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
    NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
               CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
    NVTE_ERROR(
        "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
        CUBLAS_VERSION);
#endif
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
    CUBLAS_VERSION < 130000
    NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
               "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
               cuda::cudart_version());
    NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
               "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
               cublas_version());
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    }
#endif
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  const auto workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
  NVTE_CHECK(workspace_alignment % 256 == 0,
             "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment);

  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);
  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
                                   static_cast<const void *>(&alpha),       /* alpha */
                                   param.A,                                 /* A */
                                   Adesc, param.B,                          /* B */
                                   Bdesc, static_cast<const void *>(&beta), /* beta */
                                   C,                                       /* C */
                                   Cdesc, D,                                /* D */
                                   Ddesc, &heuristicResult.algo,            /* algo */
                                   workspace,                               /* workspace */
                                   workspaceSize, stream));                 /* stream */

  // Update FP8 scale-inv in output tensor
  // Note: This is a workaround (WAR) for the case where we have FP8 output but D->scale_inv is
  // not allocated.
  // TODO: Change the GEMM interface so that D->scale_inv is always allocated and the scale_inv
  // can be calculated here.
  if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
    update_tensor_scale_inv(outputD, stream);
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
#endif // __HIP_PLATFORM_AMD__

// Add for batchgemm
static std::once_flag init_flag_batchgemm;
static cudaStream_t compute_streams_batchgemm[num_batchgemm_streams];
static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];

// Warning: only call once per device!
static void init_streams_and_events_batchgemm() {
  int comm_cu_nums = getIntEnv("TORCH_COMM_CU_NUMS", 8, 4);
  const unsigned int cuMaskSize = 4;
  unsigned int cuMask[4] = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff};
  switch (comm_cu_nums) {
    case 4:
      cuMask[0] = 0xfffffff0;
      break;
    case 8:
      cuMask[0] = 0xffffff00;
      break;
    case 16:
      cuMask[0] = 0xffff0000;
      break;
    case 32:
      cuMask[0] = 0x00000000;
      break;
    default:
      NVTE_CHECK(false, "comm_cu_nums must be 4, 8, 16, or 32");
  }
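  // Assumed semantics of the CU mask (consumed by hipExtStreamCreateWithCUMask
  // below): each set bit enables one compute unit for the stream, so clearing
  // the low comm_cu_nums bits keeps those CUs out of the batch-GEMM streams,
  // presumably reserving them for communication kernels.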
  const char *TORCH_COMM_CU_NUMS = std::getenv("TORCH_COMM_CU_NUMS");
  for (int i = 0; i < num_batchgemm_streams; i++) {
#ifdef __HIP_PLATFORM_AMD__
    if (TORCH_COMM_CU_NUMS != nullptr && TORCH_COMM_CU_NUMS[0] != '\0') {
      NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams_batchgemm[i], cuMaskSize, cuMask));
    } else {
      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
    }
#else
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
#endif
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[i]));
  }
}

}  // namespace transformer_engine

void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

#ifdef __HIP_PLATFORM_AMD__
  const size_t A0 = inputA->flat_first_dim();
  const size_t A1 = inputA->flat_last_dim();
  const size_t B0 = inputB->flat_first_dim();
  const size_t B1 = inputB->flat_last_dim();

  const int m = transa ? A0 : A1;
  const int k = transa ? A1 : A0;
  const int n = transb ? B1 : B0;
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }
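  // Illustration: in the TN case the row-major A has shape [m, k], which
  // hipBLAS/cuBLAS reads as a column-major k x m matrix, hence lda = k; in
  // the NN and NT cases A is [k, m] row-major, hence lda = m.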

  const bool use_int8 = is_int8_dtype(inputA->data.dtype) ||
                        is_int8_dtype(inputB->data.dtype);

  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
  if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' &&
      use_int8 && use_split_accumulator) {
    nvte_use_hipblaslt = true;
  }
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || use_fp8 ||
      (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || nvte_use_hipblaslt ||
      nvte_use_rocblas) {
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa,
                transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream,
                nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
  } else {
    hipblas_gemm(inputA,
                 inputB,
                 outputD,
                 biasTensor,
                 outputGelu,
                 m, n, k,
                 lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr,
                 wspace->data.shape[0],
                 accumulate, use_split_accumulator,
                 math_sm_count,
                 0,
                 0,
                 false,
                 nullptr,
                 stream);
  }
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false,
              nullptr, stream);
#endif  // __HIP_PLATFORM_AMD__
}
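// Note: on the cuBLAS path, nvte_cublas_gemm above maps `accumulate` onto
// the scaled form D = alpha * op(A) * op(B) + beta * C with alpha = 1.0f
// and beta = accumulate ? 1.0f : 0.0f; nvte_cublas_gemm_scaled below
// exposes alpha and beta directly.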

void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm_scaled);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}

void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
                             cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);
  using namespace transformer_engine;

#ifndef __HIP_PLATFORM_AMD__
  // Check CUDA and cuBLAS versions
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
             CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
  NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
             CUBLAS_VERSION);
#endif
  NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
             "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
             cuda::cudart_version());
  NVTE_CHECK(
      cublas_version() >= 120205 && cublas_version() < 130000,
      "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
      cublas_version());
#endif

  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  const Tensor *inputCounter = convertNVTETensor(counter);
  Tensor *wspace = convertNVTETensor(workspace);

  NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                 is_delayed_tensor_scaling(inputB->scaling_mode),
             "Atomic GEMM only supports delayed scaling.");
#ifdef __HIP_PLATFORM_AMD__
  const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || use_fp8 ||
      (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || nvte_use_hipblaslt ||
      nvte_use_rocblas) {
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa,
                transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
                inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
  } else {
    hipblas_gemm(inputA,
                 inputB,
                 outputD,
                 biasTensor,
                 outputGelu,
                 m, n, k,
                 lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr,
                 wspace->data.shape[0],
                 accumulate, use_split_accumulator,
                 math_sm_count,
                 m_split,
                 n_split,
                 gemm_producer,
                 inputCounter,
                 stream);
  }
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split,
              n_split, gemm_producer, inputCounter, stream);
#endif  // __HIP_PLATFORM_AMD__
}



void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;
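  // Fork/join pattern used below: record one event on the caller's stream,
  // make every compute stream wait on it, launch the GEMMs round-robin
  // across the compute streams, then record a per-stream event and make the
  // caller's stream wait on each of them, so the whole batch behaves like a
  // single launch on `stream`.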

  int num_streams = nvte_get_num_compute_streams();

  int num_stream_used = std::min(num_streams, num_gemms);
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
  }
  const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool force_blas_mulstream = NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1';
  if (force_blas_mulstream && NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') {
    NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time.");
  }
  if (force_blas_mulstream) {
    for (int i = 0; i < num_gemms; i++) {
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                     detail::get_compute_stream(i % num_streams));
    }
  } else {
    for (int i = 0; i < num_gemms; i++) {
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                     detail::get_compute_stream(i % num_streams), 1, 0, i % num_streams);
    }
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
  }
}

#ifndef __HIP_PLATFORM_AMD__
namespace transformer_engine {

using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }

}  //  namespace transformer_engine
#endif

#ifdef __HIP_PLATFORM_AMD__
void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                       const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                       bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate,
                       bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_grouped_gemm);
  using namespace transformer_engine;

  std::vector<const Tensor*> inputA;
  std::vector<const Tensor*> inputB;
  std::vector<Tensor*> outputD;
  std::vector<const Tensor*> biasTensor;
  std::vector<Tensor*> outputGelu;
  std::vector<int64_t> m;
  std::vector<int64_t> n;
  std::vector<int64_t> k;
  std::vector<int64_t> b;

  for (int i = 0; i < num_gemms; i++) {
    inputA.push_back(convertNVTETensorCheck(A[i]));
    inputB.push_back(convertNVTETensorCheck(B[i]));
    outputD.push_back(convertNVTETensorCheck(D[i]));
    biasTensor.push_back(convertNVTETensorCheck(bias[i]));
    outputGelu.push_back(convertNVTETensorCheck(pre_gelu_out[i]));
    b.push_back(1);

    size_t A0 = inputA[i]->flat_first_dim();
    size_t A1 = inputA[i]->flat_last_dim();
    size_t B0 = inputB[i]->flat_first_dim();
    size_t B1 = inputB[i]->flat_last_dim();

    if (transa) {
      m.push_back(A0);
      k.push_back(A1);
    } else {
      m.push_back(A1);
      k.push_back(A0);
    }
    if (transb) {
      n.push_back(B1);
    } else {
      n.push_back(B0);
    }
  }
  Tensor *wspace = convertNVTETensorCheck(workspace[0]);

  if ((biasTensor[0]->data.dptr != nullptr) || (outputGelu[0]->data.dptr != nullptr)) {
    NVTE_ERROR("MOE nvte_grouped_gemm not surpport bias or gelu.");
  }

  hipblaslt_goupedgemm(inputA, inputB, outputD, m, n, k, b,
                       (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                       (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                       wspace->data.dptr, wspace->data.shape[0],
                       accumulate, use_split_accumulator,
                       math_sm_count, stream);
}

void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
  using namespace transformer_engine;
  int batch_count = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);
  // Inits streams and events (once, globally)
  std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);

  int num_stream_used = std::min(num_batchgemm_streams, num_gemms);
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0]));
  }
  for (int i = 0; i < num_gemms; i++) {
    nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator, math_sm_count,
                     batch_count, compute_streams_batchgemm[i % num_batchgemm_streams]);
  }
  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[s], compute_streams_batchgemm[s]));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event_batchgemm[s]));
  }
}

// add for batchgemm
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
    NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
  }

  // m, n, and k are identical for the NN, NT, and TN layouts; the TT layout
  // is rejected below.
  int m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
  int k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
  int n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m; 
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }
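  // Worked example (hypothetical shapes): a TN GEMM with batch_count = 2,
  // A stored as [2*m, k] = [512, 64] and B stored as [2*n, k] = [1024, 64]
  // yields m = 256, k = 64, n = 512, lda = ldb = k = 64, ldd = m = 256.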
  hipblas_batchgemm(inputA, inputB, outputD, biasTensor, outputGelu,
                    m, n, k, lda, ldb, ldd,
                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    grad, wspace->data.dptr, wspace->data.shape[0],
                    accumulate, use_split_accumulator, math_sm_count,
                    0, 0, false, nullptr,  // presumably atomic-GEMM params (splits, producer, counter); unused here
                    batch_count, stream);
}
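
// Usage sketch with hypothetical shapes: one strided-batched call where two
// experts are stacked along dim 0 (batch_count = 2, transa = true,
// transb = false):
//   A: [2 * 256, 64]  -> per-expert [256, 64] weight slices (T-side operand)
//   B: [2 * 512, 64]  -> per-expert [512, 64] activation slices
//   D: [2 * 512, 256] -> per-expert [512, 256] outputs (presumably stacked
//                        along dim 0, matching ldd = m)
//   nvte_cublas_batchgemm(A, B, D, empty_bias, empty_gelu, true, false, false,
//                         workspace, false, false, 0, 2 /*batch_count*/, stream);
// where empty_bias/empty_gelu wrap tensors whose data pointers are null.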

// Variant entry point with an implementation identical to nvte_cublas_batchgemm.
void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D,
                              const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int batch_count,
                              cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm_v2);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
    NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
  }

  // All supported layouts fold the batch into dim 0 of the stored (row-major)
  // tensor, so m/n/k are computed identically for NN, NT, and TN.
  if (transa && transb) {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }
  const int m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  }
  hipblas_batchgemm(inputA, inputB, outputD, biasTensor, outputGelu,
                    m, n, k, lda, ldb, ldd,
                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    grad, wspace->data.dptr, wspace->data.shape[0],
                    accumulate, use_split_accumulator, math_sm_count,
                    0, 0, false, nullptr,  // presumably atomic-GEMM params (splits, producer, counter); unused here
                    batch_count, stream);
}

// Batched GEMM for int8 operands with tensorwise scales for A and B (MoE path).
void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales,
                              const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
                              NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                              NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                              int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm_v3);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  const Tensor *inputA_scales = convertNVTETensorCheck(A_scales);
  const Tensor *inputB_scales = convertNVTETensorCheck(B_scales);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
    NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
  }

  // All supported layouts fold the batch into dim 0 of the stored (row-major)
  // tensor, so m/n/k are computed identically for NN, NT, and TN.
  if (transa && transb) {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }
  const int m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  }

  // Init hipblaslt handles (once, globally); this path uses only the first one.
  static std::once_flag init_flag;
  static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
  std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
  hipblasLtHandle_t handle = hipblaslt_handles[0];

  hipblaslt_batchgemm_tensorwise_int8(inputA, inputB, inputA_scales, inputB_scales, outputD,
                                      biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                                      (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                                      (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                                      grad, wspace->data.dptr, wspace->data.shape[0],
                                      accumulate, use_split_accumulator, math_sm_count,
                                      0, 0, false, nullptr,  // presumably atomic-GEMM params (splits, producer, counter); unused here
                                      batch_count, stream, handle);
}
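
// Conceptual note on the tensorwise int8 path (a sketch, not the backend's
// exact epilogue): with one float scale per tensor (or per batch slice),
// s_a and s_b, and int8 operands Aq and Bq, the dequantized result is
//
//   D[i][j] = s_a * s_b * sum_k( float(Aq[i][k]) * float(Bq[k][j]) )
//
// so the backend can accumulate in int32/float and apply the combined scale
// once per output element. How hipblaslt_batchgemm_tensorwise_int8 folds the
// A_scales/B_scales buffers into its epilogue is internal to that routine.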
#endif