swizzle.cuh 1.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file swizzle.cuh
 *  \brief Helper function for GEMM-swizzled scales
 */

#ifndef TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_SWIZZLE_CUH_
#define TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_SWIZZLE_CUH_

namespace transformer_engine {
namespace dispatch {
namespace mxfp8 {
namespace swizzle {

/*! \brief Convert compact scale indices into GEMM swizzled scale index
 *
 *  MXFP8 GEMM expects scaling factors to be in a "swizzled" order
 *  (https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout).
 *  This function converts indices from "compact" order (i.e. matching
 *  the FP8 data) to swizzled order.
 *
 *  Layout produced by this mapping:
 *  - The scale buffer is partitioned into tiles of 128 rows x 4 columns
 *    (512 entries each). Tiles are stored contiguously, row-major over the
 *    tile grid.
 *  - Inside a tile, the 128 rows form four stripes of 32 rows that are
 *    interleaved. Entry (r, c) of a tile lands at offset
 *    (r % 32) * 16 + (r / 32) * 4 + c.
 *
 *  The function is callable from both host and device code, so host-side
 *  reference code can reuse exactly the same mapping.
 *
 *  \param[in] i            Row index in the compact scale buffer.
 *  \param[in] j            Column index in the compact scale buffer.
 *  \param[in] num_tiles_X  Number of 4-column tiles per tile row, i.e. the
 *                          compact column count rounded up to a multiple of 4,
 *                          divided by 4.
 *
 *  \return Linear offset of scale (i, j) in the swizzled buffer.
 *
 *  Preconditions:
 *  - j < 4 * num_tiles_X. Otherwise the result collides with the offset of
 *    a scale in the next tile row.
 *  - The swizzled buffer holds ceil(rows / 128) * num_tiles_X * 512 entries,
 *    i.e. the compact scale dimensions padded up to multiples of 128 rows
 *    and 4 columns.
 */
__host__ __device__ __forceinline__ size_t gemm_swizzled_scale_idx(size_t i, size_t j,
                                                                   size_t num_tiles_X) {
  constexpr size_t TILE_DIM_X = 4;  // Tile dim in scale buffer
  constexpr size_t TILE_DIM_Y = 128;
  constexpr size_t TILE_SIZE = TILE_DIM_X * TILE_DIM_Y;
  // Rows of a tile are grouped into 32-row stripes which are interleaved.
  constexpr size_t STRIPE_DIM_Y = 32;
  constexpr size_t NUM_STRIPES = TILE_DIM_Y / STRIPE_DIM_Y;  // 4
  static_assert(TILE_DIM_Y % STRIPE_DIM_Y == 0, "Tile rows must split evenly into stripes");

  const size_t tile_idx_X = j / TILE_DIM_X;
  const size_t tile_idx_Y = i / TILE_DIM_Y;
  const size_t idx_in_tile_X = j % TILE_DIM_X;
  const size_t idx_in_tile_Y = i % TILE_DIM_Y;

  // Start of this 512-entry tile; tiles are laid out row-major over the tile grid.
  const size_t tile_offset = (tile_idx_Y * num_tiles_X + tile_idx_X) * TILE_SIZE;

  // Within the tile: row r belongs to stripe r / 32. The same row position of
  // all stripes is stored adjacently, giving (r % 32) * 16 + (r / 32) * 4 + c.
  const size_t stripe_idx = idx_in_tile_Y / STRIPE_DIM_Y;
  const size_t idx_in_stripe = idx_in_tile_Y % STRIPE_DIM_Y;
  const size_t offset_in_tile =
      idx_in_stripe * (NUM_STRIPES * TILE_DIM_X) + stripe_idx * TILE_DIM_X + idx_in_tile_X;

  return tile_offset + offset_in_tile;
}

}  // namespace swizzle
}  // namespace mxfp8
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_SWIZZLE_CUH_