rocm_gemm.cu 102 KB
Newer Older
yuguo's avatar
yuguo committed
1
2
3
4
5
6
7
/*************************************************************************
 * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
 *
 * License for AMD contributions = MIT. See LICENSE for more information
 ************************************************************************/
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
wenjh's avatar
wenjh committed
8
9
10

#include <type_traits>

yuguo's avatar
yuguo committed
11
#ifdef USE_HIPBLASLT
wenjh's avatar
wenjh committed
12
#include <hipblaslt/hipblaslt.h>
yuguo's avatar
yuguo committed
13
#include <unistd.h>
wenjh's avatar
wenjh committed
14
15

#include <chrono>
yuguo's avatar
yuguo committed
16
17
#include <forward_list>
#include <fstream>
wenjh's avatar
wenjh committed
18
#include <mutex>
yuguo's avatar
yuguo committed
19
#include <optional>
wenjh's avatar
wenjh committed
20
21
22
23
#include <sstream>
#include <unordered_map>
#include <vector>

yuguo's avatar
yuguo committed
24
25
#endif
#ifdef USE_ROCBLAS
wenjh's avatar
wenjh committed
26
#define ROCBLAS_BETA_FEATURES_API
yuguo's avatar
yuguo committed
27
#include <rocblas/rocblas.h>
wenjh's avatar
wenjh committed
28

29
#include <hipblaslt/hipblaslt-ext.hpp>
wenjh's avatar
wenjh committed
30
31
#include <hipcub/hipcub.hpp>

yuguo's avatar
yuguo committed
32
#endif
wenjh's avatar
wenjh committed
33
#include <cstdint>
yuguo's avatar
yuguo committed
34
#include <cstdlib>
wenjh's avatar
wenjh committed
35
#include <iostream>
yuguo's avatar
yuguo committed
36
37
38
#include <string>

#include "../common.h"
39
#include "../util/handle_manager.h"
yuguo's avatar
yuguo committed
40
#include "../util/logging.h"
wenjh's avatar
wenjh committed
41
42
#include "../util/vectorized_pointwise.h"

yuguo's avatar
yuguo committed
43
44
45
46
47

namespace {

#ifdef USE_HIPBLASLT

yuguo's avatar
yuguo committed
48
// Maps a Transformer Engine DType to the hipDataType enum understood by hipBLASLt.
// fp8 types map to the HIP_R_8F_E4M3 / HIP_R_8F_E5M2 enumerators.
// Any DType without a mapping raises NVTE_ERROR("Invalid type").
static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    // Integer types.
    case DType::kInt32:
      return HIP_R_32I;
    case DType::kInt8:
      return HIP_R_8I;
    // 8-bit floating point types.
    case DType::kFloat8E5M2:
      return HIP_R_8F_E5M2;
    case DType::kFloat8E4M3:
      return HIP_R_8F_E4M3;
    // 16/32-bit floating point types.
    case DType::kBFloat16:
      return HIP_R_16BF;
    case DType::kFloat16:
      return HIP_R_16F;
    case DType::kFloat32:
      return HIP_R_32F;
    default:
      NVTE_ERROR("Invalid type");
  }
}
#endif

#ifdef USE_ROCBLAS
// Maps a Transformer Engine DType to the corresponding rocBLAS datatype.
// Only floating point types are supported; anything else raises NVTE_ERROR("Invalid type").
rocblas_datatype get_rocblas_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    // 8-bit floating point formats.
    case DType::kFloat8E5M2:
      return rocblas_datatype_bf8_r;
    case DType::kFloat8E4M3:
      return rocblas_datatype_f8_r;
    // 16/32-bit floating point formats.
    case DType::kBFloat16:
      return rocblas_datatype_bf16_r;
    case DType::kFloat16:
      return rocblas_datatype_f16_r;
    case DType::kFloat32:
      return rocblas_datatype_f32_r;
    default:
      NVTE_ERROR("Invalid type");
  }
}
#endif

wenjh's avatar
wenjh committed
91
}  //namespace
yuguo's avatar
yuguo committed
92
93
94
95
96
97
98
99
100

namespace transformer_engine {

#ifdef USE_ROCBLAS

namespace detail {

// Tag type carrying no parameters; passed to activation helpers that take none.
struct Empty {};

// Identity activation: hands back its input unchanged (the Empty tag is ignored).
__device__ inline fp32 identity(fp32 value, const Empty& /*unused*/) {
  const fp32 result = value;
  return result;
}
yuguo's avatar
yuguo committed
102

wenjh's avatar
wenjh committed
103
// Tanh approximation of GELU:
//   gelu(x) = x * 0.5 * (1 + tanh(0.7978845608028654 * (x + 0.044715 * x^3)))
// The Empty argument is an unused tag.
__inline__ __device__ float gelu(float x, const Empty&) {
  constexpr float kScale = 0.7978845608028654f;
  constexpr float kCubicCoeff = 0.044715f;
  const float cdf = 0.5f * (1.0f + tanhf((kScale * (x + kCubicCoeff * x * x * x))));
  return x * cdf;
}

wenjh's avatar
wenjh committed
108
// Tanh-approximated GELU used by the elementwise epilogue kernels below:
//   gelu_forward(x) = x * 0.5 * (1 + tanh(0.7978845608028654 * (x + 0.044715 * x^3)))
__inline__ __device__ float gelu_forward(float x) {
  constexpr float kOuter = 0.7978845608028654f;
  constexpr float kCubic = 0.044715f;
  const float cdf = 0.5f * (1.0f + tanhf((kOuter * (x + kCubic * x * x * x))));
  return x * cdf;
}

// Applies the tanh-approximated GELU elementwise to the m * n fp32 values in `in` and writes
// them to `out`.
// When T is an fp8 type:
//   - each output is multiplied by `*scale` before the cast;
//   - the maximum |gelu(x)| (the unscaled value) is folded into `*amax` with one
//     atomicMaxFloat per block.
// `amax` and `scale` are only dereferenced on the fp8 path.
// The grid-stride loop makes any grid size correct. hipcub::BlockReduce assumes the block has
// THREADS_PER_BLOCK threads, which is what gelu_forward_kernelLauncher launches.
template <typename T, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    gelu_forward_kernel(const float* in, T* out, float* amax, const float* scale, int m, int n) {
  // 64-bit indexing: m * n and the grid-stride increment can exceed INT_MAX for large outputs.
  const int64_t total = static_cast<int64_t>(m) * n;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t first = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // fp8 output flow
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    // Loop-invariant; load once instead of once per element.
    const float scale_val = *scale;
    float thread_amax = 0;
    for (int64_t id = first; id < total; id += stride) {
      const float y = gelu_forward(in[id]);
      out[id] = (T)(scale_val * y);
      thread_amax = std::fmax(std::fabs(y), thread_amax);
    }
    // Every thread starts at 0, so threads without work do not disturb the max.
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int64_t id = first; id < total; id += stride) {
      out[id] = (T)(gelu_forward(in[id]));
    }
  }
}

// Launches gelu_forward_kernel over the m * n elements on `stream`, using 1024-thread blocks
// and one thread per element.
// Empty inputs are a no-op, because a zero-block grid is an invalid launch configuration.
// Launch errors are reported through NVTE_CHECK_CUDA.
template <typename T>
void gelu_forward_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int m,
                                 int n, hipStream_t stream) {
  constexpr int THREADS_PER_BLOCK = 1024;
  const int64_t total = static_cast<int64_t>(m) * n;
  if (total <= 0) {
    return;
  }
  dim3 block(THREADS_PER_BLOCK);
  // Integer ceil-division (no floating point rounding).
  dim3 grid(static_cast<unsigned int>((total + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK));
  hipLaunchKernelGGL((gelu_forward_kernel<T, THREADS_PER_BLOCK>), grid, block, 0, stream, in,
                     out, amax, scale, m, n);
  NVTE_CHECK_CUDA(hipGetLastError());
}

wenjh's avatar
wenjh committed
151
152
// Gradient of the tanh-approximated GELU at pre-activation `x`, times the upstream gradient `dy`.
// With t = tanh(kBeta * (x + kKappa * x^3)):
//   gelu'(x) = 0.5 * (1 + t) + 0.5 * x * (1 - t^2) * kBeta * (1 + 3 * kKappa * x^2)
// All literals are float: the previous double literals (0.5, 3.0) promoted parts of the
// computation to fp64, which is much slower on the GPU.
__inline__ __device__ float gelu_backward(float x, float dy) {
  constexpr float kBeta = 0.7978845608028654f;  // sqrt(2 / pi)
  constexpr float kKappa = 0.044715f;
  const float x_sq = x * x;
  const float x_cube = x_sq * x;
  const float tanh_inner = tanhf((kBeta * (x + kKappa * x_cube)));

  const float left = 0.5f * x;
  const float right = 1.0f + tanh_inner;

  // d/dx of `left`, multiplied by `right`.
  const float left_derivative = 0.5f * right;

  // d/dx of `right` (chain rule through tanh), multiplied by `left`.
  const float tanh_derivative = 1.0f - tanh_inner * tanh_inner;
  const float inner_derivative = kBeta * (1.0f + 3.0f * kKappa * x_sq);
  const float right_derivative = left * tanh_derivative * inner_derivative;

  return dy * (left_derivative + right_derivative);
}

// Elementwise GELU backward over m * n elements:
//   out[id] = gelu_backward(pre_gelu_out[id], dy[id])
// The grid-stride loop covers every element for any launch configuration.
// Indices are 64-bit because m * n can exceed INT_MAX.
template <typename T, typename Taux>
__global__ void gelu_backward_kernel(const float* dy, T* out, const Taux* __restrict pre_gelu_out,
                                     int m, int n) {
  const int64_t total = static_cast<int64_t>(m) * n;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; id < total;
       id += stride) {
    float x = (float)pre_gelu_out[id];
    float dx = (float)gelu_backward(x, dy[id]);
    out[id] = (T)(dx);
  }
}

// Launches gelu_backward_kernel on `stream`.
// Launch shape:
//   - blocks of min(n, 256) threads;
//   - min(m * ceil(n / 256), 65536) blocks (the kernel's grid-stride loop handles the rest).
// Empty inputs are a no-op, because a zero-sized grid or block is an invalid launch.
template <typename T, typename Taux>
void gelu_backward_kernelLauncher(const float* in, T* out, const Taux* pre_gelu_out, int m, int n,
                                  hipStream_t stream) {
  if (m <= 0 || n <= 0) {
    return;
  }
  constexpr int kMaxThreads = 256;
  constexpr int64_t kMaxBlocks = 65536;
  const int blocks_per_row = (n + kMaxThreads - 1) / kMaxThreads;
  // 64-bit product: m * blocks_per_row can overflow int for very large m.
  const int64_t wanted_blocks = static_cast<int64_t>(m) * blocks_per_row;
  dim3 grid(static_cast<unsigned int>(wanted_blocks < kMaxBlocks ? wanted_blocks : kMaxBlocks));
  dim3 block(n < kMaxThreads ? n : kMaxThreads);
  hipLaunchKernelGGL((gelu_backward_kernel<T, Taux>), grid, block, 0, stream, in, out,
                     pre_gelu_out, m, n);
  NVTE_CHECK_CUDA(hipGetLastError());
}

// Adds a length-n bias to each row of an m x n fp32 buffer:
//   out[id] = in[id] + bias[id % n]
// When T is an fp8 type:
//   - the sum is multiplied by `*scale` before the cast;
//   - the maximum |sum| (the unscaled value) is folded into `*amax` with one atomicMaxFloat
//     per block.
// `amax` and `scale` are only dereferenced on the fp8 path. Indices are 64-bit because
// m * n can exceed INT_MAX.
template <typename T, typename Tb, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    add_bias_kernel(const float* in, T* out, const Tb* __restrict bias, float* amax,
                    const float* scale, int m, int n) {
  const int64_t total = static_cast<int64_t>(m) * n;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t first = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // fp8 output flow
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    // Loop-invariant; load once instead of once per element.
    const float scale_val = *scale;
    float thread_amax = 0;
    for (int64_t id = first; id < total; id += stride) {
      const float reg_bias = (float)bias[id % n];
      const float val = in[id] + reg_bias;
      out[id] = (T)(scale_val * val);
      // deal with amax of D
      thread_amax = std::fmax(std::fabs(val), thread_amax);
    }
    // num_valid can be ignored since each thread amax is set to 0
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int64_t id = first; id < total; id += stride) {
      const float reg_bias = (float)bias[id % n];
      out[id] = (T)(in[id] + reg_bias);
    }
  }
}

// Launches add_bias_kernel over the m * n elements on `stream`, using 1024-thread blocks and
// one thread per element.
// Empty inputs are a no-op, because a zero-block grid is an invalid launch configuration.
template <typename T, typename Tb>
void add_bias_kernelLauncher(const float* in, T* out, const Tb* __restrict bias, float* amax,
                             const float* scale, int m, int n, hipStream_t stream) {
  constexpr int THREADS_PER_BLOCK = 1024;
  const int64_t total = static_cast<int64_t>(m) * n;
  if (total <= 0) {
    return;
  }
  dim3 block(THREADS_PER_BLOCK);
  dim3 grid(static_cast<unsigned int>((total + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK));
  hipLaunchKernelGGL((add_bias_kernel<T, Tb, THREADS_PER_BLOCK>), grid, block, 0, stream, in,
                     out, bias, amax, scale, m, n);
  NVTE_CHECK_CUDA(hipGetLastError());
}

// Fused bias-add + GELU over an m x n fp32 buffer. For each element:
//   pre_gelu_out[id] = in[id] + bias[id % n]   (stored before the activation)
//   out[id]          = gelu_forward(in[id] + bias[id % n])
// When T is an fp8 type:
//   - `out` is multiplied by `*scale` before the cast;
//   - the maximum unscaled |gelu| is folded into `*amax` with one atomicMaxFloat per block.
// Indices are 64-bit because m * n can exceed INT_MAX.
template <typename T, typename Taux, typename Tb, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    add_bias_gelu_kernel(const float* in, T* out, Taux* pre_gelu_out, const Tb* __restrict bias,
                         float* amax, const float* scale, int m, int n) {
  const int64_t total = static_cast<int64_t>(m) * n;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t first = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // fp8 output flow
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    // only need to deal with amax and scale of D, no need to deal with amax and scale of pre_gelu_out
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    // Loop-invariant; load once instead of once per element.
    const float scale_val = *scale;
    float thread_amax = 0;
    for (int64_t id = first; id < total; id += stride) {
      const float reg_bias = (float)bias[id % n];
      const float pre_act = in[id] + reg_bias;
      // pre_gelu_out guaranteed not to be fp8 type
      pre_gelu_out[id] = (Taux)(pre_act);
      const float val = gelu_forward(pre_act);
      out[id] = (T)(scale_val * val);
      // deal with amax of D
      thread_amax = std::fmax(std::fabs(val), thread_amax);
    }
    // num_valid can be ignored since each thread amax is set to 0
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int64_t id = first; id < total; id += stride) {
      const float reg_bias = (float)bias[id % n];
      const float pre_act = in[id] + reg_bias;
      pre_gelu_out[id] = (Taux)(pre_act);
      out[id] = (T)(gelu_forward(pre_act));
    }
  }
}

// Launches add_bias_gelu_kernel over the m * n elements on `stream`, using 1024-thread blocks
// and one thread per element.
// Empty inputs are a no-op, because a zero-block grid is an invalid launch configuration.
template <typename T, typename Taux, typename Tb>
void add_bias_gelu_kernelLauncher(const float* in, T* out, Taux* pre_gelu_out,
                                  const Tb* __restrict bias, float* amax, const float* scale, int m,
                                  int n, hipStream_t stream) {
  constexpr int THREADS_PER_BLOCK = 1024;
  const int64_t total = static_cast<int64_t>(m) * n;
  if (total <= 0) {
    return;
  }
  dim3 block(THREADS_PER_BLOCK);
  dim3 grid(static_cast<unsigned int>((total + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK));
  hipLaunchKernelGGL((add_bias_gelu_kernel<T, Taux, Tb, THREADS_PER_BLOCK>), grid, block, 0,
                     stream, in, out, pre_gelu_out, bias, amax, scale, m, n);
  NVTE_CHECK_CUDA(hipGetLastError());
}

// Elementwise type conversion: out[id] = (T)in[id] for id in [0, n).
// The grid-stride index is 64-bit so `id += stride` cannot overflow when n is close to INT_MAX.
template <typename Tin, typename T>
__global__ void identity_kernel(const Tin* in, T* out, int n) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; id < n;
       id += stride) {
    Tin val = in[id];
    out[id] = (T)(val);
  }
}

// Launches identity_kernel on `stream` with 256-thread blocks and one thread per element.
// n <= 0 is a no-op, because a zero-block grid is an invalid launch configuration.
template <typename Tin, typename T>
void identity_kernelLauncher(const Tin* in, T* out, int n, hipStream_t stream) {
  constexpr int kThreadsPerBlock = 256;
  if (n <= 0) {
    return;
  }
  dim3 block(kThreadsPerBlock);
  // Integer ceil-division in 64 bits so n + 255 cannot overflow.
  dim3 grid(static_cast<unsigned int>((static_cast<int64_t>(n) + kThreadsPerBlock - 1) /
                                      kThreadsPerBlock));
  hipLaunchKernelGGL((identity_kernel<Tin, T>), grid, block, 0, stream, in, out, n);
  NVTE_CHECK_CUDA(hipGetLastError());
}

// Copies n fp32 values to `out`, converting to T.
// When T is an fp8 type:
//   - each value is multiplied by `*scale` before the cast;
//   - the maximum |value| (the unscaled value) is folded into `*amax` with one atomicMaxFloat
//     per block.
// Indices are 64-bit so the grid-stride increment cannot overflow int.
template <typename T, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    identity_output_kernel(const float* in, T* out, float* amax, const float* scale, int n) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t first = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    // Loop-invariant; load once instead of once per element.
    const float scale_val = *scale;
    float thread_amax = 0;
    for (int64_t id = first; id < n; id += stride) {
      const float val = in[id];
      out[id] = (T)(scale_val * val);
      // deal with amax of D
      thread_amax = std::fmax(std::fabs(val), thread_amax);
    }
    // num_valid can be ignored since each thread amax is set to 0
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int64_t id = first; id < n; id += stride) {
      out[id] = (T)(in[id]);
    }
  }
}

// Launches identity_output_kernel on `stream` with 1024-thread blocks and one thread per
// element.
// n <= 0 is a no-op, because a zero-block grid is an invalid launch configuration.
template <typename T>
void identity_output_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int n,
                                    hipStream_t stream) {
  constexpr int THREADS_PER_BLOCK = 1024;
  if (n <= 0) {
    return;
  }
  dim3 block(THREADS_PER_BLOCK);
  dim3 grid(static_cast<unsigned int>((static_cast<int64_t>(n) + THREADS_PER_BLOCK - 1) /
                                      THREADS_PER_BLOCK));
  hipLaunchKernelGGL((identity_output_kernel<T, THREADS_PER_BLOCK>), grid, block, 0, stream, in,
                     out, amax, scale, n);
  NVTE_CHECK_CUDA(hipGetLastError());
}

// Column sums of an m x n row-major matrix, accumulated into out[0..n) with atomicAdd.
// `out` must be zeroed beforehand.
// Launch layout:
//   - each column is covered by ceil(m / THREADS_PER_BLOCK) consecutive blocks of
//     THREADS_PER_BLOCK threads (grid = that count * n);
//   - each block reduces its slice with hipcub::BlockReduce.
// Because every column's thread range is a multiple of the block size, the branch choosing
// between the full and the partial reduction is uniform within a block.
template <typename Tin, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    bias_gradient_kernel(const Tin* in, float* out, int m, int n) {
  typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
  __shared__ typename BlockReduce::TempStorage block_temp_storage;

  // No rows: nothing to reduce. This also avoids dividing by THREADS_PER_COL == 0 below.
  // The condition is uniform across the block.
  if (m <= 0) return;

  // Integer ceil-division; float(m) loses precision once m exceeds 2^24.
  const int BLOCKS_PER_COL = (m + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  const int THREADS_PER_COL = BLOCKS_PER_COL * THREADS_PER_BLOCK;
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t col_idx = idx / THREADS_PER_COL;
  const int row_idx = static_cast<int>(idx % THREADS_PER_COL);
  // Threads past the last row contribute nothing. Initialise so no thread hands an
  // indeterminate value to the reduction.
  float thread_data = 0.f;
  if (row_idx < m) thread_data = (float)in[static_cast<int64_t>(row_idx) * n + col_idx];
  float local_sum;
  if (row_idx < (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK) {
    local_sum = BlockReduce(block_temp_storage).Sum(thread_data);
  } else {
    // Last block of the column: only the first (m - full blocks' rows) threads are valid.
    local_sum = BlockReduce(block_temp_storage)
                    .Sum(thread_data, m - (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK);
  }
  if (threadIdx.x == 0) atomicAdd(&out[col_idx], local_sum);
}

yuguo's avatar
yuguo committed
354
355
356
357
358
359
360
361
362
363
364
365
// Edge length of the square shared-memory tile used by the column-wise bias-gradient kernels.
constexpr int kColwiseReduceTileSize = 32;

// Shuffle-down tree reduction. It adds in the value `offset` lanes away for
// offset = max, max/2, ..., 1. Afterwards, a lane whose index is a multiple of 2*max holds the
// sum of the 2*max lanes starting at itself (when those lanes exist in the wavefront).
template <typename T>
__inline__ __device__ T WarpReduceSum(T val, int max = 32) {
  int offset = max;
  while (offset > 0) {
    val += __shfl_down(val, offset);
    offset >>= 1;
  }
  return val;
}

// Column sums of an M x N row-major matrix: dst[j] = sum_i src[i * N + j].
// Expects blockDim == (kColwiseReduceTileSize, kColwiseReduceTileSize) and
// ceil(N / kColwiseReduceTileSize) blocks.
// How each block works:
//   1. Each thread (x, y) accumulates rows y, y + blockDim.y, ... of column
//      blockIdx.x * blockDim.x + x.
//   2. The partials are transposed through shared memory.
//   3. A shuffle reduction leaves each column's total in the thread with threadIdx.x == 0.
// dst[j] is assigned (not accumulated), so dst needs no zero-fill.
template <typename InputType>
__launch_bounds__(1024) __global__
    void bias_gradient_kernel_v2(float* dst, const InputType* src, int M, int N) {
  // One padding column: the transposed read below walks down a column of the tile. With a
  // 32-word row stride, lanes that differ only in threadIdx.x would hit the same shared-memory
  // bank.
  __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1];
  const int j = blockIdx.x * blockDim.x + threadIdx.x;
  float grad_sum = 0.f;
  if (j < N) {
    for (int i = threadIdx.y; i < M; i += blockDim.y) {
      // 64-bit offset: i * N can exceed INT_MAX for large inputs.
      grad_sum += static_cast<float>(src[static_cast<int64_t>(i) * N + j]);
    }
  }
  g_shared[threadIdx.y][threadIdx.x] = grad_sum;
  __syncthreads();
  float sum = g_shared[threadIdx.x][threadIdx.y];
  sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
  if (threadIdx.x == 0) {
    const int col = blockIdx.x * blockDim.x + threadIdx.y;
    if (col < N) {
      dst[col] = sum;
    }
  }
}

387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
// Column sums of an M x N row-major int8 matrix dequantised by the single scale scale[0]:
//   dst[j] = (OutputType) sum_i (float)src[i * N + j] * scale[0]
// Expects blockDim == (kColwiseReduceTileSize, kColwiseReduceTileSize) and
// ceil(N / kColwiseReduceTileSize) blocks. The reduction scheme matches
// bias_gradient_kernel_v2, and dst[j] is assigned rather than accumulated.
template <typename OutputType>
__launch_bounds__(1024) __global__
    void tensorwise_int8_bias_gradient_kernel(OutputType* dst, const int8_t* src, float* scale, int M, int N) {
  // One padding column to avoid shared-memory bank conflicts on the transposed read.
  __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1];
  const int j = blockIdx.x * blockDim.x + threadIdx.x;
  float grad_sum = 0.f;
  const float tensorwise_scale = scale[0];
  if (j < N) {
    for (int i = threadIdx.y; i < M; i += blockDim.y) {
      // 64-bit offset: i * N can exceed INT_MAX for large inputs.
      grad_sum += static_cast<float>(src[static_cast<int64_t>(i) * N + j]) * tensorwise_scale;
    }
  }
  g_shared[threadIdx.y][threadIdx.x] = grad_sum;
  __syncthreads();
  float sum = g_shared[threadIdx.x][threadIdx.y];
  sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
  if (threadIdx.x == 0) {
    const int col = blockIdx.x * blockDim.x + threadIdx.y;
    if (col < N) {
      dst[col] = static_cast<OutputType>(sum);
    }
  }
}

yuguo's avatar
yuguo committed
411
// Computes out[j] = sum over the m rows of column j of the m x n row-major `in`, using
// bias_gradient_kernel_v2 on `stream`.
//
// No zero-fill of `out` is needed, because the v2 kernel assigns every out[j] for j < n.
// The former zero-fill used hipMemset on the null stream when !stream_order_alloc. That fill
// is not necessarily ordered against a non-blocking `stream`, so it could race with (and wipe)
// the kernel's writes.
// `stream_order_alloc` only selected how that fill was issued; it is kept for API
// compatibility. The legacy atomic bias_gradient_kernel is not used and does need a
// zeroed `out`.
template <typename Tin>
void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc,
                                  hipStream_t stream) {
  (void)stream_order_alloc;
  const int num_blocks = (n - 1) / kColwiseReduceTileSize + 1;
  bias_gradient_kernel_v2<Tin>
      <<<num_blocks, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in,
                                                                                        m, n);
  NVTE_CHECK_CUDA(hipGetLastError());
}

430
431
432
433
434
435
436
437
438
439
440
441
442
// Computes the scale-dequantised column sums of the m x n int8 matrix `in` into out[0..n),
// using tensorwise_int8_bias_gradient_kernel on `stream`.
// The zero-fill is stream-ordered with the kernel. Launch errors are reported through
// NVTE_CHECK_CUDA.
template <typename Tout>
void tensorwise_int8_bias_gradient_kernelLauncher(const int8_t* in, Tout* out, float* scale, int m, int n, hipStream_t stream) {
  NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(Tout), stream));
  const int num_blocks = (n - 1) / kColwiseReduceTileSize + 1;
  tensorwise_int8_bias_gradient_kernel<Tout>
      <<<num_blocks, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(
          out, in, scale, m, n);
  NVTE_CHECK_CUDA(hipGetLastError());
}

wenjh's avatar
wenjh committed
443
}  // namespace detail
yuguo's avatar
yuguo committed
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461

// Inverse of get_rocblas_dtype: maps a rocBLAS datatype back to a Transformer Engine DType.
// Unsupported datatypes raise NVTE_ERROR("Invalid type").
transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t) {
  using namespace transformer_engine;
  switch (t) {
    // 8-bit floating point formats.
    case rocblas_datatype_bf8_r:
      return DType::kFloat8E5M2;
    case rocblas_datatype_f8_r:
      return DType::kFloat8E4M3;
    // 16/32-bit floating point formats.
    case rocblas_datatype_bf16_r:
      return DType::kBFloat16;
    case rocblas_datatype_f16_r:
      return DType::kFloat16;
    case rocblas_datatype_f32_r:
      return DType::kFloat32;
    default:
      NVTE_ERROR("Invalid type");
  }
}
wenjh's avatar
wenjh committed
462
#endif  //USE_ROCBLAS
yuguo's avatar
yuguo committed
463
464
465
466
467
468

#ifdef USE_HIPBLASLT

namespace {

// Process-wide pool of hipBLASLt handles with one list per device, guarded by `mt`.
// Threads borrow handles through get()/obtain() and return them in bulk through store().
static class HandlePool {
 public:
  // Pops a stored handle for device_id, or returns nullptr if none is available.
  // The first call sizes the pool to hipGetDeviceCount() lists.
  hipblasLtHandle_t get(int device_id) {
    std::lock_guard<std::mutex> lock(mt);

    if (pool.empty()) {
      int device_count = 0;
      NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
      pool.resize(device_count);
    }

    // device_id indexes the per-device lists directly. Reject out-of-range ids instead of
    // reading past the vector (this also catches a pool left empty by zero devices).
    NVTE_CHECK(device_id >= 0 && static_cast<size_t>(device_id) < pool.size(),
               "Invalid device id ", device_id, " for hipBLASLt handle pool with ", pool.size(),
               " device(s)");

    auto& handles = pool[device_id];
    if (handles.empty()) {
      return nullptr;
    }
    hipblasLtHandle_t h = handles.front();
    handles.pop_front();
    return h;
  }

  // Like get(), but creates a new handle when none is stored for device_id.
  hipblasLtHandle_t obtain(int device_id) {
    hipblasLtHandle_t h = get(device_id);
    if (h == nullptr) {
      NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&h));
    }
    return h;
  }

  // Returns per-device handles (indexed by device id; nullptr entries are skipped) to the pool.
  void store(const std::vector<hipblasLtHandle_t>& handles) {
    std::lock_guard<std::mutex> lock(mt);
    if (pool.empty()) {
      std::cout << "[ERROR] Attempt to store handles to invalid pool" << std::endl;
    }
    // Bound by both sizes so a shorter `handles` vector is never read past its end.
    const size_t count = handles.size() < pool.size() ? handles.size() : pool.size();
    for (size_t i = 0; i < count; i++) {
      if (handles[i] != nullptr) {
        pool[i].push_front(handles[i]);
      }
    }
  }

  ~HandlePool() {
#if DESTROY_HIPBLASLT_HANDLES_POOL
    std::lock_guard<std::mutex> lock(mt);
    for (auto& hlist : pool) {
      for (auto& h : hlist) {
        hipblasLtDestroy(h);
      }
    }
    pool.clear();
#endif
  }

  // Number of per-device lists (0 until the first get()).
  inline size_t get_size() const { return pool.size(); }

 private:
  std::mutex mt;
  using Pool = std::vector<std::forward_list<hipblasLtHandle_t>>;
  // Order of destructors between thread_local and global is not actually guaranteed
  // As a simple w/a make pool storage "leaky"
  // Just do not destruct it and do not destroy hipblasLt handles
  // Let OS deal with it on application exit
#if DESTROY_HIPBLASLT_HANDLES_POOL
  Pool pool;
#else
  Pool& pool = *new Pool();
#endif
} handle_pool;

// Per-thread cache of hipBLASLt handles indexed by device id. Misses borrow from (or create
// through) handle_pool. When the thread_local object is destroyed, the cached handles are
// handed back to handle_pool.
thread_local static class HandleCache {
 public:
  // Cached handle for device_id, or nullptr when this thread has cached nothing yet.
  hipblasLtHandle_t get(int device_id) const {
    if (d.empty()) return nullptr;
    return d[device_id];
  }

  // Cached handle for device_id, fetched from handle_pool and remembered on a miss.
  hipblasLtHandle_t obtain(int device_id) {
    if (hipblasLtHandle_t cached = get(device_id)) {
      return cached;
    }
    hipblasLtHandle_t fresh = handle_pool.obtain(device_id);
    set(device_id, fresh);
    return fresh;
  }

  // Records h for device_id, sizing the cache to the pool's device count on first use.
  void set(int device_id, hipblasLtHandle_t h) {
    if (d.empty()) d.resize(handle_pool.get_size());
    d[device_id] = h;
  }

  ~HandleCache() {
    if (d.empty()) return;
    handle_pool.store(d);
  }

 private:
  std::vector<hipblasLtHandle_t> d;
} cached_handles;

wenjh's avatar
wenjh committed
568
569
// Small separator-aware writer over an std::ostream.
// Usage:
//   - stream `start{}` to begin a row, then values, then `end{}`;
//   - the first value after `start{}` is written with the current separator (empty initially
//     and after `end{}`), and every later value is preceded by the separator character.
class csv_helper {
 public:
  // Row delimiters.
  struct start {};
  struct end {};

  csv_helper(std::ostream& os, char sep_val)
      : m_os{os}, m_sep_val(sep_val), m_start(true), m_sep("") {}

  csv_helper& operator<<(const start&) {
    m_start = true;
    return *this;
  }

  csv_helper& operator<<(const end&) {
    m_start = false;
    m_sep.clear();
    return *this;
  }

  template <typename T>
  csv_helper& operator<<(const T& v) {
    m_os << m_sep << v;
    if (!m_start) {
      return *this;
    }
    // First value of the row: every following value gets a separator.
    m_start = false;
    m_sep.assign(1, m_sep_val);
    return *this;
  }

 private:
  std::ostream& m_os;
  char m_sep_val;
  bool m_start;
  std::string m_sep;
};

wenjh's avatar
wenjh committed
604
605
606
607
608
609
610
611
// Two-way lookup between enum values and their printable names, backed by a caller-owned map.
// The map is held by reference, so it must outlive the mapper.
template <typename T>
class NameMapper {
 public:
  NameMapper(const std::unordered_map<T, std::string_view>& name_map) : map(name_map) {}

  // Name for val; std::out_of_range (from unordered_map::at) if val has no entry.
  const std::string_view& getName(const T& val) { return map.at(val); }

  // First value (in map iteration order) named `name` that also passes `filter`, if one is
  // given. `filter` is only consulted for entries whose name matches. Raises NVTE_ERROR when
  // nothing matches.
  T getValue(const std::string& name, const char* label = "",
             std::function<bool(const T&)> filter = nullptr) {
    for (const auto& entry : map) {
      if (name != entry.second) continue;
      if (!filter || filter(entry.first)) return entry.first;
    }
    NVTE_ERROR("Invalid ", label, " name: ", name);
  }

 protected:
  const std::unordered_map<T, std::string_view>& map;
};

yuguo's avatar
yuguo committed
621
// Printable names for hipBLASLt data types. The FNUZ and (HIP >= 6.3) OCP fp8 enumerators
// share a name, so NameMapper::getValue needs its filter argument to tell them apart.
// Entry order is kept as-is; it can influence unordered_map iteration order.
static std::unordered_map<hipDataType, std::string_view> type_name_map = {
    {HIP_R_32F, "float32"},
    {HIP_R_16F, "float16"},
    {HIP_R_16BF, "bfloat16"},
    {HIP_R_8F_E4M3_FNUZ, "float8e4m3"},
    {HIP_R_8F_E5M2_FNUZ, "float8e5m2"},
#if HIP_VERSION >= 60300000
    {HIP_R_8F_E4M3, "float8e4m3"},
    {HIP_R_8F_E5M2, "float8e5m2"},
#endif
};
static NameMapper<hipDataType> typeNameMapper(type_name_map);

// Transpose flags.
static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = {
    {HIPBLAS_OP_N, "N"},
    {HIPBLAS_OP_T, "T"},
};
static NameMapper<hipblasOperation_t> transposeNameMapper(trans_name_map);

// Epilogues; "-" stands for the default (no epilogue).
static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map = {
    {HIPBLASLT_EPILOGUE_DEFAULT, "-"},
    {HIPBLASLT_EPILOGUE_BIAS, "bias"},
    {HIPBLASLT_EPILOGUE_GELU_AUX, "geluaux"},
    {HIPBLASLT_EPILOGUE_GELU_AUX_BIAS, "geluauxbias"},
    {HIPBLASLT_EPILOGUE_DGELU, "dgelu"},
    {HIPBLASLT_EPILOGUE_DGELU_BGRAD, "dgelubgrad"},
    {HIPBLASLT_EPILOGUE_BGRADB, "bgradb"},
};
static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);

// Compute types (only fp32 accumulation is named here).
static std::unordered_map<hipblasComputeType_t, std::string_view> comp_name_map = {
    {HIPBLAS_COMPUTE_32F, "f32"},
};
static NameMapper<hipblasComputeType_t> computeNameMapper(comp_name_map);
yuguo's avatar
yuguo committed
648
649

static class GemmAlgoCache {
wenjh's avatar
wenjh committed
650
 public:
yuguo's avatar
yuguo committed
651
652
  struct Key {
    int deviceCap;
yuguo's avatar
yuguo committed
653
    hipDataType a_type, b_type, d_type, bias_type;
yuguo's avatar
yuguo committed
654
655
656
657
658
    int m, n, k;
    int lda, ldb, ldd;
    hipblasOperation_t transa, transb;
    hipblasLtEpilogue_t epilogue;

wenjh's avatar
wenjh committed
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
    Key(int deviceCap_, hipDataType a_type_, hipDataType b_type_, hipDataType d_type_,
        hipDataType bias_type_, int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
        hipblasOperation_t transa_, hipblasOperation_t transb_, hipblasLtEpilogue_t epilogue_)
        : deviceCap(deviceCap_),
          a_type(a_type_),
          b_type(b_type_),
          d_type(d_type_),
          bias_type(bias_type_),
          m(m_),
          n(n_),
          k(k_),
          lda(lda_),
          ldb(ldb_),
          ldd(ldd_),
          transa(transa_),
          transb(transb_),
          epilogue(epilogue_) {}
yuguo's avatar
yuguo committed
676
677
678

    Key() {}

wenjh's avatar
wenjh committed
679
680
681
682
683
684
    bool operator==(const Key& val) const {
      return ((deviceCap == val.deviceCap) && (a_type == val.a_type) && (b_type == val.b_type) &&
              (d_type == val.d_type) && (bias_type == val.bias_type) && (m == val.m) &&
              (n == val.n) && (k == val.k) && (lda == val.lda) && (ldb == val.ldb) &&
              (ldd == val.ldd) && (transa == val.transa) && (transb == val.transb) &&
              (epilogue == val.epilogue));
yuguo's avatar
yuguo committed
685
686
    }

wenjh's avatar
wenjh committed
687
688
689
690
    struct Comp {
      bool operator()(const Key& lhs, const Key& rhs) const {
        return ::std::string_view((const char*)&lhs, sizeof(lhs)) <
               ::std::string_view((const char*)&rhs, sizeof(rhs));
yuguo's avatar
yuguo committed
691
692
693
694
      }
    };
  };

wenjh's avatar
wenjh committed
695
  void init() {
yuguo's avatar
yuguo committed
696
    std::lock_guard<std::mutex> lock(mt);
wenjh's avatar
wenjh committed
697
    int device_count = 0;
yuguo's avatar
yuguo committed
698
699
    NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
    dev_cap.resize(device_count);
wenjh's avatar
wenjh committed
700
    for (int i = 0; i < device_count; i++) {
yuguo's avatar
yuguo committed
701
702
      hipDeviceProp_t prop;
      NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, i));
wenjh's avatar
wenjh committed
703
      dev_cap[i] = prop.major * 100 + prop.minor;
yuguo's avatar
yuguo committed
704
705
706
707
708
    }
    load_();
    save_();
  }

wenjh's avatar
wenjh committed
709
710
  inline int device_cap(int device_id) {
    if (dev_cap.empty()) init();
yuguo's avatar
yuguo committed
711
712
713
714
715
716
717
718
719
    return dev_cap[device_id];
  }

  struct Algo {
    std::optional<hipblasLtMatmulAlgo_t> algo;
    int64_t algoId;
    int index;
    size_t ws_size_min;
    size_t ws_size_max;
wenjh's avatar
wenjh committed
720
721
722
723
724
    Algo() : algo(), index(-1), algoId(), ws_size_min(0), ws_size_max(0) {}
    Algo(int idx, int64_t id, size_t ws_min, size_t ws_max)
        : algo(), index(idx), algoId(id), ws_size_min(ws_min), ws_size_max(ws_max) {}
    inline bool hasId() { return index >= 0; }
    const static inline int64_t getAlgoId(const hipblasLtMatmulAlgo_t& algo) {
yuguo's avatar
yuguo committed
725
726
727
728
      return *(const int64_t*)&algo;
    }
  };

wenjh's avatar
wenjh committed
729
  bool find(const Key& cfg, size_t ws_size, Algo& algo) {
yuguo's avatar
yuguo committed
730
    std::lock_guard<std::mutex> lock(mt);
wenjh's avatar
wenjh committed
731
    if (auto* pentry = find_(cfg, ws_size, ws_size); pentry != nullptr) {
yuguo's avatar
yuguo committed
732
733
734
735
736
737
      algo = *pentry;
      return true;
    }
    return false;
  }

wenjh's avatar
wenjh committed
738
  void store(const Key& cfg, const Algo& algo) {
yuguo's avatar
yuguo committed
739
740
741
742
743
744
745
    size_t ws_size_min = algo.ws_size_min;
    size_t ws_size_max = algo.ws_size_max;
    NVTE_CHECK(ws_size_max >= ws_size_min, "Invalid WS size");
    std::lock_guard<std::mutex> lock(mt);

    //Remove overlapping with existing entries;
    while (auto* pentry = find_(cfg, ws_size_min, ws_size_max)) {
wenjh's avatar
wenjh committed
746
      if (pentry->ws_size_min <= ws_size_min && pentry->ws_size_max >= ws_size_max) {
yuguo's avatar
yuguo committed
747
748
749
750
751
        *pentry = algo;
        save_();
        return;
      }

wenjh's avatar
wenjh committed
752
      if (ws_size_max > pentry->ws_size_max) {
yuguo's avatar
yuguo committed
753
        ws_size_min = pentry->ws_size_max + 1;
wenjh's avatar
wenjh committed
754
      } else if (ws_size_min < pentry->ws_size_min) {
yuguo's avatar
yuguo committed
755
        ws_size_max = pentry->ws_size_min - 1;
wenjh's avatar
wenjh committed
756
      } else {
yuguo's avatar
yuguo committed
757
758
759
760
761
762
763
        //Should never be here
        NVTE_ERROR("Cannot merge WS size range");
      }
    }

    //Merge to adjusted entry if possible
    auto* pentry = find_(cfg, ws_size_min - 1, ws_size_min);
wenjh's avatar
wenjh committed
764
    if (pentry && pentry->algoId == algo.algoId) {
yuguo's avatar
yuguo committed
765
766
767
      pentry->algo = algo.algo;
      pentry->ws_size_max = ws_size_max;
      save_();
wenjh's avatar
wenjh committed
768
    } else {
yuguo's avatar
yuguo committed
769
770
771
772
773
774
775
      auto it = d.emplace(cfg, algo);
      it->second.ws_size_min = ws_size_min;
      it->second.ws_size_max = ws_size_max;
      save_(it->first, it->second);
    }
  }

wenjh's avatar
wenjh committed
776
777
 protected:
  Algo* find_(const Key& cfg, size_t ws_min, size_t ws_max) {
yuguo's avatar
yuguo committed
778
    const auto key_range = d.equal_range(cfg);
wenjh's avatar
wenjh committed
779
780
    for (auto i = key_range.first; i != key_range.second; i++) {
      if (ws_min <= i->second.ws_size_max && ws_max >= i->second.ws_size_min) {
yuguo's avatar
yuguo committed
781
782
783
784
785
786
        return &i->second;
      }
    }
    return nullptr;
  }

wenjh's avatar
wenjh committed
787
  void header_(std::ostream& ofs) {
yuguo's avatar
yuguo committed
788
    csv_helper fs(ofs, csv_sep);
wenjh's avatar
wenjh committed
789
790
791
792
793
794
795
    fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b"
       << "type_a" << "type_b" << "type_d" << "bias_type"
       << "lda" << "ldb" << "ldd" << "epi" << "comp" << "scale"
       << "ws_min" << "ws_max" << "algo_id" << "aidx";
  }

  void load_() {
yuguo's avatar
yuguo committed
796
    const char* env = std::getenv("TE_HIPBLASLT_ALGO_LOAD");
wenjh's avatar
wenjh committed
797
    if (env == nullptr || env[0] == '\0') {
yuguo's avatar
yuguo committed
798
799
800
      return;
    }
    std::ifstream ifs{env};
wenjh's avatar
wenjh committed
801
    if (!ifs.is_open()) {
yuguo's avatar
yuguo committed
802
803
804
805
806
807
808
      std::cerr << "Could not load autotune results storage " << env << "\n";
      return;
    }
    std::cout << "Loading autotune results from " << env << "\n";

    Key cfg;
    std::string line;
wenjh's avatar
wenjh committed
809
    std::getline(ifs, line);  // the first line with legend
yuguo's avatar
yuguo committed
810
811
812
813
814
815
816
817
818
    {
      std::ostringstream hline;
      header_(hline);
      if (hline.str() != line) {
        std::cerr << "Incorrect algo storage legend. Expected " << hline.str() << "\n";
        return;
      }
    }

wenjh's avatar
wenjh committed
819
    while (std::getline(ifs, line)) {
yuguo's avatar
yuguo committed
820
      line.erase(0, line.find_first_not_of(" \t\n\r\f\v"));
wenjh's avatar
wenjh committed
821
822
      if (auto pos = line.find_last_not_of(" \t\n\r\f\v"); pos != std::string::npos) {
        line.resize(pos + 1);
yuguo's avatar
yuguo committed
823
824
825
826
827
828
829
830
831
832
833
834
835
836
      }
      if (line.empty() || line[0] == '#') continue;
      std::istringstream is(line);
      char c;
      std::string type_a, type_b, type_d, bias_type, trans_a, trans_b, epi, comp, scale;
      int64_t algo_id;
      int algo_idx;
      size_t ws_min, ws_max;

      is >> std::skipws;
      is >> cfg.deviceCap >> c >> cfg.m >> c >> cfg.n >> c >> cfg.k >> c;

      //Filter out entries for devices not presented on the curent system
      bool b_found = false;
wenjh's avatar
wenjh committed
837
838
      for (int i = 0; i < dev_cap.size(); i++) {
        if (dev_cap[i] == cfg.deviceCap) {
yuguo's avatar
yuguo committed
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
          b_found = true;
          break;
        }
      }
      if (!b_found) continue;

      std::getline(is, trans_a, csv_sep);
      std::getline(is, trans_b, csv_sep);
      std::getline(is, type_a, csv_sep);
      std::getline(is, type_b, csv_sep);
      std::getline(is, type_d, csv_sep);
      std::getline(is, bias_type, csv_sep);
      is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c;
      std::getline(is, epi, csv_sep);
      std::getline(is, comp, csv_sep);
      std::getline(is, scale, csv_sep);
      is >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx;
wenjh's avatar
wenjh committed
856
857

      if (is.bad()) {
yuguo's avatar
yuguo committed
858
859
860
861
        std::cerr << "Parsing CSV line failed: " << line << "\n";
        return;
      }

wenjh's avatar
wenjh committed
862
      if (ws_min > ws_max) {
yuguo's avatar
yuguo committed
863
864
865
        std::cout << "[WARNING] Invalid WS size at " << line << "\n";
        continue;
      }
yuguo's avatar
yuguo committed
866
867

#if HIP_VERSION >= 60300000
wenjh's avatar
wenjh committed
868
      auto fp8_filter = [](const hipDataType& val) {
wenjh's avatar
wenjh committed
869
870
        return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ);
      };
yuguo's avatar
yuguo committed
871
872
873
874
875
876
877
878
879
880
#else
      auto fp8_filter = nullptr;
#endif

      cfg.a_type = typeNameMapper.getValue(type_a, "type_a", fp8_filter);
      cfg.b_type = typeNameMapper.getValue(type_b, "type_b", fp8_filter);
      cfg.d_type = typeNameMapper.getValue(type_d, "type_d", fp8_filter);
      cfg.bias_type = (bias_type == "-")
                          ? (hipDataType)-1
                          : typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
yuguo's avatar
yuguo committed
881
882
883
884
885
886

      cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
      cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");

      cfg.epilogue = epilogueNameMapper.getValue(epi, "epi");
      //Check and filter out compute and scale types
yuguo's avatar
yuguo committed
887
      if (computeNameMapper.getValue(comp, "comp") != HIPBLAS_COMPUTE_32F ||
wenjh's avatar
wenjh committed
888
          typeNameMapper.getValue(scale, "scale") != HIP_R_32F) {
yuguo's avatar
yuguo committed
889
890
891
        continue;
      }

wenjh's avatar
wenjh committed
892
893
894
      if (find_(cfg, ws_min, ws_max)) {
        std::cout << "[WARNING] Duplicated/overlapped entry in algo cache\n";
        continue;
yuguo's avatar
yuguo committed
895
896
897
898
899
900
      }

      d.emplace(cfg, Algo(algo_idx, algo_id, ws_min, ws_max));
    }
  }

wenjh's avatar
wenjh committed
901
902
  bool can_save_(bool reopen = false) {
    if (!save_fs) {
yuguo's avatar
yuguo committed
903
      const char* temp = std::getenv("TE_HIPBLASLT_ALGO_SAVE");
wenjh's avatar
wenjh committed
904
      if (temp == nullptr || temp[0] == '\0') {
yuguo's avatar
yuguo committed
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
        return false;
      }

      save_fs_name = temp;

      pid_t pid = getpid();

      size_t pos = 0;
      while ((pos = save_fs_name.find("%i", pos)) != std::string::npos) {
        save_fs_name.replace(pos, 2, std::to_string(pid));
      }

      save_fs = std::make_unique<std::ofstream>();
      std::cout << "Saving autotune results to " << save_fs_name << "\n";
    }

wenjh's avatar
wenjh committed
921
922
    if (reopen) {
      if (save_fs->is_open()) {
yuguo's avatar
yuguo committed
923
924
925
926
927
        save_fs->close();
      }
      save_fs->open(save_fs_name, std::ios_base::trunc);
    }

wenjh's avatar
wenjh committed
928
    if (save_fs->is_open() && !save_fs->bad()) {
yuguo's avatar
yuguo committed
929
      return true;
wenjh's avatar
wenjh committed
930
    } else {
yuguo's avatar
yuguo committed
931
932
933
934
935
      if (reopen) std::cerr << "Could not open autotune results storage " << save_fs_name << "\n";
      return false;
    }
  }

wenjh's avatar
wenjh committed
936
937
  void save_() {
    if (!can_save_(true)) {
yuguo's avatar
yuguo committed
938
939
940
941
942
      return;
    }
    header_(*save_fs);
    *save_fs << "\n";

wenjh's avatar
wenjh committed
943
    for (const auto& elem : d) {
yuguo's avatar
yuguo committed
944
945
946
947
      save_(elem.first, elem.second);
    }
  }

wenjh's avatar
wenjh committed
948
949
  void save_(const Key& cfg, const Algo& algo) {
    if (!can_save_()) {
yuguo's avatar
yuguo committed
950
951
952
      return;
    }
    csv_helper csv(*save_fs, csv_sep);
wenjh's avatar
wenjh committed
953
954
955
956
957
958
959
960
    csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k << transposeNameMapper.getName(cfg.transa)
        << transposeNameMapper.getName(cfg.transb) << typeNameMapper.getName(cfg.a_type)
        << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
        << ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
        << cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue)
        << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
        << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end()
        << "\n";
yuguo's avatar
yuguo committed
961
962
  }

wenjh's avatar
wenjh committed
963
 private:
yuguo's avatar
yuguo committed
964
  std::vector<int> dev_cap;
wenjh's avatar
wenjh committed
965
  constexpr static char csv_sep = ',';
yuguo's avatar
yuguo committed
966
967
968
969
970
971
972
973
974
975
  std::unique_ptr<std::ofstream> save_fs;
  std::string save_fs_name;
  std::mutex mt;
  /* Map of problem config to tuple of ws_size and Algo
   * When searching, elements matching Key are filtered 
   * for requested WS size be between Algo.ws_size and pair.first
   */
  std::multimap<Key, Algo, Key::Comp> d;
} algoCache;

wenjh's avatar
wenjh committed
976
/* Read the integer environment variable `name`.
 * Returns `defval` when the variable is unset or empty; otherwise returns the value
 * parsed with atoi(), raised to `minval` if it is smaller than that.
 */
static inline int getIntEnv(const char* name, int defval, int minval) {
  const char* raw = std::getenv(name);
  if (raw == nullptr || *raw == '\0') {
    return defval;
  }
  const int parsed = atoi(raw);
  return (parsed < minval) ? minval : parsed;
}

wenjh's avatar
wenjh committed
988
}  //namespace
yuguo's avatar
yuguo committed
989
990
991
992
993
994
995
996

/* Create one hipBLASLt handle per compute stream.
 * `hipblaslt_handles` must point to storage for at least compute_num_streams handles.
 * Warning: only call once per device! When nvte_multi_stream_cublas_gemm runs with the
 * hipBLASLt backend, each compute stream needs its own handle so that no handle is used
 * by several streams concurrently.
 */
static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
  NVTE_CHECK(hipblaslt_handles != nullptr);
  int created = 0;
  while (created < compute_num_streams) {
    NVTE_CHECK_HIPBLASLT(hipblasLtCreate(hipblaslt_handles + created));
    ++created;
  }
}

1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
/* Map a hipBLASLt data type back to the Transformer Engine DType.
 * Only fp32, fp16 and bf16 are handled; any other value reports NVTE_ERROR("Invalid type").
 */
transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
  using namespace transformer_engine;
  switch (t) {
    case HIP_R_32F:
      return DType::kFloat32;
    case HIP_R_16BF:
      return DType::kBFloat16;
    case HIP_R_16F:
      return DType::kFloat16;
    default:
      NVTE_ERROR("Invalid type");
  }
}

wenjh's avatar
wenjh committed
1016
1017
1018
1019
1020
1021
1022
1023
1024
void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                    const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
                    int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
                    bool grad, void* workspace, size_t workspaceSize, bool accumulate,
                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                    bool gemm_producer, const Tensor* inputCounter, hipStream_t stream,
                    hipblasLtHandle_t handle) {
  void* A = inputA->data.dptr;
  void* A_scale_inverse = inputA->scale_inv.dptr;
yuguo's avatar
yuguo committed
1025
  float* A_scale_inverse_float = (float*)(inputA->scale_inv.dptr);
wenjh's avatar
wenjh committed
1026
1027
  void* B = inputB->data.dptr;
  void* B_scale_inverse = inputB->scale_inv.dptr;
yuguo's avatar
yuguo committed
1028
  float* B_scale_inverse_float = (float*)(inputB->scale_inv.dptr);
wenjh's avatar
wenjh committed
1029
1030
  void* D = outputD->data.dptr;
  void* bias_ptr = inputBias->data.dptr;
yuguo's avatar
yuguo committed
1031
  const bool bias = bias_ptr != nullptr;
wenjh's avatar
wenjh committed
1032
  void* pre_gelu_out = outputPreGelu->data.dptr;
yuguo's avatar
yuguo committed
1033
  const bool gelu = pre_gelu_out != nullptr;
wenjh's avatar
wenjh committed
1034
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
yuguo's avatar
yuguo committed
1035
  const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);
yuguo's avatar
yuguo committed
1036
1037
1038
1039
  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
  const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
yuguo's avatar
yuguo committed
1040
1041
1042
1043
1044

  NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
yuguo's avatar
yuguo committed
1045
1046
1047
1048
1049
1050
1051
1052
  NVTE_CHECK(!is_int8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
             "INT8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_int8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
             "INT8 input to GEMM requires inverse of scale!");

  bool tensorwise_int8 = 0;;
  const char* NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");      
  if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8) tensorwise_int8 = 1;           
yuguo's avatar
yuguo committed
1053
1054
1055
1056

  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
yuguo's avatar
yuguo committed
1057
  if (use_fp8 || use_int8) {
yuguo's avatar
yuguo committed
1058
1059
1060
1061
1062
1063
1064
1065
1066
    NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
  }
  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;

  int device_id;
  NVTE_CHECK_CUDA(hipGetDevice(&device_id));

yuguo's avatar
yuguo committed
1067
1068
  if (handle == nullptr) {
    handle = cached_handles.get(device_id);
wenjh's avatar
wenjh committed
1069
    if (handle == nullptr) {
yuguo's avatar
yuguo committed
1070
1071
      handle = cached_handles.obtain(device_id);
    }
yuguo's avatar
yuguo committed
1072
1073
  }

wenjh's avatar
wenjh committed
1074
1075
  hipblasLtMatmulDesc_t operationDesc = nullptr;
  hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
yuguo's avatar
yuguo committed
1076
1077
1078
  hipblasLtMatmulPreference_t preference = nullptr;
  hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;

wenjh's avatar
wenjh committed
1079
  int64_t ld_gelumat = (int64_t)ldd;
yuguo's avatar
yuguo committed
1080
1081

  // default to tf32 except for e5m2 inputs where the config is not supported
yuguo's avatar
yuguo committed
1082
  hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
yuguo's avatar
yuguo committed
1083
1084

  // Create matrix descriptors. Not setting any extra attributes.
wenjh's avatar
wenjh committed
1085
1086
1087
1088
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type, transa == HIPBLAS_OP_N ? m : k,
                                                   transa == HIPBLAS_OP_N ? k : m, lda));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Bdesc, B_type, transb == HIPBLAS_OP_N ? k : n,
                                                   transb == HIPBLAS_OP_N ? n : k, ldb));
yuguo's avatar
yuguo committed
1089
1090
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

wenjh's avatar
wenjh committed
1091
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
yuguo's avatar
yuguo committed
1092
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
wenjh's avatar
wenjh committed
1093
                                                       &transa, sizeof(transa)));
yuguo's avatar
yuguo committed
1094
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
wenjh's avatar
wenjh committed
1095
                                                       &transb, sizeof(transb)));
yuguo's avatar
yuguo committed
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
    /*
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                     HIPBLASLT_MATMUL_DESC_FAST_ACCUM, //TODO: We don't have fast accum mode yet
                                                     &fastAccuMode,
                                                     sizeof(fastAccuMode)));
    */
wenjh's avatar
wenjh committed
1109
1110
1111
1112
1113
1114
    NVTE_CHECK_HIPBLASLT(
        hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                        &A_scale_inverse, sizeof(A_scale_inverse)));
    NVTE_CHECK_HIPBLASLT(
        hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                        &B_scale_inverse, sizeof(B_scale_inverse)));
yuguo's avatar
yuguo committed
1115
    if (bias) {
wenjh's avatar
wenjh committed
1116
1117
      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
          operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
yuguo's avatar
yuguo committed
1118
1119
    }
  }
yuguo's avatar
yuguo committed
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
  if (tensorwise_int8) {
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                     HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                     (void*)&A_scale_inverse_float,
                                                     sizeof(void*)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                     HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                     (void*)&B_scale_inverse_float,
                                                     sizeof(void*)));
  }
yuguo's avatar
yuguo committed
1130
1131
1132
1133
1134
1135
1136
1137

  if (bias && gelu) {
    if (grad) {
      epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
wenjh's avatar
wenjh committed
1138
        operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
yuguo's avatar
yuguo committed
1139
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
wenjh's avatar
wenjh committed
1140
1141
1142
1143
                                                         HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                         &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
        operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
yuguo's avatar
yuguo committed
1144
  } else if (bias) {
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
    if (tensorwise_int8) {
      if (grad) {
        int batch_size = k;
        int output_dim = n;
        DType te_bias_dtype = get_transformer_engine_dtype_from_hipblaslt_dtype(bias_type);
        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
          te_bias_dtype, BType,·
          detail::tensorwise_int8_bias_gradient_kernelLauncher<BType>(
            reinterpret_cast<const int8_t*>(B), reinterpret_cast<BType*>(bias_ptr), B_scale_inverse_float, batch_size,
            output_dim, stream););
      } else {
        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
          operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
        epilogue = HIPBLASLT_EPILOGUE_BIAS;
        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
          operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
      }
yuguo's avatar
yuguo committed
1162
    } else {
1163
1164
1165
1166
1167
1168
1169
1170
      if (grad) {
        // grad output is always input B
        epilogue = HIPBLASLT_EPILOGUE_BGRADB;
      } else {
        epilogue = HIPBLASLT_EPILOGUE_BIAS;
      }
      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
          operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
yuguo's avatar
yuguo committed
1171
    }
1172
    
yuguo's avatar
yuguo committed
1173
1174
1175
1176
1177
1178
1179
  } else if (gelu) {
    if (grad) {
      epilogue = HIPBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
wenjh's avatar
wenjh committed
1180
1181
1182
1183
                                                         HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                         &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
        operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
yuguo's avatar
yuguo committed
1184
1185
  }

wenjh's avatar
wenjh committed
1186
1187
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
      operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
yuguo's avatar
yuguo committed
1188

wenjh's avatar
wenjh committed
1189
1190
1191
  GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
                              use_fp8 ? bias_type : (hipDataType)-1, m, n, k, lda, ldb, ldd, transa,
                              transb, epilogue);
yuguo's avatar
yuguo committed
1192
  GemmAlgoCache::Algo cached_algo;
wenjh's avatar
wenjh committed
1193
  if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value()) {
yuguo's avatar
yuguo committed
1194
1195
1196
1197
1198
1199
    int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0);
    int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0);
    int algoTuneCount = 1;
    std::vector<hipblasLtMatmulHeuristicResult_t> algoArr;
    bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0;

wenjh's avatar
wenjh committed
1200
    if (tuneLoopCount) {
yuguo's avatar
yuguo committed
1201
1202
1203
1204
1205
1206
1207
      /* HIPBLASLT may return hundreds of algos for some configs
       * Limit amount by default. User may override with env
       */
      static const int defaultAlgoCount = 16;
      algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1);
    }
    algoTuneCount += firstAlgo;
wenjh's avatar
wenjh committed
1208
1209
    int algoTotalCount =
        cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount;
yuguo's avatar
yuguo committed
1210
1211
1212
    algoArr.resize(algoTotalCount);

    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference));
wenjh's avatar
wenjh committed
1213
1214
1215
    NVTE_CHECK_HIPBLASLT(
        hipblasLtMatmulPreferenceSetAttribute(preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                              &workspaceSize, sizeof(workspaceSize)));
yuguo's avatar
yuguo committed
1216
1217

    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc,
wenjh's avatar
wenjh committed
1218
1219
                                                         Ddesc, preference, algoTotalCount,
                                                         algoArr.data(), &algoTotalCount));
yuguo's avatar
yuguo committed
1220
1221
1222
1223
1224
    algoArr.resize(algoTotalCount);

    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference));

    //If cached algo exists in persistent storage we just need to find matching hipblasLtMatmulAlgo_t
wenjh's avatar
wenjh committed
1225
    if (cached_algo.hasId()) {
yuguo's avatar
yuguo committed
1226
      int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0;
wenjh's avatar
wenjh committed
1227
1228
1229
1230
      for (int i = 0; i < algoTotalCount; i++) {
        const auto& algo = algoArr[idx];
        if (algo.state == HIPBLAS_STATUS_SUCCESS) {
          if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo)) {
yuguo's avatar
yuguo committed
1231
            cached_algo.algo = algo.algo;
wenjh's avatar
wenjh committed
1232
            if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index) {
yuguo's avatar
yuguo committed
1233
1234
1235
1236
1237
1238
1239
1240
1241
              cached_algo.ws_size_min = algo.workspaceSize;
              cached_algo.index = idx;
              algoCache.store(gemm_cfg, cached_algo);
            }
            break;
          }
        }
        idx = (idx + 1) % algoTotalCount;
      }
wenjh's avatar
wenjh committed
1242
1243
1244
      if (logTuning && !cached_algo.algo.has_value()) {
        std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId
                  << " in hipBLASLt results" << std::endl;
yuguo's avatar
yuguo committed
1245
1246
1247
1248
      }
    }

    //No suitable entry in autotune cache or could not find matched algo in hipBLASLt results
wenjh's avatar
wenjh committed
1249
    if (!cached_algo.algo.has_value()) {
yuguo's avatar
yuguo committed
1250
1251
      int bestAlgo = -1;
      algoTuneCount = std::min(algoTuneCount, algoTotalCount);
wenjh's avatar
wenjh committed
1252
      if (tuneLoopCount > 0) {
yuguo's avatar
yuguo committed
1253
1254
1255
1256
1257
        if (logTuning)
          std::cout << "[INFO] Perform hipBLASLt algo selection on GPU" << device_id
                    << " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
                    << tuneLoopCount << " loops " << std::endl;

yuguo's avatar
yuguo committed
1258
        NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
yuguo's avatar
yuguo committed
1259
1260
1261
        hipStream_t profilingStream;
        NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
        using tuning_clock = std::chrono::steady_clock;
wenjh's avatar
wenjh committed
1262
        tuning_clock::now();  //the first call takes little longer so do it outside the loop
yuguo's avatar
yuguo committed
1263
1264
        tuning_clock::duration bestTime = tuning_clock::duration::max();

wenjh's avatar
wenjh committed
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
        for (int algo = firstAlgo; algo < algoTuneCount; algo++) {
          if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS) {
            continue;
          }
          // Warm-up call
          NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                               static_cast<const void*>(&one),         /* alpha */
                                               A,                                      /* A */
                                               Adesc, B,                               /* B */
                                               Bdesc, static_cast<const void*>(&beta), /* beta */
                                               D,                                      /* C */
                                               Ddesc, D,                               /* D */
                                               Ddesc, &algoArr[algo].algo,             /* algo */
                                               workspace,                        /* workspace */
                                               workspaceSize, profilingStream)); /* stream */
yuguo's avatar
yuguo committed
1280
1281
1282
1283
          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));

          //Profiling loop
          tuning_clock::time_point startTime = tuning_clock::now();
wenjh's avatar
wenjh committed
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
          for (int loop = 0; loop < tuneLoopCount; loop++) {
            NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                                 static_cast<const void*>(&one),         /* alpha */
                                                 A,                                      /* A */
                                                 Adesc, B,                               /* B */
                                                 Bdesc, static_cast<const void*>(&beta), /* beta */
                                                 D,                                      /* C */
                                                 Ddesc, D,                               /* D */
                                                 Ddesc, &algoArr[algo].algo,             /* algo */
                                                 workspace,                        /* workspace */
                                                 workspaceSize, profilingStream)); /* stream */
yuguo's avatar
yuguo committed
1295
1296
          }
          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
wenjh's avatar
wenjh committed
1297
1298
          tuning_clock::duration algoTime = tuning_clock::now() - startTime;
          if (algoTime < bestTime) {
yuguo's avatar
yuguo committed
1299
1300
1301
1302
1303
1304
            bestAlgo = algo;
            bestTime = algoTime;
          }
        }

        NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream));
wenjh's avatar
wenjh committed
1305
        if (bestAlgo >= 0) {
yuguo's avatar
yuguo committed
1306
1307
          if (logTuning)
            std::cout << "[INFO] Select hipBLASLt algo " << bestAlgo << " with time "
wenjh's avatar
wenjh committed
1308
1309
                      << std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() /
                             tuneLoopCount
yuguo's avatar
yuguo committed
1310
1311
                      << " ns" << std::endl;
        }
wenjh's avatar
wenjh committed
1312
      } else if (firstAlgo < algoTuneCount) {
yuguo's avatar
yuguo committed
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
        bestAlgo = firstAlgo;
      }

      if (bestAlgo < 0) {
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
        throw std::runtime_error("Unable to find any suitable algorithms");
      }
      cached_algo.algo = algoArr[bestAlgo].algo;
      cached_algo.index = bestAlgo;
      cached_algo.algoId = cached_algo.getAlgoId(algoArr[bestAlgo].algo);
      cached_algo.ws_size_min = algoArr[bestAlgo].workspaceSize;
      cached_algo.ws_size_max = workspaceSize;

      if (logTuning)
wenjh's avatar
wenjh committed
1330
1331
        std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId
                  << std::endl;
yuguo's avatar
yuguo committed
1332
1333
1334
1335
1336
1337

      algoCache.store(gemm_cfg, cached_algo);
    }
  }

  // D = alpha * (A * B) + beta * C
wenjh's avatar
wenjh committed
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                       static_cast<const void*>(&one),         /* alpha */
                                       A,                                      /* A */
                                       Adesc, B,                               /* B */
                                       Bdesc, static_cast<const void*>(&beta), /* beta */
                                       D,                                      /* C */
                                       Ddesc, D,                               /* D */
                                       Ddesc, &cached_algo.algo.value(),       /* algo */
                                       workspace,                              /* workspace */
                                       workspaceSize, stream));                /* stream */
yuguo's avatar
yuguo committed
1348
1349
1350
1351
1352
1353

  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
1354

yuguo's avatar
yuguo committed
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
void hipblaslt_batchgemm_tensorwise_int8(const Tensor *inputA,
                 const Tensor *inputB,
                 const Tensor *inputA_scales,
                 const Tensor *inputB_scales,
                 Tensor *outputD,
                 const Tensor *inputBias,
                 Tensor *outputPreGelu,
                 int m, int n, int k,
                 int lda, int ldb, int ldd,
                 hipblasOperation_t transa,
                 hipblasOperation_t transb,
                 bool grad,
                 void* workspace,
                 size_t workspaceSize,
                 bool accumulate,
                 bool use_split_accumulator,
                 int math_sm_count,
                 int m_split,
                 int n_split,
                 bool gemm_producer,
                 const Tensor *inputCounter,
                 size_t batch_count,
                 hipStream_t stream,
                 hipblasLtHandle_t handle
) {
  void *A = inputA->data.dptr;
  void *A_scale_inverse = inputA_scales->data.dptr;
  float *A_scale_inverse_float = (float*)(inputA_scales->data.dptr);
  void *B = inputB->data.dptr;
  void *B_scale_inverse = inputB_scales->data.dptr;
  float *B_scale_inverse_float = (float*)(inputB_scales->data.dptr);
  void *D = outputD->data.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
  const bool use_int8 = is_int8_dtype(inputA->data.dtype) ||
                        is_int8_dtype(inputB->data.dtype);
  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
  const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);

  NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_int8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
             "INT8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_int8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
             "INT8 input to GEMM requires inverse of scale!");

  bool tensorwise_int8 = 0;;
  const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");      
  if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8) tensorwise_int8 = 1;           

  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 || use_int8) {
    NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
  }
  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;

  int device_id;
  NVTE_CHECK_CUDA(hipGetDevice(&device_id));

  if (handle == nullptr) {
    handle = cached_handles.get(device_id);
    if (handle == nullptr)
    {
      handle = cached_handles.obtain(device_id);
    }
  }

  hipblasLtMatmulDesc_t       operationDesc = nullptr;
  hipblasLtMatrixLayout_t     Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  hipblasLtMatmulPreference_t preference = nullptr;
  hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;

  int64_t ld_gelumat = (int64_t) ldd;

  // default to tf32 except for e5m2 inputs where the config is not supported
  hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type,
                                               transa == HIPBLAS_OP_N ? m : k,
                                               transa == HIPBLAS_OP_N ? k : m,
                                               lda));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Bdesc, B_type,
                                               transb == HIPBLAS_OP_N ? k : n,
                                               transb == HIPBLAS_OP_N ? n : k,
                                               ldb));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

  if (tensorwise_int8) {
    size_t strideA = m*k;
    size_t strideB = k*n;
    size_t strideD = m*n;

    hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t));
    hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(int64_t));

    hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t));
    hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(int64_t));

    hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t));
    hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(int64_t));
    
wenjh's avatar
wenjh committed
1469
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
yuguo's avatar
yuguo committed
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
  } else {
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
  }
  
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
                                                   &transa, sizeof(transa)));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
                                                   &transb, sizeof(transb)));

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
    /*
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                     HIPBLASLT_MATMUL_DESC_FAST_ACCUM, //TODO: We don't have fast accum mode yet
                                                     &fastAccuMode,
                                                     sizeof(fastAccuMode)));
    */
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                     HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                     &A_scale_inverse,
                                                     sizeof(A_scale_inverse)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                     HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                     &B_scale_inverse,
                                                     sizeof(B_scale_inverse)));
    if (bias) {
      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                       HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
                                                       &bias_type, sizeof(bias_type)));
    }
  }
  if (tensorwise_int8) {
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                     HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                     (void*)&A_scale_inverse_float,
                                                     sizeof(void*)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                     HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                     (void*)&B_scale_inverse_float,
                                                     sizeof(void*)));
    if (bias) {
      NVTE_CHECK(false, "tensorwise_int8 not surpport bias!");
    }
  }

  if (bias && gelu) {
    if (grad) {
      epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                      HIPBLASLT_MATMUL_DESC_BIAS_POINTER,
                                                      &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
                            operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                            &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                      HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
                                                      &ld_gelumat, sizeof(ld_gelumat)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = HIPBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = HIPBLASLT_EPILOGUE_BIAS;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                      HIPBLASLT_MATMUL_DESC_BIAS_POINTER,
                                                      &bias_ptr, sizeof(bias_ptr)));
  } else if (gelu) {
    if (grad) {
      epilogue = HIPBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
                            operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                            &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                     HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
                                                     &ld_gelumat, sizeof(ld_gelumat)));
  }

  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                   HIPBLASLT_MATMUL_DESC_EPILOGUE,
                                                   &epilogue, sizeof(epilogue)));

  GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type, 
    use_fp8 ? bias_type : (hipDataType)-1,
    m, n, k, lda, ldb, ldd, transa, transb, epilogue );
  GemmAlgoCache::Algo cached_algo;
  if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())
  {
    int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0);
    int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0);
    int algoTuneCount = 1;
    std::vector<hipblasLtMatmulHeuristicResult_t> algoArr;
    bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0;

    if (tuneLoopCount)
    {
      /* HIPBLASLT may return hundreds of algos for some configs
       * Limit amount by default. User may override with env
       */
      static const int defaultAlgoCount = 16;
      algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1);
    }
    algoTuneCount += firstAlgo;
    int algoTotalCount = cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount;
    algoArr.resize(algoTotalCount);

    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceSetAttribute(
                            preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                            &workspaceSize, sizeof(workspaceSize)));

    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc,
                                                    Ddesc, preference, algoTotalCount, algoArr.data(),
                                                    &algoTotalCount));
    algoArr.resize(algoTotalCount);

    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference));

    //If cached algo exists in persistent storage we just need to find matching hipblasLtMatmulAlgo_t
    if (cached_algo.hasId())
    {
      int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0;
      for (int i=0; i<algoTotalCount; i++)
      {
        const auto &algo = algoArr[idx];
        if (algo.state == HIPBLAS_STATUS_SUCCESS)
        {
          if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo))
          {
            cached_algo.algo = algo.algo;
            if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index)
            {
              cached_algo.ws_size_min = algo.workspaceSize;
              cached_algo.index = idx;
              algoCache.store(gemm_cfg, cached_algo);
            }
            break;
          }
        }
        idx = (idx + 1) % algoTotalCount;
      }
      if (logTuning && !cached_algo.algo.has_value())
      {
        std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId << " in hipBLASLt results" << std::endl;
      }
    }

    //No suitable entry in autotune cache or could not find matched algo in hipBLASLt results
    if (!cached_algo.algo.has_value())
    {

      int bestAlgo = -1;
      algoTuneCount = std::min(algoTuneCount, algoTotalCount);
      if (tuneLoopCount > 0)
      {
        if (logTuning)
          std::cout << "[INFO] Perform hipBLASLt algo selection on GPU" << device_id
                    << " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
                    << tuneLoopCount << " loops " << std::endl;

        NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
        hipStream_t profilingStream;
        NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
        using tuning_clock = std::chrono::steady_clock;
        tuning_clock::now(); //the first call takes little longer so do it outside the loop
        tuning_clock::duration bestTime = tuning_clock::duration::max();

        for (int algo=firstAlgo; algo<algoTuneCount; algo++)
        {
            if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS)
            {
              continue;
            }
            // Warm-up call
            NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
                                            operationDesc,
                                            static_cast<const void*>(&one),         /* alpha */
                                            A,                                      /* A */
                                            Adesc,
                                            B,                                      /* B */
                                            Bdesc,
                                            static_cast<const void*>(&beta),        /* beta */
                                            D,                                      /* C */
                                            Ddesc,
                                            D,                                      /* D */
                                            Ddesc,
                                            &algoArr[algo].algo,                    /* algo */
                                            workspace,                              /* workspace */
                                            workspaceSize,
                                            profilingStream));                       /* stream */
          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));

          //Profiling loop
          tuning_clock::time_point startTime = tuning_clock::now();
          for (int loop=0; loop<tuneLoopCount; loop++)
          {
            NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
                                            operationDesc,
                                            static_cast<const void*>(&one),         /* alpha */
                                            A,                                      /* A */
                                            Adesc,
                                            B,                                      /* B */
                                            Bdesc,
                                            static_cast<const void*>(&beta),        /* beta */
                                            D,                                      /* C */
                                            Ddesc,
                                            D,                                      /* D */
                                            Ddesc,
                                            &algoArr[algo].algo,                    /* algo */
                                            workspace,                              /* workspace */
                                            workspaceSize,
                                            profilingStream));                       /* stream */
          }
          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
          tuning_clock::duration algoTime = tuning_clock::now() - startTime; 
          if (algoTime < bestTime)
          {
            bestAlgo = algo;
            bestTime = algoTime;
          }
        }

        NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream));
        if (bestAlgo >= 0)
        {
          if (logTuning)
            std::cout << "[INFO] Select hipBLASLt algo " << bestAlgo << " with time "
                      << std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() / tuneLoopCount
                      << " ns" << std::endl;
        }
      }
      else if (firstAlgo < algoTuneCount)
      {
        bestAlgo = firstAlgo;
      }

      if (bestAlgo < 0) {
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
        throw std::runtime_error("Unable to find any suitable algorithms");
      }
      cached_algo.algo = algoArr[bestAlgo].algo;
      cached_algo.index = bestAlgo;
      cached_algo.algoId = cached_algo.getAlgoId(algoArr[bestAlgo].algo);
      cached_algo.ws_size_min = algoArr[bestAlgo].workspaceSize;
      cached_algo.ws_size_max = workspaceSize;

      if (logTuning)
        std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId << std::endl;

      algoCache.store(gemm_cfg, cached_algo);
    }
  }

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
                                   operationDesc,
                                   static_cast<const void*>(&one),         /* alpha */
                                   A,                                      /* A */
                                   Adesc,
                                   B,                                      /* B */
                                   Bdesc,
                                   static_cast<const void*>(&beta),        /* beta */
                                   D,                                      /* C */
                                   Ddesc,
                                   D,                                      /* D */
                                   Ddesc,
                                   &cached_algo.algo.value(),              /* algo */
                                   workspace,                              /* workspace */
                                   workspaceSize,
                                   stream));                               /* stream */


  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}

1761
class userArgsManager {
wenjh's avatar
wenjh committed
1762
1763
 public:
  userArgsManager() {}
1764

wenjh's avatar
wenjh committed
1765
1766
1767
1768
  ~userArgsManager() {
    // Release all userArgs when the manager is destroyed
    for (auto& device_pair : userArgs_map_) {
      hipFree(device_pair.second);  // Only one userArgs per device
1769
    }
wenjh's avatar
wenjh committed
1770
  }
1771

wenjh's avatar
wenjh committed
1772
1773
1774
  // Get a userArgs for the given device (creates if necessary)
  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
1775

wenjh's avatar
wenjh committed
1776
1777
1778
1779
1780
    // Check if the userArgs for this device exists
    auto device_it = userArgs_map_.find(device_id);
    if (device_it != userArgs_map_.end()) {
      return device_it->second;
    }
1781

wenjh's avatar
wenjh committed
1782
1783
1784
    // Create a new userArgs for this device if it doesn't exist
    hipblaslt_ext::UserArguments* userArgs;
    NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
1785

wenjh's avatar
wenjh committed
1786
1787
1788
1789
    // Store the userArgs in the map for this device
    userArgs_map_[device_id] = userArgs;
    return userArgs;
  }
1790

wenjh's avatar
wenjh committed
1791
1792
1793
1794
 private:
  std::unordered_map<int, hipblaslt_ext::UserArguments*>
      userArgs_map_;  // Map from device_id to hipblasHandle
  std::mutex mutex_;
1795
1796
};

yuguo's avatar
yuguo committed
1797
class d_userArgsManager {
wenjh's avatar
wenjh committed
1798
1799
 public:
  d_userArgsManager() {}
yuguo's avatar
yuguo committed
1800

wenjh's avatar
wenjh committed
1801
1802
1803
1804
  ~d_userArgsManager() {
    // Release all userArgs when the manager is destroyed
    for (auto& device_pair : d_userArgs_map_) {
      hipFree(device_pair.second);  // Only one userArgs per device
yuguo's avatar
yuguo committed
1805
    }
wenjh's avatar
wenjh committed
1806
  }
yuguo's avatar
yuguo committed
1807

wenjh's avatar
wenjh committed
1808
1809
1810
  // Get a userArgs for the given device (creates if necessary)
  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
yuguo's avatar
yuguo committed
1811

wenjh's avatar
wenjh committed
1812
1813
1814
1815
1816
    // Check if the userArgs for this device exists
    auto device_it = d_userArgs_map_.find(device_id);
    if (device_it != d_userArgs_map_.end()) {
      return device_it->second;
    }
yuguo's avatar
yuguo committed
1817

wenjh's avatar
wenjh committed
1818
1819
1820
    // Create a new userArgs for this device if it doesn't exist
    hipblaslt_ext::UserArguments* d_userArgs;
    NVTE_CHECK_CUDA(hipMalloc(&d_userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
yuguo's avatar
yuguo committed
1821

wenjh's avatar
wenjh committed
1822
1823
1824
1825
    // Store the userArgs in the map for this device
    d_userArgs_map_[device_id] = d_userArgs;
    return d_userArgs;
  }
yuguo's avatar
yuguo committed
1826

wenjh's avatar
wenjh committed
1827
1828
1829
1830
 private:
  std::unordered_map<int, hipblaslt_ext::UserArguments*>
      d_userArgs_map_;  // Map from device_id to hipblasHandle
  std::mutex mutex_;
yuguo's avatar
yuguo committed
1831
1832
};

1833
// Define a static userArgs manager
yuguo's avatar
yuguo committed
1834
1835
static userArgsManager UAManager;
static d_userArgsManager d_UAManager;
1836

wenjh's avatar
wenjh committed
1837
1838
1839
1840
1841
void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
                          std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
                          std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b,
                          hipblasOperation_t transa, hipblasOperation_t transb, void* workspace,
                          size_t workspaceSize, bool accumulate, bool use_split_accumulator,
1842
1843
1844
1845
                          int math_sm_count, hipStream_t stream, int compute_stream_offset = 0) {
  // Check compute_stream_offset valid.
  NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);

yuguo's avatar
yuguo committed
1846
1847
1848
1849
  int device_id;
  hipGetDevice(&device_id);
  hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
  hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
1850
1851
1852
1853
1854
1855
1856
1857

  // hipblaslt_ext::UserArguments* userArgs;
  // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));

  hipblasLtHandle_t handle = nullptr;
  if (compute_stream_offset != -1) {
    // Init hipblaslt handles (once, globally)
    static std::once_flag init_flag;
yuguo's avatar
yuguo committed
1858
    static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
    std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);

    handle = hipblaslt_handles[compute_stream_offset];
  }

  const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
  const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
  const hipDataType D_type = get_hipblaslt_dtype(outputD[0]->data.dtype);

  hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;

  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;
  int int_one = 1;
  int int_zero = 0;
  int int_beta = int_zero;
  bool use_int8 = false;
wenjh's avatar
wenjh committed
1877

1878
  if ((A_type == HIP_R_8I) && (B_type == HIP_R_8I) && (D_type == HIP_R_32I)) {
wenjh's avatar
wenjh committed
1879
    NVTE_CHECK(!accumulate, "Int8 gemm not support accumulate.");
1880
1881
1882
1883
1884
1885
    use_int8 = true;
    computeType = HIPBLAS_COMPUTE_32I;
  }

  hipblaslt_ext::GemmPreference gemmPref;
  gemmPref.setMaxWorkspaceBytes(workspaceSize);
wenjh's avatar
wenjh committed
1886
  hipblaslt_ext::GroupedGemm groupedgemm(handle, transa, transb, A_type, B_type, D_type, D_type,
wenjh's avatar
wenjh committed
1887
                                         computeType);
1888
1889
1890

  std::vector<hipblaslt_ext::GemmEpilogue> epilogue{
      hipblaslt_ext::
wenjh's avatar
wenjh committed
1891
          GemmEpilogue()};  // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
1892
  std::vector<hipblaslt_ext::GemmInputs> inputs(m.size());
wenjh's avatar
wenjh committed
1893
1894
1895
1896
1897
1898
1899
  for (int i = 0; i < m.size(); i++) {
    inputs[i].a = inputA[i]->data.dptr;
    inputs[i].b = inputB[i]->data.dptr;
    inputs[i].c = outputD[i]->data.dptr;
    inputs[i].d = outputD[i]->data.dptr;
    inputs[i].alpha = use_int8 ? static_cast<void*>(&int_one) : static_cast<void*>(&one);
    inputs[i].beta = use_int8 ? static_cast<void*>(&int_beta) : static_cast<void*>(&beta);
1900
1901
1902
1903
1904
1905
  }
  // hipblaslt_ext::GemmEpilogue supports broadcasting
  groupedgemm.setProblem(m, n, k, b, epilogue, inputs);

  const int request_solutions = 1;
  std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
wenjh's avatar
wenjh committed
1906
  NVTE_CHECK_HIPBLASLT(groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
1907

wenjh's avatar
wenjh committed
1908
1909
1910
  if (heuristicResult.empty()) {
    std::cerr << "No valid solution found!" << std::endl;
    return;
1911
1912
  }

wenjh's avatar
wenjh committed
1913
1914
  // Make sure to initialize everytime the algo changes
  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
wenjh's avatar
wenjh committed
1915

1916
  // Get the default values from the grouepdgemm object
yuguo's avatar
yuguo committed
1917
  groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
1918
1919
1920
  // Copy them to device memory
  // hipblaslt_ext::UserArguments* d_userArgs;
  // NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
wenjh's avatar
wenjh committed
1921
  NVTE_CHECK_CUDA(hipMemcpy(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments),
1922
                            hipMemcpyHostToDevice));
wenjh's avatar
wenjh committed
1923

yuguo's avatar
yuguo committed
1924
1925
1926
  NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
1927
1928
1929
1930
1931

  // NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
  // NVTE_CHECK_CUDA(hipFree(userArgs));
}

wenjh's avatar
wenjh committed
1932
#endif  //USE_HIPBLASLT
yuguo's avatar
yuguo committed
1933

wenjh's avatar
wenjh committed
1934
#ifdef USE_ROCBLAS  // Use rocblas + kernel, no fusion
1935

wenjh's avatar
wenjh committed
1936
// Creation callback handed to detail::HandleManager: creates a rocBLAS handle into
// *handle and validates the returned status with NVTE_CHECK_ROCBLAS.
inline void CreateRocblasHandle(rocblas_handle* handle) {
  NVTE_CHECK_ROCBLAS(rocblas_create_handle(handle));
}

// Handle manager type for rocBLAS. rocblas_gemm fetches its handle through
// rocblasHandleManager::Instance().GetHandle().
typedef detail::HandleManager<rocblas_handle, CreateRocblasHandle> rocblasHandleManager;
wenjh's avatar
wenjh committed
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                  const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
                  int ldb, int ldd, rocblas_operation transa, rocblas_operation transb, bool grad,
                  void* workspace, size_t workspaceSize, bool accumulate,
                  bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                  bool gemm_producer, const Tensor* inputCounter, hipStream_t stream) {
  void* A = inputA->data.dptr;
  void* A_scale_inverse = inputA->scale_inv.dptr;
  void* B = inputB->data.dptr;
  void* B_scale_inverse = inputB->scale_inv.dptr;
  void* C = outputD->data.dptr;
  void* D = outputD->data.dptr;
  void* D_scale = outputD->scale.dptr;
  void* D_amax = outputD->amax.dptr;
  void* bias_ptr = inputBias->data.dptr;
yuguo's avatar
yuguo committed
1956
  const bool bias = bias_ptr != nullptr;
wenjh's avatar
wenjh committed
1957
  void* pre_gelu_out = outputPreGelu->data.dptr;
yuguo's avatar
yuguo committed
1958
  const bool gelu = pre_gelu_out != nullptr;
wenjh's avatar
wenjh committed
1959
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
yuguo's avatar
yuguo committed
1960
1961
1962
1963
1964
  const rocblas_datatype A_type = get_rocblas_dtype(inputA->data.dtype);
  const rocblas_datatype B_type = get_rocblas_dtype(inputB->data.dtype);
  const rocblas_datatype D_type = get_rocblas_dtype(outputD->data.dtype);
  const rocblas_datatype bias_type = get_rocblas_dtype(inputBias->data.dtype);
  const rocblas_datatype gelu_type = get_rocblas_dtype(outputPreGelu->data.dtype);
wenjh's avatar
wenjh committed
1965

yuguo's avatar
yuguo committed
1966
1967
1968
1969
1970
  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
wenjh's avatar
wenjh committed
1971
               "fp8 Aux output for gemm + gelu fusion not supported!");
yuguo's avatar
yuguo committed
1972
1973
  }
  if (is_fp8_dtype(outputD->data.dtype)) {
wenjh's avatar
wenjh committed
1974
    NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
yuguo's avatar
yuguo committed
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
  }
  // fp8 + grad unavailable in upstream
  NVTE_CHECK(!(use_fp8 && grad), "fp8 + grad not supported!");

  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;

  float alpha = 1.0;
  if (use_fp8) {
wenjh's avatar
wenjh committed
1985
1986
1987
1988
    float A_scale_inv, B_scale_inv;
    (void)hipMemcpy(&A_scale_inv, A_scale_inverse, sizeof(float), hipMemcpyDeviceToHost);
    (void)hipMemcpy(&B_scale_inv, B_scale_inverse, sizeof(float), hipMemcpyDeviceToHost);
    alpha = A_scale_inv * B_scale_inv;
yuguo's avatar
yuguo committed
1989
  }
wenjh's avatar
wenjh committed
1990

1991
  rocblas_handle handle = rocblasHandleManager::Instance().GetHandle();
yuguo's avatar
yuguo committed
1992
1993
1994
1995
  NVTE_CHECK_ROCBLAS(rocblas_set_stream(handle, stream));

  // extract the stream order alloc env
  bool stream_order_alloc = false;
wenjh's avatar
wenjh committed
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
  if (const char* env_p = std::getenv("ROCBLAS_STREAM_ORDER_ALLOC")) {
    if (env_p == nullptr || std::string(env_p) == "1") stream_order_alloc = true;
  }

  int64_t ld_gelumat = (int64_t)ldd;

  NVTE_CHECK((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
              D_type == rocblas_datatype_f16_r) ||
                 (A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
                  D_type == rocblas_datatype_bf16_r) ||
                 (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_f32_r && B_type == rocblas_datatype_f32_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f16_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_bf16_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f8_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_bf8_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
                  D_type == rocblas_datatype_f16_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
                  D_type == rocblas_datatype_bf16_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
                  D_type == rocblas_datatype_f8_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
                  D_type == rocblas_datatype_bf8_r) ||
                 (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f16_r) ||
                 (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_bf16_r) ||
                 (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f8_r) ||
                 (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_bf8_r),
             "Only the following combinations of data types are enabled now!\n\
yuguo's avatar
yuguo committed
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
1. input: fp32, output: fp32.\n\
2. input: fp16, output: fp16.\n\
3. input: bf16, output: bf16.\n\
4. input: fp8/bf8, output: fp8/bf8, fp16/bf16, fp32");

  //If D is not fp32, then we need a temp buffer for GEMM result before applying epilogues. Otherwise, we can apply epilogues in-place.
  // with bias or gelu, allocate fp32 D_temp if the output is not fp32
  // with input fp8/bf8 (use_fp8) and bf16 output, need a fp32 D_temp, as rocblas does not support this case (fp8/bf8 input fp16/fp32 output is supported)
  // with use_fp8 true and fp8/bf8 output, need fp32 D_temp to support amax and scale operation
  void* D_temp;
wenjh's avatar
wenjh committed
2053
2054
2055
2056
2057
2058
2059
  if (((bias || gelu) && (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r)) ||
      (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
                   D_type == rocblas_datatype_bf8_r))) {
    if (!stream_order_alloc) {
      NVTE_CHECK_CUDA(hipMalloc(&D_temp, sizeof(float) * m * n));
    } else {
      NVTE_CHECK_CUDA(hipMallocAsync(&D_temp, sizeof(float) * m * n, stream));
yuguo's avatar
yuguo committed
2060
    }
wenjh's avatar
wenjh committed
2061
  } else {
yuguo's avatar
yuguo committed
2062
2063
2064
2065
2066
    D_temp = D;
  }

  // When Ti=To=fp16 and there is no bias or gelu, D_temp points to D and we would like it to be fp16
  rocblas_datatype D_temp_type = rocblas_datatype_f32_r;
wenjh's avatar
wenjh committed
2067
2068
  if (!(bias || gelu) && (A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
                          D_type == rocblas_datatype_f16_r)) {
yuguo's avatar
yuguo committed
2069
2070
2071
    D_temp_type = rocblas_datatype_f16_r;
  }
  // When Ti=To=bf16 and there is no bias or gelu, D_temp points to D and we would like it to be bf16
wenjh's avatar
wenjh committed
2072
2073
  if (!(bias || gelu) && (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
                          D_type == rocblas_datatype_bf16_r)) {
yuguo's avatar
yuguo committed
2074
2075
2076
    D_temp_type = rocblas_datatype_bf16_r;
  }
  // When Ti in fp8 or bf8, To=fp16, there is no bias or gelu, D_temp points to D and we would like it to be fp16, as rocblas support this case.
wenjh's avatar
wenjh committed
2077
  if ((!(bias || gelu)) && (use_fp8 && D_type == rocblas_datatype_f16_r)) {
yuguo's avatar
yuguo committed
2078
2079
    D_temp_type = rocblas_datatype_f16_r;
  }
wenjh's avatar
wenjh committed
2080
2081

  if (accumulate && (D_temp != D || D_temp_type != D_type)) {
yuguo's avatar
yuguo committed
2082
    DType output_dtype = get_transformer_engine_dtype(D_type);
wenjh's avatar
wenjh committed
2083
2084
2085
2086
2087
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
        output_dtype, OType,
        //D_temp allocated only with fp32
        detail::identity_kernelLauncher<OType, float>(
            reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(D_temp), m * n, stream););
yuguo's avatar
yuguo committed
2088
2089
2090
2091
2092
  }

  // D = alpha * (A * B) + beta * C
  if (use_fp8) {
    rocblas_computetype computeType = rocblas_compute_type_f32;
wenjh's avatar
wenjh committed
2093
2094
2095
2096
2097
    NVTE_CHECK_ROCBLAS(rocblas_gemm_ex3(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B,
                                        B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp,
                                        D_temp_type, ldd, computeType,
                                        rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, 0));
  } else {
yuguo's avatar
yuguo committed
2098
2099
    rocblas_datatype computeType = rocblas_datatype_f32_r;
    uint32_t flags = rocblas_gemm_flags_none;
wenjh's avatar
wenjh committed
2100
2101
    if ((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r) && grad) {
      flags = rocblas_gemm_flags_fp16_alt_impl;
yuguo's avatar
yuguo committed
2102
    }
wenjh's avatar
wenjh committed
2103
2104
2105
2106
    NVTE_CHECK_ROCBLAS(rocblas_gemm_ex(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B,
                                       B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp,
                                       D_temp_type, ldd, computeType,
                                       rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, flags));
yuguo's avatar
yuguo committed
2107
2108
2109
2110
2111
2112
  }

  int batch_size, input_dim, output_dim;
  if (bias && gelu) {
    if (grad) {
      // epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
wenjh's avatar
wenjh committed
2113
2114
      // Apply GELU gradient to D_temp and store in D
      // Apply bias gradient to D (D is already the result of GELU gradient) and store in bias_ptr;
yuguo's avatar
yuguo committed
2115
2116
2117
2118
2119
2120
2121
2122
2123
      // This case is NN
      // D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major
      // The bias vector length is m. So it will be reduced along axis 0 in row major
      // (TODO): The cublasLt doc is not very clear wrt the bias gradient here.
      // It does not explicitly say that it goes through GELU gradient first. We will need to
      // confirm in the future. As of now, my implementation for the bias gradient takes
      // the GELU gradient result in lower precision (D). It might be better to take the GELU
      // gradient result in fp32 but as it requires some kernel changes I would only do that
      // once we confirm that this is the right form of the epilogue.
wenjh's avatar
wenjh committed
2124
      // This is for linear1 -> gelu -> linear2
yuguo's avatar
yuguo committed
2125
2126
2127
      // compute dX = dY * W for linear2
      // gemm_ex(A=W, B=dY)
      batch_size = n;
wenjh's avatar
wenjh committed
2128
2129
      input_dim =
          m;  // input dimension of the second linear layer is the output dimension of the first linear layer
yuguo's avatar
yuguo committed
2130
2131
2132
      output_dim = k;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
wenjh's avatar
wenjh committed
2133
2134
2135
2136
2137
2138
2139
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              gelu_dtype, GType,
              detail::gelu_backward_kernelLauncher<OType, GType>(
                  reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                  reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream);););
yuguo's avatar
yuguo committed
2140
2141
2142

      void* bias_tmp;
      if (bias_type != rocblas_datatype_f32_r) {
wenjh's avatar
wenjh committed
2143
2144
2145
2146
2147
2148
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipMalloc(
              &bias_tmp,
              sizeof(float) * input_dim));  // The bias gradient is for the first linear layer
        } else {
          NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * input_dim, stream));
yuguo's avatar
yuguo committed
2149
        }
wenjh's avatar
wenjh committed
2150
      } else {
yuguo's avatar
yuguo committed
2151
2152
2153
        bias_tmp = bias_ptr;
      }

wenjh's avatar
wenjh committed
2154
2155
2156
2157
2158
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          detail::bias_gradient_kernelLauncher<OType>(
              reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(bias_tmp), batch_size,
              input_dim, stream_order_alloc, stream););
yuguo's avatar
yuguo committed
2159
2160
2161

      if (bias_type != rocblas_datatype_f32_r) {
        DType bias_dtype = get_transformer_engine_dtype(bias_type);
wenjh's avatar
wenjh committed
2162
2163
2164
2165
2166
2167
2168
2169
2170
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            bias_dtype, BType,
            detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp),
                                                          reinterpret_cast<BType*>(bias_ptr),
                                                          input_dim, stream););
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipFree(bias_tmp));
        } else {
          NVTE_CHECK_CUDA(hipFreeAsync(bias_tmp, stream));
yuguo's avatar
yuguo committed
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
        }
      }

    } else {
      // epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
      // Add bias_ptr to D_temp and store in pre_gelu_out, and apply GELU to the pre_gelu_output and then store in D
      // D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major
      // gemm_ex(A=W, B=X, transA=T)
      batch_size = n;
      input_dim = k;
      output_dim = m;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType bias_dtype = get_transformer_engine_dtype(bias_type);
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
wenjh's avatar
wenjh committed
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              gelu_dtype, GType,
              TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
                  bias_dtype, BType,
                  detail::add_bias_gelu_kernelLauncher<OType, GType, BType>(
                      reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                      reinterpret_cast<GType*>(pre_gelu_out),
                      reinterpret_cast<const BType*>(bias_ptr), reinterpret_cast<float*>(D_amax),
                      reinterpret_cast<const float*>(D_scale), batch_size, output_dim,
                      stream););););
yuguo's avatar
yuguo committed
2197
    }
wenjh's avatar
wenjh committed
2198
  } else if (bias) {
yuguo's avatar
yuguo committed
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
    if (grad) {
      // grad output is always input B
      // epilogue = CUBLASLT_EPILOGUE_BGRADB;
      // Apply bias gradient to matrix B and store in bias_ptr, reduce along the k dimension, output bias length is n
      // As B is transposed, is of shape (n, k) in column major, and is of shape (k, n) in row major.
      // bias gradient vector length is n. So it will be reduced along axis 0 in row major.
      // The backward pass calculate the bias gradient along with dW = dY^T * X
      // gemm_ex(A=X, B = dY, transB=T)
      batch_size = k;
      input_dim = m;
      output_dim = n;
wenjh's avatar
wenjh committed
2210
      void* bias_tmp;
yuguo's avatar
yuguo committed
2211
      if (bias_type != rocblas_datatype_f32_r) {
wenjh's avatar
wenjh committed
2212
2213
2214
2215
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipMalloc(&bias_tmp, sizeof(float) * output_dim));
        } else {
          NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * output_dim, stream));
yuguo's avatar
yuguo committed
2216
        }
wenjh's avatar
wenjh committed
2217
      } else {
yuguo's avatar
yuguo committed
2218
2219
2220
2221
2222
2223
        bias_tmp = bias_ptr;
      }

      DType input_dtype = get_transformer_engine_dtype(B_type);
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType bias_dtype = get_transformer_engine_dtype(bias_type);
wenjh's avatar
wenjh committed
2224
2225
2226
2227
2228
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          input_dtype, IType,
          detail::bias_gradient_kernelLauncher<IType>(
              reinterpret_cast<const IType*>(B), reinterpret_cast<float*>(bias_tmp), batch_size,
              output_dim, stream_order_alloc, stream););
yuguo's avatar
yuguo committed
2229
      if (bias_type != rocblas_datatype_f32_r) {
wenjh's avatar
wenjh committed
2230
2231
2232
2233
2234
2235
2236
2237
2238
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            bias_dtype, BType,
            detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp),
                                                          reinterpret_cast<BType*>(bias_ptr),
                                                          output_dim, stream););
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipFree(bias_tmp));
        } else {
          NVTE_CHECK_CUDA(hipFreeAsync(bias_tmp, stream));
yuguo's avatar
yuguo committed
2239
2240
2241
        }
      }
      if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
wenjh's avatar
wenjh committed
2242
2243
2244
2245
2246
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            output_dtype, OType,
            detail::identity_kernelLauncher<float, OType>(reinterpret_cast<const float*>(D_temp),
                                                          reinterpret_cast<OType*>(D),
                                                          input_dim * output_dim, stream););
yuguo's avatar
yuguo committed
2247
2248
2249
      }
    } else {
      // epilogue = CUBLASLT_EPILOGUE_BIAS;
wenjh's avatar
wenjh committed
2250
      // Broadcast bias and add it to D_temp and store in D. The bias vector length is m
yuguo's avatar
yuguo committed
2251
2252
2253
2254
2255
2256
2257
      // D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major
      // gemm_ex(A=W, B=X, transA=T)
      batch_size = n;
      input_dim = k;
      output_dim = m;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType bias_dtype = get_transformer_engine_dtype(bias_type);
wenjh's avatar
wenjh committed
2258
2259
2260
2261
2262
2263
2264
2265
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              bias_dtype, BType,
              detail::add_bias_kernelLauncher<OType, BType>(
                  reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                  reinterpret_cast<const BType*>(bias_ptr), reinterpret_cast<float*>(D_amax),
                  reinterpret_cast<const float*>(D_scale), batch_size, output_dim, stream);););
yuguo's avatar
yuguo committed
2266
    }
wenjh's avatar
wenjh committed
2267
  } else if (gelu) {
yuguo's avatar
yuguo committed
2268
2269
2270
2271
    if (grad) {
      // epilogue = CUBLASLT_EPILOGUE_DGELU;
      // Take input from pre_gelu_out and apply GELU gradients to D_temp and store result in D
      // D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major
wenjh's avatar
wenjh committed
2272
      // gemm_ex(A=W, B=dY)
yuguo's avatar
yuguo committed
2273
2274
2275
2276
2277
      batch_size = n;
      input_dim = m;
      output_dim = k;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
wenjh's avatar
wenjh committed
2278
2279
2280
2281
2282
2283
2284
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              gelu_dtype, GType,
              detail::gelu_backward_kernelLauncher<OType, GType>(
                  reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                  reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream);););
yuguo's avatar
yuguo committed
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
    } else {
      // epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
      // Store (quantized) D_temp in pre_gelu_out, and apply GELU to D_temp then store in D
      // D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major
      // gemm_ex(A=W, B=X, transA=T)
      batch_size = n;
      input_dim = k;
      output_dim = m;

      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
wenjh's avatar
wenjh committed
2295
2296
2297
2298
2299
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          gelu_dtype, GType,
          detail::identity_kernelLauncher<float, GType>(reinterpret_cast<const float*>(D_temp),
                                                        reinterpret_cast<GType*>(pre_gelu_out),
                                                        batch_size * output_dim, stream););
yuguo's avatar
yuguo committed
2300
      DType output_dtype = get_transformer_engine_dtype(D_type);
wenjh's avatar
wenjh committed
2301
2302
2303
2304
2305
2306
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          detail::gelu_forward_kernelLauncher<OType>(
              reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
              reinterpret_cast<float*>(D_amax), reinterpret_cast<const float*>(D_scale), batch_size,
              output_dim, stream););
yuguo's avatar
yuguo committed
2307
    }
wenjh's avatar
wenjh committed
2308
2309
2310
  } else {  // No epilogue - !(bias || gelu)
    if (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
                    D_type == rocblas_datatype_bf8_r)) {
yuguo's avatar
yuguo committed
2311
      DType output_dtype = get_transformer_engine_dtype(D_type);
wenjh's avatar
wenjh committed
2312
2313
2314
2315
2316
2317
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          detail::identity_output_kernelLauncher<OType>(
              reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
              reinterpret_cast<float*>(D_amax), reinterpret_cast<const float*>(D_scale), m * n,
              stream););
yuguo's avatar
yuguo committed
2318
2319
    }
  }
wenjh's avatar
wenjh committed
2320
2321
2322
2323
2324
2325
2326
2327

  if (((bias || gelu) && (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r)) ||
      (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
                   D_type == rocblas_datatype_bf8_r))) {
    if (!stream_order_alloc) {
      NVTE_CHECK_CUDA(hipFree(D_temp));
    } else {
      NVTE_CHECK_CUDA(hipFreeAsync(D_temp, stream));
yuguo's avatar
yuguo committed
2328
2329
2330
2331
    }
  }
}

wenjh's avatar
wenjh committed
2332
#endif  //USE_ROCBLAS
yuguo's avatar
yuguo committed
2333

wenjh's avatar
wenjh committed
2334
2335
2336
2337
/* Dispatch a GEMM to the hipBLASLt or rocBLAS backend.
 *
 * Backend selection:
 *  - A backend is "requested" when its environment variable (NVTE_USE_HIPBLASLT /
 *    NVTE_USE_ROCBLAS) is present, regardless of its value, or when the matching
 *    nvte_use_hipblaslt / nvte_use_rocblas argument is true.
 *  - When both backends are compiled in, hipBLASLt is used if it is requested or if neither
 *    backend is requested. rocBLAS is used only when it alone is requested.
 *  - When only one backend is compiled in, that backend is always used. A request for the
 *    missing backend is ignored, with a notice printed to stdout once per process.
 *
 * compute_stream_offset must be -1 or lie in [0, compute_num_streams). A non-negative value
 * selects one of the lazily-initialized, process-wide hipBLASLt handles. With -1, a null
 * handle is passed to hipblaslt_gemm. The offset is only consulted on the hipBLASLt path.
 */
void cublas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                 const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, bool transa, bool transb, bool grad, void* workspace,
                 size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor* inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = false,
                 bool nvte_use_rocblas = false, int compute_stream_offset = -1) {
  // Presence of the env var (any value) counts as a request, as does the explicit flag.
  bool use_hipblaslt = (std::getenv("NVTE_USE_HIPBLASLT") != nullptr) || nvte_use_hipblaslt;
  bool use_rocblas = (std::getenv("NVTE_USE_ROCBLAS") != nullptr) || nvte_use_rocblas;

#if !defined(USE_HIPBLASLT) && !defined(USE_ROCBLAS)
#error GEMM backend is not specified
#elif !defined(USE_HIPBLASLT)
  // rocBLAS-only build: rocBLAS is the sole backend, so select it unconditionally.
  // Without this, a call with no backend requested would leave both flags false and
  // skip the GEMM entirely, leaving outputD unwritten.
  if (use_hipblaslt) {
    // Function-local static initialization is thread-safe and runs once, so the notice
    // is emitted once per process rather than on every GEMM call.
    static const bool hipblaslt_notice_printed = [] {
      std::cout << "[NOTICE] hipBLASLt is not enabled, NVTE_USE_HIPBLASLT env is ignored\n";
      return true;
    }();
    (void)hipblaslt_notice_printed;
  }
  use_hipblaslt = false;
  use_rocblas = true;
#elif !defined(USE_ROCBLAS)
  // hipBLASLt-only build: hipBLASLt is the sole backend, so select it unconditionally.
  if (use_rocblas) {
    // Emitted once per process (thread-safe function-local static initialization).
    static const bool rocblas_notice_printed = [] {
      std::cout << "[NOTICE] rocBLAS is not enabled, NVTE_USE_ROCBLAS env is ignored\n";
      return true;
    }();
    (void)rocblas_notice_printed;
  }
  use_rocblas = false;
  use_hipblaslt = true;
#else
  // Both backends built: hipBLASLt wins when it is requested or when nothing is
  // requested. rocBLAS is kept only when it alone was requested.
  if (use_hipblaslt || !use_rocblas) {
    use_hipblaslt = true;
    use_rocblas = false;
  }
#endif

#ifdef USE_HIPBLASLT
  if (use_hipblaslt || !use_rocblas) {
    NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams,
               "Invalid compute_stream_offset: ", compute_stream_offset);

    hipblasLtHandle_t handle = nullptr;
    if (compute_stream_offset != -1) {
      // Per-stream hipBLASLt handles, created once for the whole process on first use.
      static std::once_flag init_flag;
      static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
      std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);

      handle = hipblaslt_handles[compute_stream_offset];
    }

    hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
                   (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                   grad, workspace, workspaceSize, accumulate, use_split_accumulator, math_sm_count,
                   m_split, n_split, gemm_producer, inputCounter, stream, handle);
    return;
  }
#endif

#ifdef USE_ROCBLAS
  if (use_rocblas) {
    rocblas_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? rocblas_operation_transpose : rocblas_operation_none,
                 (transb) ? rocblas_operation_transpose : rocblas_operation_none, grad, workspace,
                 workspaceSize, accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
                 gemm_producer, inputCounter, stream);
  }
#endif
}

wenjh's avatar
wenjh committed
2410
}  //namespace transformer_engine