"vscode:/vscode.git/clone" did not exist on "7640a8d407225b9e416563170bb668cc93f98424"
Unverified commit 12065ac2, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[Core] Add launch bounds to swizzle kernels (#2076)



Add launch bounds to swizzle kernel, use empty scale inv
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a169e9e7
...@@ -145,9 +145,9 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, ...@@ -145,9 +145,9 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output,
} }
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K> template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, __global__ void __launch_bounds__(TB_DIM* TB_DIM)
const int K, const int original_M, swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K,
const int original_K) { const int original_M, const int original_K) {
swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>( swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
} }
...@@ -238,9 +238,9 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, ...@@ -238,9 +238,9 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output,
} }
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K> template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, __global__ void __launch_bounds__(TB_DIM* TB_DIM)
const int K, const int original_M, swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K,
const int original_K) { const int original_M, const int original_K) {
swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>( swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
} }
......
...@@ -100,7 +100,7 @@ class MXFP8Quantizer(Quantizer): ...@@ -100,7 +100,7 @@ class MXFP8Quantizer(Quantizer):
# Allocate FP8 data # Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device) data = torch.empty(shape, dtype=torch.uint8, device=device)
scale_inv = torch.zeros( scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8, dtype=torch.uint8,
...@@ -112,7 +112,7 @@ class MXFP8Quantizer(Quantizer): ...@@ -112,7 +112,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv = None columnwise_scale_inv = None
if self.columnwise_usage: if self.columnwise_usage:
columnwise_data = torch.empty_like(data) columnwise_data = torch.empty_like(data)
columnwise_scale_inv = torch.zeros( columnwise_scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128), round_up_to_nearest_multiple(shape[-1], 128),
dtype=torch.uint8, dtype=torch.uint8,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment