/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda_runtime.h>
#include <transformer_engine/swizzle.h>

#include <cassert>
#include <numeric>
#include <type_traits>

#include "../common.h"
#include "../util/logging.h"
#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {
namespace {

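// Number of data elements that share one scaling factor in each format.
// The unpadded K extent of a scale matrix is the data K divided by this
// block size (see the original_K computations in the host launchers below).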
constexpr int MXFP8_BLOCK_SIZE = 32;
constexpr int NVFP4_BLOCK_SIZE = 16;

#ifdef __HIP_PLATFORM_AMD__
constexpr int TB_DIM = 32;
constexpr int NEW_SF_TILE_DIM_K = 16;
constexpr int N_SF_PER_TD_PER_TILE = 4;

// output is in ~K-major interleaved blocks
constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
constexpr int NEW_SF_TILE_DIM_M_I32 = 32;
#else
constexpr __device__ __host__ int TB_DIM = 32;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;

// output is in ~K-major interleaved blocks
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32;
#endif
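
// The kernels below repack scaling factors from the compact layout produced
// during quantization into the swizzled layout expected by block-scaled GEMMs,
// built from interleaved 128x4 scale-factor tiles (SF_TILE_DIM_M = 128,
// SF_TILE_DIM_K = 4 in the host launchers below).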

template <typename LType>
__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
  // inp, 4-byte chunks [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15]
  // out, swapping byte to form new 4-byte chunks [0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15]

  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
  constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD;
  int32_t new_regs[kVectorSize];
  int32_t* regs = reinterpret_cast<int32_t*>(regs_vec);

#pragma unroll
  for (int i = 0; i < N_TILE_PER_TD; i++) {
#pragma unroll
    for (int j = 0; j < N_SF_PER_TD_PER_TILE; j++) {
      new_regs[i * N_SF_PER_TD_PER_TILE + j] =
          (((regs[i + 0 * N_TILE_PER_TD] >> 8 * j) & 0xFF)) |
          (((regs[i + 1 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 8) |
          (((regs[i + 2 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 16) |
          (((regs[i + 3 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 24);
    }
  }
#pragma unroll
  for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i];
}

template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M,
                                                const int K, const int original_M,
                                                const int original_K, const int bid_x,
                                                const int bid_y, const int grid_dim_x,
                                                const int grid_dim_y) {
  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
  constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;

  // input is in M-major
  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4;
  constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K;

  const int M_i32 = M / 4;
  const int K_i32 = K;

  int m_tiles_in_tb = N_TILE_PER_TD;
  int k_tiles_in_tb = TB_DIM;
  if (bid_x == grid_dim_x - 1) {
    k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
  }
  if (bid_y == grid_dim_y - 1) {
    m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
  }

  bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
  bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);

  const int input_offset =
      bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
  const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) + input_offset;
  int32_t* output_i32[N_TILE_PER_TD];
#pragma unroll
  for (int i = 0; i < m_tiles_in_tb; i++) {
    output_i32[i] = reinterpret_cast<int32_t*>(output) + bid_x * TB_DIM * SF_TILE_SIZE_I32 +
                    (bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
  }
  extern __shared__ int slm[];

  // load, global -> regs
  LType regs_vec[N_SF_PER_TD_PER_TILE];
  if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 &&
      threadIdx.y < k_tiles_in_tb) {
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
      const int thread_offset =
          (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD;
      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
      // Pad zeros
      if (padding_m || padding_k) {
        for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
          const int index = (input_offset + thread_offset) * sizeof(int) + j;
          if (index / M >= original_K || index % M >= original_M) {
            reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
          }
        }
      }
    }

    // local shuffle
    regs_shuffle_with_bit_shifts(regs_vec);

    // store, regs -> shared
    int tM = threadIdx.x * N_SF_PER_TD;
    int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 +
                           tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD; i++) {
      /* TODO rotate_i */
      slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 +
               ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] =
          reinterpret_cast<int*>(regs_vec)[i];
    }
  }
  __syncthreads();

  // store, shared -> global
  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
#pragma unroll
  for (int i = 0; i < m_tiles_in_tb; i++) {
    __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32[i]);
    __align__(16) int4* slm_v4i =
        reinterpret_cast<int4*>(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
    for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4;
         j += blockDim.x * blockDim.y) {
      output_v4i[j] = slm_v4i[j];
    }
  }
}

template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void __launch_bounds__(TB_DIM* TB_DIM)
    swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K,
                               const int original_M, const int original_K) {
  swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
}

template <typename LType>
__device__ inline void regs_shuffle(LType* regs_vec) {
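  // Word-level analogue of regs_shuffle_with_bit_shifts above: transpose the
  // 4-byte words so the N_SF_PER_TD_PER_TILE words of each tile become
  // contiguous. E.g. for LType = int4 (N_TILE_PER_TD = 4), input word order
  // [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15] becomes
  // [0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15].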
  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
  if constexpr (N_TILE_PER_TD == 1) return;

  constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD;
  int32_t tmp[kVectorSize];
  int32_t* ptr = reinterpret_cast<int32_t*>(regs_vec);
#pragma unroll
  for (int i = 0; i < kVectorSize; i++)
    tmp[i % N_TILE_PER_TD * N_SF_PER_TD_PER_TILE + i / N_TILE_PER_TD] = ptr[i];

#pragma unroll
  for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i];
}

template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M,
                                                const int K, const int original_M,
                                                const int original_K, const int bid_x,
                                                const int bid_y, const int grid_dim_x,
                                                const int grid_dim_y) {
  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
  constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;

  // input is in K-major
  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M;

  int n_tiles_in_tb = N_TILES_IN_TB;
  const int K_i32 = K / 4;
  if (bid_x == grid_dim_x - 1) {
    n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
  }

  bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
  bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);

  const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB;
  const int* input_i32 = reinterpret_cast<const int*>(input) + input_offset;
  int* output_i32 = reinterpret_cast<int*>(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 +
                    bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32;

  extern __shared__ int4 slm_v4i[];

  // load, global -> regs
  LType regs_vec[N_SF_PER_TD_PER_TILE];
  if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
      const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD;
      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
      if (padding_m || padding_k) {
        // Pad zeros
        for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
          const int index = (input_offset + thread_offset) * sizeof(int) + j;
          if (index / K >= original_M || index % K >= original_K) {
            reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
          }
        }
      }
    }

    // shuffle regs
    regs_shuffle<LType>(regs_vec);

// store, regs -> shared
#pragma unroll
    for (int i = 0; i < N_TILE_PER_TD; i++) {
      /* TODO rotate i */
      slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] =
          reinterpret_cast<int4*>(regs_vec)[i];
    }
  }
  __syncthreads();

  // store, shared -> global
  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
  __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32);
#pragma unroll
  for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) {
    output_v4i[i] = slm_v4i[i];
  }
}

template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void __launch_bounds__(TB_DIM* TB_DIM)
    swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K,
                               const int original_M, const int original_K) {
  swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
}

constexpr int kMaxTensorsPerKernel = 64;  // Args must be <4 KB
struct MultiSwizzleArgs {
  // (input) Data buffers for input scaling factors
  void* input_list[kMaxTensorsPerKernel];
  // (output) Data buffers for swizzled scaling factors
  void* output_list[kMaxTensorsPerKernel];
  // Input scaling factor m
  int m_list[kMaxTensorsPerKernel];
  // Input scaling factor k
  int k_list[kMaxTensorsPerKernel];
  // Input scaling factor m before padding
  int original_m_list[kMaxTensorsPerKernel];
  // Input scaling factor k before padding
  int original_k_list[kMaxTensorsPerKernel];
  // Prefix sum (with leading zero) of CUDA blocks needed for each
  // tensor
  int block_range[kMaxTensorsPerKernel + 1];
  // Number of tensors being processed by kernel
  int num_tensors;
};
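
// Example: with three tensors that need 12, 8 and 20 CUDA blocks respectively,
// block_range = {0, 12, 20, 40}; block b belongs to the tensor t satisfying
// block_range[t] <= b < block_range[t + 1] (see the multi-tensor kernels below).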

template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) {
  // Find tensor corresponding to block
  const int bid = blockIdx.x;
  int tensor_id = 0;
  while (kernel_args.block_range[tensor_id + 1] <= bid) {
    ++tensor_id;
  }
  // Get args corresponding to block
  const void* input = kernel_args.input_list[tensor_id];
  void* output = kernel_args.output_list[tensor_id];
  const int M = kernel_args.m_list[tensor_id];
  const int K = kernel_args.k_list[tensor_id];
  const int original_M = kernel_args.original_m_list[tensor_id];
  const int original_K = kernel_args.original_k_list[tensor_id];

  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
  constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;

  // Get block index in grid. Emulate 2D grid.
  const int num_tiles_k = K / SF_TILE_DIM_K;
  const int num_tiles_m = M / SF_TILE_DIM_M;
  const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB);
  const int grid_dim_y = num_tiles_m;
  const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y;
  const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y;

  swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
      input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
}

template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_args) {
  // Find tensor corresponding to block
  const int bid = blockIdx.x;
  int tensor_id = 0;
  while (kernel_args.block_range[tensor_id + 1] <= bid) {
    ++tensor_id;
  }
  // Get args corresponding to block
  const void* input = kernel_args.input_list[tensor_id];
  void* output = kernel_args.output_list[tensor_id];
  const int M = kernel_args.m_list[tensor_id];
  const int K = kernel_args.k_list[tensor_id];
  const int original_M = kernel_args.original_m_list[tensor_id];
  const int original_K = kernel_args.original_k_list[tensor_id];

  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);

  // Get block index in grid. Emulate 2D grid.
  const int num_tiles_k = K / SF_TILE_DIM_K;
  const int num_tiles_m = M / SF_TILE_DIM_M;
  const int grid_dim_x = DIVUP(num_tiles_k, TB_DIM);
  const int grid_dim_y = DIVUP(num_tiles_m, N_TILE_PER_TD);
  const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y;
  const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y;

  swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
      input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
}

}  // namespace

void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
  // Check scaling mode
  const auto& scaling_mode = input->scaling_mode;
  NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
             "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");

  // Check tensors
  CheckInputTensor(*input, "scaling_factor_input");
  CheckInputTensor(*output, "scaling_factor_output");
  NVTE_CHECK(!input->with_gemm_swizzled_scales,
             "Expected input tensor with scales in compact format.");
  NVTE_CHECK(output->with_gemm_swizzled_scales,
             "Expected output tensor with scales in GEMM swizzled format.");
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING:
      NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ",
                 to_string(input->dtype()), ").");
      break;
    case NVTE_NVFP4_1D_SCALING:
      NVTE_CHECK(is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP4, got ",
                 to_string(input->dtype()), ").");
      break;
    default:
      NVTE_ERROR("Invalid scaling mode");
  }

  // Check if scaling factors are non-trivial
  const bool has_rowwise_scale_inv = input->scale_inv.has_data();
  const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
  NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
             "Input tensor has both row-wise and column-wise scaling factors");
  if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
    return;
  }

  // Deduce tensor dims
  int m{0}, k{0};
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING: {
      if (has_rowwise_scale_inv) {
        NVTE_CHECK(input->scale_inv.shape.size() == 2,
                   "Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
        m = input->scale_inv.shape[0];
        k = input->scale_inv.shape[1];
      } else if (has_columnwise_scale_inv) {
        NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
                   "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
                   ".");
        m = input->columnwise_scale_inv.shape[1];
        k = input->columnwise_scale_inv.shape[0];
      }
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      if (has_rowwise_scale_inv) {
        NVTE_CHECK(input->scale_inv.shape.size() == 2,
                   "Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
        m = input->scale_inv.shape[0];
        k = input->scale_inv.shape[1];
      } else if (has_columnwise_scale_inv) {
        NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
                   "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
                   ".");
        m = input->columnwise_scale_inv.shape[0];
        k = input->columnwise_scale_inv.shape[1];
      }
      break;
    }
    default:
      NVTE_ERROR("Invalid scaling mode");
  }

  // Check dims
  constexpr int SF_TILE_DIM_M = 128;
  constexpr int SF_TILE_DIM_K = 4;
  NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
  NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");

  // Check that output tensor matches input tensor
  if (has_rowwise_scale_inv) {
    NVTE_CHECK(output->scale_inv.has_data(),
               "Output tensor does not have row-wise scaling factors.");
    NVTE_CHECK(m * k == output->scale_inv.numel(), "Expected output tensor to have ", m * k,
               " row-wise scaling factors, but got shape=", output->scale_inv.shape, ".");
  }
  if (has_columnwise_scale_inv) {
    NVTE_CHECK(output->columnwise_scale_inv.has_data(),
               "Output tensor does not have column-wise scaling factors.");
    NVTE_CHECK(
        m * k == output->columnwise_scale_inv.numel(), "Expected output tensor to have ", m * k,
        " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, ".");
  }

  // Choose swizzle implementation
  bool rowwise_swizzle{false}, columnwise_swizzle{false};
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING: {
      rowwise_swizzle = has_rowwise_scale_inv;
      columnwise_swizzle = has_columnwise_scale_inv;
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      // NVFP4 column-wise data is transposed, so row-wise and
      // column-wise scales have same swizzling format
      rowwise_swizzle = true;
      columnwise_swizzle = false;
      break;
    }
    default:
      NVTE_ERROR("Invalid scaling mode");
  }

  const dim3 block_size(TB_DIM, TB_DIM);
  const int num_tiles_m = m / SF_TILE_DIM_M;
  const int num_tiles_k = k / SF_TILE_DIM_K;

  // Perform row-wise swizzle
  if (rowwise_swizzle) {
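    // Pick the widest 32-bit vector load (4, 2 or 1 words, i.e. int4/int2/int)
    // whose width evenly divides the number of K tiles; a width of 3 falls
    // back to 1 below since there is no int3 type.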
    int vec_load_size = (num_tiles_k - 1) % 4 + 1;
    /* there is no int3 and misaligned if using int4/int2 */
    if (vec_load_size == 3) vec_load_size = 1;
    int n_tiles_in_tb = TB_DIM * vec_load_size;
    dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
    int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);

    int original_M{0}, original_K{0};
    void *input_scale_inv_ptr{nullptr}, *output_scale_inv_ptr{nullptr};
    switch (scaling_mode) {
      case NVTE_MXFP8_1D_SCALING: {
        original_M = input->flat_first_dim();
        original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
        input_scale_inv_ptr = input->scale_inv.dptr;
        output_scale_inv_ptr = output->scale_inv.dptr;
        break;
      }
      case NVTE_NVFP4_1D_SCALING: {
        if (has_rowwise_scale_inv) {
          original_M = input->flat_first_dim();
          original_K = input->flat_last_dim() / NVFP4_BLOCK_SIZE;
          input_scale_inv_ptr = input->scale_inv.dptr;
          output_scale_inv_ptr = output->scale_inv.dptr;
        } else if (has_columnwise_scale_inv) {
          original_M = input->flat_last_dim();
          original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
          input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
          output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
        }
        break;
      }
      default:
        NVTE_ERROR("Invalid scaling mode");
    }

    switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
      case 4:
        cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(
                input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
        break;
      case 2:
        cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(
                input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
        break;
      case 1:
        cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(
                input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
        break;
#else
      case 4:
        NVTE_CHECK_CUDA(
            cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(
                input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
        break;
      case 2:
        NVTE_CHECK_CUDA(
            cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(
                input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
        break;
      case 1:
        NVTE_CHECK_CUDA(
            cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(
                input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
        break;
#endif
      default:
        NVTE_ERROR("Not valid vec_load_size.");
        break;
    }
    NVTE_CHECK_CUDA(cudaGetLastError());
  }

  // Perform column-wise swizzle
  if (columnwise_swizzle) {
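    // Same vector-width selection as the row-wise path above, but driven by
    // the number of M tiles since the column-wise input is M-major.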
    int vec_load_size = (num_tiles_m - 1) % 4 + 1;
    if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */
    int n_tiles_in_tb = TB_DIM * vec_load_size;
    dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
    int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
    const int original_M = input->flat_last_dim();
    const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;

    switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
      case 4:
        cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                            output->columnwise_scale_inv.dptr, m,
                                                            k, original_M, original_K);
        break;
      case 2:
        cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                            output->columnwise_scale_inv.dptr, m,
                                                            k, original_M, original_K);
        break;
      case 1:
        cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                            output->columnwise_scale_inv.dptr, m,
                                                            k, original_M, original_K);
        break;
#else
      case 4:
        NVTE_CHECK_CUDA(
            cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                           output->columnwise_scale_inv.dptr, m, k,
                                                           original_M, original_K);
        break;
      case 2:
        NVTE_CHECK_CUDA(
            cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                           output->columnwise_scale_inv.dptr, m, k,
                                                           original_M, original_K);
        break;
      case 1:
        NVTE_CHECK_CUDA(
            cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                           output->columnwise_scale_inv.dptr, m, k,
                                                           original_M, original_K);
        break;
#endif
      default:
        NVTE_ERROR("Not valid vec_load_size.");
        break;
    }
    NVTE_CHECK_CUDA(cudaGetLastError());
  }
}

template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
                                                 const int vec_load_size, const bool is_rowwise,
                                                 cudaStream_t stream) {
  int n_tiles_in_tb = TB_DIM * vec_load_size;
  int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
  /* Calculate number of CUDA blocks needed for each tensor.
   * We have to do it here because we have to iterate over all tensors in this batch to
   * get the minimum vec_load_size.
   */
  for (size_t j = 0; j < kernel_args.num_tensors; j++) {
    const int m = kernel_args.m_list[j];
    const int k = kernel_args.k_list[j];
    int num_tiles_m = m / SF_TILE_DIM_M;
    int num_tiles_k = k / SF_TILE_DIM_K;
    if (is_rowwise) {
      kernel_args.block_range[j + 1] =
          kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m;
    } else {
      kernel_args.block_range[j + 1] =
          kernel_args.block_range[j] +
          DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size);
    }
  }
  // Launch kernel
  const int num_blocks = kernel_args.block_range[kernel_args.num_tensors];
  dim3 block_size(TB_DIM, TB_DIM);
  if (is_rowwise) {
    switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
      case 4:
        cudaFuncSetAttribute(
            (const void *)multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 2:
        cudaFuncSetAttribute(
            (const void *)multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 1:
        cudaFuncSetAttribute(
            (const void *)multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
#else
      case 4:
        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
            multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 2:
        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
            multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 1:
        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
            multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
#endif
      default:
        NVTE_ERROR("Not valid vec_load_size.");
        break;
    }
  } else {
    switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
      case 4:
        cudaFuncSetAttribute(
            (const void *)multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 2:
        cudaFuncSetAttribute(
            (const void *)multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 1:
        cudaFuncSetAttribute(
            (const void *)multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
#else
      case 4:
        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
            multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 2:
        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
            multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 1:
        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
            multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
        multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
#endif
      default:
        NVTE_ERROR("Not valid vec_load_size.");
        break;
    }
  }
  NVTE_CHECK_CUDA(cudaGetLastError());
}

void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
                                          std::vector<Tensor*>& output, cudaStream_t stream) {
  auto num_tensors = input.size();
  bool all_has_data = true;
  bool all_has_columnwise_data = true;
  bool all_nvfp4 = true;
  for (size_t i = 0; i < num_tensors; i++) {
    auto scaling_mode = input[i]->scaling_mode;
    auto is_fp8 = is_fp8_dtype(input[i]->dtype());
    auto is_fp4 = is_fp4_dtype(input[i]->dtype());
    NVTE_CHECK(
        (is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)),
        "Not implemented scaling mode " + to_string(scaling_mode) + ".");
    NVTE_CHECK(!input[i]->with_gemm_swizzled_scales,
               "Expected input tensors with scales in compact format.");
    NVTE_CHECK(output[i]->with_gemm_swizzled_scales,
               "Expected output tensors with scales in GEMM swizzled format.");

    // We don't allow empty tensors. They should be filtered out before calling this function.
    NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty.");
    CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
    CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
    all_has_data = all_has_data && input[i]->scale_inv.has_data();
    all_has_columnwise_data =
        (all_has_columnwise_data && input[i]->columnwise_scale_inv.has_data());
    all_nvfp4 = all_nvfp4 && is_nvfp4_scaling(scaling_mode);
  }
  NVTE_CHECK(all_has_data || all_has_columnwise_data,
             "All tensors should have data or columnwise data.");
  NVTE_CHECK(!all_has_data || !all_has_columnwise_data,
             "All tensors have both data and columnwise data.");

  const bool rowwise_swizzle = all_has_data || all_nvfp4;
  const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4;

  constexpr int SF_TILE_DIM_M = 128;
  constexpr int SF_TILE_DIM_K = 4;
  if (rowwise_swizzle) {
    MultiSwizzleArgs kernel_args;
    kernel_args.num_tensors = 0;
    kernel_args.block_range[0] = 0;
    int vec_load_size = 4;
    for (size_t i = 0; i < num_tensors; i++) {
      // Launch kernel if argument struct is full
      if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
        // There is no int3 and misaligned if using int4/int2.
        if (vec_load_size == 3) vec_load_size = 1;
        launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
            kernel_args, vec_load_size, true, stream);
        // Reset the argument struct and vec_load_size
        kernel_args.num_tensors = 0;
        vec_load_size = 4;
      }

      int m, k;

      if (all_has_data) {
        m = input[i]->scale_inv.shape[0];
        k = input[i]->scale_inv.shape[1];
      } else {
        NVTE_CHECK(all_nvfp4, "When doing rowwise swizzle without rowwise data, the tensor must be NVFP4");
        m = input[i]->columnwise_scale_inv.shape[0];
        k = input[i]->columnwise_scale_inv.shape[1];
      }

      NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
      NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
      NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");

      if (all_has_data) {
        NVTE_CHECK(output[i]->scale_inv.has_data(), "Output tensor ", i,
                   " does not have row-wise scaling factors.");
        NVTE_CHECK(m * k == output[i]->scale_inv.numel(), "Expected output tensor ", i, " to have ",
                   m * k, " row-wise scaling factors, but got shape=", output[i]->scale_inv.shape,
                   ".");
      }
      if (all_has_columnwise_data) {
        NVTE_CHECK(output[i]->columnwise_scale_inv.has_data(), "Output tensor ", i,
                   " does not have column-wise scaling factors.");
        NVTE_CHECK(m * k == output[i]->columnwise_scale_inv.numel(), "Expected output tensor ", i,
                   " to have ", m * k, " column-wise scaling factors, but got shape=",
                   output[i]->columnwise_scale_inv.shape, ".");
      }

      int num_tiles_k = k / SF_TILE_DIM_K;
      int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
      // We use the minimum vec_load_size across all tensors.
      // TODO(zhongbo): fix vec_load_size for NVFP4
      // Current unit tests won't capture this issue, but in E2E runs
      // a vec_load_size other than 1 leads to misaligned-address errors
      // in MoE training.
      vec_load_size = all_nvfp4 ? 1 : std::min(vec_load_size, vec_load_size_i);

      const int pos = kernel_args.num_tensors;
      kernel_args.m_list[pos] = m;
      kernel_args.k_list[pos] = k;
      if (!all_nvfp4 || all_has_data) {
        int block_scale_size = all_nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
        kernel_args.input_list[pos] = const_cast<void*>(input[i]->scale_inv.dptr);
        kernel_args.output_list[pos] = output[i]->scale_inv.dptr;
        kernel_args.original_m_list[pos] = input[i]->flat_first_dim();
        kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / block_scale_size;
      } else {
        kernel_args.input_list[pos] = const_cast<void*>(input[i]->columnwise_scale_inv.dptr);
        kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr;
        kernel_args.original_m_list[pos] = input[i]->flat_last_dim();
        kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / NVFP4_BLOCK_SIZE;
      }
      kernel_args.num_tensors++;
    }
    // Launch the remaining tensors
    // There is no int3 and misaligned if using int4/int2.
    if (vec_load_size == 3) vec_load_size = 1;
    launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
        kernel_args, vec_load_size, true, stream);
  }

  if (columnwise_swizzle) {
    // NVFP4 shouldn't end up here because it only needs rowwise swizzle
    NVTE_CHECK(!all_nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");

    MultiSwizzleArgs kernel_args;
    kernel_args.num_tensors = 0;
    kernel_args.block_range[0] = 0;
    int vec_load_size = 4;
    for (size_t i = 0; i < num_tensors; i++) {
      // Launch kernel if argument struct is full
      if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
        // There is no int3 and misaligned if using int4/int2.
        if (vec_load_size == 3) vec_load_size = 1;
        launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
            kernel_args, vec_load_size, false, stream);
        // Reset the argument struct and vec_load_size
        kernel_args.num_tensors = 0;
        vec_load_size = 4;
      }
      const int m = input[i]->columnwise_scale_inv.shape[1];
      const int k = input[i]->columnwise_scale_inv.shape[0];

      NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
      NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
      NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
      NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
                                          output[i]->columnwise_scale_inv.shape.end(), 1,
                                          std::multiplies<int>()),
                 "Input.columnwise_scale_inv size is not equal to "
                 "Output.columnwise_scale_inv size!");

      int num_tiles_k = k / SF_TILE_DIM_K;
      int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
      // We use the minimum vec_load_size across all tensors.
      vec_load_size = std::min(vec_load_size, vec_load_size_i);

      const int pos = kernel_args.num_tensors;
      kernel_args.input_list[pos] = const_cast<void*>(input[i]->columnwise_scale_inv.dptr);
      kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr;
      kernel_args.m_list[pos] = m;
      kernel_args.k_list[pos] = k;
      kernel_args.original_m_list[pos] = input[i]->flat_last_dim();
      kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / MXFP8_BLOCK_SIZE;
      kernel_args.num_tensors++;
    }
    // Launch the remaining tensors
    // There is no int3 and misaligned if using int4/int2.
    if (vec_load_size == 3) vec_load_size = 1;
    launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
        kernel_args, vec_load_size, false, stream);
  }
}
}  // namespace transformer_engine

/*
 * WIP (Phuong):
 *   - Opt for bank conflicts
 *   - Adding swizzle for 2d-block scaling.
 */
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swizzle_scaling_factors);
  using namespace transformer_engine;
  swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
}

void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
                                               const size_t num_tensors, cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors);
  using namespace transformer_engine;
  NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
  std::vector<Tensor*> input_list, output_list;
  for (size_t i = 0; i < num_tensors; i++) {
    input_list.push_back(convertNVTETensorCheck(inputs[i]));
    output_list.push_back(convertNVTETensorCheck(outputs[i]));
  }
  multi_tensor_swizzle_scaling_factors(input_list, output_list, stream);
}
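
/*
 * Minimal usage sketch (illustrative only, not part of the build). It assumes
 * `scales_in` and `scales_out` are NVTETensor handles whose scale_inv buffers
 * are already allocated in the compact and GEMM-swizzled formats respectively;
 * the tensor-construction calls are omitted.
 *
 *   // Single tensor
 *   nvte_swizzle_scaling_factors(scales_in, scales_out, stream);
 *
 *   // Batched
 *   std::vector<NVTETensor> ins = {...};   // inputs with compact scales
 *   std::vector<NVTETensor> outs = {...};  // outputs expecting swizzled scales
 *   nvte_multi_tensor_swizzle_scaling_factors(ins.data(), outs.data(), ins.size(), stream);
 */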