/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef __HIP_PLATFORM_AMD__
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#else
#include <iostream>
#include "hipblas_gemm.h"
#include "rocm_gemm.hip"
#endif // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <cstdlib>
#include <mutex>

#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#ifdef __HIP_PLATFORM_AMD__
#include "common/util/hip_runtime.h"
#else
#include "common/util/cuda_runtime.h"
#endif

#ifndef __HIP_PLATFORM_AMD__
namespace {

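// Map a Transformer Engine dtype to the CUDA data type understood by cuBLAS.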
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return CUDA_R_16F;
    case DType::kFloat32:
      return CUDA_R_32F;
    case DType::kBFloat16:
      return CUDA_R_16BF;
    case DType::kFloat8E4M3:
      return CUDA_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return CUDA_R_8F_E5M2;
    default:
      NVTE_ERROR("Invalid type");
  }
}

uint32_t _getAlignment(uintptr_t address) {
  // Return the largest power-of-two alignment (in bytes, at most 256) that the
  // address satisfies; used below for the cublasLt matmul alignment hints.
  uint32_t alignment = 256;
  for (;; alignment /= 2) {
    if (address % alignment == 0) {
      return alignment;
    }
  }
}

inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}

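// Canonicalized GEMM operands: device pointers, transpose ops, dtypes and
// scale-inverse pointers for A and B in the form consumed by cublasLt.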
struct GemmParam {
  void *A;
  void *B;
  cublasOperation_t transA;
  cublasOperation_t transB;
  transformer_engine::DType Atype;
  transformer_engine::DType Btype;
  void *A_scale_inv;
  void *B_scale_inv;
  int lda;
  int ldb;

  GemmParam(cublasOperation_t transA, cublasOperation_t transB)
      : A(nullptr),
        B(nullptr),
        transA(transA),
        transB(transB),
        Atype(transformer_engine::DType::kNumTypes),
        Btype(transformer_engine::DType::kNumTypes),
        A_scale_inv(nullptr),
        B_scale_inv(nullptr),
        lda(0),
        ldb(0) {}
};

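// Pick the row-wise or column-wise copy of each operand and adjust the
// transpose ops and leading dimensions accordingly. On Hopper and Ada
// (SM < 100) cublasLt only supports FP8 GEMM in the TN layout, so FP8
// operands that would need the unsupported orientation fall back to their
// column-wise (pre-transposed) data.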
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                const int k, const int lda, const int ldb) {
  using namespace transformer_engine;
  NVTE_CHECK(A.scaling_mode == B.scaling_mode,
             "Inputs A and B to GEMM need to have the same scaling mode!");
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
  GemmParam ret(transA, transB);

  ret.lda = lda;
  ret.ldb = ldb;

  if (is_tensor_scaling(A.scaling_mode)) {
    ret.A = A.data.dptr;
    ret.A_scale_inv = A.scale_inv.dptr;
    if (transA == CUBLAS_OP_T) {
      ret.Atype = A.data.dtype;
    } else {
      ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype;
      if (is_fp8_dtype(ret.Atype)) {
        int arch = cuda::sm_arch(cuda::current_device());
        if (arch < 100) {
          // Hopper and Ada - we need to use columnwise_data and change transA
          NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!");
          ret.A = A.columnwise_data.dptr;
          ret.transA = CUBLAS_OP_T;
          ret.A_scale_inv = A.columnwise_scale_inv.dptr;
          ret.lda = k;
        }
      }
    }
    ret.B = B.data.dptr;
    ret.B_scale_inv = B.scale_inv.dptr;
    if (transB == CUBLAS_OP_T) {
      ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype;
      if (is_fp8_dtype(ret.Btype)) {
        int arch = cuda::sm_arch(cuda::current_device());
        if (arch < 100) {
          // Hopper and Ada - we need to use columnwise_data and change transB
          NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!");
          ret.B = B.columnwise_data.dptr;
          ret.transB = CUBLAS_OP_N;
          ret.B_scale_inv = B.columnwise_scale_inv.dptr;
          ret.ldb = k;
        }
      }
    } else {
      ret.Btype = B.data.dtype;
    }
  } else {
    // For non-tensor scaling modes (tensor scaling also covers the high-precision
    // types), we need to pick the proper copy of the data directly.
    // We leave the transA/transB values as-is, since Blackwell supports transposes.
    ret.A = transA ? A.data.dptr : A.columnwise_data.dptr;
    ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.B = transB ? B.columnwise_data.dptr : B.data.dptr;
    ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
  }
  return ret;
}

}  // namespace
#endif  // __HIP_PLATFORM_AMD__

namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
// Forward declaration. The implementation is in rocm_gemm.cu.
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, bool transa, bool transb, bool grad,
                 void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, hipStream_t stream);
#else // Use cublasLt
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

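// Internal GEMM driver: computes D = op(A) * op(B) + beta * C via cublasLt,
// with optional bias / GELU epilogue fusion and FP8 scaling attributes.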
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
                 void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, cudaStream_t stream) {
  // Return immediately if GEMM is trivial
  if (m <= 0 || n <= 0) {
    return;
  }
  NVTE_CHECK(k > 0);

  const GemmParam &param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb);
  void *C = outputD->data.dptr;
  void *D = outputD->data.dptr;
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);

  const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
  const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);

  NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");

  // Check consistency of arguments:
  // fp8 + gelu fusion + fp8 aux output is unavailable right now.
  if (use_fp8 && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp8_dtype(outputD->data.dtype)) {
    NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
  }

  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;

  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();

  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  cublasLtMatmulPreference_t preference = nullptr;
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  int64_t ld_gelumat = (int64_t)ldd;

  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
                                               param.transA == CUBLAS_OP_N ? k : m, param.lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
                                               param.transB == CUBLAS_OP_N ? n : k, param.ldb));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
                                                   &param.transA, sizeof(param.transA)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                   &param.transB, sizeof(param.transB)));

  // Set math SM count
  if (math_sm_count != 0) {
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
  }

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));

    // Scaling factors.
#if CUDA_VERSION >= 12080
    cublasLtMatmulMatrixScale_t scaling_mode;
#endif
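    // Per-tensor scaling uses one FP32 scale-inverse scalar per operand, while
    // MXFP8 block scaling uses vectors of E8M0 scale factors (one per 32 elements).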
    if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
      void *A_scale_inverse = param.A_scale_inv;
      void *B_scale_inverse = param.B_scale_inv;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
#if CUDA_VERSION >= 12080
      scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
    } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) {
      fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
      fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
      // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
      if (cublasLtGetVersion() <= 120803) {
        const int64_t dummy_a_vec_stride = 1;
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
            sizeof(dummy_a_vec_stride)));
      }
#endif
    } else {
      NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and  " +
                 to_string(inputB->scaling_mode) + ".");
    }

#if CUDA_VERSION >= 12080
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
#endif
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUDA_VERSION >= 12080
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
#endif
      // For FP8 output, cuBLAS requires C_type to match bias_type and
      // be FP16/BF16
      const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
    if (bias) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
  }

  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                   &epilogue, sizeof(epilogue)));

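  // Atomic GEMM: D is produced or consumed in chunks, and progress is
  // signalled through the counter buffer so a neighbouring kernel can overlap.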
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
  if (counter != nullptr) {
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    }
  }
#endif

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));

  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);

  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
                                   static_cast<const void *>(&one),         /* alpha */
                                   param.A,                                 /* A */
                                   Adesc, param.B,                          /* B */
                                   Bdesc, static_cast<const void *>(&beta), /* beta */
                                   C,                                       /* C */
                                   Cdesc, D,                                /* D */
                                   Ddesc, &heuristicResult.algo,            /* algo */
                                   workspace,                               /* workspace */
                                   workspaceSize, stream));                 /* stream */

  // Update FP8 scale-inv in output tensor.
  // Note: This is a WAR for the case when we have FP8 output but D->scale_inv is not allocated.
  // TODO: Change the GEMM interface so that D->scale_inv is always allocated and the scale_inv
  // can be calculated here.
  if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
    update_tensor_scale_inv(outputD, stream);
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
#endif  // __HIP_PLATFORM_AMD__

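// Pool of side streams and events used to fan GEMMs out in
// nvte_multi_stream_cublas_gemm (one entry per stream in the num_streams pool).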
static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
static cudaEvent_t cublas_event[num_streams];

// Warning: only call once per device!
static void init_streams_and_events() {
  for (int i = 0; i < num_streams; i++) {
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
  }
}

// ROCm fork addition: a dedicated stream/event pool for batched GEMM.
static std::once_flag init_flag_batchgemm;
static cudaStream_t compute_streams_batchgemm[num_batchgemm_streams];
static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];

// Warning: only call once per device!
static void init_streams_and_events_batchgemm() {
  for (int i = 0; i < num_batchgemm_streams; i++) {
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[i]));
  }
}

}  // namespace transformer_engine

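// Public GEMM entry point: D = op(A) * op(B) (+ C when accumulate), with
// optional bias and GELU epilogues. A sketch of a typical forward call
// (tensor creation omitted; the two trailing ROCm routing flags are assumed
// to be defaulted in the header):
//
//   nvte_cublas_gemm(A, B, D, bias, pre_gelu_out,
//                    /*transa=*/true, /*transb=*/false, /*grad=*/false,
//                    workspace, /*accumulate=*/false,
//                    /*use_split_accumulator=*/false,
//                    /*math_sm_count=*/0, stream);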
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt,
                      bool nvte_use_rocblas) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  const size_t A0 = inputA->flat_first_dim();
  const size_t A1 = inputA->flat_last_dim();
  const size_t B0 = inputB->flat_first_dim();
  const size_t B1 = inputB->flat_last_dim();

  const int m = transa ? A0 : A1;
  const int k = transa ? A1 : A0;
  const int n = transb ? B1 : B0;
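  // TE tensors are row-major while cuBLAS is column-major, so a row-major
  // [rows, cols] tensor is seen by cuBLAS as a cols x rows matrix; the leading
  // dimensions below follow from that transposed view.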
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

#ifdef __HIP_PLATFORM_AMD__
  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || use_fp8 ||
      (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') ||
      nvte_use_hipblaslt || nvte_use_rocblas) {
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif  // __HIP_PLATFORM_AMD__

#ifdef __HIP_PLATFORM_AMD__
              transa, transb,
#else
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif  // __HIP_PLATFORM_AMD__
              grad,
              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
#ifdef __HIP_PLATFORM_AMD__
              math_sm_count, 0, 0, false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas);
#else
              math_sm_count, 0, 0, false, nullptr, stream);
#endif  // __HIP_PLATFORM_AMD__

#ifdef __HIP_PLATFORM_AMD__
  } else {
    hipblas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                 use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
  }
#endif //__HIP_PLATFORM_AMD__

}

void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
                             cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);

#ifndef __HIP_PLATFORM_AMD__
  int cudart_version;
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
  NVTE_CHECK(cudart_version >= 12020, "CUDA version 12.2 is required for atomic GEMM.");
  NVTE_CHECK(cublasLtGetVersion() >= 120205, "cuBLAS version 12.2.5 is required for atomic GEMM.");
#endif

  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                 is_delayed_tensor_scaling(inputB->scaling_mode),
             "Atomic GEMM only supports delayed scaling.");

  const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

#ifdef __HIP_PLATFORM_AMD__
  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || use_fp8 ||
      (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') ||
      nvte_use_hipblaslt || nvte_use_rocblas) {
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif  // __HIP_PLATFORM_AMD__

#ifdef __HIP_PLATFORM_AMD__
              transa, transb,
#else
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif  // __HIP_PLATFORM_AMD__
              grad,
              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
#ifdef __HIP_PLATFORM_AMD__
              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream,
              nvte_use_hipblaslt, nvte_use_rocblas);
#else
              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
#endif  // __HIP_PLATFORM_AMD__

#ifdef __HIP_PLATFORM_AMD__
  } else {
    hipblas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                 use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
                 inputCounter, stream);
  }
#endif //__HIP_PLATFORM_AMD__

}

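// Fans num_gemms GEMMs out over the side-stream pool: every side stream first
// waits on `stream`, the GEMMs are issued round-robin, and `stream` then waits
// on all side streams so the call behaves synchronously with respect to it.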
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;
  // Inits streams and events (once, globally)
  std::call_once(init_flag, init_streams_and_events);

  int num_stream_used = std::min(num_streams, num_gemms);
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
  }
  const char *env_hipblas_mulstream = std::getenv("NVTE_FORCE_HIPBLAS_MULSTREAM");
  const char *env_force_rocm_gemm = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool force_hipblas_mulstream =
      (env_hipblas_mulstream != nullptr && env_hipblas_mulstream[0] == '1');
  if (force_hipblas_mulstream && env_force_rocm_gemm != nullptr && env_force_rocm_gemm[0] == '1') {
    NVTE_ERROR(
        "NVTE_FORCE_HIPBLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time.");
  }
  if (force_hipblas_mulstream) {
    for (int i = 0; i < num_gemms; i++) {
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                     compute_streams[i % num_streams]);
    }
  } else {
    for (int i = 0; i < num_gemms; i++) {
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                     compute_streams[i % num_streams], /*nvte_use_hipblaslt=*/true,
                     /*nvte_use_rocblas=*/false);
    }
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
  }
}

#ifndef __HIP_PLATFORM_AMD__
namespace transformer_engine {

using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }

}  //  namespace transformer_engine
#endif

#ifdef __HIP_PLATFORM_AMD__
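// ROCm-only batched-GEMM entry points, mirroring the multi-stream logic above.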
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
  using namespace transformer_engine;
  NVTE_CHECK(num_gemms % num_batchgemm_streams == 0,
             "num_gemms must be divisible by num_batchgemm_streams.");
  const int batch_count = num_gemms / num_batchgemm_streams;
  // Inits streams and events (once, globally)
  std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);

  int num_stream_used = num_batchgemm_streams;
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0]));
  }
  for (int i = 0; i < num_stream_used; i++) {
    nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator, math_sm_count,
                     batch_count, compute_streams_batchgemm[i % num_batchgemm_streams]);
  }
  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[s], compute_streams_batchgemm[s]));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event_batchgemm[s]));
  }
}

// ROCm fork addition: batched GEMM entry point.
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  // The NT, TN and NN layouts all derive the per-batch sizes the same way: the
  // flattened leading dimension of each operand holds batch_count batches.
  const int m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m; 
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }
  hipblas_batchgemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                    use_split_accumulator, math_sm_count, 0, 0, false, nullptr, batch_count,
                    stream);
}
#endif  // __HIP_PLATFORM_AMD__