Unverified commit 3dc01ef3, authored by Asaf Gardin, committed by GitHub
Browse files

[Quantization] Consolidate dummy format logic into DummyModelLoader (#38637)


Signed-off-by: Josephasafg <ajgard7@gmail.com>
parent cc671cb1
...@@ -4,8 +4,19 @@ import torch.nn as nn ...@@ -4,8 +4,19 @@ import torch.nn as nn
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights from vllm.model_executor.model_loader.reload.layerwise import (
_get_original_loader,
get_layerwise_info,
)
from vllm.model_executor.model_loader.reload.meta import materialize_layer
from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
from vllm.model_executor.model_loader.weight_utils import (
initialize_dummy_weights,
initialize_single_dummy_weight,
)
class DummyModelLoader(BaseModelLoader): class DummyModelLoader(BaseModelLoader):
...@@ -23,6 +34,31 @@ class DummyModelLoader(BaseModelLoader): ...@@ -23,6 +34,31 @@ class DummyModelLoader(BaseModelLoader):
pass # Nothing to download pass # Nothing to download
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
# NOTE(woosuk): For accurate performance evaluation, we assign for layer in model.modules():
# random values to the weights. info = get_layerwise_info(layer)
initialize_dummy_weights(model, model_config) if info.can_load():
self._process_online_quant_layer(layer, info)
else:
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(layer, model_config)
def _process_online_quant_layer(
    self,
    layer: nn.Module,
    info: LayerReloadingInfo,
) -> None:
    """Dummy-load a single layer that is set up for layerwise loading.

    The steps run in a fixed order: materialize the layer, fill each of its
    tensors with dummy values, re-attach each tensor's original weight
    loader, run the quantization post-processing hook when the layer has a
    quant method, and finally reset ``info``.

    Args:
        layer: The module to populate with dummy weights.
        info: Layerwise reloading state associated with ``layer``.
    """
    # Tensors have to be materialized before they can be filled.
    materialize_layer(layer, info)

    # First pass: write dummy values into every tensor of the layer.
    for dummy_target in get_layer_tensors(layer).values():
        initialize_single_dummy_weight(dummy_target)

    # Second pass: the tensors are looked up again (not reused from the
    # first pass) and each gets the loader returned by _get_original_loader.
    for layer_tensor in get_layer_tensors(layer).values():
        layer_tensor.weight_loader = _get_original_loader(layer_tensor)

    # Layers without a QuantizeMethodBase quant method skip post-processing.
    method = getattr(layer, "quant_method", None)
    if isinstance(method, QuantizeMethodBase):
        method.process_weights_after_loading(layer)

    info.reset()
...@@ -11,10 +11,7 @@ from vllm.config import ModelConfig ...@@ -11,10 +11,7 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader
default_weight_loader,
initialize_single_dummy_weight,
)
from .meta import ( from .meta import (
capture_layer_to_meta, capture_layer_to_meta,
...@@ -224,7 +221,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon ...@@ -224,7 +221,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
# No weights were loaded # No weights were loaded
elif info.load_numel <= 0: elif info.load_numel <= 0:
# first load but received no weights. This happens on dummy load # first load: checkpoint did not contain weights for this layer
if info.kernel_tensors is None: if info.kernel_tensors is None:
_layerwise_process(layer, info) _layerwise_process(layer, info)
continue continue
...@@ -262,12 +259,6 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): ...@@ -262,12 +259,6 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Materialize layer tensors onto device # Materialize layer tensors onto device
materialize_layer(layer, info) materialize_layer(layer, info)
# If no weights were loaded (e.g. dummy loading), initialize with
# small random values to avoid NaN from zero/garbage data
if len(info.loaded_weights) <= 0:
for tensor in get_layer_tensors(layer).values():
initialize_single_dummy_weight(tensor)
# Reset online quantization flag so process_weights_after_loading # Reset online quantization flag so process_weights_after_loading
# will run again during reload # will run again during reload
if hasattr(layer, "_already_called_process_weights_after_loading"): if hasattr(layer, "_already_called_process_weights_after_loading"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment