skinny_gemms.cu 84.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include <stdexcept>
#include <algorithm>

12
#include "../cuda_compat.h"
13
#include "dispatch_utils.h"
14
#include "quantization/w8a8/fp8/common.cuh"
15

16
17
18
19
20
21
22
// TODO(rasmith): The kernels in this file are susceptible to integer overflow
// issues, do not take strides, and are unable to handle PyTorch tensors that
// return is_contiguous() as False (the tensors may actually be contiguous
// in memory).
//
// However, it may be possible to fix these kernels to handle all of these
// issues.

23
24
25
// Target-architecture feature flags, defined only when compiling with hipcc
// for one of the listed gfx targets:
//   __HIP__GFX9__  : gfx90a, gfx942 or gfx950
//   __HIP__MI3XX__ : gfx942 or gfx950
#if defined(__HIPCC__)
  #if defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
    #define __HIP__GFX9__
  #endif
  #if defined(__gfx942__) || defined(__gfx950__)
    #define __HIP__MI3XX__
  #endif
#endif

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// LDS capacity in bytes assumed by the kernels below: 160 KB on gfx950 and
// 64 KB otherwise. The kernels stage `LDS_SIZE / 2` 16-bit elements in LDS.
// The value is parenthesized so the macro acts as one operand in any expression.
// Without parentheses, `x / LDS_SIZE` would expand to `x / 64 * 1024`.
#if defined(__gfx950__)
  #define LDS_SIZE (160 * 1024)
#else
  #define LDS_SIZE (64 * 1024)
#endif

// Host-side counterpart of LDS_SIZE. It returns the LDS capacity in bytes:
// 160 KB when the current device's gcnArchName contains "gfx95", otherwise
// 64 KB.
//
// The value is computed once, on the first call, and reused afterwards. A
// function-local static gives thread-safe one-time initialization. The previous
// unsynchronized flag + global pair was a data race when first called
// concurrently from several threads.
// NOTE(review): the value reflects whichever device is current at the first
// call. Processes mixing GPU architectures would get the first device's value.
// Confirm whether any caller relies on per-device results.
int get_lds_size() {
  static const int cached_lds_size = [] {
    const auto* dprops = at::cuda::getCurrentDeviceProperties();
    const std::string arch_name = dprops->gcnArchName;
    const bool is_gfx95x = arch_name.find("gfx95") != std::string::npos;
    return is_gfx95x ? 160 * 1024 : 64 * 1024;
  }();
  return cached_lds_size;
}

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
// UNREACHABLE_CODE expands to `assert(false);`. It forms the body of the kernel
// stubs compiled for targets on which the real kernels are not available.
//
// When NDEBUG is set, it is temporarily undefined around the <assert.h>
// include. That leaves `assert` active (not a no-op) from this point on, and
// NDEBUG is then restored for the rest of the translation unit.
// NOTE(review): <assert.h> redefines `assert` on every inclusion. A later
// include while NDEBUG is defined would make UNREACHABLE_CODE a no-op again.
// The #else branch assumes `assert` is already available from an earlier
// include, presumably via the torch/HIP headers. Confirm both points.
#if defined(NDEBUG)
  #undef NDEBUG
  #include <assert.h>
  #define UNREACHABLE_CODE assert(false);
  #define NDEBUG
#else
  #define UNREACHABLE_CODE assert(false);
#endif

// Maps PyTorch element types (c10::Half, c10::BFloat16) to the device scalar
// type (`scalar<T>::type`) and the packed two-element type
// (`scalar2<T>::type`). It also declares the float conversion helpers that the
// kernels use. Only the fp16 and bf16 specializations exist.
template <typename T>
struct scalar {};

template <typename T>
struct scalar2 {};

// Packed pair -> float2.
template <typename T>
__device__ __forceinline__ float2 __s22float2(T v);

// float -> scalar.
template <typename T>
__device__ __forceinline__ T __float2s(float v);

// float2 -> packed pair, round-to-nearest.
template <typename T>
__device__ __forceinline__ T __float22s2_rn(float2 v);

// ---------------------------- fp16 ----------------------------
template <>
struct scalar<c10::Half> {
  typedef half type;
};

template <>
struct scalar2<c10::Half> {
  typedef __half2 type;
};

template <>
__device__ __forceinline__ half __float2s<half>(float v) {
  return __float2half(v);
}

template <>
__device__ __forceinline__ float2 __s22float2<__half2>(__half2 v) {
  return __half22float2(v);
}

template <>
__device__ __forceinline__ __half2 __float22s2_rn<__half2>(float2 v) {
  return __float22half2_rn(v);
}

// ---------------------------- bf16 ----------------------------
template <>
struct scalar<c10::BFloat16> {
  typedef __hip_bfloat16 type;
};

template <>
struct scalar2<c10::BFloat16> {
  typedef __hip_bfloat162 type;
};

template <>
__device__ __forceinline__ __hip_bfloat16 __float2s<__hip_bfloat16>(float v) {
  return __float2bfloat16(v);
}

template <>
__device__ __forceinline__ float2 __s22float2<__hip_bfloat162>(
    __hip_bfloat162 v) {
  return __bfloat1622float2(v);
}

template <>
__device__ __forceinline__ __hip_bfloat162 __float22s2_rn<__hip_bfloat162>(
    float2 v) {
  return __float22bfloat162_rn(v);
}

// Reads *addr through the compiler builtin `__builtin_nontemporal_load`, which
// emits a load carrying a nontemporal hint.
template <typename T>
__device__ __forceinline__ T loadnt(T* addr) {
  T value = __builtin_nontemporal_load(addr);
  return value;
}

// Loads a float4 as four separate 32-bit nontemporal loads (via loadnt). The
// address is treated as four consecutive floats, and the result is
// reassembled into a float4.
__device__ __forceinline__ float4 load_ntmprl(const float4* addr) {
  const float* src = reinterpret_cast<const float*>(addr);
  const float x = loadnt(src + 0);
  const float y = loadnt(src + 1);
  const float z = loadnt(src + 2);
  const float w = loadnt(src + 3);
  return make_float4(x, y, z, w);
}

// Matrix-vector product for a single activation row (N == 1):
//   out_c[m] = sum_k in_a[m * K + k] * in_b[k]
//
// Launch layout (as used by LLMM1): one block per NUM_A_ROWS_PER_BLOCK
// consecutive rows of in_a, so gridDim.x == M / NUM_A_ROWS_PER_BLOCK. Each
// thread multiplies one 8-element slice of K (a float4 worth of 16-bit values)
// for every row of its block. There is no loop over K, so the block size must:
//   - be at least K / 8 (threads past that point contribute 0);
//   - be at least 16 * NUM_A_ROWS_PER_BLOCK (one 16-thread group per row runs
//     the final reduction);
//   - be a multiple of WARP_SIZE, because of the warp shuffles.
// Preconditions not checked here: in_a and in_b are densely packed, K is a
// multiple of 8, and NUM_A_ROWS_PER_BLOCK is even (rows are written in pairs).
// NOTE(review): the second reduction stage gathers at most 16 per-warp
// partials (qthreadid < 16), so blocks with more than 16 warps would drop
// partial sums. That is fine for 64-wide waves at 1024 threads; confirm for
// other configurations.
// Uses HIP's mask-less __shfl_xor.
template <typename scalar_t, int NUM_A_ROWS_PER_BLOCK>
__global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
                               scalar_t* out_c, const int K) {
  using scalar2_t = typename scalar2<scalar_t>::type;
  auto af4 = reinterpret_cast<const float4*>(in_a);
  auto bf4 = reinterpret_cast<const scalar2_t*>(in_b);
  auto c = reinterpret_cast<scalar2_t*>(out_c);
  __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE];
  // Offset, in float4 units, of the first row handled by this block. This is
  // computed in 64-bit: blockIdx.x * NUM_A_ROWS_PER_BLOCK * K overflows 32-bit
  // unsigned arithmetic once M * K reaches 2^32 elements.
  const int64_t row_addr =
      static_cast<int64_t>(blockIdx.x) * NUM_A_ROWS_PER_BLOCK * K / 8;
  const int threadid = threadIdx.x;
  const int warp = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;
  const int num_warps = blockDim.x / WARP_SIZE;
  // 16-thread groups: group q reduces the per-warp partials of row q.
  const int qwarpid = threadid / 16;
  const int qthreadid = threadid % 16;

  float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
  scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
  float acc[NUM_A_ROWS_PER_BLOCK];
  scalar2_t acch2;
  scalar2_t oval;

  // As we later use warp shuffle operations, we may have more threads in the
  // block than the actual available data, hence the if guard here.
  if (threadid * 8 < K) {
#pragma unroll
    for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
      // rowA_elem4[i] holds 8 * half numbers seen as a single float4.
      rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
    }
    colB_elem4x = bf4[threadid * 4 + 0];
    colB_elem4y = bf4[threadid * 4 + 1];
    colB_elem4z = bf4[threadid * 4 + 2];
    colB_elem4w = bf4[threadid * 4 + 3];
  }

  scalar2_t Af2;
  float2 S;

  auto Ah2ptr = reinterpret_cast<scalar2_t*>(&rowA_elem4);
  scalar2_t* ah2lptr;

#pragma unroll
  for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
    // Multiply-add on 8 scalar_t.
    ah2lptr = Ah2ptr + i * 4;
    Af2 = *(ah2lptr);
    acch2 = __hmul2(Af2, colB_elem4x);
    Af2 = *(ah2lptr + 1);
    acch2 = __hfma2(Af2, colB_elem4y, acch2);
    Af2 = *(ah2lptr + 2);
    acch2 = __hfma2(Af2, colB_elem4z, acch2);
    Af2 = *(ah2lptr + 3);
    acch2 = __hfma2(Af2, colB_elem4w, acch2);
    S = __s22float2(acch2);

    // See comment above concerning the if guard.
    acc[i] = (threadid * 8 < K ? S.x + S.y : 0.f);
  }

// all reduce across warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
#pragma unroll
    for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
      acc[i] += __shfl_xor(acc[i], mask);
    }
  }

  // After the butterfly every lane holds the warp's sum for every row. Lane r
  // stores row r's partial sum for this warp.
  if (lane < NUM_A_ROWS_PER_BLOCK) {
    red_smem[lane][warp] = acc[lane];
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
    acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
#pragma unroll
    for (int mask = 16 / 2; mask >= 1; mask /= 2) {
      acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
    }
    // Fetch the neighbouring group's total (row qwarpid ^ 1) so that rows
    // 2j and 2j+1 can be written together as one packed pair.
    float oval2 = __shfl_xor(acc[qwarpid], 16);

    if (lane % 32 == 0) {
      oval = __float22s2_rn<scalar2_t>(make_float2(acc[qwarpid], oval2));
      c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
    }
  }
}

