Unverified Commit 12065ac2 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[Core] Add launch bounds to swizzle kernels (#2076)



Add launch bounds to swizzle kernel, use empty scale inv
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a169e9e7
......@@ -145,9 +145,9 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output,
}
// Thin __global__ entry point for column-wise scale-factor swizzling.
// Delegates all work to swizzle_col_scaling_kernel_impl, forwarding this
// block's coordinates (blockIdx.x/y) and the grid extent (gridDim.x/y) so the
// impl can be shared by other launch shapes.
// __launch_bounds__(TB_DIM * TB_DIM) tells the compiler the maximum block
// size, letting it budget registers for that occupancy.
// NOTE(review): assumes the kernel is launched with a TB_DIM x TB_DIM thread
// block — confirm against the host-side launch code.
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void __launch_bounds__(TB_DIM* TB_DIM)
    swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K,
                               const int original_M, const int original_K) {
  swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
}
......@@ -238,9 +238,9 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output,
}
// Thin __global__ entry point for row-wise scale-factor swizzling.
// Delegates all work to swizzle_row_scaling_kernel_impl, forwarding this
// block's coordinates (blockIdx.x/y) and the grid extent (gridDim.x/y) so the
// impl can be shared by other launch shapes.
// __launch_bounds__(TB_DIM * TB_DIM) tells the compiler the maximum block
// size, letting it budget registers for that occupancy.
// NOTE(review): assumes the kernel is launched with a TB_DIM x TB_DIM thread
// block — confirm against the host-side launch code.
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void __launch_bounds__(TB_DIM* TB_DIM)
    swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K,
                               const int original_M, const int original_K) {
  swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
}
......
......@@ -100,7 +100,7 @@ class MXFP8Quantizer(Quantizer):
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
scale_inv = torch.zeros(
scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8,
......@@ -112,7 +112,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty_like(data)
columnwise_scale_inv = torch.zeros(
columnwise_scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
dtype=torch.uint8,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.