Unverified Commit c0d2f1a5 authored by Xin Yao, committed by GitHub

[PyTorch] Multi-tensor swizzle scaling factors for MXFP8 and fuse padding zeros (#2019)



* for loop
Signed-off-by: Xin Yao <xiny@nvidia.com>

* bulk alloc
Signed-off-by: Xin Yao <xiny@nvidia.com>

* multi-tensor swizzle
Signed-off-by: Xin Yao <xiny@nvidia.com>

* pad zeros in swizzle kernels
Signed-off-by: Xin Yao <xiny@nvidia.com>

* unify single- and multi-tensor swizzle
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix empty tensor list
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix bug for col swizzle
Signed-off-by: Xin Yao <xiny@nvidia.com>

* check context & fix signifiers
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6d178b4e
@@ -247,7 +247,7 @@ if __name__ == "__main__":
     num_gemms_list = [8]
     if args.profile:
-        mkns = [(4096, 4096, 4096)]
+        mkns = [(4096 * 8, 4096, 4096)]
         # in profile mode, only run one recipe specified in args.recipe
         assert args.recipe != "all", (
             "In profile mode, only one recipe can be specified, please specify the recipe as"
...
@@ -138,6 +138,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                           const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                           const uint32_t shmemX, const uint32_t stride_elems,
                           const uint32_t offset_elems, const size_t type_num_bits) {
+  cuda_driver::ensure_context_exists();
   // Get a function pointer to the cuTensorMapEncodeTiled driver API
   // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
   static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
...
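The new guard matters because `cuTensorMapEncodeTiled` is a CUDA driver-API entry point, and driver calls require a current context, which a thread that has not yet touched the runtime may not have. As a rough illustration only, a minimal sketch of what such a guard can do (hypothetical; the actual `cuda_driver::ensure_context_exists()` helper in Transformer Engine may be implemented differently):

    #include <cuda.h>
    #include <cuda_runtime.h>

    inline void ensure_context_exists_sketch() {
      cuInit(0);  // Idempotent; required before any other driver call.
      CUcontext ctx = nullptr;
      cuCtxGetCurrent(&ctx);
      if (ctx == nullptr) {
        // No context is bound to this thread yet: touching the runtime
        // lazily creates and binds the device's primary context.
        cudaFree(nullptr);
      }
    }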
@@ -30,6 +30,20 @@ extern "C" {
  */
 void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
+/*! \brief Swizzle scaling factors into the interleaved layout required for GEMM.
+ *
+ *  \param[in]     inputs       Input tensors with non-swizzled scale_inv.
+ *  \param[in,out] outputs      Output tensors which host the swizzled scale_inv.
+ *  \param[in]     num_tensors  Number of tensors in the lists.
+ *  \param[in]     stream       CUDA stream used for the operation.
+ *
+ *  Requirements:
+ *  - scale_inv is stored in row-major order.
+ *  - scale_inv is padded to 128x4 for row-scale and 4x128 for col-scale.
+ *  - data is quantized along the K-dimension, i.e. each 1D scaling block lies along K.
+ */
+void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
+                                               const size_t num_tensors, cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
...
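A hedged usage sketch of the new entry point (the include path and tensor setup are assumptions; tensors must already satisfy the padding and layout requirements documented above):

    #include <vector>
    #include <cuda_runtime.h>
    #include <transformer_engine/swizzle.h>  // assumed header location

    // Swizzle the scale_inv of a whole group of MXFP8 tensors in one fused
    // launch instead of one kernel launch per tensor.
    void swizzle_group(const std::vector<NVTETensor>& inputs,
                       std::vector<NVTETensor>& outputs, cudaStream_t stream) {
      nvte_multi_tensor_swizzle_scaling_factors(inputs.data(), outputs.data(),
                                                inputs.size(), stream);
    }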
@@ -15,15 +15,17 @@
 #include "../util/logging.h"
 #include "transformer_engine/transformer_engine.h"
 
+namespace transformer_engine {
 namespace {
 
-constexpr int TB_DIM = 32;
-constexpr int NEW_SF_TILE_DIM_K = 16;
-constexpr int N_SF_PER_TD_PER_TILE = 4;
+constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32;
+constexpr __device__ __host__ int TB_DIM = 32;
+constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
+constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
 
 // output is in ~K-major interleaved blocks
-constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
-constexpr int NEW_SF_TILE_DIM_M_I32 = 32;
+constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
+constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32;
 
 template <typename LType>
 __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
@@ -51,8 +53,11 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
 }
 
 template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
-__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M,
-                                           const int K) {
+__device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M,
+                                                const int K, const int original_M,
+                                                const int original_K, const int bid_x,
+                                                const int bid_y, const int grid_dim_x,
+                                                const int grid_dim_y) {
   constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
   constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
   constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
@@ -66,21 +71,24 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
   int m_tiles_in_tb = N_TILE_PER_TD;
   int k_tiles_in_tb = TB_DIM;
-  if (blockIdx.x == gridDim.x - 1) {
+  if (bid_x == grid_dim_x - 1) {
     k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
   }
-  if (blockIdx.y == gridDim.y - 1) {
+  if (bid_y == grid_dim_y - 1) {
     m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
   }
 
-  const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
-                             blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
-                             blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
+  bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
+  bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);
+
+  const int input_offset =
+      bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
+  const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) + input_offset;
   int32_t* output_i32[N_TILE_PER_TD];
 #pragma unroll
   for (int i = 0; i < m_tiles_in_tb; i++) {
-    output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
-                    (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
+    output_i32[i] = reinterpret_cast<int32_t*>(output) + bid_x * TB_DIM * SF_TILE_SIZE_I32 +
+                    (bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
   }
 
   extern __shared__ int slm[];
@@ -90,8 +98,18 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
       threadIdx.y < k_tiles_in_tb) {
 #pragma unroll
     for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
-      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(
-          input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD));
+      const int thread_offset =
+          (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD;
+      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
+      // Pad zeros
+      if (padding_m || padding_k) {
+        for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
+          const int index = (input_offset + thread_offset) * sizeof(int) + j;
+          if (index / M >= original_K || index % M >= original_M) {
+            reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
+          }
+        }
+      }
     }
 
     // local shuffle
@@ -126,6 +144,14 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
   }
 }
 
+template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M,
+                                           const int K, const int original_M,
+                                           const int original_K) {
+  swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
+      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
+}
+
 template <typename LType>
 __device__ inline void regs_shuffle(LType* regs_vec) {
   constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
@@ -143,8 +169,11 @@ __device__ inline void regs_shuffle(LType* regs_vec) {
 }
 
 template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
-__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M,
-                                           const int K) {
+__device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M,
+                                                const int K, const int original_M,
+                                                const int original_K, const int bid_x,
+                                                const int bid_y, const int grid_dim_x,
+                                                const int grid_dim_y) {
   constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
   constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
@@ -154,14 +183,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
   int n_tiles_in_tb = N_TILES_IN_TB;
   const int K_i32 = K / 4;
-  if (blockIdx.x == gridDim.x - 1) {
+  if (bid_x == grid_dim_x - 1) {
     n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
   }
 
-  const int* input_i32 = reinterpret_cast<const int*>(input) +
-                         blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB;
-  int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
-                    blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
+  bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
+  bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);
+
+  const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB;
+  const int* input_i32 = reinterpret_cast<const int*>(input) + input_offset;
+  int* output_i32 = reinterpret_cast<int*>(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 +
+                    bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
 
   extern __shared__ int4 slm_v4i[];
@@ -170,8 +202,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
   if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
 #pragma unroll
     for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
-      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(
-          input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD));
+      const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD;
+      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
+      if (padding_m || padding_k) {
+        // Pad zeros
+        for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
+          const int index = (input_offset + thread_offset) * sizeof(int) + j;
+          if (index / K >= original_M || index % K >= original_K) {
+            reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
+          }
+        }
+      }
     }
 
     // shuffle regs
@@ -196,9 +237,99 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
   }
 }
-}  // namespace
 
-namespace transformer_engine {
+template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M,
+                                           const int K, const int original_M,
+                                           const int original_K) {
+  swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
+      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
+}
+
+constexpr int kMaxTensorsPerKernel = 64;  // Args must be <4 KB
+
+struct MultiSwizzleArgs {
+  // (input) Data buffers for input scaling factors
+  void* input_list[kMaxTensorsPerKernel];
+  // (output) Data buffers for swizzled scaling factors
+  void* output_list[kMaxTensorsPerKernel];
+  // Input scaling factor m
+  int m_list[kMaxTensorsPerKernel];
+  // Input scaling factor k
+  int k_list[kMaxTensorsPerKernel];
+  // Input scaling factor m before padding
+  int original_m_list[kMaxTensorsPerKernel];
+  // Input scaling factor k before padding
+  int original_k_list[kMaxTensorsPerKernel];
+  // Prefix sum (with leading zero) of CUDA blocks needed for each tensor
+  int block_range[kMaxTensorsPerKernel + 1];
+  // Number of tensors being processed by kernel
+  int num_tensors;
+};
+
+template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+__global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) {
+  // Find tensor corresponding to block
+  const int bid = blockIdx.x;
+  int tensor_id = 0;
+  while (kernel_args.block_range[tensor_id + 1] <= bid) {
+    ++tensor_id;
+  }
+  // Get args corresponding to block
+  const void* input = kernel_args.input_list[tensor_id];
+  void* output = kernel_args.output_list[tensor_id];
+  const int M = kernel_args.m_list[tensor_id];
+  const int K = kernel_args.k_list[tensor_id];
+  const int original_M = kernel_args.original_m_list[tensor_id];
+  const int original_K = kernel_args.original_k_list[tensor_id];
+
+  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
+  constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
+
+  // Get block index in grid. Emulate 2D grid.
+  const int num_tiles_k = K / SF_TILE_DIM_K;
+  const int num_tiles_m = M / SF_TILE_DIM_M;
+  const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB);
+  const int grid_dim_y = num_tiles_m;
+  const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y;
+  const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y;
+
+  swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
+      input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
+}
+
+template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+__global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_args) {
+  // Find tensor corresponding to block
+  const int bid = blockIdx.x;
+  int tensor_id = 0;
+  while (kernel_args.block_range[tensor_id + 1] <= bid) {
+    ++tensor_id;
+  }
+  // Get args corresponding to block
+  const void* input = kernel_args.input_list[tensor_id];
+  void* output = kernel_args.output_list[tensor_id];
+  const int M = kernel_args.m_list[tensor_id];
+  const int K = kernel_args.k_list[tensor_id];
+  const int original_M = kernel_args.original_m_list[tensor_id];
+  const int original_K = kernel_args.original_k_list[tensor_id];
+
+  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
+  constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
+  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
+
+  // Get block index in grid. Emulate 2D grid.
+  const int num_tiles_k = K / SF_TILE_DIM_K;
+  const int num_tiles_m = M / SF_TILE_DIM_M;
+  const int grid_dim_x = DIVUP(num_tiles_k, TB_DIM);
+  const int grid_dim_y = DIVUP(num_tiles_m, N_TILE_PER_TD);
+  const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y;
+  const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y;
+
+  swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
+      input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
+}
+
+}  // namespace
 
 void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
   if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
@@ -252,27 +383,29 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
     int n_tiles_in_tb = TB_DIM * vec_load_size;
     dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
     int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
+    const int original_M = input->flat_first_dim();
+    const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
 
     switch (vec_load_size) {
       case 4:
        cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
-            <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
-                                                           output->scale_inv.dptr, m, k);
+            <<<num_blocks, block_size, slm_size, stream>>>(
+                input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
        break;
      case 2:
        cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
-            <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
-                                                           output->scale_inv.dptr, m, k);
+            <<<num_blocks, block_size, slm_size, stream>>>(
+                input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
        break;
      case 1:
        cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
-            <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
-                                                           output->scale_inv.dptr, m, k);
+            <<<num_blocks, block_size, slm_size, stream>>>(
+                input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
        break;
      default:
        NVTE_ERROR("Not valid vec_load_size.");
@@ -285,27 +418,32 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
     int n_tiles_in_tb = TB_DIM * vec_load_size;
     dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
     int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
+    const int original_M = input->flat_last_dim();
+    const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
 
     switch (vec_load_size) {
      case 4:
        cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
-            <<<num_blocks, block_size, slm_size, stream>>>(
-                input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
+            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
+                                                           output->columnwise_scale_inv.dptr, m,
+                                                           k, original_M, original_K);
        break;
      case 2:
        cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
-            <<<num_blocks, block_size, slm_size, stream>>>(
-                input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
+            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
+                                                           output->columnwise_scale_inv.dptr, m,
+                                                           k, original_M, original_K);
        break;
      case 1:
        cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
-            <<<num_blocks, block_size, slm_size, stream>>>(
-                input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
+            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
+                                                           output->columnwise_scale_inv.dptr, m,
+                                                           k, original_M, original_K);
        break;
      default:
        NVTE_ERROR("Not valid vec_load_size.");
@@ -317,10 +455,212 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
   } else {
     NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans.");
   }
-  cudaError_t err = cudaGetLastError();
-  if (err != cudaSuccess) {
-    printf("CUDA Error: %s\n", cudaGetErrorString(err));
-    exit(-1);
-  }
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
+
+template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
+                                                 const int vec_load_size, const bool is_rowwise,
+                                                 cudaStream_t stream) {
+  int n_tiles_in_tb = TB_DIM * vec_load_size;
+  int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
+
+  /* Calculate the number of CUDA blocks needed for each tensor.
+   * This has to happen here because the minimum vec_load_size is only known
+   * after iterating over all tensors in the batch.
+   */
+  for (size_t j = 0; j < kernel_args.num_tensors; j++) {
+    const int m = kernel_args.m_list[j];
+    const int k = kernel_args.k_list[j];
+    int num_tiles_m = m / SF_TILE_DIM_M;
+    int num_tiles_k = k / SF_TILE_DIM_K;
+    if (is_rowwise) {
+      kernel_args.block_range[j + 1] =
+          kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m;
+    } else {
+      kernel_args.block_range[j + 1] =
+          kernel_args.block_range[j] +
+          DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size);
+    }
+  }
+
+  // Launch kernel
+  const int num_blocks = kernel_args.block_range[kernel_args.num_tensors];
+  dim3 block_size(TB_DIM, TB_DIM);
+  if (is_rowwise) {
+    switch (vec_load_size) {
+      case 4:
+        cudaFuncSetAttribute(
+            multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
+        break;
+      case 2:
+        cudaFuncSetAttribute(
+            multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
+        break;
+      case 1:
+        cudaFuncSetAttribute(
+            multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
+        break;
+      default:
+        NVTE_ERROR("Not valid vec_load_size.");
+        break;
+    }
+  } else {
+    switch (vec_load_size) {
+      case 4:
+        cudaFuncSetAttribute(
+            multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
+        break;
+      case 2:
+        cudaFuncSetAttribute(
+            multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
+        break;
+      case 1:
+        cudaFuncSetAttribute(
+            multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
+        break;
+      default:
+        NVTE_ERROR("Not valid vec_load_size.");
+        break;
+    }
+  }
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
+void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
+                                          std::vector<Tensor*>& output, cudaStream_t stream) {
+  auto num_tensors = input.size();
+  bool all_has_data = true;
+  bool all_has_columnwise_data = true;
+  for (size_t i = 0; i < num_tensors; i++) {
+    if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) {
+      NVTE_ERROR("Not implemented scaling mode " + to_string(input[i]->scaling_mode) + ".");
+    }
+    // We don't allow empty tensors. They should be filtered out before calling this function.
+    if (input[i]->data.numel() == 0) {
+      NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
+    }
+    CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
+    CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
+    all_has_data &= input[i]->has_data();
+    all_has_columnwise_data &= input[i]->has_columnwise_data();
+  }
+  NVTE_CHECK(all_has_data || all_has_columnwise_data,
+             "All tensors should have data or columnwise data.");
+
+  constexpr int SF_TILE_DIM_M = 128;
+  constexpr int SF_TILE_DIM_K = 4;
+
+  if (all_has_data) {
+    MultiSwizzleArgs kernel_args;
+    kernel_args.num_tensors = 0;
+    kernel_args.block_range[0] = 0;
+    int vec_load_size = 4;
+    for (size_t i = 0; i < num_tensors; i++) {
+      // Launch kernel if argument struct is full
+      if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
+        // There is no int3 type, and int4/int2 loads would be misaligned.
+        if (vec_load_size == 3) vec_load_size = 1;
+        launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
+            kernel_args, vec_load_size, true, stream);
+        // Reset the argument struct and vec_load_size
+        kernel_args.num_tensors = 0;
+        vec_load_size = 4;
+      }
+      const int m = input[i]->scale_inv.shape[0];
+      const int k = input[i]->scale_inv.shape[1];
+      NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
+      NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
+      NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
+      NVTE_CHECK(
+          m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
+                                   output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()),
+          "Input.scale_inv size is not equal to Output.scale_inv size!");
+
+      int num_tiles_k = k / SF_TILE_DIM_K;
+      int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
+      // We use the minimum vec_load_size across all tensors.
+      vec_load_size = std::min(vec_load_size, vec_load_size_i);
+
+      const int pos = kernel_args.num_tensors;
+      kernel_args.input_list[pos] = const_cast<void*>(input[i]->scale_inv.dptr);
+      kernel_args.output_list[pos] = output[i]->scale_inv.dptr;
+      kernel_args.m_list[pos] = m;
+      kernel_args.k_list[pos] = k;
+      kernel_args.original_m_list[pos] = input[i]->flat_first_dim();
+      kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE;
+      kernel_args.num_tensors++;
+    }
+    // Launch the remaining tensors
+    // There is no int3 type, and int4/int2 loads would be misaligned.
+    if (vec_load_size == 3) vec_load_size = 1;
+    launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
+        kernel_args, vec_load_size, true, stream);
+  }
+
+  if (all_has_columnwise_data) {
+    MultiSwizzleArgs kernel_args;
+    kernel_args.num_tensors = 0;
+    kernel_args.block_range[0] = 0;
+    int vec_load_size = 4;
+    for (size_t i = 0; i < num_tensors; i++) {
+      // Launch kernel if argument struct is full
+      if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
+        // There is no int3 type, and int4/int2 loads would be misaligned.
+        if (vec_load_size == 3) vec_load_size = 1;
+        launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
+            kernel_args, vec_load_size, false, stream);
+        // Reset the argument struct and vec_load_size
+        kernel_args.num_tensors = 0;
+        vec_load_size = 4;
+      }
+      const int m = input[i]->columnwise_scale_inv.shape[1];
+      const int k = input[i]->columnwise_scale_inv.shape[0];
+      NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
+      NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
+      NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
+      NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
+                                          output[i]->columnwise_scale_inv.shape.end(), 1,
+                                          std::multiplies<int>()),
+                 "Input.columnwise_scale_inv size is not equal to "
+                 "Output.columnwise_scale_inv size!");
+
+      int num_tiles_k = k / SF_TILE_DIM_K;
+      int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
+      // We use the minimum vec_load_size across all tensors.
+      vec_load_size = std::min(vec_load_size, vec_load_size_i);
+
+      const int pos = kernel_args.num_tensors;
+      kernel_args.input_list[pos] = const_cast<void*>(input[i]->columnwise_scale_inv.dptr);
+      kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr;
+      kernel_args.m_list[pos] = m;
+      kernel_args.k_list[pos] = k;
+      kernel_args.original_m_list[pos] = input[i]->flat_last_dim();
+      kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / MXFP8_BLOCK_SIZE;
+      kernel_args.num_tensors++;
+    }
+    // Launch the remaining tensors
+    // There is no int3 type, and int4/int2 loads would be misaligned.
+    if (vec_load_size == 3) vec_load_size = 1;
+    launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
+        kernel_args, vec_load_size, false, stream);
+  }
+}
 
 }  // namespace transformer_engine
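To make the block-to-tensor mapping above concrete, here is a small worked example with hypothetical block counts, mirroring the prefix-sum lookup loop in the multi-tensor kernels:

    // Three tensors needing 10, 4, and 7 blocks: the fused kernel launches
    // block_range[3] = 21 blocks in total.
    const int block_range[] = {0, 10, 14, 21};
    const int bid = 12;  // global block index
    int tensor_id = 0;
    while (block_range[tensor_id + 1] <= bid) ++tensor_id;  // -> tensor_id == 1
    const int local_bid = bid - block_range[tensor_id];     // -> 2
    // local_bid is then decomposed into (bid_x, bid_y) of the per-tensor
    // emulated 2D grid, y varying fastest.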
@@ -335,3 +675,16 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
   using namespace transformer_engine;
   swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
 }
+
+void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
+                                               const size_t num_tensors, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors);
+  using namespace transformer_engine;
+  NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
+  std::vector<Tensor*> input_list, output_list;
+  for (size_t i = 0; i < num_tensors; i++) {
+    input_list.push_back(convertNVTETensorCheck(inputs[i]));
+    output_list.push_back(convertNVTETensorCheck(outputs[i]));
+  }
+  multi_tensor_swizzle_scaling_factors(input_list, output_list, stream);
+}
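One detail in the batching logic above deserves a note: the per-tensor vector width is derived as `(num_tiles_k - 1) % 4 + 1`, which equals `num_tiles_k mod 4` with exact multiples of 4 mapped to 4; the batch then uses the minimum across all tensors, and 3 is demoted to 1 since no int3 load type exists. A quick sanity check:

    static_assert((8 - 1) % 4 + 1 == 4, "K-tiles divisible by 4 -> int4 loads");
    static_assert((6 - 1) % 4 + 1 == 2, "-> int2 loads");
    static_assert((7 - 1) % 4 + 1 == 3, "no int3 type; later demoted to 1");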
@@ -35,6 +35,7 @@ struct MultiPaddingArgs {
   int padded_num_rows_list[kMaxTensorsPerKernel];
   // Input matrix widths
   int row_length_list[kMaxTensorsPerKernel];
+  // Prefix sum (with leading zero) of CUDA blocks needed for each
   // tensor
   int block_range[kMaxTensorsPerKernel + 1];
   // Number of tensors being processed by kernel
...
@@ -398,11 +398,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   }
 
   // Allocate full buffer
-  // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
   auto buffer = std::make_shared<at::Tensor>(
-      at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-  // auto buffer = std::make_shared<at::Tensor>(
-  //     at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+      at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
 
   // Construct tensor views
   for (size_t i = 0; i < num_tensors; ++i) {
...
@@ -441,11 +438,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   }
 
   // Allocate full buffer
-  // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
   auto buffer = std::make_shared<at::Tensor>(
-      at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-  // auto buffer = std::make_shared<at::Tensor>(
-  //     at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+      at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
 
   // Construct tensor views
   for (size_t i = 0; i < num_tensors; ++i) {
...
@@ -326,10 +326,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
   std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
       te_pre_gelu_out_vector, te_workspace_vector;
-  std::vector<TensorWrapper> wrappers;
+  std::vector<TensorWrapper> te_A_wrappers, te_B_wrappers, wrappers;
   std::vector<at::Tensor> D_vectors;
-  // Keep the swizzled scaling factor tensors alive during the GEMMs.
-  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
 
   auto none = py::none();
...
@@ -396,10 +394,6 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
       continue;
     }
 
-    // Optionally swizzle the scaling factors
-    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
-    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));
-
     auto te_D = makeTransformerEngineTensor(out_tensor);
     auto te_bias = makeTransformerEngineTensor(bias[i]);
     auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
...
@@ -419,18 +413,25 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     te_bias_vector.emplace_back(te_bias.data());
     te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
 
-    wrappers.emplace_back(std::move(te_A));
-    wrappers.emplace_back(std::move(te_B));
+    te_A_wrappers.emplace_back(std::move(te_A));
+    te_B_wrappers.emplace_back(std::move(te_B));
     wrappers.emplace_back(std::move(te_D));
     wrappers.emplace_back(std::move(te_bias));
     wrappers.emplace_back(std::move(te_pre_gelu_out));
   }
 
+  // Optionally swizzle the scaling factors.
+  // Keep the swizzled scaling factor tensors alive during the GEMMs.
+  auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa);
+  auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb);
+
   for (size_t i = 0; i < workspace.size(); i++) {
     auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
                                            std::vector<size_t>{workspaceSize}, DType::kByte);
     te_workspace_vector.emplace_back(wsp.data());
     wrappers.emplace_back(std::move(wsp));
   }
 
   // For now, we only have multi-stream cublas backend.
   NVTE_SCOPED_GIL_RELEASE({
     nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
...
@@ -841,13 +841,13 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
     const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
                                                      rowwise_scale_inv_shape.end());
     rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
-    rowwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts);
+    rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
   }
   if (columnwise_usage) {
     const std::vector<int64_t> scale_inv_shape_int64(columnwise_scale_inv_shape.begin(),
                                                      columnwise_scale_inv_shape.end());
     columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
-    columnwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts);
+    columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
   }
 
   // Convert tensors to Python
@@ -939,7 +939,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
       const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
                                                        scale_inv_shape.end());
       const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
-      rowwise_scale_inv = at::zeros(scale_inv_shape_int64, opts);
+      rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
       tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
     }
   } else {  // rowwise_usage == false
@@ -966,7 +966,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
       const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
                                                        scale_inv_shape.end());
       const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
-      columnwise_scale_inv = at::zeros(scale_inv_shape_int64, opts);
+      columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
       tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
     }
   } else {  // columnwise_usage == false
...
@@ -75,3 +75,98 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
   return swizzled_scale_inv;
 }
+
+std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
+    std::vector<transformer_engine::TensorWrapper>& tensors, bool rowwise) {
+  using namespace transformer_engine::pytorch;
+
+  if (tensors.empty()) {
+    return std::nullopt;
+  }
+
+  bool all_same_scaling_mode = std::all_of(
+      tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) {
+        return val.scaling_mode() == tensors.front().scaling_mode();
+      });
+  NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same.");
+
+  if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) {
+    NVTE_ERROR("Invalid scaling mode for swizzle.");
+  } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) {
+    return std::nullopt;
+  }
+
+  std::vector<transformer_engine::TensorWrapper> wrappers;
+  std::vector<NVTETensor> input_tensors, output_tensors;
+
+  // Collect scale_inv shapes and calculate buffer size and offsets for scale_invs
+  std::vector<std::vector<size_t>> scale_inv_shapes;
+  std::vector<void*> scale_inv_dptrs;
+  size_t buffer_size = 0;
+  std::vector<size_t> scale_inv_offsets;
+  constexpr size_t scale_elem_size = 1;
+  for (auto& tensor : tensors) {
+    NVTEBasicTensor scale_inv;
+    if (rowwise) {
+      scale_inv = tensor.get_rowwise_scale_inv();
+    } else {
+      scale_inv = tensor.get_columnwise_scale_inv();
+    }
+    auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
+    buffer_size = roundup(buffer_size, 16);  // align to 16B
+    scale_inv_offsets.push_back(buffer_size);
+    buffer_size += product(scale_inv_shape) * scale_elem_size;
+    scale_inv_shapes.emplace_back(scale_inv_shape);
+    scale_inv_dptrs.push_back(scale_inv.data_ptr);
+  }
+
+  // Allocate full buffer
+  auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8));
+
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    auto& tensor = tensors[i];
+    void* scale_inv_dptr = scale_inv_dptrs[i];
+    void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
+    auto input_shape = nvte_shape_to_vector(tensor.shape());
+
+    // Reconstruct input only to avoid swizzling both directions if not needed.
+    // Use any 8 bit type, it's irrelevant.
+    transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
+    transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
+    if (rowwise) {
+      input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
+      input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                                     scale_inv_shapes[i]);
+      output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3,
+                                 input_shape);
+      output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
+                                      transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
+      // Set the swizzled scaling factor to the original tensor.
+      tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                                   scale_inv_shapes[i]);
+    } else {
+      input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
+                                   input_shape);
+      input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                                        scale_inv_shapes[i]);
+      output_cu.set_columnwise_data(tensor.columnwise_dptr(),
+                                    transformer_engine::DType::kFloat8E4M3, input_shape);
+      output_cu.set_columnwise_scale_inv(
+          swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
+      // Set the swizzled scaling factor to the original tensor.
+      tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
+                                      transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
+    }
+    input_tensors.emplace_back(input_cu.data());
+    output_tensors.emplace_back(output_cu.data());
+    wrappers.emplace_back(std::move(input_cu));
+    wrappers.emplace_back(std::move(output_cu));
+  }
+
+  // Launch kernel
+  nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(),
+                                            input_tensors.size(), at::cuda::getCurrentCUDAStream());
+
+  return buffer;
+}
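As a usage note, a sketch of the intended call pattern, mirroring the grouped-GEMM integration earlier in this commit (variable names are illustrative): the returned buffers must outlive the GEMMs that read the swizzled pointers.

    // Swizzle all A and B scale factors in two fused launches, then keep the
    // returned backing buffers alive until the GEMMs have been enqueued.
    auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa);
    auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb);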
@@ -13,11 +13,18 @@
 #include "transformer_engine/transformer_engine.h"
 
-/* Swizzle the scaling factor of the input tensor.
+/*! \brief Swizzle the scaling factor of the input tensor.
  *
  * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
  */
 std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
-                                                  bool trans);
+                                                  bool rowwise);
+
+/*! \brief Swizzle the scaling factors of the input tensors.
+ *
+ *  The returned swizzled scaling factor tensors should be kept alive during the GEMMs.
+ */
+std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
+    std::vector<transformer_engine::TensorWrapper> &inputs, bool rowwise);
 
 #endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_