// Host launcher for LLGemm1_kernel. It computes out = in_b * in_a^T for a
// single activation row.
//
//   in_a           : (M, K) weight matrix, fp16 or bf16.
//   in_b           : (1, K) activation row with the same dtype as in_a.
//   rows_per_block : rows of in_a reduced per thread block. The values 2, 4, 8
//                    and 16 select a matching kernel instantiation. Any other
//                    value falls back to 4 rows per block.
//   returns        : (1, M) tensor with in_b's dtype, on in_b's device.
//
// The kernel reads both tensors through raw pointers, so they are assumed to
// be densely packed (see the TODO at the top of the file). The checks below
// reject inputs the kernel would otherwise silently mis-compute:
//   - K must be a multiple of 8, because each thread consumes 8 elements.
//   - M must be a multiple of the effective rows per block; otherwise the
//     tail rows would be left uninitialized.
//   - K / 8 threads, rounded up to a warp multiple, must fit in one block,
//     since the kernel has no loop over K.
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
                    const int64_t rows_per_block) {
  auto M = in_a.size(0);
  auto K = in_a.size(1);
  auto N = in_b.size(0);

  TORCH_CHECK(N == 1, "Row number of activation tensor must be 1.");
  TORCH_CHECK(in_a.dtype() == in_b.dtype());
  TORCH_CHECK(in_b.dtype() == torch::kFloat16 ||
              in_b.dtype() == torch::kBFloat16);
  TORCH_CHECK(in_b.size(1) == K, "LLMM1: activation length (", in_b.size(1),
              ") must match the weight's K (", K, ").");
  TORCH_CHECK(K % 8 == 0, "LLMM1: K must be a multiple of 8, got ", K, ".");

  // Unsupported rows_per_block values run the 4-row kernel, so size the launch
  // for the instantiation that actually runs. This also avoids dividing by a
  // zero rows_per_block.
  const bool has_kernel = rows_per_block == 2 || rows_per_block == 4 ||
                          rows_per_block == 8 || rows_per_block == 16;
  const int64_t rows = has_kernel ? rows_per_block : 4;
  TORCH_CHECK(M % rows == 0, "LLMM1: M (", M,
              ") must be a multiple of the rows per block (", rows, ").");

  auto out_c = torch::empty(
      {N, M}, torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));
  if (M == 0) {
    // Nothing to compute, and a launch with zero blocks is invalid.
    return out_c;
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(in_b));

  // NUM_THREADS needs to be a multiple of WARP_SIZE, as we are using warp
  // shuffle operations. The kernel uses one thread per 8 K-elements, rounded up
  // to a whole warp, and needs at least 16 threads per row for its final
  // per-row reduction.
  const int64_t k_threads = K / 8;
  const int64_t k_threads_padded =
      (k_threads + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
  const int64_t num_threads = std::max<int64_t>(rows * 16, k_threads_padded);

  const int64_t max_threads =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
  TORCH_CHECK(num_threads <= max_threads, "LLMM1: K = ", K, " needs ",
              num_threads, " threads per block, but the device supports at most ",
              max_threads, ".");

  const int NUM_THREADS = static_cast<int>(num_threads);
  const int NUM_BLOCKS = static_cast<int>(M / rows);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // call the kernel function...
  AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "LLGemm1", [&] {
    auto a_ptr = in_a.data_ptr<scalar_t>();
    auto b_ptr = in_b.data_ptr<scalar_t>();
    auto c_ptr = out_c.data_ptr<scalar_t>();
    switch (rows) {
      case 2:
        LLGemm1_kernel<scalar_t, 2>
            <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
        break;
      case 8:
        LLGemm1_kernel<scalar_t, 8>
            <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
        break;
      case 16:
        LLGemm1_kernel<scalar_t, 16>
            <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
        break;
      default:  // 4, including the fallback for unsupported values
        LLGemm1_kernel<scalar_t, 4>
            <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
        break;
    }
  });

  return out_c;
}

// DOT2C(V0, V2, V3): V0 += dot(V2, V3). V2 and V3 are 32-bit lvalues, each
// holding two packed 16-bit elements, and V0 is a float accumulator.
//  - scalar_t == half: a single `v_dot2c_f32_f16` instruction (AMD ISA inline
//    asm).
//  - scalar_t == __hip_bfloat16: both operands are unpacked to float2 and
//    multiplied, and the two products are added to V0.
//  - any other scalar_t: expands to nothing, so no accumulation happens.
// The macro reads `scalar_t` from the scope where it is expanded. It is not
// wrapped in do { } while (0); callers use it as a standalone statement with
// no trailing semicolon (e.g. `{ DOT2C(a, b, c) }`).
#define DOT2C(V0, V2, V3)                                                     \
  if constexpr (std::is_same_v<scalar_t, half>) {                             \
    asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \
  } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {            \
    float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) *             \
               __bfloat1622float2(*((__hip_bfloat162*)(&(V3))));              \
    V0 += (s.x + s.y);                                                        \
  }

297
298
299
300
301
// Unsigned 32-bit minimum. The uint32_t parameters force both operands to a
// 32-bit unsigned type, so that LLVM does not silently pick a floating-point
// overload and upcast to double.
__device__ inline unsigned int min__(uint32_t a, uint32_t b) {
  return (b < a) ? b : a;
}

302
#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
303
304
305
306
// This version targets cases where A[] fits LDS capacity.
//
// For every output column m < M and activation row n < N it computes
//   C[m + n * M] = sum_k A[n * K + k] * B[m * K + k]   (+ bias)
// Each weight "column" m is therefore a contiguous run of K elements in B, and
// A holds N activation rows of length K back to back.
//
// Launch layout, as read by the code: blockDim = (THRDS, >= _WvPrGrp). Each
// y-index is one wave that splits K across its THRDS lanes and handles YTILE
// consecutive columns. Blocks are persistent and advance by
// CuCount * _WvPrGrp * YTILE columns per iteration, which presumably assumes
// gridDim.x == CuCount.
// The final reduction and the `threadIdx.x == 63` writer assume 64-lane waves,
// presumably THRDS == 64.
//
// Requirements not checked here (presumably enforced by the host wrapper;
// TODO confirm):
//  - K * N <= LDS_SIZE / 2, because A is read only from LDS;
//  - K is a multiple of A_CHUNK (see the TODOs below);
//  - M is a multiple of YTILE. Unlike wvSplitK_hf_, this variant has no
//    commitColumn handling for a partial last tile, so a partial tile would
//    read and write columns >= M.
// If BIAS is non-null, BIAS[(m + i) % Bx + (n % By) * M] is added before the
// store.
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
307
308
309
    wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By,
                     const scalar_t* B, const scalar_t* __restrict__ A,
                     const scalar_t* __restrict__ BIAS, scalar_t* C,
310
                     const int _WvPrGrp, const int CuCount) {
311
312
  // LDS staging buffer size in scalar_t (16-bit) elements.
  constexpr int max_lds_len = LDS_SIZE / 2;
  // bf16 on gfx942/gfx950 accumulates with MFMA; everything else uses DOT2C.
  #if defined(__HIP__MI3XX__)
313
314
315
316
317
  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
  #else
  constexpr bool use_mfma = false;
  #endif

318
319
  // scalar8: A_CHUNK 16-bit elements viewed as one vector of A_CHUNK/2 floats.
  using scalar8 =
      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
320
321
  // half4: MFMA operand vector. NOTE(review): h4[A_CHUNK / 4] spans the same
  // bytes as h[A_CHUNK] only when A_CHUNK == 8.
  using half4 =
      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
322
323
324
325
326
  union bigType {
    scalar_t h[A_CHUNK];
    float f[A_CHUNK / 2];
    float2 f2[A_CHUNK / 4];
    double d[A_CHUNK / 4];
327
    half4 h4[A_CHUNK / 4];
328
329
330
331
    scalar8 h8;
  };

  //----------------------------------------------------
332
  // Reserving 64/160 KB of LDS to have 1 WG / CU
333
334
335
  // Goal is to bring the activation matrix A to the LDS
  // and use it across the lifetime of the work group
  // NOTE: this variant requires all of A (K * N elements) to fit in
336
  //       the LDS buffer (max_lds_len); larger A needs wvSplitK_hf_.
337
  //----------------------------------------------------
338
  __shared__ scalar_t s[max_lds_len];
339
340
341
342
343
344
345
346
347
348

  //----------------------------------------------------
  // Fetch the activation matrix to LDS
  // Loop iteration:
  // - Each thread (lane) is fetching 8 elements (A_Chunk)
  // - Each wave will fetch 64*8=> 512 elements
  // - Each WG will fetch 512 * 16 => 8K elements
  // - Then the WG will move to another 8 K elements
  // TODO: Logic below will only work when K is multiple of 8
  //----------------------------------------------------
349
  for (uint32_t k = 0; k < min__(K * N, max_lds_len);
350
351
352
       k += THRDS * WvPrGrp * A_CHUNK) {
    uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

353
    if (k_in >= min__(K * N, max_lds_len)) break;
354
355
356
357
358
359
360
361
362
363

    *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
  }
  __syncthreads();

  // Waves with threadIdx.y >= _WvPrGrp only helped with the LDS fill above.
  if (threadIdx.y >= _WvPrGrp) return;

  // First column of the YTILE-wide tile handled by this wave.
  uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;

  float sum[N][YTILE];
364
  scalar8 sum4[N][YTILE];
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389

  //----------------------------------------------------
  // Each wave works on a single column of weight matrix.
  // There are 16 waves per WG, and hence, each WG is
  // working on 16 columns of weight matrix. Moreover,
  // we tile in column direction by YTILE, so when YTILE=1
  // the above math is right, however, when YTILE=2 then
  // each wave  will be working on 2 columns and WG will
  // be working on 32 columns.
  //
  // Top level loop that makes WGs persistent!
  // - WGs iterates across columns of weight matrix
  // - Each wave within WG works on a given column(s)
  // - After completing first set of columns, WGs start
  //   working on the next set of available columns
  //----------------------------------------------------
  while (m < M) {
    //----------------------------------------------------
    // 'sum' accumulates the matrix A x B computation
    // split across 64 lanes.
    //
    // YTILE represents how many column of weight matrix
    // are being worked on by each wave.
    //----------------------------------------------------
    for (int i = 0; i < YTILE; i++)
390
391
392
393
394
      for (int n = 0; n < N; n++)
        if constexpr (!use_mfma)
          sum[n][i] = 0;
        else
          sum4[n][i] = {0, 0, 0, 0};
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424

    bigType bigA[N][UNRL];
    bigType bigB[YTILE][UNRL];
    //----------------------------------------------------
    // Fetch weight matrix B in interleaved K-split!
    // - Each thread (lane) is fetching 8 elements (A_Chunk)
    // - Each wave will fetch 64*8=> 512 elements (1024B)
    // - YTILE represents the number of column being serviced
    //   by wave
    // - Loop for fetching weight matrix (B) are unrolled
    //
    // Fetch activation matrix A from LDS
    // - Loop for fetching activation matrix (A) are unrolled
    //
    // Finally, do the matrix multiplication in an unrolled
    // fashion. This provides lot of food for compiler
    // scheduling.
    //
    // TODO: Logic below will only work when K is multiple of 8
    //----------------------------------------------------
    // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
    for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
      // Fetch the weight matrix from memory!
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        if (k_ >= K) break;

        // NOTE(review): (m + 0) * K is 32-bit unsigned arithmetic and wraps
        // for weights with >= 2^32 elements.
        const scalar_t* B_ = &B[(m + 0) * K + k_];
425
426
        for (int y = 0; y < YTILE; y++)
          bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
      }

      // Fetch activation matrix from either just LDS or from both LDS / memory
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        if (k_ >= K) break;

        // Fetch A activation matrix in interleaved fashion from LDS or memory

        for (int n = 0; n < N; n++) {
          bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
        }
      }

      // Do the matrix multiplication in interleaved manner
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        if (k_ >= K) break;
        // Do the matrix multiplication of activation and weight matrix
        // - Remember the accumulation is happening for K-split of 64!
  #pragma unroll
        for (uint32_t n = 0; n < N; n++) {
  #pragma unroll
454
455
456
457
458
459
460
461
462
463
464
          for (int y = 0; y < YTILE; y++) {
            if constexpr (!use_mfma)
  #pragma unroll
              for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
                DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
              }
            else
  #pragma unroll
              for (uint32_t b = 0; b < A_CHUNK / 4; b++)
                sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
                    bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
465
466
467
468
469
470
471
472
          }
        }
      }
    }

    //----------------------------------------------------
    // Final reduction step using shuffle
    //----------------------------------------------------
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
    if constexpr (!use_mfma) {
      // DPP add chain (row_shr / wave_shr / row_bcast) across the wave. Lane
      // 63, which writes C below, ends up holding the full dot product.
      for (int n = 0; n < N; n++) {
        for (int y = 0; y < YTILE; y++) {
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
        }
495
      }
496
497
498
499

      if (threadIdx.x == 63) {
        for (int n = 0; n < N; n++) {
          for (int i = 0; i < YTILE; i++) {
500
501
502
503
504
505
506
507
            if constexpr (std::is_same_v<scalar_t, half>) {
              if (BIAS)
                sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
            } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
              if (BIAS)
                sum[n][i] +=
                    __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
            }
508
509
510
511
512
513
            C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
          }
        }
      }
    } else {
      // MFMA path: fold the four accumulator entries and the wave's lanes into
      // a single value in sum4[n][y][0], which lane 63 stores below.
  #pragma unroll
514
      for (int n = 0; n < N; n++) {
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
  #pragma unroll
        for (int y = 0; y < YTILE; y++) {
          // float accm1 = 0;
          // for (int i=0; i<64; i++)
          //    accm1 += __shfl(sum4[n][y][i%4], i);
          float accm = sum4[n][y][0];
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));
          asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));

          sum4[n][y][0] = accm;
        }
      }
      if (threadIdx.x == 63) {
        for (int n = 0; n < N; n++) {
          for (int i = 0; i < YTILE; i++) {
552
553
554
            if (BIAS)
              sum4[n][i][0] +=
                  __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
555
556
            C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
          }
557
558
559
560
561
562
        }
      }
    }
    // Persistent loop: skip over the columns covered by all waves of all
    // CuCount blocks in this iteration.
    m += CuCount * _WvPrGrp * YTILE;
  }
}
563
#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
564
565
// Stub used when the target is not gfx90a/gfx942/gfx950. The real kernel is
// not compiled for this target, so the body only trips UNREACHABLE_CODE.
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
__global__ void wvSplitK_hf_sml_(
    const int K, const int M, const int Bx, const int By,
    const scalar_t* B, const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ BIAS, scalar_t* C,
    const int _WvPrGrp, const int CuCount) {
  // Not expected to execute on this target.
  UNREACHABLE_CODE
}
573
#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
574

575
#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
576
577
578
579
// This version targets cases where A[] marginally exceeds LDS capacity.
//
// It performs the same computation as wvSplitK_hf_sml_:
//   C[m + n * M] = sum_k A[n * K + k] * B[m * K + k]   (+ bias)
// It also uses the same launch layout (64-lane waves along x, _WvPrGrp working
// waves along y, persistent blocks striding by CuCount * _WvPrGrp * YTILE
// columns).
// The difference is that only the first max_lds_len elements of A are staged
// in LDS. Elements at flat index k + K * n >= max_lds_len are read from global
// memory instead.
//
// A partial last tile is handled by shifting it back to start at M - YTILE.
// commitColumn then masks out its leading columns, which fall inside the range
// covered by the preceding tile.
// NOTE(review): if M < YTILE, `M - YTILE` wraps around (uint32_t) and the
// shifted tile is skipped entirely. Presumably the host never launches with
// M < YTILE; confirm.
// NOTE(review): K is assumed to be a multiple of A_CHUNK (see the TODOs below).
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
580
581
582
    wvSplitK_hf_(const int K, const int M, const int Bx, const int By,
                 const scalar_t* B, const scalar_t* __restrict__ A,
                 const scalar_t* __restrict__ BIAS, scalar_t* C,
583
                 const int _WvPrGrp, const int CuCount) {
584
585
  // LDS staging buffer size in scalar_t (16-bit) elements.
  constexpr int max_lds_len = LDS_SIZE / 2;
  // bf16 on gfx942/gfx950 accumulates with MFMA; everything else uses DOT2C.
  #if defined(__HIP__MI3XX__)
586
587
588
589
590
  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
  #else
  constexpr bool use_mfma = false;
  #endif

591
592
  // scalar8: A_CHUNK 16-bit elements viewed as one vector of A_CHUNK/2 floats.
  using scalar8 =
      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
593
594
  // half4: MFMA operand vector. NOTE(review): h4[A_CHUNK / 4] spans the same
  // bytes as h[A_CHUNK] only when A_CHUNK == 8.
  using half4 =
      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
595
596
597
598
599
  union bigType {
    scalar_t h[A_CHUNK];
    float f[A_CHUNK / 2];
    float2 f2[A_CHUNK / 4];
    double d[A_CHUNK / 4];
600
    half4 h4[A_CHUNK / 4];
601
602
603
604
605
606
607
608
    scalar8 h8;
  };

  //----------------------------------------------------
  // Reserving 64/160 KB (LDS_SIZE) of LDS to have 1 WG / CU
  // Goal is to bring the activation matrix A to the LDS
  // and use it across the lifetime of the work group
  // Activation elements that do not fit in the LDS buffer
609
  // are fetched from global memory in the main loop below.
610
  //----------------------------------------------------
611
  __shared__ scalar_t s[max_lds_len];
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628

  //----------------------------------------------------
  // Computation of columns that need to be committed to memory!
  //----------------------------------------------------
  uint32_t commitColumn[YTILE];
  for (uint32_t i = 0; i < YTILE; i++) {
    commitColumn[i] = 1;
  }

  //----------------------------------------------------
  // Indexing function into the column of weight matrix B
  // Algorithm does 64 lane k-splitting / wave and uses
  // WG ID and Thread ID to find the index.
  //----------------------------------------------------
  // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
  uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;

629
  // Check whether there will be fragmentation!
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
  // This will happen only for the last wave!
  // The tile is shifted back to end exactly at column M - 1; its first
  // (m - startColumn) columns are masked so that they are not stored twice.
  if (m < M && (m + YTILE) >= M) {
    uint32_t startColumn = M - YTILE;
    for (uint32_t i = 0; i < (m - startColumn); i++) {
      commitColumn[i] = 0;
    }
    m = startColumn;
  }

  //----------------------------------------------------
  // Fetch the activation matrix to LDS
  // Loop iteration:
  // - Each thread (lane) is fetching 8 elements (A_Chunk)
  // - Each wave will fetch 64*8=> 512 elements
  // - Each WG will fetch 512 * 16 => 8K elements
  // - Then the WG will move to another 8 K elements
  // TODO: Logic below will only work when K is multiple of 8
  //----------------------------------------------------
648
  for (uint32_t k = 0; k < min__(K * N, max_lds_len);
649
650
651
       k += THRDS * WvPrGrp * A_CHUNK) {
    uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

652
    if (k_in >= min__(K * N, max_lds_len)) break;
653
654
655
656
657
658
659
660
661

    *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
  }

  __syncthreads();

  // Waves with threadIdx.y >= _WvPrGrp only helped with the LDS fill above.
  if (threadIdx.y >= _WvPrGrp) return;

  float sum[N][YTILE];
662
  scalar8 sum4[N][YTILE];
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687

  //----------------------------------------------------
  // Each wave works on a single column of weight matrix.
  // There are 16 waves per WG, and hence, each WG is
  // working on 16 columns of weight matrix. Moreover,
  // we tile in column direction by YTILE, so when YTILE=1
  // the above math is right, however, when YTILE=2 then
  // each wave  will be working on 2 columns and WG will
  // be working on 32 columns.
  //
  // Top level loop that makes WGs persistent!
  // - WGs iterates across columns of weight matrix
  // - Each wave within WG works on a given column(s)
  // - After completing first set of columns, WGs start
  //   working on the next set of available columns
  //----------------------------------------------------
  while (m < M) {
    //----------------------------------------------------
    // 'sum' accumulates the matrix A x B computation
    // split across 64 lanes.
    //
    // YTILE represents how many column of weight matrix
    // are being worked on by each wave.
    //----------------------------------------------------
    for (int i = 0; i < YTILE; i++)
688
689
690
691
692
      for (int n = 0; n < N; n++)
        if constexpr (!use_mfma)
          sum[n][i] = 0;
        else
          sum4[n][i] = {0, 0, 0, 0};
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721

    bigType bigA[N][UNRL];
    bigType bigB[YTILE][UNRL];
    //----------------------------------------------------
    // Fetch weight matrix B in interleaved K-split!
    // - Each thread (lane) is fetching 8 elements (A_Chunk)
    // - Each wave will fetch 64*8=> 512 elements (1024B)
    // - YTILE represents the number of column being serviced
    //   by wave
    // - Loop for fetching weight matrix (B) are unrolled
    //
    // Fetch activation matrix A from LDS
    // - Loop for fetching activation matrix (A) are unrolled
    //
    // Finally, do the matrix multiplication in an unrolled
    // fashion. This provides lot of food for compiler
    // scheduling.
    //
    // TODO: Logic below will only work when K is multiple of 8
    //----------------------------------------------------
    for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
      // Fetch the weight matrix from memory!
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        if (k_ >= K) break;

        // NOTE(review): (m + 0) * K is 32-bit unsigned arithmetic and wraps
        // for weights with >= 2^32 elements.
        const scalar_t* B_ = &B[(m + 0) * K + k_];
722
723
        for (int b = 0; b < YTILE; b++)
          bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
724
725
726
727
728
729
730
731
732
733
734
735
      }

      // Fetch activation matrix from either just LDS or from both LDS / memory
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        if (k_ >= K) break;

        // Fetch A activation matrix in interleaved fashion from LDS or memory

        for (int n = 0; n < N; n++) {
736
          // Elements staged in LDS come from s[]; the remainder from global A.
          if (k_ + K * n < max_lds_len)
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
            bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
          else
            bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
        }
      }

      // Do the matrix multiplication in interleaved manner
  #pragma unroll
      for (uint32_t n = 0; n < N; n++) {
  #pragma unroll
        for (uint32_t k2 = 0; k2 < UNRL; k2++) {
          uint32_t k = k1 + k2 * THRDS * A_CHUNK;
          uint32_t k_ = k + threadIdx.x * A_CHUNK;
          if (k_ >= K) break;
          // Do the matrix multiplication of activation and weight matrix
          // - Remember the accumulation is happening for K-split of 64!
  #pragma unroll
754
755
756
757
758
759
760
761
762
763
764
          for (int y = 0; y < YTILE; y++) {
            if constexpr (!use_mfma)
  #pragma unroll
              for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
                DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
              }
            else
  #pragma unroll
              for (uint32_t b = 0; b < A_CHUNK / 4; b++)
                sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
                    bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
765
766
767
768
769
770
771
772
          }
        }
      }
    }

    //----------------------------------------------------
    // Final reduction step using shuffle
    //----------------------------------------------------
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
    if constexpr (!use_mfma) {
      // DPP add chain (row_shr / wave_shr / row_bcast) across the wave. Lane
      // 63, which writes C below, ends up holding the full dot product.
      for (int n = 0; n < N; n++) {
        for (int y = 0; y < YTILE; y++) {
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
        }
795
796
      }

797
798
799
      if (threadIdx.x == 63) {
        for (int n = 0; n < N; n++) {
          for (int i = 0; i < YTILE; i++) {
800
801
802
803
804
805
806
807
808
            // Columns masked by the fragmentation handling are skipped.
            if (commitColumn[i]) {
              if constexpr (std::is_same_v<scalar_t, half>) {
                if (BIAS)
                  sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
              } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
                if (BIAS)
                  sum[n][i] +=
                      __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
              }
809
              C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
810
            }
811
812
813
814
815
          }
        }
      }
    } else {
      // MFMA path: fold the four accumulator entries and the wave's lanes into
      // a single value in sum4[n][y][0], which lane 63 stores below.
  #pragma unroll
816
      for (int n = 0; n < N; n++) {
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
  #pragma unroll
        for (int y = 0; y < YTILE; y++) {
          // float accm1 = 0;
          // for (int i=0; i<64; i++)
          //    accm1 += __shfl(sum4[n][y][i%4], i);

          float accm = sum4[n][y][0];
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));
          asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));

          sum4[n][y][0] = accm;
        }
      }
      if (threadIdx.x == 63) {
        for (int n = 0; n < N; n++) {
          for (int i = 0; i < YTILE; i++) {
855
856
857
858
859
860
            if (commitColumn[i]) {
              if (BIAS)
                sum4[n][i][0] +=
                    __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
              C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
            }
861
          }
862
863
864
865
866
867
        }
      }
    }

    // Persistent loop: skip over the columns covered by all waves of all
    // CuCount blocks in this iteration.
    m += CuCount * _WvPrGrp * YTILE;

868
    // Check whether there will be fragmentation!
869
870
871
872
873
874
875
876
877
878
879
    // This will happen only for the last wave!
    if (m < M && (m + YTILE) >= M) {
      uint32_t startColumn = M - YTILE;
      for (uint32_t i = 0; i < (m - startColumn); i++) {
        commitColumn[i] = 0;
      }
      m = startColumn;
    }
  }
}

880
#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
881
882
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
883
884
885
886
__global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
                             const int By, const scalar_t* B,
                             const scalar_t* __restrict__ A,
                             const scalar_t* __restrict__ BIAS, scalar_t* C,
887
888
889
                             const int _WvPrGrp, const int CuCount) {
  UNREACHABLE_CODE
}
890
#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
891

892
#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
// This version targets big A[] cases, where it is much larger than LDS capacity.
//
// Computes C[n * M + m] = sum_k A[n * K + k] * B[m * K + k] (+ optional BIAS)
// for n < N and m < M. A[] (N rows of K elements) is streamed through LDS one
// K-chunk of kFit elements per row at a time; B[] (M rows of K elements) is
// read directly from global memory with non-temporal loads. Each wave owns
// YTILE consecutive values of m and splits K across its THRDS lanes; the
// partial sums are combined with DPP lane reductions and lane 63 writes C.
//
// Launch: block = (THRDS, WvPrGrp); only the first _WvPrGrp waves work.
// The logic assumes K is a multiple of A_CHUNK (see the TODOs below).
// NOTE(review): the last-tile adjustment computes M - YTILE in uint32_t, so
// this presumably requires M >= YTILE — confirm against the host dispatch.
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
    wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By,
                     const scalar_t* B, const scalar_t* __restrict__ A,
                     const scalar_t* __restrict__ BIAS, scalar_t* C,
                     const int _WvPrGrp, const int CuCount) {
  constexpr int max_lds_len = LDS_SIZE / 2;
  #if defined(__HIP__MI3XX__)
  constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
  #else
  constexpr bool use_mfma = false;
  #endif

  using scalar8 =
      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
  using half4 =
      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
  union bigType {
    scalar_t h[A_CHUNK];
    float f[A_CHUNK / 2];
    float2 f2[A_CHUNK / 4];
    double d[A_CHUNK / 4];
    half4 h4[A_CHUNK / 4];
    scalar8 h8;
  };

  //----------------------------------------------------
  // Reserving 64/160 KB of LDS to have 1 WG / CU
  // Goal is to bring the activation matrix A to the LDS
  // and use it across the lifetime of the work group.
  // A[] may be larger than the LDS here, so it is staged
  // one K-chunk (kFit elements of every row) at a time.
  //----------------------------------------------------
  __shared__ scalar_t s[max_lds_len];

  //----------------------------------------------------
  // Computation of columns that need to be committed to memory!
  //----------------------------------------------------
  uint32_t commitColumn[YTILE];
  for (uint32_t i = 0; i < YTILE; i++) {
    commitColumn[i] = 1;
  }

  // _WvPrGrp (<= WvPrGrp) is chosen by the host; surplus waves exit here.
  // NOTE(review): exited waves skip the __syncthreads() calls below; this
  // relies on the barrier not waiting for terminated waves — confirm.
  if (threadIdx.y >= _WvPrGrp) return;

  //----------------------------------------------------
  // Indexing function into the column of weight matrix B
  // Algorithm does 64 lane k-splitting / wave and uses
  // WG ID and Thread ID to find the index.
  //----------------------------------------------------
  uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;

  // Check whether there will be fragmentation!
  // This will happen only for the last wave!
  // The tile is shifted back so it ends exactly at M, and the leading columns
  // (already owned by the previous tile) are masked out of the final store.
  if (m < M && (m + YTILE) >= M) {
    uint32_t startColumn = M - YTILE;
    for (uint32_t i = 0; i < (m - startColumn); i++) {
      commitColumn[i] = 0;
    }
    m = startColumn;
  }

  //----------------------------------------------------
  // Fetch the activation matrix to LDS
  // Loop iteration:
  // - Each thread (lane) is fetching 8 elements (A_Chunk)
  // - Each wave will fetch 64*8=> 512 elements
  // - Each WG will fetch 512 * 16 => 8K elements
  // - Then the WG will move to another 8 K elements
  // TODO: Logic below will only work when K is multiple of 8
  //----------------------------------------------------
  #define PCML
  #ifndef PCML
  for (uint32_t k = 0; k < min__(K * N, max_lds_len);
       k += THRDS * WvPrGrp * A_CHUNK) {
    uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);

    if (k_in >= min__(K * N, max_lds_len)) break;

    *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
  }
  __syncthreads();
  #endif

  #define TUC (THRDS * UNRL * A_CHUNK)
  // Every LDS chunk must hold at least one full k1 step for all N rows;
  // otherwise kFit below would round down to 0 and A[] would never be staged.
  static_assert(N * TUC <= max_lds_len,
                "LDS too small for one TUC-wide K chunk of every A row");
  uint32_t kBase = 0;
  // find biggest k size that fits in LDS
  uint32_t kFit = (max_lds_len) / N;
  // Round DOWN to a multiple of TUC so each LDS chunk covers whole k1 steps
  // (the reload test below compares k1 against kBase + kFit exactly).
  kFit = (kFit % TUC == 0) ? kFit : (kFit - kFit % TUC);
  kFit = min__(kFit, K);

  float sum[N][YTILE];
  scalar8 sum4[N][YTILE];

  //----------------------------------------------------
  // Each wave works on a single column of weight matrix.
  // There are 16 waves per WG, and hence, each WG is
  // working on 16 columns of weight matrix. Moreover,
  // we tile in column direction by YTILE, so when YTILE=1
  // the above math is right, however, when YTILE=2 then
  // each wave  will be working on 2 columns and WG will
  // be working on 32 columns.
  //
  // Top level loop that makes WGs persistent!
  // - WGs iterates across columns of weight matrix
  // - Each wave within WG works on a given column(s)
  // - After completing first set of columns, WGs start
  //   working on the next set of available columns
  //----------------------------------------------------
  #ifdef PCML
  // Iterate to M rounded up to the WG tile so that every active wave keeps
  // reaching the LDS-refill barriers, even once its own columns are >= M.
  int YW = (YTILE * _WvPrGrp);
  uint32_t Mrndp = (M % YW == 0) ? M : (M - M % YW + YW);
  while (m < Mrndp) {
  #else
  while (m < M) {
  #endif
    //----------------------------------------------------
    // 'sum' accumulates the matrix A x B computation
    // split across 64 lanes.
    //
    // YTILE represents how many column of weight matrix
    // are being worked on by each wave.
    //----------------------------------------------------
    for (int i = 0; i < YTILE; i++)
      for (int n = 0; n < N; n++)
        if constexpr (!use_mfma)
          sum[n][i] = 0;
        else
          sum4[n][i] = {0, 0, 0, 0};

    bigType bigA[N][UNRL];
    bigType bigB[YTILE][UNRL];
    //----------------------------------------------------
    // Fetch weight matrix B in interleaved K-split!
    // - Each thread (lane) is fetching 8 elements (A_Chunk)
    // - Each wave will fetch 64*8=> 512 elements (1024B)
    // - YTILE represents the number of column being serviced
    //   by wave
    // - Loop for fetching weight matrix (B) are unrolled
    //
    // Fetch activation matrix A from LDS
    // - Loop for fetching activation matrix (A) are unrolled
    //
    // Finally, do the matrix multiplication in an unrolled
    // fashion. This provides lot of food for compiler
    // scheduling.
    //
    // TODO: Logic below will only work when K is multiple of 8
    //----------------------------------------------------
    for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
  #ifdef PCML
      if ((k1 == 0) || (k1 == kBase + kFit)) {  // load next chunk of A[] to LDS
        if (k1 != 0) kBase += kFit;
        // Wait until every wave is done reading the previous chunk.
        __syncthreads();
        for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) {
          uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
          if (kBase + kOff >= K) break;
          if (kOff >= kFit) break;
          for (uint32_t n = 0; n < N; n++) {
            uint32_t k_in = kBase + n * K + kOff;
            uint32_t k_ot = n * kFit + kOff;
            *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in]));
          }
        }
        // Make the freshly staged chunk visible to every wave.
        __syncthreads();
      }
      // Waves past M only help with staging; they skip the compute.
      if (m >= M) continue;
  #endif

      // Fetch the weight matrix from memory!
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        if (k_ >= K) break;

        const scalar_t* B_ = &B[(m + 0) * K + k_];
        for (int b = 0; b < YTILE; b++)
          bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
      }

      // Fetch activation matrix from either just LDS or from both LDS / memory
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        if (k_ >= K) break;

        // Fetch A activation matrix in interleaved fashion from LDS or memory

        for (int n = 0; n < N; n++) {
  #ifdef PCML
          bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n])));
  #else
          // Only the first max_lds_len elements of A[] were staged above.
          if (k_ + K * n < max_lds_len)
            bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
          else
            bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
  #endif
        }
      }

      // Do the matrix multiplication in interleaved manner
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        if (k_ >= K) break;
  #pragma unroll
        for (uint32_t n = 0; n < N; n++) {
          // Do the matrix multiplication of activation and weight matrix
          // - Remember the accumulation is happening for K-split of 64!
  #pragma unroll
          for (int y = 0; y < YTILE; y++) {
            if constexpr (!use_mfma)
  #pragma unroll
              for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
                DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
              }
            else
  #pragma unroll
              for (uint32_t b = 0; b < A_CHUNK / 4; b++)
                sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
                    bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
          }
        }
      }
    }

  #ifdef PCML
    if (m >= M) {
      m += CuCount * _WvPrGrp * YTILE;
      kBase = 0;
      continue;
    }
  #endif

    //----------------------------------------------------
    // Final reduction step using DPP lane shifts/broadcasts;
    // the complete sum ends up in lane 63, which stores it.
    //----------------------------------------------------
    if constexpr (!use_mfma) {
      for (int n = 0; n < N; n++) {
        for (int y = 0; y < YTILE; y++) {
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
              : "=v"(sum[n][y])
              : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
        }
      }

      if (threadIdx.x == 63) {
        for (int n = 0; n < N; n++) {
          for (int i = 0; i < YTILE; i++) {
            if (commitColumn[i]) {
              if constexpr (std::is_same_v<scalar_t, half>) {
                if (BIAS)
                  sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
              } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
                if (BIAS)
                  sum[n][i] +=
                      __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
              }
              C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
            }
          }
        }
      }
    } else {
  #pragma unroll
      for (int n = 0; n < N; n++) {
  #pragma unroll
        for (int y = 0; y < YTILE; y++) {
          // Fold the 4 MFMA accumulator lanes, then reduce across the wave.
          float accm = sum4[n][y][0];
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));
          asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
              : "=v"(accm)
              : "0"(accm), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));
          asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
              : "=v"(accm)
              : "0"(accm), "v"(accm), "v"(accm));

          sum4[n][y][0] = accm;
        }
      }
      if (threadIdx.x == 63) {
        for (int n = 0; n < N; n++) {
          for (int i = 0; i < YTILE; i++) {
            if (commitColumn[i]) {
              if (BIAS)
                sum4[n][i][0] +=
                    __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
              C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
            }
          }
        }
      }
    }

    m += CuCount * _WvPrGrp * YTILE;
    kBase = 0;

    // Check whether there will be fragmentation!
    // This will happen only for the last wave!
    if (m < M && (m + YTILE) >= M) {
      uint32_t startColumn = M - YTILE;
      for (uint32_t i = 0; i < (m - startColumn); i++) {
        commitColumn[i] = 0;
      }
      m = startColumn;
    }
  }
}
#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
1245
1246
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N>
1247
1248
1249
1250
__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
                                 const int By, const scalar_t* B,
                                 const scalar_t* __restrict__ A,
                                 const scalar_t* __restrict__ BIAS, scalar_t* C,
1251
1252
1253
                                 const int _WvPrGrp, const int CuCount) {
  UNREACHABLE_CODE
}
1254
#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
1255

1256
// Find the min val of div2 that doesn't increase N/(div1*div2).
//
// Returns the smallest wave count w in [max(1, div2 - 12), div2] for which
// ceil(N / (div1 * w)) equals ceil(N / (div1 * div2)). In other words, it
// drops waves per group as long as doing so does not add another round.
//
//   N    - amount of work to distribute (e.g. output columns).
//   div1 - work covered per wave per round (e.g. CuCount * YTILE).
//   div2 - maximum waves per group.
//
// The search is capped at 13 candidates, as before, but never goes below one
// wave. With a fixed 13-step search, div2 <= 12 would divide by zero.
// Degenerate (non-positive) div1/div2 are returned unchanged as div2.
int mindiv(int N, int div1, int div2) {
  if (div1 <= 0 || div2 <= 0) return div2;

  // ceil(N / (div1 * waves)), computed in 64-bit so div1 * waves can't wrap.
  auto rounds = [&](int waves) {
    const long long perRound = static_cast<long long>(div1) * waves;
    return (static_cast<long long>(N) + perRound - 1) / perRound;
  };

  const int maxSteps = std::min(13, div2);
  const long long baseline = rounds(div2);
  // Rounds are non-decreasing as waves shrink, so scan from the fewest waves.
  for (int i = maxSteps - 1; i > 0; i--)
    if (rounds(div2 - i) == baseline) return div2 - i;
  return div2;
}

1269
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
1270
                       const std::optional<at::Tensor>& in_bias,
1271
1272
1273
1274
                       const int64_t CuCount) {
  auto M_in = in_a.size(0);
  auto K_in = in_a.size(1);
  auto N_in = in_b.size(0);
1275
1276
1277
1278
1279
1280
1281
1282
  auto Bx_in =
      (in_bias.has_value() && in_bias->numel() > 0)
          ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
          : 1;
  auto By_in = (in_bias.has_value() && in_bias->numel() > 0 &&
                in_bias->sizes().size() == 2)
                   ? in_bias->size(0)
                   : 1;
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296

  TORCH_CHECK(in_a.dtype() == in_b.dtype());
  TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
  TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
              in_a.dtype() == torch::kBFloat16);

  auto out_c = torch::empty(
      {N_in, M_in},
      torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));

  dim3 grid(CuCount);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
1297
  const int max_lds_len = get_lds_size() / 2;
1298

1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
#define WVSPLITK(_YTILE, _UNRL, _N)                                        \
  {                                                                        \
    dim3 block(64, 16);                                                    \
    int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16);                    \
    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0))              \
      wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N>               \
          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
                                       biasf4, c, __wvPrGrp, CuCount);     \
    else if (K_in * N_in <= max_lds_len * 1.2)                             \
      wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N>                   \
          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
                                       biasf4, c, __wvPrGrp, CuCount);     \
    else                                                                   \
      wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N>               \
          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
                                       biasf4, c, __wvPrGrp, CuCount);     \
  }

#define WVSPLIT_TILE(_sYT, __N)                           \
  {                                                       \
    bool fit_lds = (K_in * N_in <= max_lds_len);          \
    if (_sYT <= 1)                                        \
      WVSPLITK(1, 4, __N)                                 \
    else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
      WVSPLITK(2, 2, __N)                                 \
    else if (_sYT <= 4 * 3)                               \
      WVSPLITK(3, 2, __N)                                 \
    else if (__N == 4)                                    \
      WVSPLITK(4, 1, __N)                                 \
    else                                                  \
      WVSPLITK(4, 2, __N)                                 \
1330
1331
1332
1333
1334
1335
  }

  AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
    using fptype = typename scalar<scalar_t>::type;
    fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
    const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
1336
1337
1338
1339
    const fptype* biasf4 =
        (in_bias.has_value() && in_bias->numel() > 0)
            ? reinterpret_cast<const fptype*>(in_bias->data_ptr())
            : nullptr;
1340
    fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
1341
1342
1343
1344
1345

    // first shoot for biggest tile-size that keeps all simd busy,
    // then cut the active waves to balance their distribution...
    int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);

1346
1347
    switch (N_in) {
      case 1:
1348
        WVSPLIT_TILE(sYT, 1)
1349
1350
        break;
      case 2:
1351
        WVSPLIT_TILE(sYT, 2)
1352
1353
        break;
      case 3:
1354
        WVSPLIT_TILE(sYT, 3)
1355
1356
        break;
      case 4:
1357
        WVSPLIT_TILE(sYT, 4)
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
        break;
      default:
        throw std::runtime_error(
            "Unsupported N value: " + std::to_string(M_in) + "," +
            std::to_string(K_in) + "," + std::to_string(N_in));
    }
  });
  return out_c;
}

1368
1369
1370
// This version targets cases skinny where CUs are not filled
// Wave-SplitK is used with reduction done via atomics.
#if defined(__gfx950__)
1371
1372
  #define WVSPLITKRC_1KPASS
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
1373
          int UNRL, int N, int GrpsShrB, int CHUNKK>
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
__global__ void __launch_bounds__(WvPrGrp* THRDS)
    __attribute__((amdgpu_waves_per_eu(1, 1)))
    wvSplitKrc_(const int actlN, const int K, const int M, const int Bx,
                const int By, const scalar_t* __restrict__ B,
                const scalar_t* __restrict__ A,
                const scalar_t* __restrict__ BIAS, float* glbl, scalar_t* C,
                const int CuCount) {
  // Use upper half of glbl buffer for atomic reduce counting
  int* cntr = (int*)(&glbl[M * N]);

  constexpr int NTILE = 16;
  constexpr int APAD = 1;
  constexpr int ASTRD = 64;
  constexpr int BPAD = 1;
1388
1389
  constexpr int WVLDS_ = THRDS * A_CHUNK / CHUNKK;
  constexpr int WVLDS = ((WVLDS_ + A_CHUNK * BPAD)) * YTILE;
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442

  constexpr int max_lds_len = LDS_SIZE / 2;

  using scalar16 =
      __attribute__((__vector_size__((A_CHUNK * 2) * sizeof(float)))) float;
  using scalar8 =
      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
  using half4 =
      __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
  union bigType {
    scalar_t h[A_CHUNK];
    float f[A_CHUNK / 2];
    unsigned int i[A_CHUNK / 2];
    float2 f2[A_CHUNK / 4];
    unsigned long l[A_CHUNK / 4];
    double d[A_CHUNK / 4];
    half4 h4[A_CHUNK / 4];
    scalar8 h8;
  };
  using big4 = __attribute__((__vector_size__(4 * sizeof(bigType)))) __bf16;

  __shared__ scalar_t stg[WvPrGrp * WVLDS / GrpsShrB];
  unsigned int* myStg = (unsigned int*)(&stg[WVLDS * (threadIdx.y / GrpsShrB)]);
  __shared__ scalar_t s[max_lds_len - WvPrGrp * WVLDS / GrpsShrB];

  #ifndef WVSPLITKRC_1KPASS
  constexpr int TUC_ = (THRDS * UNRL * A_CHUNK);
  // find biggest k size that fits padded into LDS
  constexpr uint32_t kFit__ = (max_lds_len - WvPrGrp * WVLDS / GrpsShrB) / N;
  constexpr uint32_t kFit_ = (kFit__ * ASTRD) / (APAD + ASTRD);
  uint32_t kFit = kFit_ - (kFit_ % TUC_);
  uint32_t kfitsPerRdc = (K + kFit - 1) / kFit;

  // find best k split to fill the CUs
  if (((K + kfitsPerRdc * kFit - 1) / (kfitsPerRdc * kFit)) * numCuWithFullK <=
      CuCount)
    while (true) {
      while (kFit > TUC_) {
        uint32_t kFit_ = kFit - TUC_;
        if (((K + (kfitsPerRdc * kFit_ - 1)) / (kfitsPerRdc * kFit_)) *
                numCuWithFullK >
            CuCount)
          break;
        kFit = kFit_;
      }
      if (((K + ((kfitsPerRdc - 1) * kFit - 1)) / ((kfitsPerRdc - 1) * kFit)) *
              numCuWithFullK <=
          CuCount)
        kfitsPerRdc--;
      else
        break;
    }
  #else
1443
  int constexpr kFit = 512 / CHUNKK;
1444
1445
1446
  int constexpr kfitsPerRdc = 1;
  #endif

1447
  bool doRdc = true;  // Assuming (kfitsPerRdc * kFit < K) is always true
1448
1449
1450
1451
1452
  uint32_t numCuWithFullK =
      ((M + (WvPrGrp * YTILE / GrpsShrB) - 1) / (WvPrGrp * YTILE / GrpsShrB));
  uint32_t Mmod = numCuWithFullK * (WvPrGrp * YTILE / GrpsShrB);

  // given above k-split, find this wave's position
1453
  uint32_t kFitPdd = kFit * CHUNKK + ((kFit * CHUNKK) / ASTRD) * APAD;
1454
1455
1456
1457
1458
1459
1460
  uint32_t m0 = (blockIdx.x * WvPrGrp / GrpsShrB) * YTILE;
  uint32_t m1 = ((threadIdx.y % WvPrGrp) / GrpsShrB) * YTILE;
  uint32_t m = (m0 + m1) % Mmod;
  const uint32_t k_str = (m0 / Mmod) * kFit * kfitsPerRdc;
  uint32_t k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc;
  const uint32_t k_rnd = (K + kFit * kfitsPerRdc - 1) / (kFit * kfitsPerRdc);

1461
1462
  scalar8 sum4[N / NTILE / GrpsShrB][1] = {0};
  bigType bigB_[YTILE / GrpsShrB / CHUNKK][UNRL];
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
  const uint32_t bLoader = (threadIdx.y % GrpsShrB);
  uint32_t kBase = 0;
  if (k_str >= K) return;
  if (m >= Mmod) return;

  bool noreloada = false;
  constexpr bool FAST_UNSAFE_RDC_INIT = false;

  #ifdef WVSPLITKRC_1KPASS
  // Early glbl init, B[] loading, if 1KPASS
  if constexpr (FAST_UNSAFE_RDC_INIT) {
    if (m + (threadIdx.x % 16) < M)
      if (doRdc)
        if (k_str == 0) {
          int mindx = m + (threadIdx.x % 16);
          int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
                       (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
          int adr_ = mindx + M * nindx_ / 4;
          __hip_atomic_store(&cntr[adr_], 0, __ATOMIC_RELAXED,
                             __HIP_MEMORY_SCOPE_AGENT);
          for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
            for (uint32_t j = 0; j < 4; j++) {
              int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
                          (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
              int adr = mindx + M * nindx;
              __hip_atomic_store(&glbl[adr], 0, __ATOMIC_RELAXED,
                                 __HIP_MEMORY_SCOPE_AGENT);
            }
          }
        }
  }

    // Load first B[] chunk
    #pragma unroll
  for (uint32_t k2 = 0; k2 < UNRL; k2++) {
    uint32_t k = k_str + k2 * THRDS * A_CHUNK;
1499
    uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK;
1500
1501
    const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
    #pragma unroll
1502
1503
1504
1505
1506
1507
    for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK)
      bigB_[y / CHUNKK][k2].h8 = (loadnt(
          (scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB +
                                   bLoader + m,
                               M - 1) *
                         K])));
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
  }
  {
  #else
  while (m < Mmod) {
  #endif

  #ifndef WVSPLITKRC_1KPASS
    if constexpr (FAST_UNSAFE_RDC_INIT) {
      if (m + (threadIdx.x % 16) < M)
        if (doRdc)
          if (k_str == 0) {
            int mindx = m + (threadIdx.x % 16);
            int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
                         (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
            int adr_ = mindx + M * nindx_ / 4;
            __hip_atomic_store(&cntr[adr_], 0, __ATOMIC_RELAXED,
                               __HIP_MEMORY_SCOPE_AGENT);
            for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
              for (uint32_t j = 0; j < 4; j++) {
                int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
                            (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
                int adr = mindx + M * nindx;
                __hip_atomic_store(&glbl[adr], 0, __ATOMIC_RELAXED,
                                   __HIP_MEMORY_SCOPE_AGENT);
              }
            }
          }
    }

  #endif

  #ifndef WVSPLITKRC_1KPASS
    for (uint32_t k1 = k_str; k1 < k_end; k1 += THRDS * A_CHUNK * UNRL) {
  #else
    const uint32_t k1 = k_str;
    {
  #endif
  #ifndef WVSPLITKRC_1KPASS
      const bool reloada = (!noreloada) &&
                           ((k1 == k_str) || (k1 == k_str + kBase + kFit)) &&
                           (k1 < k_end);
      // load next chunk of A[] to LDS
      if (reloada) {
        if (k1 != k_str) kBase += kFit;
        __syncthreads();
  #else
      const bool reloada = (!noreloada) &&
                           ((k1 == k_str) || (k1 == k_str + kBase + kFit)) &&
                           (k1 < k_end);
      if (reloada) {
  #endif
        constexpr int sprdN = 4;
1560
        const uint32_t thrd = threadIdx.x % (THRDS / CHUNKK);
1561
1562
1563

  #ifndef WVSPLITKRC_1KPASS
    #pragma unroll
1564
1565
        for (int k = 0; k < kFit;
             k += (THRDS * (WvPrGrp / sprdN) * A_CHUNK) / CHUNKK) {
1566
1567
1568
1569
1570
  #else
        const unsigned int k = 0;
        {
  #endif
          unsigned int kOff = k + (thrd * A_CHUNK);
1571
          unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff);
1572
          for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) {
1573
            __builtin_amdgcn_global_load_lds(
1574
1575
1576
1577
1578
1579
1580
1581
                (int*)(&A[min__(
                    K * actlN - A_CHUNK,
                    kOffcp + K * (n / CHUNKK +
                                  (N / CHUNKK) * (threadIdx.x / (64 / CHUNKK)) +
                                  (threadIdx.y % sprdN)))]),
                (int*)(&s[(k +
                           kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]),
                16, 0, 0);
1582
1583
1584
1585
1586
          }

          // Stage loaded B[] to LDS for MFMA swizzling...
          for (uint32_t k2 = 0; k2 < UNRL; k2++) {
            uint32_t k = k1 + k2 * THRDS * A_CHUNK;
1587
            uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK;
1588
            const bool oob_k = (k_ >= K);
1589
1590
1591
1592
1593
            for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) {
              uint32_t idx =
                  (threadIdx.x % (THRDS / CHUNKK)) * 4 +
                  ((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB + bLoader) *
                      ((THRDS / CHUNKK + BPAD) * 4);
1594
1595
              // zero out if oob
              *((scalar8*)&myStg[idx]) =
1596
                  (oob_k)  // TODO: ever necessary (y*GrpsShrB+bLoader+m>=M) ?
1597
                      ? 0
1598
                      : bigB_[y / CHUNKK][k2].h8;
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
            }
          }
        }
      }
    }
  #ifndef WVSPLITKRC_1KPASS
    // Fire load of next B[] chunk...
    if ((k1 + THRDS * A_CHUNK * UNRL < k_end) &&
        (k1 + THRDS * A_CHUNK * UNRL < K))
    #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + THRDS * A_CHUNK * UNRL + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
    #pragma unroll
1614
1615
1616
1617
1618
1619
1620
        for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK)
          bigB_[y / CHUNKK][k2].h8 = (loadnt(
              (scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) *
                                           GrpsShrB +
                                       bLoader + m,
                                   M - 1) *
                             K])));
1621
1622
1623
1624
      }
  #endif

    // B[] staging is cooperative across GrpsShrB, so sync here before reading
1625
1626
    // back. This wait is currently inserted by compiler, but not gauranteed.
    asm volatile("s_waitcnt 0");
1627
1628
1629
    __syncthreads();

    // read back B[] swizzled for MFMA...
1630
    bigType bigB[YTILE / CHUNKK][UNRL];
1631
    for (uint32_t k2 = 0; k2 < UNRL; k2++) {
1632
1633
1634
1635
      for (uint32_t y = 0; y < YTILE / CHUNKK; y++) {
        unsigned int idx =
            (threadIdx.x % YTILE) * ((THRDS / CHUNKK + BPAD) * 4) +
            (threadIdx.x / YTILE) * 4 + y * 16;
1636
1637
1638
1639
1640
        bigB[y][k2].h8 = *((scalar8*)&myStg[idx]);
      }
    }

    // rReadback A[] swizzled for MFMA...
1641
    bigType bigA[N / GrpsShrB / CHUNKK][UNRL];
1642
1643
1644
1645
1646
1647
  #pragma unroll
    for (uint32_t k2 = 0; k2 < UNRL; k2++) {
      uint32_t k = k1 + k2 * THRDS * A_CHUNK - kBase - k_str;
  #pragma unroll
      for (uint32_t nt = 0; nt < N / GrpsShrB; nt += NTILE)
  #pragma unroll
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
        for (uint32_t n = 0; n < NTILE / CHUNKK; n++) {
          uint32_t idxa =
              ((nt + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) % (N / CHUNKK) +
               (threadIdx.x % NTILE)) *
                  kFitPdd +
              ((nt + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) /
               (N / CHUNKK)) *
                  A_CHUNK * (64 / CHUNKK) +
              A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k;
          bigA[nt / CHUNKK + n][k2] = *((const bigType*)(&(s[idxa])));
1658
1659
1660
1661
1662
1663
1664
1665
1666
        }
    }

    // Do the MFMAs
  #pragma unroll
    for (uint32_t k2 = 0; k2 < UNRL; k2++) {
  #pragma unroll
      for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
  #pragma unroll
1667
        for (uint32_t j = 0; j < YTILE / CHUNKK; j++) {
1668
          if constexpr (std::is_same_v<scalar_t, half>) {
1669
1670
1671
            sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x32_f16(
                bigA[nt * (YTILE / CHUNKK) + j][k2].h8, bigB[j][k2].h8,
                sum4[nt][0], 0, 0, 0);
1672
          } else {  // bf16
1673
1674
1675
            sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
                bigA[nt * (YTILE / CHUNKK) + j][k2].h8, bigB[j][k2].h8,
                sum4[nt][0], 0, 0, 0);
1676
1677
1678
1679
1680
1681
          }
        }
      }
    }
  }

1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
  if (m + (threadIdx.x % 16) < M) {
    int my_cntr;
    int mindx = m + (threadIdx.x % 16);
    int g_mindx = m * 4 + (threadIdx.x % 64);  // coalesced atomic reduction
    scalar_t biases[N / NTILE / GrpsShrB][4] = {};
    // Atomic add the output, read biases
    for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
      for (uint32_t j = 0; j < 4; j++) {
        // int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
        //             (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
        // int adr = mindx + M * nindx;
        int g_nindx =
            j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
        int g_adr = g_mindx + M * g_nindx * 4;
        atomicAdd(&glbl[g_adr], sum4[nt][0][j]);
      }
    int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
                 (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
    int adr_ = mindx + M * nindx_ / 4;
    // Update the complete counter
    my_cntr = atomicAdd(&cntr[adr_], 1);
    float vals[N / NTILE / GrpsShrB][4] = {};
    // If we're the last k-shard, read back the value and convert...
    if (my_cntr + 1 == k_rnd) {
1706
1707
1708
1709
1710
      if (BIAS)
        for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
          for (uint32_t j = 0; j < 4; j++) {
            int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
                        (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
1711
            biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
1712
1713
1714
1715
          }
        }
      for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
        for (uint32_t j = 0; j < 4; j++) {
1716
1717
1718
1719
          int g_nindx =
              j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
          int g_adr = g_mindx + M * g_nindx * 4;
          vals[nt][j] = glbl[g_adr];
1720
1721
        }
      }
1722
1723
1724
1725
1726
1727
      __builtin_amdgcn_sched_barrier(0);
      for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
        for (uint32_t j = 0; j < 4; j++) {
          int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
                      (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
          if (nindx < actlN) {
1728
            int adr = mindx + M * nindx;
1729
1730
1731
1732
1733
1734
            if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
              vals[nt][j] += __bfloat162float(biases[nt][j]);
              C[adr] = __float2bfloat16(vals[nt][j]);
            } else {
              vals[nt][j] += __half2float(biases[nt][j]);
              C[adr] = __float2half(vals[nt][j]);
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
            }
          }
        }
      }
    }

  #ifndef WVSPLITKRC_1KPASS
    m0 += CuCount * WvPrGrp * YTILE / GrpsShrB;
    m = (m0 + m1) % Mmod;
    k_str = (m0 / Mmod) * kFit * kfitsPerRdc;
    k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc;
    if (k_str >= K) break;
    kBase = 0;
  #endif
  }
}
#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
1753
          int UNRL, int N, int GrpsShrB, int CHUNKK>
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
__global__ void wvSplitKrc_(const int actlN, const int K, const int M,
                            const int Bx, const int By, const scalar_t* B,
                            const scalar_t* __restrict__ A,
                            const scalar_t* __restrict__ BIAS, float* glbl,
                            // int* cntr,
                            scalar_t* C, const int CuCount){UNREACHABLE_CODE}
#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support

// Returns the smallest power of two that is >= n (1 for n <= 1), capped at
// 2^31. Equivalent to `1U << (32 - __builtin_clz(n - 1))` for n >= 2, but
// well-defined for n <= 1 (where __builtin_clz(0) would be undefined).
static inline unsigned int wvSplitKrc_next_pow2(int64_t n) {
  unsigned int p = 1;
  while (static_cast<int64_t>(p) < n && p < (1U << 31)) p <<= 1;
  return p;
}

// Host launcher for the reduce-counting split-K skinny GEMM (wvSplitKrc_).
//
// in_a    : (M, K) fp16/bf16 matrix, passed as the kernel's B operand.
// in_b    : (N, K) matrix with in_a's dtype, passed as the kernel's A operand.
// in_bias : optional bias with in_b's dtype, either 1-D (Bx) or 2-D (By, Bx);
//           it is indexed with wrap-around, i.e. broadcast, by the kernel.
// CuCount : grid size (number of blocks launched).
// Returns a newly allocated (N, M) tensor with in_b's dtype.
//
// Partial results are accumulated in a zero-filled fp32 scratch buffer whose
// row count is derived from N rounded up to a power of two. Only N whose
// power-of-two bucket is 16, 32, 64 or 128 (i.e. 9 <= N <= 128) is supported;
// any other N throws std::runtime_error.
// NOTE(review): strides are not passed to the kernel, so inputs are assumed to
// be dense row-major — confirm at call sites.
torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
                         const std::optional<at::Tensor>& in_bias,
                         const int64_t CuCount) {
  auto M_in = in_a.size(0);
  auto N_in = in_b.size(0);
  auto K_in = in_a.size(1);
  const bool has_bias = in_bias.has_value() && in_bias->numel() > 0;
  // Bias broadcast extents: Bx is the last dim, By the leading dim of a 2-D bias.
  auto Bx_in = has_bias ? (in_bias->sizes().size() == 2 ? in_bias->size(1)
                                                        : in_bias->size(0))
                        : 1;
  auto By_in =
      (has_bias && in_bias->sizes().size() == 2) ? in_bias->size(0) : 1;

  TORCH_CHECK(in_a.dtype() == in_b.dtype());
  // Both operands are read with the same row length K inside the kernel.
  TORCH_CHECK(in_b.size(1) == K_in,
              "wvSplitKrc: in_a and in_b must have the same K dimension");
  TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
  TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
              in_a.dtype() == torch::kBFloat16);
  // The bias buffer is reinterpreted as the input element type below.
  TORCH_CHECK(!has_bias || in_bias->dtype() == in_b.dtype(),
              "wvSplitKrc: bias dtype must match the input dtype");

  auto out_c = torch::empty(
      {N_in, M_in},
      torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));

  // N rounded up to a power of two selects the kernel instantiation and sizes
  // the fp32 reduction scratch buffer.
  auto N_p2 = wvSplitKrc_next_pow2(N_in);
  auto axl_glbl = torch::empty(
      {N_p2 + N_p2 / 4, M_in + M_in / 4},
      torch::TensorOptions().dtype(torch::kFloat32).device(in_b.device()));
  axl_glbl.zero_();  // disable for FAST_UNSAFE_RDC_INIT

  dim3 grid(CuCount);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// Launches wvSplitKrc_ with THRDS=64, YTILE=16, WvPrGrp=4, A_CHUNK=8, UNRL=1
// for the given N bucket, B-sharing group count and K-chunking factor.
#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK)                                     \
  {                                                                            \
    dim3 block(64, 4);                                                         \
    wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK>               \
        <<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
                                     biasf4, glbl, c, CuCount);                \
  }

  AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitKrc", [&] {
    using fptype = typename scalar<scalar_t>::type;
    fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
    const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
    const fptype* biasf4 =
        has_bias ? reinterpret_cast<const fptype*>(in_bias->data_ptr())
                 : nullptr;
    fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
    auto glbl = axl_glbl.data_ptr<float>();

    // With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
    // and each working on a 512-shard of K, how many CUs would we need?
    int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);

    // How many of 4 waves in a group can work on same 16 Ms at same time? First
    // try to maximize this. This reduces the Ms each group works on, i.e.
    // increasing the number of CUs needed.
    int GrpsShrB = min(N_p2 / 16, 4);

    // Given the above, how many CUs would we need?
    int CuNeeded = rndup_cus * GrpsShrB;

    // CuNeeded may exceed CuCount; such shapes are still launched. An earlier
    // `std::runtime_error(...)` at this point was constructed but never
    // thrown, so it never rejected anything. The kernel's outer loop over M/K
    // shards appears to cover the extra work in non-WVSPLITKRC_1KPASS builds.
    // TODO: confirm before turning this into a hard error.

    // Can we increase SplitK by shrinking the K-shared to 256?
    int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;

    switch (N_p2) {
      case 16:
        WVSPLITKrc(16, 1, 1)
        break;
      case 32:
        if (chunkk == 2) {
          WVSPLITKrc(32, 2, 2)
        } else {
          WVSPLITKrc(32, 2, 1)
        }
        break;
      case 64:
        if (chunkk == 2) {
          WVSPLITKrc(64, 4, 2)
        } else {
          WVSPLITKrc(64, 4, 1)
        }
        break;
      case 128:
        if (chunkk == 2) {
          WVSPLITKrc(128, 4, 2)
        } else {
          WVSPLITKrc(128, 4, 1)
        }
        break;
      default:
        throw std::runtime_error(
            "Unsupported N value: " + std::to_string(M_in) + "," +
            std::to_string(K_in) + "," + std::to_string(N_in));
    }
  });
  return out_c;
}

1856
#if defined(__HIP__MI3XX__)  // TODO: Add NAVI support
1857
1858
1859
// FP8 x FP8 skinny GEMM kernel for MI3XX (gfx942 / gfx950), LDS-resident
// ("sml") variant: the activation matrix A is staged in LDS once per block and
// every A fragment is then read from LDS.
//
// Operand indexing as used below:
//   A    : N rows, row stride Kap; element (n, k) is A[k + Kap * n].
//   B    : M rows, row stride Kbp; element (m, k) is B[k + m * Kbp].
//   C    : result for (n, m) is stored at C[m + n * M].
//   BIAS : optional (nullptr allowed); value for (m, n) is
//          BIAS[m % Bx + (n % By) * Bx].
//   s_A, s_B : only *s_A and *s_B are read; each dot product is multiplied by
//          their product before the bias is added.
// Launch (see wvSplitKQ): blockDim = (THRDS, WvPrGrp) with THRDS = 64 and
// A_CHUNK = 16. All threads help fill LDS; only threadIdx.y < _WvPrGrp then
// computes. The host picks this variant only when Kap * N fits in the device
// LDS size and M % YTILE == 0.
// NOTE(review): the LDS fill reads A up to min(Kap * N, max_lds_len), which
// includes any row padding after the last row's K elements; this assumes that
// memory is addressable — confirm for strided inputs.
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
          int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
1860
1861
1862
    wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, const int M,
                      const int Bx, const int By, const fp8_t* B,
                      const fp8_t* __restrict__ A,
1863
                      const scalar_t* __restrict__ BIAS, scalar_t* C,
1864
1865
1866
                      const float* __restrict__ s_A,
                      const float* __restrict__ s_B, const int _WvPrGrp,
                      const int CuCount) {
1867
  // LDS capacity in fp8_t elements (LDS_SIZE: 160*1024 on gfx950, else 64*1024).
  constexpr int max_lds_len = LDS_SIZE;
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
  using scalar8 =
      __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
  using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
  using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
  // One A_CHUNK-byte fragment of an A or B row. l[] supplies the 64-bit MFMA
  // operands; h8 allows a single wide vector load/store of the fragment.
  union bigType {
    char f8[A_CHUNK];
    char2 c2[A_CHUNK / 2];
    scalar_t h[A_CHUNK / 2];
    float f[A_CHUNK / 4];
    int i[A_CHUNK / 4];
    long l[A_CHUNK / 8];
    intx4 l2[A_CHUNK / 16];
    scalar8 h8;
  };

1883
  __shared__ fp8_t s[max_lds_len];
1884
1885

  // Every thread in the block copies A_CHUNK elements per step until
  // min(Kap * N, max_lds_len) elements of A are in LDS. gfx950 uses the direct
  // global->LDS load builtin; other targets copy through registers.
  for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
1886
1887
1888
1889
       k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
  #if defined(__gfx950__)
    __builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
  #else
1890
    *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
1891
  #endif
1892
  }
1893
  // Drain outstanding vector-memory loads, then barrier so every wave sees the
  // complete LDS copy of A.
  asm volatile("s_waitcnt vmcnt(0)");
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
  __syncthreads();

  // Waves beyond the runtime wave count only took part in the LDS fill.
  if (threadIdx.y >= _WvPrGrp) return;

  // First row of B (output position m) for the YTILE rows this wave handles.
  uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;

  using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
  // Scalar scales, applied once after accumulation.
  float sA = *s_A;
  float sB = *s_B;

  while (m < M) {
1905
    // One 32x32 MFMA accumulator per (n, y); reduced to a scalar after the K loop.
    floatx16 sum[N][YTILE] = {};
1906
    // Each iteration consumes THRDS * A_CHUNK * UNRL elements of K per wave.
    for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
1907
1908
      // Zero-initialized so lanes whose k_ is past K contribute nothing.
      bigType bigA[N][UNRL] = {};
      bigType bigB[YTILE][UNRL];
1909
1910
1911
1912
1913
1914

      // Fetch the weight matrix from memory!
      // The K offset is clamped to K - A_CHUNK and the row to M - 1 so tail
      // lanes/rows never read out of bounds; rows >= M are never stored.
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
1915
        const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
1916
1917
  #pragma unroll
        for (uint32_t y = 0; y < YTILE; ++y) {
1918
          bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
        }
      }

  // Fetch activation matrix fragments from LDS only (this variant expects all
  // of A to have been staged above).
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        if (k_ >= K) break;
        for (int n = 0; n < N; n++) {
1929
          bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
        }
      }

  // Do the matrix multiplication in interleaved manner
  // Each fp8 32x32x16 MFMA consumes one 64-bit word (l[]) of the A and B
  // fragments, so A_CHUNK / 8 MFMAs cover one fragment.
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        for (uint32_t n = 0; n < N; n++) {
          for (int i = 0; i < A_CHUNK; i += 8) {
            for (int y = 0; y < YTILE; ++y) {
              sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
                  bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
                  0);
            }
          }
        }
      }
    }

    // Final reduction
    // Fold the 16 accumulator values of each (n, y) tile across lanes with DPP
    // row shifts and shuffles; only lane 0's sum[n][y][0] is consumed below.
    for (int n = 0; n < N; n++) {
      for (int y = 0; y < YTILE; y++) {
        float accm0 = sum[n][y][0];
        float accm16 = sum[n][y][8];
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
                                          1);  // row_shl1
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
                                          1);  // row_shl2
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
                                          1);  // row_shl3
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
                                          1);  // row_shl8
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
                                          1);  // row_shl9
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
                                          1);  // row_shl10
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
                                          1);  // row_shl11
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
1974
1975
1976
1977
1978
1979
1980
        accm0 += __shfl(accm0, 36);
        accm16 += __shfl(accm16, 52);
        sum[n][y][0] = accm0 + __shfl(accm16, 16);
      }
    }

    // Lane 0 applies scales and the optional bias, then stores the results.
    if (threadIdx.x == 0) {
1981
1982
1983
1984
1985
1986
1987
      // Broadcast bias lookup; stays zero when BIAS is nullptr.
      scalar_t biases[N][YTILE] = {};
      if (BIAS)
        for (int n = 0; n < N; n++) {
          for (int y = 0; y < YTILE; y++) {
            biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
          }
        }
1988
1989
      for (int n = 0; n < N; n++) {
        for (int y = 0; y < YTILE; y++) {
1990
1991
1992
          if (y + m >= M) break;  // To avoid mem access fault.
          sum[n][y][0] *= sA * sB;
          if constexpr (std::is_same_v<scalar_t, half>) {
1993
            sum[n][y][0] += __half2float(biases[n][y]);
1994
          } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
1995
            sum[n][y][0] += __bfloat162float(biases[n][y]);
1996
          }
1997
          C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
1998
1999
2000
2001
2002
2003
2004
        }
      }
    }

    // Stride over M by (grid blocks x computing waves x YTILE) rows.
    m += CuCount * _WvPrGrp * YTILE;
  }
}
2005
#else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
2006
2007
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
          int A_CHUNK, int UNRL, int N>
2008
2009
2010
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
                                  const int M, const int Bx, const int By,
                                  const fp8_t* B, const fp8_t* __restrict__ A,
2011
                                  const scalar_t* __restrict__ BIAS,
2012
2013
2014
2015
2016
                                  scalar_t* C, const float* __restrict__ s_A,
                                  const float* __restrict__ s_B,
                                  const int _WvPrGrp, const int CuCount) {
  UNREACHABLE_CODE
}
2017
#endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support
2018

2019
#if defined(__HIP__MI3XX__)  // TODO: Add NAVI support
2020
2021
2022
// FP8 x FP8 skinny GEMM kernel for MI3XX (gfx942 / gfx950), general variant:
// the leading min(Kap * N, max_lds_len) elements of the activation matrix A are
// staged in LDS, and A fragments beyond that prefix are read from global memory.
//
// Operand indexing as used below:
//   A    : N rows, row stride Kap; element (n, k) is A[k + Kap * n].
//   B    : M rows, row stride Kbp; element (m, k) is B[k + m * Kbp].
//   C    : result for (n, m) is stored at C[m + n * M].
//   BIAS : optional (nullptr allowed); value for (m, n) is
//          BIAS[m % Bx + (n % By) * Bx].
//   s_A, s_B : only *s_A and *s_B are read; each dot product is multiplied by
//          their product before the bias is added.
// Launch (see wvSplitKQ): blockDim = (THRDS, WvPrGrp) with THRDS = 64 and
// A_CHUNK = 16. All threads help fill LDS; only threadIdx.y < _WvPrGrp then
// computes. Output rows with m >= M are skipped, so M need not be a multiple
// of YTILE.
// NOTE(review): the LDS fill may read A up to Kap * N elements, which includes
// any row padding after the last row's K elements; this assumes that memory is
// addressable — confirm for strided inputs.
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
          int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
