swizzle.cu 21 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda_runtime.h>
#include <transformer_engine/swizzle.h>

#include <cassert>
#include <numeric>
#include <type_traits>

#include "../common.h"
#include "../util/logging.h"
#include "transformer_engine/transformer_engine.h"

namespace {

// Edge length of the square thread block used by every swizzle kernel (TB_DIM x TB_DIM threads).
constexpr int TB_DIM{32};
// Number of LType vectors each thread loads per tile group; every vector contributes one int32
// word (four scale bytes) per tile.
constexpr int N_SF_PER_TD_PER_TILE{4};

// Geometry of a swizzled output tile, stored as ~K-major interleaved blocks:
// NEW_SF_TILE_DIM_M_I32 rows, each NEW_SF_TILE_DIM_K bytes (NEW_SF_TILE_DIM_K_I32 int32 words) wide.
constexpr int NEW_SF_TILE_DIM_K{16};
constexpr int NEW_SF_TILE_DIM_K_I32{NEW_SF_TILE_DIM_K / 4};
constexpr int NEW_SF_TILE_DIM_M_I32{32};

// Byte-transposes the scale factors held in regs_vec, independently for every tile.
//
// regs_vec holds N_SF_PER_TD_PER_TILE vectors of type LType; word t of vector r carries four
// scale bytes. On return, word (t * N_SF_PER_TD_PER_TILE + b) is assembled from byte b of word t
// of vectors 0..3, with vector r's byte placed in bits [8r, 8r + 8). For a single tile this maps
//   input bytes  [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15]
//   output bytes [0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15].
template <typename LType>
__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
  constexpr int kWordsPerVec = sizeof(LType) / sizeof(int);
  constexpr int kNumWords = N_SF_PER_TD_PER_TILE * kWordsPerVec;
  int32_t* words = reinterpret_cast<int32_t*>(regs_vec);
  int32_t transposed[kNumWords];

#pragma unroll
  for (int tile = 0; tile < kWordsPerVec; ++tile) {
#pragma unroll
    for (int byte = 0; byte < N_SF_PER_TD_PER_TILE; ++byte) {
      // Gather byte `byte` of word `tile` from each of the four vectors.
      unsigned int packed = 0u;
#pragma unroll
      for (int row = 0; row < 4; ++row) {
        const unsigned int src = static_cast<unsigned int>(words[tile + row * kWordsPerVec]);
        packed |= ((src >> (8 * byte)) & 0xFFu) << (8 * row);
      }
      transposed[tile * N_SF_PER_TD_PER_TILE + byte] = static_cast<int32_t>(packed);
    }
  }

#pragma unroll
  for (int w = 0; w < kNumWords; ++w) words[w] = transposed[w];
}

// Swizzles column-wise (M-major) scale factors into the interleaved per-tile layout.
//
// Launch shape implied by the indexing: blockDim = (TB_DIM, TB_DIM). blockIdx.x selects a group
// of up to TB_DIM K-tiles and blockIdx.y a group of up to N_TILE_PER_TD M-tiles; the last block
// along each axis may hold fewer. Dynamic shared memory must hold
// N_TILE_PER_TD * TB_DIM * SF_TILE_SIZE_I32 int32 words.
// Thread roles: threadIdx.y owns one K-tile (SF_TILE_DIM_K_I32 input rows); threadIdx.x owns
// N_TILE_PER_TD consecutive int32 words (4 * N_TILE_PER_TD scale bytes) along M.
// input:  K rows of M scale bytes; row r starts at int32 offset r * (M / 4).
// output: for each M-tile, all of its K-tiles stored back to back, SF_TILE_SIZE_I32 words each.
// NOTE(review): assumes M % SF_TILE_DIM_M == 0, K % SF_TILE_DIM_K == 0 and that M / 4 is a
// multiple of N_TILE_PER_TD so the LType loads are aligned; the host launcher is expected to
// guarantee these — nothing is checked here.
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M,
                                           const int K) {
  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
  constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;

  // input is in M-major
  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4;
  constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K;

  // M_i32: int32 words per input row; K_i32: number of input rows.
  const int M_i32 = M / 4;
  const int K_i32 = K;

  // Tiles handled by this block; the last block along x / y only gets the remainder.
  int m_tiles_in_tb = N_TILE_PER_TD;
  int k_tiles_in_tb = TB_DIM;
  if (blockIdx.x == gridDim.x - 1) {
    k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
  }
  if (blockIdx.y == gridDim.y - 1) {
    m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
  }

  // First input word of this block: skip blockIdx.x * TB_DIM K-tiles worth of rows and
  // blockIdx.y * N_TILE_PER_TD M-tiles worth of columns.
  const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
                             blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
                             blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
  // Output base of each M-tile handled here, advanced to this block's first K-tile.
  int32_t* output_i32[N_TILE_PER_TD];
#pragma unroll
  for (int i = 0; i < m_tiles_in_tb; i++) {
    output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
                    (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
  }
  // Shared staging buffer, laid out as [m_tile][k_tile][SF_TILE_SIZE_I32 words].
  extern __shared__ int slm[];

  // load, global -> regs
  // regs_vec[i] receives this thread's N_TILE_PER_TD words of row i of its K-tile.
  LType regs_vec[N_SF_PER_TD_PER_TILE];
  if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 &&
      threadIdx.y < k_tiles_in_tb) {
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(
          input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD));
    }

    // local shuffle
    // Afterwards word i holds the four row scales of M position threadIdx.x * N_SF_PER_TD + i.
    regs_shuffle_with_bit_shifts(regs_vec);

    // store, regs -> shared
    // tM is the first M position (scale byte) owned by this thread; tM / SF_TILE_DIM_M selects
    // the M-tile slot in shared memory and threadIdx.y the K-tile slot.
    int tM = threadIdx.x * N_SF_PER_TD;
    int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 +
                           tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32);
    // Inside a swizzled tile, M position m sits in row m % NEW_SF_TILE_DIM_M_I32 at word
    // (m % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 of that row.
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD; i++) {
      /* TODO rotate_i */
      slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 +
               ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] =
          reinterpret_cast<int*>(regs_vec)[i];
    }
  }
  // All tiles must be staged before any thread copies shared memory out.
  __syncthreads();

  // store, shared -> global
  // Each M-tile's K-tiles are contiguous in both shared and global memory, so they are copied
  // with int4 (16-byte) accesses spread across the whole block.
  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
#pragma unroll
  for (int i = 0; i < m_tiles_in_tb; i++) {
    __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32[i]);
    __align__(16) int4* slm_v4i =
        reinterpret_cast<int4*>(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
    for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4;
         j += blockDim.x * blockDim.y) {
      output_v4i[j] = slm_v4i[j];
    }
  }
}

yuguo's avatar
yuguo committed
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#ifdef __HIP_PLATFORM_AMD__
// Single-int variant of swizzle_col_scaling_kernel (compiled only under __HIP_PLATFORM_AMD__).
//
// Same algorithm as swizzle_col_scaling_kernel with LType = int (N_TILE_PER_TD == 1), except
// that input words are read with a plain dereference instead of __ldg.
// Launch shape implied by the indexing: blockDim = (TB_DIM, TB_DIM). blockIdx.x selects a group
// of up to TB_DIM K-tiles and blockIdx.y a single M-tile. Dynamic shared memory must hold
// TB_DIM * SF_TILE_SIZE_I32 int32 words.
// input:  K rows of M scale bytes; row r starts at int32 offset r * (M / 4).
// output: for each M-tile, all of its K-tiles stored back to back, SF_TILE_SIZE_I32 words each.
// NOTE(review): assumes M % SF_TILE_DIM_M == 0 and K % SF_TILE_DIM_K == 0 (host-checked).
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_col_scaling_kernel_int(const void* input, void* output, const int M,
                                           const int K) {
  constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int);
  constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;

  // input is in M-major
  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4;
  constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K;

  // M_i32: int32 words per input row; K_i32: number of input rows.
  const int M_i32 = M / 4;
  const int K_i32 = K;

  // Tiles handled by this block; the last block along x / y only gets the remainder.
  int m_tiles_in_tb = N_TILE_PER_TD;
  int k_tiles_in_tb = TB_DIM;
  if (blockIdx.x == gridDim.x - 1) {
    k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
  }
  if (blockIdx.y == gridDim.y - 1) {
    m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
  }

  // First input word of this block (rows for the K-tile group, columns for the M-tile).
  const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
                             blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
                             blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
  // Output base of each M-tile handled here, advanced to this block's first K-tile.
  int32_t* output_i32[N_TILE_PER_TD];
#pragma unroll
  for (int i = 0; i < m_tiles_in_tb; i++) {
    output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
                    (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
  }
  // Shared staging buffer, laid out as [m_tile][k_tile][SF_TILE_SIZE_I32 words].
  extern __shared__ int slm[];

  // load, global -> regs
  // regs_vec[i] receives this thread's word of row i of its K-tile.
  int regs_vec[N_SF_PER_TD_PER_TILE];
  if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 &&
      threadIdx.y < k_tiles_in_tb) {
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
      regs_vec[i] = *reinterpret_cast<const int*>(
          input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD);
    }

    // local shuffle
    // Afterwards word i holds the four row scales of M position threadIdx.x * N_SF_PER_TD + i.
    regs_shuffle_with_bit_shifts(regs_vec);

    // store, regs -> shared
    // tM is the first M position (scale byte) owned by this thread.
    int tM = threadIdx.x * N_SF_PER_TD;
    int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 +
                           tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32);
    // Inside a swizzled tile, M position m sits in row m % NEW_SF_TILE_DIM_M_I32 at word
    // (m % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 of that row.
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD; i++) {
      /* TODO rotate_i */
      slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 +
               ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] =
          reinterpret_cast<int*>(regs_vec)[i];
    }
  }
  // All tiles must be staged before any thread copies shared memory out.
  __syncthreads();

  // store, shared -> global
  // Each M-tile's K-tiles are contiguous in both shared and global memory; copy as int4.
  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
#pragma unroll
  for (int i = 0; i < m_tiles_in_tb; i++) {
    __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32[i]);
    __align__(16) int4* slm_v4i =
        reinterpret_cast<int4*>(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
    for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4;
         j += blockDim.x * blockDim.y) {
      output_v4i[j] = slm_v4i[j];
    }
  }
}
#endif

207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
// Regroups the int32 words held in regs_vec so that each tile's words become contiguous.
//
// regs_vec holds N_SF_PER_TD_PER_TILE vectors of type LType; word t of vector r belongs to tile
// t and is moved to position t * N_SF_PER_TD_PER_TILE + r. Tile t therefore ends up as
// N_SF_PER_TD_PER_TILE consecutive words. Nothing to do when LType is a single int.
template <typename LType>
__device__ inline void regs_shuffle(LType* regs_vec) {
  constexpr int kTiles = sizeof(LType) / sizeof(int);
  if constexpr (kTiles > 1) {
    constexpr int kWords = N_SF_PER_TD_PER_TILE * kTiles;
    int32_t* words = reinterpret_cast<int32_t*>(regs_vec);
    int32_t regrouped[kWords];

#pragma unroll
    for (int tile = 0; tile < kTiles; ++tile) {
#pragma unroll
      for (int row = 0; row < N_SF_PER_TD_PER_TILE; ++row) {
        regrouped[tile * N_SF_PER_TD_PER_TILE + row] = words[row * kTiles + tile];
      }
    }

#pragma unroll
    for (int w = 0; w < kWords; ++w) words[w] = regrouped[w];
  }
}

// Swizzles row-wise (K-major) scale factors into the interleaved per-tile layout.
//
// Launch shape implied by the indexing: blockDim = (TB_DIM, TB_DIM). blockIdx.y selects one
// M-tile (SF_TILE_DIM_M rows) and blockIdx.x a group of up to TB_DIM * N_TILE_PER_TD K-tiles; the
// last group along x may be partial. Dynamic shared memory must hold
// TB_DIM * N_TILE_PER_TD * SF_TILE_SIZE_I32 int32 words.
// Thread roles: threadIdx.x owns N_TILE_PER_TD consecutive K-tiles (one int32 word of four scale
// bytes per row each); threadIdx.y owns rows threadIdx.y + i * TB_DIM, i < N_SF_PER_TD_PER_TILE.
// input:  rows of K scale bytes; row r starts at int32 offset r * (K / 4).
// output: SF_TILE_SIZE_I32-word tiles, all K-tiles of an M-tile stored back to back.
// M is not read by this kernel.
// NOTE(review): assumes M % SF_TILE_DIM_M == 0, K % SF_TILE_DIM_K == 0 and, for the LType loads
// to be aligned, (K / 4) % N_TILE_PER_TD == 0 — the host launcher is expected to guarantee these.
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M,
                                           const int K) {
  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
  constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;

  // input is in K-major
  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M;

  // K_i32 is the int32 pitch of an input row (one word per K-tile); the last block along x
  // only handles the remaining tiles.
  int n_tiles_in_tb = N_TILES_IN_TB;
  const int K_i32 = K / 4;
  if (blockIdx.x == gridDim.x - 1) {
    n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
  }

  const int* input_i32 = reinterpret_cast<const int*>(input) +
                         blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB;
  int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
                    blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32;

  extern __shared__ int4 slm_v4i[];

  // load, global -> regs
  // regs_vec is reinterpreted as int4 when it is stored to shared memory below. int2 and int
  // only guarantee 8- and 4-byte alignment, so force 16 bytes; otherwise that vector access can
  // be misaligned whenever the array is placed in local memory.
  __align__(16) LType regs_vec[N_SF_PER_TD_PER_TILE];
  if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(
          input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD));
    }

    // shuffle regs: make each tile's N_SF_PER_TD_PER_TILE words contiguous
    regs_shuffle<LType>(regs_vec);

    // store, regs -> shared: tile t's words become row threadIdx.y of swizzled tile t
#pragma unroll
    for (int i = 0; i < N_TILE_PER_TD; i++) {
      /* TODO rotate i */
      slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] =
          reinterpret_cast<int4*>(regs_vec)[i];
    }
  }
  // Every tile must be staged before any thread copies shared memory out.
  __syncthreads();

  // store, shared -> global: this block's tiles are contiguous in the output, copy as int4
  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
  __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32);
#pragma unroll
  for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) {
    output_v4i[i] = slm_v4i[i];
  }
}

yuguo's avatar
yuguo committed
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
#ifdef __HIP_PLATFORM_AMD__
// Single-int variant of swizzle_row_scaling_kernel (compiled only under __HIP_PLATFORM_AMD__).
//
// Same algorithm as swizzle_row_scaling_kernel with LType = int (N_TILE_PER_TD == 1), except
// that input words are read with a plain dereference instead of __ldg.
// Launch shape implied by the indexing: blockDim = (TB_DIM, TB_DIM). blockIdx.y selects one
// M-tile and blockIdx.x a group of up to TB_DIM K-tiles. Dynamic shared memory must hold
// TB_DIM * SF_TILE_SIZE_I32 int32 words. M is not read by this kernel.
// NOTE(review): assumes M % SF_TILE_DIM_M == 0 and K % SF_TILE_DIM_K == 0 (host-checked).
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_row_scaling_kernel_int(const void* input, void* output, const int M,
                                               const int K) {
  constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int);
  constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;

  // input is in K-major
  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M;

  // K_i32 is the int32 pitch of an input row (one word per K-tile); the last block along x
  // only handles the remaining tiles.
  int n_tiles_in_tb = N_TILES_IN_TB;
  const int K_i32 = K / 4;
  if (blockIdx.x == gridDim.x - 1) {
    n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
  }

  const int* input_i32 = reinterpret_cast<const int*>(input) +
                         blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB;
  int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
                    blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32;

  extern __shared__ int4 slm_v4i[];

  // load, global -> regs
  // regs_vec is reinterpreted as int4 below, but an int array is only 4-byte aligned; force
  // 16-byte alignment so that vector access is valid even if the array lands in local memory.
  __align__(16) int regs_vec[N_SF_PER_TD_PER_TILE];
  if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
      regs_vec[i] = *reinterpret_cast<const int*>(
          input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD);
    }

    // shuffle regs (a no-op for a single int, kept for symmetry with the vector kernel)
    regs_shuffle<int>(regs_vec);

    // store, regs -> shared: the four words become row threadIdx.y of this thread's tile
#pragma unroll
    for (int i = 0; i < N_TILE_PER_TD; i++) {
      /* TODO rotate i */
      slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] =
          reinterpret_cast<int4*>(regs_vec)[i];
    }
  }
  // Every tile must be staged before any thread copies shared memory out.
  __syncthreads();

  // store, shared -> global: this block's tiles are contiguous in the output, copy as int4
  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
  __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32);
#pragma unroll
  for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) {
    output_v4i[i] = slm_v4i[i];
  }
}
#endif
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
}  // namespace

namespace transformer_engine {

/* Throws via NVTE_ERROR when a CUDA runtime call did not return cudaSuccess.
 *
 * Replaces the former printf + exit(-1): a library should report a failed launch to its caller
 * instead of terminating the whole process.
 */
static void swizzle_check_cuda(const cudaError_t err) {
  if (err != cudaSuccess) {
    NVTE_ERROR(std::string("CUDA Error: ") + cudaGetErrorString(err));
  }
}

/* Launches the row-wise swizzle for a K-major scale buffer of m rows by k scale bytes.
 *
 * The load width is the widest of int4 / int2 / int that keeps every row's loads aligned, picked
 * from num_tiles_k % 4 (there is no int3, and int4/int2 would be misaligned otherwise).
 * Expects m % SF_TILE_DIM_M == 0 and k % SF_TILE_DIM_K == 0 (validated by the caller).
 */
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
static void launch_swizzle_row_scaling(const void* input, void* output, const int m, const int k,
                                       cudaStream_t stream) {
  const int num_tiles_m = m / SF_TILE_DIM_M;
  const int num_tiles_k = k / SF_TILE_DIM_K;
  const dim3 block_size(TB_DIM, TB_DIM);

  int vec_load_size = (num_tiles_k - 1) % 4 + 1;
  /* there is no int3 and misaligned if using int4/int2 */
  if (vec_load_size == 3) vec_load_size = 1;
  const int n_tiles_in_tb = TB_DIM * vec_load_size;
  const dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
  // One byte per scale factor for every tile a block stages in shared memory.
  const int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);

  switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
    case 4:
      swizzle_check_cuda(cudaFuncSetAttribute(
          (const void*)swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
    case 2:
      swizzle_check_cuda(cudaFuncSetAttribute(
          (const void*)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
    case 1:
      swizzle_check_cuda(cudaFuncSetAttribute(
          (const void*)swizzle_row_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_row_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
#else
    case 4:
      swizzle_check_cuda(cudaFuncSetAttribute(
          swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
    case 2:
      swizzle_check_cuda(cudaFuncSetAttribute(
          swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
    case 1:
      swizzle_check_cuda(cudaFuncSetAttribute(
          swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
#endif
    default:
      NVTE_ERROR("Not valid vec_load_size.");
      break;
  }
}

/* Launches the column-wise swizzle for an M-major scale buffer of k rows by m scale bytes.
 *
 * The load width is picked from num_tiles_m % 4 in the same way as for the row-wise swizzle.
 * Expects m % SF_TILE_DIM_M == 0 and k % SF_TILE_DIM_K == 0 (validated by the caller).
 */
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
static void launch_swizzle_col_scaling(const void* input, void* output, const int m, const int k,
                                       cudaStream_t stream) {
  const int num_tiles_m = m / SF_TILE_DIM_M;
  const int num_tiles_k = k / SF_TILE_DIM_K;
  const dim3 block_size(TB_DIM, TB_DIM);

  int vec_load_size = (num_tiles_m - 1) % 4 + 1;
  if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */
  const int n_tiles_in_tb = TB_DIM * vec_load_size;
  const dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
  // One byte per scale factor for every tile a block stages in shared memory.
  const int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);

  switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
    case 4:
      swizzle_check_cuda(cudaFuncSetAttribute(
          (const void*)swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
    case 2:
      swizzle_check_cuda(cudaFuncSetAttribute(
          (const void*)swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
    case 1:
      swizzle_check_cuda(cudaFuncSetAttribute(
          (const void*)swizzle_col_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_col_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
#else
    case 4:
      swizzle_check_cuda(cudaFuncSetAttribute(
          swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
    case 2:
      swizzle_check_cuda(cudaFuncSetAttribute(
          swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
    case 1:
      swizzle_check_cuda(cudaFuncSetAttribute(
          swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
      swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input, output, m, k);
      break;
#endif
    default:
      NVTE_ERROR("Not valid vec_load_size.");
      break;
  }
}

/* Swizzles the MXFP8 scale-inverse factors of `input` into the interleaved layout in `output`.
 *
 * Row-wise factors (input->scale_inv, shape (m, k)) are swizzled when input->has_data().
 * Column-wise factors (input->columnwise_scale_inv, shape (k, m)) are swizzled when
 * input->has_columnwise_data(). Each direction uses its own extents. All work is enqueued on
 * `stream`. Errors are raised through NVTE_ERROR / NVTE_CHECK for:
 *   - unsupported scaling modes,
 *   - non-2D or unpadded scale buffers,
 *   - output size mismatches,
 *   - CUDA runtime/launch failures.
 */
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
  if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
    NVTE_ERROR("Not implemented scaling mode " + to_string(input->scaling_mode) + ".");
  }

  // Do nothing if tensor is empty.
  // NOTE(review): only the row-wise data buffer is inspected here — confirm that a tensor
  // carrying only column-wise data never reports data.numel() == 0.
  if (input->data.numel() == 0) {
    return;
  }

  CheckInputTensor(*input, "scaling_factor_input");
  CheckInputTensor(*output, "scaling_factor_output");

  const auto& scaling_mode = input->scaling_mode;

  // 1D block scaling, row-wise or column-wise
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    constexpr int SF_TILE_DIM_M = 128;
    constexpr int SF_TILE_DIM_K = 4;

    if (input->has_data()) {
      // Row-wise scales are K-major with shape (m, k).
      const auto& in_shape = input->scale_inv.shape;
      NVTE_CHECK(in_shape.size() == 2, "Input scale inverse should be 2D!");
      const int m = static_cast<int>(in_shape[0]);
      const int k = static_cast<int>(in_shape[1]);
      NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
      NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
      NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
      if (output->has_data()) {
        NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(),
                                            output->scale_inv.shape.end(), 1,
                                            std::multiplies<int>()),
                   "Input.scale_inv size is not equal to Output.scale_inv size!");
      }
      launch_swizzle_row_scaling<SF_TILE_DIM_M, SF_TILE_DIM_K>(
          input->scale_inv.dptr, output->scale_inv.dptr, m, k, stream);
    }
    if (input->has_columnwise_data()) {
      // Column-wise scales are M-major with shape (k, m). Their extents differ from the
      // row-wise ones, so they must not be reused when both directions are present.
      const auto& in_shape = input->columnwise_scale_inv.shape;
      NVTE_CHECK(in_shape.size() == 2, "Input scale inverse should be 2D!");
      const int m = static_cast<int>(in_shape[1]);
      const int k = static_cast<int>(in_shape[0]);
      NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
      NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
      NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
      if (output->has_columnwise_data()) {
        NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(),
                                            output->columnwise_scale_inv.shape.end(), 1,
                                            std::multiplies<int>()),
                   "Input.columnwise_scale_inv size is not equal to "
                   "Output.columnwise_scale_inv size!");
      }
      launch_swizzle_col_scaling<SF_TILE_DIM_M, SF_TILE_DIM_K>(
          input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k, stream);
    }

    // 2D block scaling
  } else {
    NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans.");
  }
  // Surface launch-configuration errors from the kernels enqueued above.
  swizzle_check_cuda(cudaGetLastError());
}
}  // namespace transformer_engine

/*
 * WIP (Phuong):
 *   - Optimize shared-memory accesses to avoid bank conflicts.
 *   - Add swizzle support for 2D block scaling.
 */
// C API entry point: converts the opaque NVTETensor handles with convertNVTETensorCheck and
// forwards them to transformer_engine::swizzle_scaling_factors on `stream`.
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swizzle_scaling_factors);
  using namespace transformer_engine;
  const auto* input_tensor = convertNVTETensorCheck(input);
  auto* output_tensor = convertNVTETensorCheck(output);
  swizzle_scaling_factors(input_tensor, output_tensor, stream);
}