Commit 3f06bae9 (unverified), authored by Peter Salas, committed by GitHub
Browse files

[Core][Model] Support loading weights by ID within models (#7931)

parent b8747e8a
# ruff: noqa: SIM117 # ruff: noqa: SIM117
import collections import collections
import copy import copy
import dataclasses
import fnmatch import fnmatch
import glob import glob
import json import json
...@@ -8,7 +9,8 @@ import math ...@@ -8,7 +9,8 @@ import math
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Tuple, Type from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
Type, cast)
import gguf import gguf
import huggingface_hub import huggingface_hub
...@@ -207,6 +209,22 @@ class BaseModelLoader(ABC): ...@@ -207,6 +209,22 @@ class BaseModelLoader(ABC):
class DefaultModelLoader(BaseModelLoader): class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk.""" """Model loader that can load different file types from disk."""
@dataclasses.dataclass
class Source:
    """A source for weights.

    Bundles the location of a set of weights with the options used when
    reading them, so that a loader can be handed a single object instead of
    several loose arguments.
    """

    model_or_path: str
    """The model ID or path."""

    revision: Optional[str] = None
    """The optional model revision.

    Defaults to ``None`` so callers that have no specific revision can omit
    the argument instead of passing ``revision=None`` explicitly.
    """

    prefix: str = ""
    """A prefix to prepend to all weights."""

    fall_back_to_pt: bool = True
    """Whether .pt weights can be used."""
def __init__(self, load_config: LoadConfig): def __init__(self, load_config: LoadConfig):
super().__init__(load_config) super().__init__(load_config)
if load_config.model_loader_extra_config: if load_config.model_loader_extra_config:
...@@ -313,17 +331,16 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -313,17 +331,16 @@ class DefaultModelLoader(BaseModelLoader):
return hf_folder, hf_weights_files, use_safetensors return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator( def _get_weights_iterator(
self, model_name_or_path: str, revision: Optional[str], self, source: "Source"
fall_back_to_pt: bool
) -> Generator[Tuple[str, torch.Tensor], None, None]: ) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format.""" """Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision, fall_back_to_pt) source.model_or_path, source.revision, source.fall_back_to_pt)
if self.load_config.load_format == LoadFormat.NPCACHE: if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints # Currently np_cache only support *.bin checkpoints
assert use_safetensors is False assert use_safetensors is False
weights_iterator = np_cache_weights_iterator( weights_iterator = np_cache_weights_iterator(
model_name_or_path, self.load_config.download_dir, hf_folder, source.model_or_path, self.load_config.download_dir, hf_folder,
hf_weights_files) hf_weights_files)
elif use_safetensors: elif use_safetensors:
weights_iterator = safetensors_weights_iterator(hf_weights_files) weights_iterator = safetensors_weights_iterator(hf_weights_files)
...@@ -341,7 +358,29 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -341,7 +358,29 @@ class DefaultModelLoader(BaseModelLoader):
xm.mark_step() xm.mark_step()
weights_iterator = _xla_weights_iterator(weights_iterator) weights_iterator = _xla_weights_iterator(weights_iterator)
return weights_iterator
# Apply the prefix.
return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator)
def _get_all_weights(
    self,
    model_config: ModelConfig,
    model: nn.Module,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs from the primary and secondary sources.

    The primary source is built from ``model_config.model`` and
    ``model_config.revision`` with an empty prefix, and its weights are
    yielded first. Afterwards, every entry of the model's optional
    ``secondary_weights`` attribute is streamed in iteration order.

    Args:
        model_config: Supplies the primary model ID/path and revision.
        model: The module being loaded; it may define
            ``fall_back_to_pt_during_load`` and ``secondary_weights``.

    Yields:
        Whatever ``self._get_weights_iterator`` produces for each source.
    """
    model_id = model_config.model
    model_revision = model_config.revision
    # A model may opt out of the .pt fallback via an attribute; when the
    # attribute is absent the fallback stays enabled.
    allow_pt = getattr(model, "fall_back_to_pt_during_load", True)

    main_source = DefaultModelLoader.Source(
        model_id,
        model_revision,
        prefix="",
        fall_back_to_pt=allow_pt,
    )
    yield from self._get_weights_iterator(main_source)

    # Looked up only once the primary weights have been fully yielded; a
    # model without the attribute contributes no extra sources.
    extra_sources = cast(
        Iterable[DefaultModelLoader.Source],
        getattr(model, "secondary_weights", ()),
    )
    for extra in extra_sources:
        yield from self._get_weights_iterator(extra)
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, self._prepare_weights(model_config.model,
...@@ -360,13 +399,8 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -360,13 +399,8 @@ class DefaultModelLoader(BaseModelLoader):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, cache_config, lora_config, cache_config,
scheduler_config) scheduler_config)
model.load_weights(
self._get_weights_iterator(model_config.model, model.load_weights(self._get_all_weights(model_config, model))
model_config.revision,
fall_back_to_pt=getattr(
model,
"fall_back_to_pt_during_load",
True)), )
for _, module in model.named_modules(): for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
......
...@@ -25,6 +25,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -25,6 +25,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (flatten_bn, from vllm.model_executor.models.utils import (flatten_bn,
...@@ -334,14 +335,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal): ...@@ -334,14 +335,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
self.multi_modal_config = multimodal_config self.multi_modal_config = multimodal_config
assert self.multi_modal_config assert self.multi_modal_config
self.secondary_weights = []
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.audio_model_id is not None: if config.audio_model_id is not None:
self.audio_tower = ModifiedWhisperEncoder.from_pretrained( self.secondary_weights.append(
config.audio_model_id) DefaultModelLoader.Source(
else: model_or_path=config.audio_model_id,
self.audio_tower = ModifiedWhisperEncoder(config.audio_config) revision=None,
prefix="audio_tower.",
))
self.multi_modal_projector = UltravoxProjector(config) self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
if config.text_model_id is not None:
self.secondary_weights.append(
DefaultModelLoader.Source(model_or_path=config.text_model_id,
revision=None,
prefix="language_model."))
def _audio_features_to_embeddings( def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor: self, input_features: torch.Tensor) -> torch.Tensor:
...@@ -466,6 +476,18 @@ class UltravoxModel(nn.Module, SupportsMultiModal): ...@@ -466,6 +476,18 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
# prepare weight iterators for components # prepare weight iterators for components
weights_group = group_weights_with_prefix(weights) weights_group = group_weights_with_prefix(weights)
# load audio tower weights
audio_tower_weights = weights_group["audio_tower"]
audio_tower_params_dict = dict(
self.audio_tower.named_parameters(
prefix=self.audio_tower.base_model_prefix))
for name, loaded_weight in audio_tower_weights:
if name in audio_tower_params_dict:
param = audio_tower_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load projector weights # load projector weights
projector_weights = weights_group["multi_modal_projector"] projector_weights = weights_group["multi_modal_projector"]
projector_params_dict = dict( projector_params_dict = dict(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment