Unverified Commit 4572a06a authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Misc] Enable weights loading tracking for quantized models (#35074)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 5cc29cfb
...@@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME ...@@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -286,7 +287,6 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -286,7 +287,6 @@ class DefaultModelLoader(BaseModelLoader):
): ):
self.load_config.safetensors_load_strategy = "torchao" self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(self.get_all_weights(model_config, model)) loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter() self.counter_after_loading_weights = time.perf_counter()
...@@ -295,9 +295,20 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -295,9 +295,20 @@ class DefaultModelLoader(BaseModelLoader):
self.counter_after_loading_weights - self.counter_before_loading_weights, self.counter_after_loading_weights - self.counter_before_loading_weights,
scope="local", scope="local",
) )
# We only enable strict check for non-quantized models self.track_weights_loading(model, loaded_weights)
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None: def track_weights_loading(
self, model: nn.Module, loaded_weights: set[str] | None
) -> None:
weights_to_load = {name for name, _ in model.named_parameters()}
if loaded_weights is not None:
for name, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
# ignore kv_cache scale, which can be missing in checkpoints
if isinstance(quant_method, BaseKVCacheMethod):
for param_name, _ in module.named_parameters():
full_name = f"{name}.{param_name}" if name else param_name
loaded_weights.add(full_name)
weights_not_loaded = weights_to_load - loaded_weights weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded: if weights_not_loaded:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment