/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef __HIP_PLATFORM_AMD__
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#else
#include <iostream>
#include "hipblas_gemm.h"
#include "rocm_gemm.hip"
#endif // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <mutex>

#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "common/util/cuda_runtime.h"

#ifndef __HIP_PLATFORM_AMD__
namespace {

cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return CUDA_R_16F;
    case DType::kFloat32:
      return CUDA_R_32F;
    case DType::kBFloat16:
      return CUDA_R_16BF;
    case DType::kFloat8E4M3:
      return CUDA_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return CUDA_R_8F_E5M2;
    default:
      NVTE_ERROR("Invalid type");
  }
}

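// Returns the largest power-of-two alignment (capped at 256 bytes) satisfied by the given
// address, e.g. an address ending in 0x...40 yields 64. The result feeds the
// CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_*_BYTES heuristic hints set further below.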
uint32_t _getAlignment(uintptr_t address) {
  // alignment is in bytes
  uint32_t alignment = 256;
  for (;; alignment /= 2) {
    if (address % alignment == 0) {
      return alignment;
    }
  }
}

inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}

struct GemmParam {
  void *A;
  void *B;
  cublasOperation_t transA;
  cublasOperation_t transB;
  transformer_engine::DType Atype;
  transformer_engine::DType Btype;
  void *A_scale_inv;
  void *B_scale_inv;
  int lda;
  int ldb;

  GemmParam(cublasOperation_t transA, cublasOperation_t transB)
      : A(nullptr),
        B(nullptr),
        transA(transA),
        transB(transB),
        Atype(transformer_engine::DType::kNumTypes),
        Btype(transformer_engine::DType::kNumTypes),
        A_scale_inv(nullptr),
        B_scale_inv(nullptr),
        lda(0),
        ldb(0) {}
};

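// Picks, for each operand, which copy of the data (row-wise `data` or `columnwise_data`)
// and which scale-inverse pointer the GEMM should use, based on the scaling mode and the
// requested transposes. For FP8 tensor scaling on pre-Blackwell GPUs (sm < 100), the
// non-transposed operand is switched to its column-wise copy and transA/transB adjusted,
// since cuBLASLt FP8 GEMM on those architectures expects the TN layout.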
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                const int k, const int lda, const int ldb) {
  using namespace transformer_engine;
  NVTE_CHECK(A.scaling_mode == B.scaling_mode,
             "Inputs A and B to GEMM need to have the same scaling mode!");
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
  GemmParam ret(transA, transB);

  ret.lda = lda;
  ret.ldb = ldb;

  if (is_tensor_scaling(A.scaling_mode)) {
    ret.A = A.data.dptr;
    ret.A_scale_inv = A.scale_inv.dptr;
    if (transA == CUBLAS_OP_T) {
      ret.Atype = A.data.dtype;
    } else {
      ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype;
      if (is_fp8_dtype(ret.Atype)) {
        int arch = cuda::sm_arch(cuda::current_device());
        if (arch < 100) {
          // Hopper and Ada - we need to use columnwise_data and change transA
          NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!");
          ret.A = A.columnwise_data.dptr;
          ret.transA = CUBLAS_OP_T;
          ret.A_scale_inv = A.columnwise_scale_inv.dptr;
          ret.lda = k;
        }
      }
    }
    ret.B = B.data.dptr;
    ret.B_scale_inv = B.scale_inv.dptr;
    if (transB == CUBLAS_OP_T) {
      ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype;
      if (is_fp8_dtype(ret.Btype)) {
        int arch = cuda::sm_arch(cuda::current_device());
        if (arch < 100) {
          // Hopper and Ada - we need to use columnwise_data and change transB
          NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!");
          ret.B = B.columnwise_data.dptr;
          ret.transB = CUBLAS_OP_N;
          ret.B_scale_inv = B.columnwise_scale_inv.dptr;
          ret.ldb = k;
        }
      }
    } else {
      ret.Btype = B.data.dtype;
    }
  } else {
    // Not tensor scaling (tensor scaling also covers the high-precision types), so we
    // need to use the proper (row-wise or column-wise) version of the data.
    // We leave the transA/transB values as-is, since Blackwell supports transposes.
    ret.A = transA ? A.data.dptr : A.columnwise_data.dptr;
    ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.B = transB ? B.columnwise_data.dptr : B.data.dptr;
    ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
  }
  return ret;
}

}  // namespace
#endif // __HIP_PLATFORM_AMD__

namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
// Forward declaration. The implementation is in rocm_gemm.cu.
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, bool transa, bool transb, bool grad,
                 void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, hipStream_t stream);
#else // Use cublasLt
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
                 void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, cudaStream_t stream) {
  // Return immediately if GEMM is trivial
  if (m <= 0 || n <= 0) {
    return;
  }
  NVTE_CHECK(k > 0);

  const GemmParam &param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb);
  void *C = outputD->data.dptr;
  void *D = outputD->data.dptr;
  void *D_scale = outputD->scale.dptr;
  void *D_amax = outputD->amax.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);

  const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
  const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
  const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
  const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);

  NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");

  // Check consistency of the arguments.
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp8_dtype(outputD->data.dtype)) {
    NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
  }

  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;

  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();

  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  cublasLtMatmulPreference_t preference = nullptr;
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  int64_t ld_gelumat = (int64_t)ldd;

  // Use TF32 only for pure FP32 GEMM.
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
    gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
                                               param.transA == CUBLAS_OP_N ? k : m, param.lda));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
                                               param.transB == CUBLAS_OP_N ? n : k, param.ldb));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
                                                   &param.transA, sizeof(param.transA)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                   &param.transB, sizeof(param.transB)));
  // Set math SM count
  if (math_sm_count != 0) {
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
                                                     &math_sm_count, sizeof(math_sm_count)));
  }

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));

    // Scaling factors.
#if CUDA_VERSION >= 12080
    cublasLtMatmulMatrixScale_t scaling_mode;
#endif
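    // Per-tensor scaling uses single FP32 scale factors; MXFP8 block scaling uses E8M0
    // scale vectors over 32-element blocks (CUDA 12.8+), selected through the A/B
    // scale-mode attributes set below.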
    if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
      void *A_scale_inverse = param.A_scale_inv;
      void *B_scale_inverse = param.B_scale_inv;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
#if CUDA_VERSION >= 12080
      scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
    } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) {
      fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
      fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                       &A_scale_inverse, sizeof(A_scale_inverse)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &B_scale_inverse, sizeof(B_scale_inverse)));
      scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
      // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
      // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
      if (cublasLtGetVersion() <= 120803) {
        const int64_t dummy_a_vec_stride = 1;
        NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
            sizeof(dummy_a_vec_stride)));
      }
#endif
    } else {
      NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and  " +
                 to_string(inputB->scaling_mode) + ".");
    }

#if CUDA_VERSION >= 12080
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
#endif
    if (is_fp8_dtype(outputD->data.dtype)) {
      // Accumulation mode not supported for FP8 output
      C = nullptr;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUDA_VERSION >= 12080
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
#endif
      // For FP8 output, cuBLAS requires C_type to match bias_type and
      // be FP16/BF16
      const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
    }
    if (bias) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  } else {
    NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
  }

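  // Select the cuBLASLt epilogue: fused bias and/or GELU, using the DGELU/BGRAD variants
  // in the backward (grad) pass.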
  if (bias && gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = CUBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
  } else if (gelu) {
    if (grad) {
      epilogue = CUBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                     &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                   &epilogue, sizeof(epilogue)));

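  // Atomic-GEMM counters: D is split into m_split x n_split chunks, and the GEMM either
  // signals (producer) or waits on (consumer) the per-chunk counters so it can be
  // pipelined with the kernel on the other end of the counters.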
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
  if (counter != nullptr) {
    if (m_split == 0) m_split = 1;
    if (n_split == 0) n_split = 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split,
        sizeof(m_split)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split,
        sizeof(n_split)));
    if (gemm_producer) {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    } else {
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
          sizeof(counter)));
    }
  }
#endif

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));

  const auto status =
      cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
                                     1, &heuristicResult, &returnedResults);
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);

  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
                                   static_cast<const void *>(&one),         /* alpha */
                                   param.A,                                 /* A */
                                   Adesc, param.B,                          /* B */
                                   Bdesc, static_cast<const void *>(&beta), /* beta */
                                   C,                                       /* C */
                                   Cdesc, D,                                /* D */
                                   Ddesc, &heuristicResult.algo,            /* algo */
                                   workspace,                               /* workspace */
                                   workspaceSize, stream));                 /* stream */

  // Update FP8 scale-inv in output tensor
  // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
  // TODO: Changing gemm interface so that D->scale_inv is allocated and the scale_inv can be
  // calculated here.
  if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
    update_tensor_scale_inv(outputD, stream);
  }

  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
#endif // __HIP_PLATFORM_AMD__

static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
static cudaEvent_t cublas_event[num_streams];

// Warning: only call once per device!
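// Each side stream is non-blocking and created at elevated priority (-1), so GEMMs
// launched on it can be scheduled ahead of default-priority work.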
static void init_streams_and_events() {
  for (int i = 0; i < num_streams; i++) {
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
  }
}

// Add for batchgemm
static std::once_flag init_flag_batchgemm;
static cudaStream_t compute_streams_batchgemm[num_batchgemm_streams];
static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];

// Warning: only call once per device!
static void init_streams_and_events_batchgemm() {
  for (int i = 0; i < num_batchgemm_streams; i++) {
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[i]));
  }
}

}  // namespace transformer_engine

void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  const size_t A0 = inputA->flat_first_dim();
  const size_t A1 = inputA->flat_last_dim();
  const size_t B0 = inputB->flat_first_dim();
  const size_t B1 = inputB->flat_last_dim();

  const int m = transa ? A0 : A1;
  const int k = transa ? A1 : A0;
  const int n = transb ? B1 : B0;
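  // Leading dimensions for the column-major view cuBLAS takes of the row-major TE
  // tensors; only the TN, NN, and NT layouts are supported.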
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

#ifdef __HIP_PLATFORM_AMD__

#ifdef USE_HIPBLASLT
  const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1')){
    cublas_gemm(inputA, inputB, outputD,  biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#else 
  cublas_gemm(inputA, inputB, outputD,  biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif  //USE_HIPBLASLT

#else 
  cublas_gemm(inputA, inputB, outputD,  biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif  //__HIP_PLATFORM_AMD__

#ifdef __HIP_PLATFORM_AMD__
              transa, transb,
#else
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif //__HIP_PLATFORM_AMD__
              grad,
              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
              math_sm_count, 0, 0, false, nullptr, stream);

#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
  }
  else{
    hipblas_gemm(inputA,
                inputB,
                outputD,
                biasTensor,
                outputGelu,
                m, n, k,
                lda, ldb, ldd,
                (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                grad, wspace->data.dptr,
                wspace->data.shape[0],
                accumulate, use_split_accumulator,
                math_sm_count,
                0,
                0,
                false,
                nullptr,
                stream);
  }
#endif  //USE_HIPBLASLT
#endif //__HIP_PLATFORM_AMD__

}

void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
                             cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);

#ifndef __HIP_PLATFORM_AMD__
  int cudart_version;
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
  NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
  NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");
#endif

  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                 is_delayed_tensor_scaling(inputB->scaling_mode),
             "Atomic GEMM only supports delayed scaling.");

  const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

#ifdef __HIP_PLATFORM_AMD__

#ifdef USE_HIPBLASLT
  const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1')){
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#else 
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif  //USE_HIPBLASLT

#else 
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif  //__HIP_PLATFORM_AMD__  

#ifdef __HIP_PLATFORM_AMD__
              transa, transb,
#else
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif //__HIP_PLATFORM_AMD__
              grad,
              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);

#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
  }
  else{
    hipblas_gemm(
              inputA,
              inputB,
              outputD,
              biasTensor,
              outputGelu,
              m, n, k,
              lda, ldb, ldd,
              (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
              (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
              grad, wspace->data.dptr,
              wspace->data.shape[0],
              accumulate, use_split_accumulator,
              math_sm_count,
              m_split,
              n_split,
              gemm_producer,
              inputCounter,
              stream);
  }
#endif  //USE_HIPBLASLT
#endif //__HIP_PLATFORM_AMD__

}


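// Variant of nvte_cublas_gemm that always routes through the cublasLt/hipBLASLt-based
// cublas_gemm path; the multi-stream wrapper below uses it when NVTE_FORCE_BLAS_MULSTREAM
// is not set.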
void nvte_cublaslt_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublaslt_gemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  cublas_gemm(inputA, inputB, outputD,  biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#ifdef __HIP_PLATFORM_AMD__
              transa, transb,
#else
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif //__HIP_PLATFORM_AMD__
              grad,
              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
              math_sm_count, 0, 0, false, nullptr, stream);
}

void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;
  // Inits streams and events (once, globally)
  std::call_once(init_flag, init_streams_and_events);

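  // Round-robin the GEMMs over up to num_streams side streams, then join them back into
  // the caller's stream via events below.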
  int num_stream_used = std::min(num_streams, num_gemms);
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
  }
  const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
  const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
  bool NVTE_FORCE_BLASLT_MULSTREAM = false;

  if (NVTE_BLAS_MULSTREAM == nullptr) {
    NVTE_FORCE_BLASLT_MULSTREAM = true;
  } else if ((NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1') &&
             (NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1')) {
    NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_BLASLT can't be set at the same time.");
  } else {
    NVTE_FORCE_BLASLT_MULSTREAM = false;
  }
  if (NVTE_FORCE_BLASLT_MULSTREAM){
    for (int i = 0; i < num_gemms; i++) {
      nvte_cublaslt_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                     compute_streams[i % num_streams]);
    }
  } else{
    for (int i = 0; i < num_gemms; i++) {
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                     compute_streams[i % num_streams]);
    }
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
  }
}

#ifdef __HIP_PLATFORM_AMD__
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                   const int num_gemms, bool transa, bool transb, bool grad,
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
  using namespace transformer_engine;
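  // num_gemms is split evenly across the batch-GEMM streams; each stream issues a single
  // batched GEMM covering batch_count problems.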
  NVTE_CHECK(num_gemms % num_batchgemm_streams == 0,
             "num_gemms must be divisible by num_batchgemm_streams.");
  const int batch_count = num_gemms / num_batchgemm_streams;
  // Inits streams and events (once, globally)
  std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);

  int num_stream_used = num_batchgemm_streams;
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0]));
  }
  for (int i = 0; i < num_stream_used; i++) {
    nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator, math_sm_count,
                     batch_count, compute_streams_batchgemm[i % num_batchgemm_streams]);
  }
  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[s], compute_streams_batchgemm[s]));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event_batchgemm[s]));
  }
}

// add for batchgemm
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  // The batch dimension is folded into the first dimension of A and B, so m, k, and n
  // are computed the same way for all supported layouts (TN, NN, NT).
  const int m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m; 
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }
    hipblas_batchgemm(inputA,
              inputB,
              outputD,
              biasTensor,
              outputGelu,
              m, n, k,
              lda, ldb, ldd,
              (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
              (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
              grad, wspace->data.dptr,
              wspace->data.shape[0],
              accumulate, use_split_accumulator,
              math_sm_count,
              0,
              0,
              false,
              nullptr,
              batch_count,
              stream);
}
#endif