2023
2024
2025
    wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, const int M,
                  const int Bx, const int By, const fp8_t* B,
                  const fp8_t* __restrict__ A,
2026
                  const scalar_t* __restrict__ BIAS, scalar_t* C,
2027
2028
                  const float* __restrict__ s_A, const float* __restrict__ s_B,
                  const int _WvPrGrp, const int CuCount) {
2029
  // LDS capacity in fp8_t elements (LDS_SIZE: 160*1024 on gfx950, else 64*1024).
  constexpr int max_lds_len = LDS_SIZE;
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
  using scalar8 =
      __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
  using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
  using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
  // One A_CHUNK-byte fragment of an A or B row. l[] supplies the 64-bit MFMA
  // operands; h8 allows a single wide vector load/store of the fragment.
  union bigType {
    char f8[A_CHUNK];
    char2 c2[A_CHUNK / 2];
    scalar_t h[A_CHUNK / 2];
    float f[A_CHUNK / 4];
    int i[A_CHUNK / 4];
    long l[A_CHUNK / 8];
    intx4 l2[A_CHUNK / 16];
    scalar8 h8;
  };

2045
  __shared__ fp8_t s[max_lds_len];
2046
2047

  // Every thread in the block copies A_CHUNK elements per step until
  // min(Kap * N, max_lds_len) elements of A are in LDS. gfx950 uses the direct
  // global->LDS load builtin; other targets copy through registers.
  for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
2048
2049
2050
2051
       k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
  #if defined(__gfx950__)
    __builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
  #else
2052
    *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
2053
  #endif
2054
  }
2055
  // Drain outstanding vector-memory loads, then barrier so every wave sees the
  // complete LDS prefix of A.
  asm volatile("s_waitcnt vmcnt(0)");
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
  __syncthreads();

  // Waves beyond the runtime wave count only took part in the LDS fill.
  if (threadIdx.y >= _WvPrGrp) return;

  // First row of B (output position m) for the YTILE rows this wave handles.
  uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;

  using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
  // Scalar scales, applied once after accumulation.
  float sA = *s_A;
  float sB = *s_B;

  while (m < M) {
2067
    // One 32x32 MFMA accumulator per (n, y); reduced to a scalar after the K loop.
    floatx16 sum[N][YTILE] = {};
2068
    // Each iteration consumes THRDS * A_CHUNK * UNRL elements of K per wave.
    for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
2069
2070
2071
      // Zero-initialized so lanes whose k_ is past K contribute nothing.
      bigType bigA[N][UNRL] = {};
      bigType bigB[YTILE][UNRL];

2072
2073
2074
2075
2076
      // Fetch the weight matrix from memory!
      // The K offset is clamped to K - A_CHUNK and the row to M - 1 so tail
      // lanes/rows never read out of bounds; rows >= M are never stored.
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
2077
        const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
2078
        for (int y = 0; y < YTILE; ++y) {
2079
          bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
        }
      }

  // Fetch activation matrix fragments: from LDS while the offset lies inside
  // the staged prefix (< max_lds_len), otherwise directly from global memory.
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        uint32_t k = k1 + k2 * THRDS * A_CHUNK;
        uint32_t k_ = k + threadIdx.x * A_CHUNK;
        if (k_ >= K) break;
        for (int n = 0; n < N; n++) {
2090
2091
          if (k_ + Kap * n < max_lds_len)
            bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
2092
          else
2093
            bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
        }
      }

  // Do the matrix multiplication in interleaved manner
  // Each fp8 32x32x16 MFMA consumes one 64-bit word (l[]) of the A and B
  // fragments, so A_CHUNK / 8 MFMAs cover one fragment.
  #pragma unroll
      for (uint32_t k2 = 0; k2 < UNRL; k2++) {
        for (uint32_t n = 0; n < N; n++) {
          for (int i = 0; i < A_CHUNK; i += 8) {
            for (int y = 0; y < YTILE; ++y) {
              sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
                  bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
                  0);
            }
          }
        }
      }
    }

    // Final reduction
    // Fold the 16 accumulator values of each (n, y) tile across lanes with DPP
    // row shifts and shuffles; only lane 0's sum[n][y][0] is consumed below.
    for (int n = 0; n < N; n++) {
      for (int y = 0; y < YTILE; y++) {
        float accm0 = sum[n][y][0];
        float accm16 = sum[n][y][8];
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
                                          1);  // row_shl1
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
                                          1);  // row_shl2
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
                                          1);  // row_shl3
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
                                          1);  // row_shl8
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
                                          1);  // row_shl9
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
                                          1);  // row_shl10
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
                                          1);  // row_shl11
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
2138
2139
2140
2141
2142
2143
2144
        accm0 += __shfl(accm0, 36);
        accm16 += __shfl(accm16, 52);
        sum[n][y][0] = accm0 + __shfl(accm16, 16);
      }
    }

    // Lane 0 applies scales and the optional bias, then stores the results.
    if (threadIdx.x == 0) {
2145
2146
2147
2148
2149
2150
2151
      // Broadcast bias lookup; stays zero when BIAS is nullptr.
      scalar_t biases[N][YTILE] = {};
      if (BIAS)
        for (int n = 0; n < N; n++) {
          for (int y = 0; y < YTILE; y++) {
            biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
          }
        }
2152
2153
2154
      for (int n = 0; n < N; n++) {
        for (int y = 0; y < YTILE; y++) {
          if (y + m >= M) break;  // To avoid mem access fault.
2155
2156
          sum[n][y][0] *= sA * sB;
          if constexpr (std::is_same_v<scalar_t, half>) {
2157
            sum[n][y][0] += __half2float(biases[n][y]);
2158
          } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
2159
            sum[n][y][0] += __bfloat162float(biases[n][y]);
2160
2161
          }
          C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
2162
2163
2164
2165
2166
2167
2168
        }
      }
    }

    // Stride over M by (grid blocks x computing waves x YTILE) rows.
    m += CuCount * _WvPrGrp * YTILE;
  }
}
2169
#else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
2170
2171
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
          int A_CHUNK, int UNRL, int N>
2172
2173
2174
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
                              const int M, const int Bx, const int By,
                              const fp8_t* B, const fp8_t* __restrict__ A,
2175
2176
                              const scalar_t* __restrict__ BIAS, scalar_t* C,
                              const float* __restrict__ s_A,
2177
2178
2179
2180
                              const float* __restrict__ s_B, const int _WvPrGrp,
                              const int CuCount) {
  UNREACHABLE_CODE
}
2181
#endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support
2182

2183
// Host launcher for the FP8 skinny GEMM kernels (wvSplitKQ_hf_sml_ and
// wvSplitKQ_hf_).
//
// in_b    : (M, K) fp8 weights; row stride Kbp = in_b.stride(0).
// in_a    : (N, K) fp8 activations; row stride Kap = in_a.stride(0).
//           Only N in 1..4 is supported; other N throw std::runtime_error.
// in_bias : optional bias with out_c's dtype, 1-D (Bx) or 2-D (By, Bx); it is
//           indexed with wrap-around, i.e. broadcast, by the kernels.
// out_c   : fp16/bf16 output, written by the kernels as out_c[m + n * M].
// scale_a, scale_b : float tensors; the kernels read only their first element.
// CuCount : grid size (number of blocks launched).
//
// Both fp8 inputs must use the platform's fp8 flavour (e4m3fn when
// is_fp8_ocp(), e4m3fnuz otherwise). The LDS-only kernel is used when
// Kap * N fits in the device LDS size and M is a multiple of its YTILE;
// otherwise the variant that reads A rows beyond LDS from global memory is used.
// NOTE(review): out_c is assumed to be a dense (N, M) buffer; its shape is not
// validated here.
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
               const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
               const at::Tensor& scale_a, const at::Tensor& scale_b,
               const int64_t CuCount) {
  static c10::ScalarType kFp8Type = is_fp8_ocp()
                                        ? c10::ScalarType::Float8_e4m3fn
                                        : c10::ScalarType::Float8_e4m3fnuz;
  auto M_in = in_b.size(0);
  auto K_in = in_b.size(1);
  auto N_in = in_a.size(0);
  auto Kap_in = in_a.stride(0);
  auto Kbp_in = in_b.stride(0);
  const bool has_bias = in_bias.has_value() && in_bias->numel() > 0;
  // Bias broadcast extents: Bx is the last dim, By the leading dim of a 2-D bias.
  auto Bx_in = has_bias ? (in_bias->sizes().size() == 2 ? in_bias->size(1)
                                                        : in_bias->size(0))
                        : 1;
  auto By_in =
      (has_bias && in_bias->sizes().size() == 2) ? in_bias->size(0) : 1;

  TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0");
  TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type);
  // The kernels read A rows with the same length K as B rows.
  TORCH_CHECK(in_a.size(1) == K_in,
              "wvSplitKQ: in_a and in_b must have the same K dimension");
  TORCH_CHECK(out_c.dtype() == torch::kFloat16 ||
              out_c.dtype() == torch::kBFloat16);
  // The bias buffer is reinterpreted as the output element type below.
  TORCH_CHECK(!has_bias || in_bias->dtype() == out_c.dtype(),
              "wvSplitKQ: bias dtype must match out_c dtype");

  dim3 grid(CuCount);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int max_lds_len = get_lds_size();

// Picks the LDS-only kernel (YTILE=_YTILEs, UNRL=_UNRLs) when all of in_a fits
// in LDS and M divides evenly, otherwise the general kernel (_YTILEm, _UNRLm).
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N)             \
  {                                                                           \
    dim3 block(64, _WvPrGrp);                                                 \
    if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {            \
      int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16));     \
      wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
          <<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in,     \
                                       By_in, b_ptr, a_ptr, bias_ptr, c_ptr,  \
                                       s_a, s_b, __wvPrGrp, CuCount);         \
    } else {                                                                  \
      int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16));     \
      wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N>     \
          <<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in,     \
                                       By_in, b_ptr, a_ptr, bias_ptr, c_ptr,  \
                                       s_a, s_b, __wvPrGrp, CuCount);         \
    }                                                                         \
  }

  AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] {
    using fptype = typename scalar<scalar_t>::type;
    auto c_ptr = reinterpret_cast<fptype*>(out_c.data_ptr());
    auto s_a = scale_a.data_ptr<float>();
    auto s_b = scale_b.data_ptr<float>();
    VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] {
      auto a_ptr = in_a.data_ptr<fp8_t>();
      auto b_ptr = in_b.data_ptr<fp8_t>();
      auto bias_ptr = has_bias ? reinterpret_cast<fptype*>(in_bias->data_ptr())
                               : nullptr;
      switch (N_in) {
        case 1:
          WVSPLITKQ(12, 2, 2, 2, 2, 1)
          break;
        case 2:
          WVSPLITKQ(12, 2, 2, 2, 2, 2)
          break;
        case 3:
          WVSPLITKQ(8, 2, 2, 1, 1, 3)
          break;
        case 4:
          WVSPLITKQ(4, 2, 2, 1, 1, 4)
          break;
        default:
          throw std::runtime_error(
              "Unsupported N value: " + std::to_string(M_in) + "," +
              std::to_string(K_in) + "," + std::to_string(N_in));
      }
    });
  });
}