Unverified Commit 5869f69c authored by Kyle Sayers's avatar Kyle Sayers Committed by GitHub
Browse files

[Online Quant] [QeRL] Minor code cleanup (#38574)


Signed-off-by: default avatarKyle Sayers <kylesayrs@gmail.com>
parent 4dfad17e
......@@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8KVCacheMethod,
Fp8OnlineLinearMethod,
Fp8OnlineMoEMethod,
_copy_missing_attrs,
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
......@@ -43,11 +42,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
)
from vllm.model_executor.model_loader.weight_utils import (
initialize_single_dummy_weight,
)
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -183,17 +178,6 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
if layer.weight.device == torch.device("meta"):
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=layer.weight.weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
initialize_single_dummy_weight(layer.weight)
weight_fp8, weight_scale = mxfp8_e4m3_quantize(layer.weight.contiguous())
if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
......@@ -265,28 +249,6 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
**extra_weight_attrs,
)
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // MXFP8_BLOCK_SIZE,
dtype=torch.uint8,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // MXFP8_BLOCK_SIZE,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
layer.weight_block_size = [1, MXFP8_BLOCK_SIZE]
def _quantize_mxfp8_moe_weight(
......@@ -309,34 +271,9 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
if layer.w13_weight.device == torch.device("meta"):
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
initialize_single_dummy_weight(layer.w13_weight)
if layer.w2_weight.device == torch.device("meta"):
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
initialize_single_dummy_weight(layer.w2_weight)
fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
layer.w13_input_scale = None
layer.w2_input_scale = None
......
......@@ -230,7 +230,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
continue
# reloading: place kernel tensors back as a fallback
else:
elif info.load_numel_total > 0: # type: ignore[operator]
logger.warning("%s: Failed to load weights", layer.__class__.__name__)
_place_kernel_tensors(layer, info)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment