Unverified commit 1fc69f59, authored by Asaf Gardin, committed by GitHub
Browse files

[Bug fix][Quantization] Fix dummy weight loading (#38478)


Signed-off-by: Josephasafg <ajgard7@gmail.com>
parent d9c7db18
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.reload.meta import materialize_meta_tensor
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
...@@ -26,12 +23,6 @@ class DummyModelLoader(BaseModelLoader): ...@@ -26,12 +23,6 @@ class DummyModelLoader(BaseModelLoader):
pass # Nothing to download pass # Nothing to download
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
# materialize meta tensors as part of online quantization lifecycle
for layer in model.modules():
for name, param in get_layer_tensors(layer).items():
if param.device == torch.device("meta"):
setattr(layer, name, materialize_meta_tensor(param))
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
initialize_dummy_weights(model, model_config) initialize_dummy_weights(model, model_config)
...@@ -11,7 +11,10 @@ from vllm.config import ModelConfig ...@@ -11,7 +11,10 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
initialize_single_dummy_weight,
)
from .meta import ( from .meta import (
capture_layer_to_meta, capture_layer_to_meta,
...@@ -223,7 +226,8 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon ...@@ -223,7 +226,8 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
elif info.load_numel <= 0: elif info.load_numel <= 0:
# first load but received no weights. This happens on dummy load # first load but received no weights. This happens on dummy load
if info.kernel_tensors is None: if info.kernel_tensors is None:
materialize_layer(layer, info) _layerwise_process(layer, info)
continue
# reloading: place kernel tensors back as a fallback # reloading: place kernel tensors back as a fallback
else: else:
...@@ -258,6 +262,12 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): ...@@ -258,6 +262,12 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Materialize layer tensors onto device # Materialize layer tensors onto device
materialize_layer(layer, info) materialize_layer(layer, info)
# If no weights were loaded (e.g. dummy loading), initialize with
# small random values to avoid NaN from zero/garbage data
if len(info.loaded_weights) <= 0:
for tensor in get_layer_tensors(layer).values():
initialize_single_dummy_weight(tensor)
# Reset online quantization flag so process_weights_after_loading # Reset online quantization flag so process_weights_after_loading
# will run again during reload # will run again during reload
if hasattr(layer, "_already_called_process_weights_after_loading"): if hasattr(layer, "_already_called_process_weights_after_loading"):
......
...@@ -239,10 +239,12 @@ def convert_bin_to_safetensor_file( ...@@ -239,10 +239,12 @@ def convert_bin_to_safetensor_file(
sf_size = os.stat(sf_filename).st_size sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size pt_size = os.stat(pt_filename).st_size
if (sf_size - pt_size) / pt_size > 0.01: if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(f"""The file size different is more than 1%: raise RuntimeError(
f"""The file size different is more than 1%:
- {sf_filename}: {sf_size} - {sf_filename}: {sf_size}
- {pt_filename}: {pt_size} - {pt_filename}: {pt_size}
""") """
)
# check if the tensors are the same # check if the tensors are the same
reloaded = load_file(sf_filename) reloaded = load_file(sf_filename)
...@@ -1377,6 +1379,9 @@ def initialize_single_dummy_weight( ...@@ -1377,6 +1379,9 @@ def initialize_single_dummy_weight(
high: float = 1e-3, high: float = 1e-3,
seed: int = 1234, seed: int = 1234,
) -> None: ) -> None:
if param.device.type == "meta":
return # deferred to finalize_layerwise_processing (e.g. online quant)
if not torch.is_floating_point(param): if not torch.is_floating_point(param):
if current_platform.is_rocm(): if current_platform.is_rocm():
# On ROCm, integer params (e.g. GPTQ qweight/qzeros) are left # On ROCm, integer params (e.g. GPTQ qweight/qzeros) are left
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment