Unverified Commit 13d88d41 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Refactor composite weight loading logic (#8656)

parent d66ac628
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# Copyright (c) 2023 OpenGVLab # Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
import itertools
import re import re
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -33,8 +32,8 @@ from vllm.utils import is_list_of ...@@ -33,8 +32,8 @@ from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches) get_clip_num_patches)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings) init_vllm_registered_model, merge_multimodal_embeddings)
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -518,21 +517,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -518,21 +517,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components # prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) weights_group = group_weights_with_prefix(weights)
# load vision encoder # load vision encoder
vit_weights = filter_weights(vit_weights, "vision_model") self.vision_model.load_weights(weights_group["vision_model"])
self.vision_model.load_weights(vit_weights)
# load mlp projector # load mlp projector
mlp_weights = filter_weights(mlp_weights, "mlp1")
mlp_params_dict = dict(self.mlp1.named_parameters()) mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in mlp_weights: for name, loaded_weight in weights_group["mlp1"]:
param = mlp_params_dict[name] param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(weights_group["language_model"])
self.language_model.load_weights(llm_weights)
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal ...@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens, dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip) input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings) init_vllm_registered_model, merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
...@@ -393,21 +392,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -393,21 +392,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components # prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) weights_group = group_weights_with_prefix(weights)
# load vision encoder # load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower") self.vision_tower.load_weights(weights_group["vision_tower"])
self.vision_tower.load_weights(vit_weights)
# load mlp projector # load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights: for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name] param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(weights_group["language_model"])
self.language_model.load_weights(llm_weights)
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -30,8 +29,8 @@ from .llava import LlavaMultiModalProjector ...@@ -30,8 +29,8 @@ from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size, dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings) init_vllm_registered_model, merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -637,25 +636,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -637,25 +636,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components # prepare weight iterators for components
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( weights_group = group_weights_with_prefix(weights)
weights, 4)
# load vision encoder # load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower") self.vision_tower.load_weights(weights_group["vision_tower"])
self.vision_tower.load_weights(vit_weights)
# load mlp projector # load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights: for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name] param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load newline # load newline
newline_weights = filter_weights(newline_weights, "image_newline") for name, loaded_weight in weights_group["image_newline"]:
for name, loaded_weight in newline_weights:
assert name == "" assert name == ""
param = self.image_newline param = self.image_newline
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
...@@ -663,5 +658,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -663,5 +658,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(weights_group["language_model"])
self.language_model.load_weights(llm_weights)
import itertools
import math import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -30,7 +29,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip ...@@ -30,7 +29,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip) dummy_seq_data_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model, from .utils import (group_weights_with_prefix, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -449,23 +448,19 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -449,23 +448,19 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators # prepare weight iterators for components
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( weights_group = group_weights_with_prefix(weights)
weights, 4)
# load vision encoder # load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower") self.vision_tower.load_weights(weights_group["vision_tower"])
self.vision_tower.load_weights(vit_weights)
# load mlp projector # load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights: for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name] param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(weights_group["language_model"])
self.language_model.load_weights(llm_weights)
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -23,7 +22,7 @@ from vllm.sequence import IntermediateTensors ...@@ -23,7 +22,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens) dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import filter_weights, merge_multimodal_embeddings from .utils import group_weights_with_prefix, merge_multimodal_embeddings
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -286,21 +285,18 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -286,21 +285,18 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components # prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) weights_group = group_weights_with_prefix(weights)
# load vision tower # load vision tower
vit_weights = filter_weights(vit_weights, "vision_tower") self.vision_tower.load_weights(weights_group["vision_tower"])
self.vision_tower.load_weights(vit_weights)
# load mlp projector # load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights: for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name] param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(weights_group["language_model"])
self.language_model.load_weights(llm_weights)
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import itertools
import math import math
from array import array from array import array
from functools import lru_cache from functools import lru_cache
...@@ -29,7 +28,8 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -29,7 +28,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (filter_weights, flatten_bn, from vllm.model_executor.models.utils import (flatten_bn,
group_weights_with_prefix,
init_vllm_registered_model, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -467,11 +467,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal): ...@@ -467,11 +467,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components # prepare weight iterators for components
projector_weights, llm_weights = itertools.tee(weights, 2) weights_group = group_weights_with_prefix(weights)
# load projector weights # load projector weights
projector_weights = filter_weights(projector_weights, projector_weights = weights_group["multi_modal_projector"]
"multi_modal_projector")
projector_params_dict = dict( projector_params_dict = dict(
self.multi_modal_projector.named_parameters()) self.multi_modal_projector.named_parameters())
for name, loaded_weight in projector_weights: for name, loaded_weight in projector_weights:
...@@ -481,5 +480,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal): ...@@ -481,5 +480,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(weights_group["language_model"])
self.language_model.load_weights(llm_weights)
import itertools
from collections import UserDict
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
Union, overload) Union, overload)
...@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors ...@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): class WeightsGroup(UserDict):
"""
Wraps grouped weights dictionary for a more informative error message
when attempting to access a weight component that does not exist.
"""
def __getitem__(self, key: str) -> int:
try:
return super().__getitem__(key)
except KeyError as exc:
msg = (f"There is no weights named with the prefix: {key}. "
f"Available prefix: {set(self.keys())}")
raise KeyError(msg) from exc
def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
""" """
Helper function to load weights for inner vLLM models. Helper function to load weights for inner vLLM models.
...@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): ...@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
yield name, loaded_weight yield name, loaded_weight
def group_weights_with_prefix(
weights: Iterable[Tuple[str, torch.Tensor]]
) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
"""
Helper function to group weights with prefix
"""
init_weights, repeated_weights = itertools.tee(weights, 2)
weights_prefix = {name.split(".")[0] for name, _ in init_weights}
repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
return WeightsGroup({
prefix: filter_weights(component, prefix)
for component, prefix in zip(repeated_weights, weights_prefix)
})
def init_vllm_registered_model( def init_vllm_registered_model(
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig], cache_config: Optional[CacheConfig],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment