Unverified commit 0278f1ac, authored by Yi Liu, committed by GitHub
Browse files

Fix nvfp4 swizzling (#23140)


Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent a482e4e7
...@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
run_nvfp4_emulations) run_nvfp4_emulations)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
swizzle_blockscale)
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
...@@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
weight_loader=weight_loader) weight_loader=weight_loader)
layer.register_parameter("input_global_scale", input_global_scale) layer.register_parameter("input_global_scale", input_global_scale)
def swizzle_blockscale(self, scale: torch.Tensor) -> torch.Tensor:
    """Pad and block-interleave ``scale`` and return it as a CUDA tensor.

    The last two dimensions are zero-padded up to multiples of 128 (rows)
    and 4 (columns); the padded tensor is then viewed as
    ``(B, rows // 128, 4, 32, cols // 4, 4)`` and permuted with
    ``(0, 1, 4, 3, 2, 5)``.

    Args:
        scale: 2-D ``(M, K)`` or 3-D ``(B, M, K)`` tensor of dtype
            ``torch.float8_e4m3fn``.

    Returns:
        The swizzled scales on the CUDA device, shaped
        ``(M_padded, K_padded)`` for 2-D input and
        ``(B, M_padded, K_padded)`` for 3-D input.

    Raises:
        AssertionError: if the dtype is not ``torch.float8_e4m3fn`` or the
            input is neither 2-D nor 3-D.
    """
    assert (scale.dtype == torch.float8_e4m3fn)
    scale_ndim = scale.ndim
    if scale_ndim == 2:
        # Treat a single matrix as a batch of one.
        scale = scale.unsqueeze(0)
    assert scale.ndim == 3
    B, M, K = scale.shape

    def round_up_multiple(x: int, m: int) -> int:
        return (x + m - 1) // m * m

    M_padded = round_up_multiple(M, 128)
    K_padded = round_up_multiple(K, 4)

    # Allocate next to the input so the copy below does not bounce
    # through host memory when ``scale`` already lives on an accelerator.
    padded_scale = torch.zeros((B, M_padded, K_padded),
                               dtype=scale.dtype,
                               device=scale.device)
    padded_scale[:, :M, :K] = scale

    padded_scale = padded_scale.reshape(B, M_padded // 128, 4, 32,
                                        K_padded // 4, 4)
    swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
    swizzled_scale = swizzled_scale.contiguous().cuda()

    # The swizzled tensor holds B * M_padded * K_padded elements, so it must
    # be reshaped with the padded dims; using the raw (M, K) raises whenever
    # padding was actually added.
    if scale_ndim == 2:
        return swizzled_scale.reshape(M_padded, K_padded)
    return swizzled_scale.reshape(B, M_padded, K_padded)
def process_weights_after_loading(self, layer) -> None: def process_weights_after_loading(self, layer) -> None:
global_input_scale = layer.input_global_scale.max().to(torch.float32) global_input_scale = layer.input_global_scale.max().to(torch.float32)
...@@ -137,7 +116,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -137,7 +116,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
requires_grad=False) requires_grad=False)
layer.weight_packed = Parameter(weight, requires_grad=False) layer.weight_packed = Parameter(weight, requires_grad=False)
else: else:
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False) requires_grad=False)
layer.weight_packed = Parameter(layer.weight_packed.data, layer.weight_packed = Parameter(layer.weight_packed.data,
......
...@@ -552,8 +552,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor: ...@@ -552,8 +552,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda() swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
if scale_ndim == 2: if scale_ndim == 2:
return swizzled.reshape(M, K) return swizzled.reshape(M_padded, K_padded)
return swizzled.reshape(B, M, K) return swizzled.reshape(B, M_padded, K_padded)
def cutlass_fp4_supported() -> bool: def cutlass_fp4_supported() -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment