Unverified Commit 8bfaa4e3 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] fix composite weight loading and EAGLE weight loading (#9160)

parent 0b5b5d76
...@@ -13,7 +13,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs ...@@ -13,7 +13,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
...@@ -21,7 +20,7 @@ from vllm.sequence import IntermediateTensors, SequenceData ...@@ -21,7 +20,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip, from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens) get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (group_weights_with_prefix, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token # We use this internally as placeholders since there is no image token
...@@ -687,35 +686,5 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -687,35 +686,5 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.vision_model.load_weights(weights_group["vision_model"])
# load query tokens
for name, loaded_weight in weights_group["query_tokens"]:
assert name == ""
param = self.query_tokens
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load qformer
qformer_params_dict = dict(self.qformer.named_parameters())
for name, loaded_weight in weights_group["qformer"]:
param = qformer_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load mlp projector
mlp_params_dict = dict(self.language_projection.named_parameters())
for name, loaded_weight in weights_group["language_projection"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -31,7 +31,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs ...@@ -31,7 +31,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -42,8 +41,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, ...@@ -42,8 +41,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
merge_multimodal_embeddings)
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011 _IMAGE_TOKEN_ID = 71011
...@@ -349,16 +347,5 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -349,16 +347,5 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision embeddings
vision_params_dict = dict(self.vision_embed_tokens.named_parameters())
for name, loaded_weight in weights_group["vision_embed_tokens"]:
param = vision_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -40,7 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -40,7 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (group_weights_with_prefix, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -447,19 +447,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -447,19 +447,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weights_group = group_weights_with_prefix(weights) loader = AutoWeightsLoader(
self,
self.model.load_weights(weights_group["model"]) skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
if not self.config.tie_word_embeddings: )
# NOTE: For now self.lm_head is not defined because loader.load_weights(weights)
# tie_word_embeddings is assumed to the False
lm_head_dict = dict(self.lm_head.named_parameters())
for name, loaded_weight in weights_group["lm_head"]:
if is_pp_missing_parameter(name, self.lm_head):
continue
param = lm_head_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -20,7 +20,6 @@ from vllm.config import CacheConfig, MultiModalConfig ...@@ -20,7 +20,6 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -32,8 +31,8 @@ from vllm.utils import is_list_of ...@@ -32,8 +31,8 @@ from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches) get_clip_num_patches)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
init_vllm_registered_model, merge_multimodal_embeddings) merge_multimodal_embeddings)
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -609,19 +608,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -609,19 +608,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.vision_model.load_weights(weights_group["vision_model"])
# load mlp projector
mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in weights_group["mlp1"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -51,8 +51,7 @@ from vllm.sequence import IntermediateTensors ...@@ -51,8 +51,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, group_weights_with_prefix, from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers)
...@@ -564,25 +563,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -564,25 +563,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weights = [ loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
loader.load_weights(
self.maybe_remap_mistral(name, loaded_weight) self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights for name, loaded_weight in weights)
]
weights_group = group_weights_with_prefix(weights)
self.model.load_weights(weights_group["model"])
if not self.config.tie_word_embeddings:
lm_head_dict = dict(self.lm_head.named_parameters())
for name, loaded_weight in weights_group["lm_head"]:
if is_pp_missing_parameter(name, self.lm_head):
continue
param = lm_head_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def load_kv_cache_scales(self, quantization_param_path: str) -> None: def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path) self.model.load_kv_cache_scales(quantization_param_path)
......
...@@ -13,7 +13,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs ...@@ -13,7 +13,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal, SupportsPP ...@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens, dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip) input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
init_vllm_registered_model, merge_multimodal_embeddings) merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
...@@ -406,19 +405,5 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -406,19 +405,5 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.vision_tower.load_weights(weights_group["vision_tower"])
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -15,7 +15,6 @@ from vllm.config import CacheConfig, MultiModalConfig ...@@ -15,7 +15,6 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -29,8 +28,8 @@ from .llava import LlavaMultiModalProjector ...@@ -29,8 +28,8 @@ from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size, dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
init_vllm_registered_model, merge_multimodal_embeddings) merge_multimodal_embeddings)
# Result in the max possible feature size (2x2 grid of 336x336px tiles) # Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
...@@ -642,27 +641,5 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -642,27 +641,5 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.vision_tower.load_weights(weights_group["vision_tower"])
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load newline
for name, loaded_weight in weights_group["image_newline"]:
assert name == ""
param = self.image_newline
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -15,7 +15,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs ...@@ -15,7 +15,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -28,7 +27,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip ...@@ -28,7 +27,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip) dummy_seq_data_for_siglip)
from .utils import (group_weights_with_prefix, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
# For profile run # For profile run
...@@ -458,19 +457,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -458,19 +457,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(
weights_group = group_weights_with_prefix(weights) self,
# This model doesn't support images for now
# load vision encoder ignore_unexpected_prefixes=["image_newline"],
self.vision_tower.load_weights(weights_group["vision_tower"]) )
loader.load_weights(weights)
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -20,7 +20,6 @@ from vllm.logger import init_logger ...@@ -20,7 +20,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
...@@ -35,8 +34,8 @@ from .interfaces import SupportsMultiModal, SupportsPP ...@@ -35,8 +34,8 @@ from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size, dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
init_vllm_registered_model, merge_multimodal_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -872,19 +871,5 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -872,19 +871,5 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.vision_tower.load_weights(weights_group["vision_tower"])
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -11,7 +11,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs ...@@ -11,7 +11,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaForCausalLM from vllm.model_executor.models.gemma import GemmaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -21,7 +20,7 @@ from vllm.sequence import IntermediateTensors ...@@ -21,7 +20,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens) dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import group_weights_with_prefix, merge_multimodal_embeddings from .utils import AutoWeightsLoader, merge_multimodal_embeddings
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -292,19 +291,5 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -292,19 +291,5 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision tower
self.vision_tower.load_weights(weights_group["vision_tower"])
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -31,7 +31,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs ...@@ -31,7 +31,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -42,15 +41,11 @@ from vllm.utils import is_list_of ...@@ -42,15 +41,11 @@ from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
merge_multimodal_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
"model.vision_embed_tokens": "vision_embed_tokens",
}
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 32044 _IMAGE_TOKEN_ID = 32044
...@@ -295,35 +290,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): ...@@ -295,35 +290,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline return image_features_hd_newline
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.img_processor.load_weights(weights_group["img_processor"])
# load glb_GN
for name, loaded_weight in weights_group["glb_GN"]:
assert name == ""
param = self.glb_GN
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load sub_GN
for name, loaded_weight in weights_group["sub_GN"]:
assert name == ""
param = self.sub_GN
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load mlp projector
mlp_params_dict = dict(self.img_projection.named_parameters())
for name, loaded_weight in weights_group["img_projection"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
...@@ -715,27 +683,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -715,27 +683,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapping = { hf_to_vllm_mapper = WeightsMapper(
"model.vision_embed_tokens.": "vision_embed_tokens.", orig_to_new_prefix={
"lm_head.": "language_model.lm_head.", "model.vision_embed_tokens.": "vision_embed_tokens.",
"model.": "language_model.model.", "lm_head.": "language_model.lm_head.",
} "model.": "language_model.model.",
})
def hf_to_vllm_name(key: str) -> str:
for hf_name, vllm_name in hf_to_vllm_mapping.items(): loader = AutoWeightsLoader(self)
if key.startswith(hf_name): loader.load_weights(weights, mapper=hf_to_vllm_mapper)
return key.replace(hf_name, vllm_name, 1)
return key
vllm_weights = {hf_to_vllm_name(k): v for k, v in weights}
# prepare weight iterators for components
weights_group = group_weights_with_prefix(vllm_weights.items())
# load vision embeddings and encoder
self.vision_embed_tokens.load_weights(
weights_group["vision_embed_tokens"])
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -48,8 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -48,8 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, group_weights_with_prefix, from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers)
...@@ -435,17 +434,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -435,17 +434,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weights_group = group_weights_with_prefix(weights) loader = AutoWeightsLoader(
self,
self.model.load_weights(weights_group["model"]) skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
if not self.config.tie_word_embeddings: )
lm_head_dict = dict(self.lm_head.named_parameters()) loader.load_weights(weights)
for name, loaded_weight in weights_group["lm_head"]:
if is_pp_missing_parameter(name, self.lm_head):
continue
param = lm_head_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -16,13 +16,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -16,13 +16,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .qwen2 import Qwen2Model from .qwen2 import Qwen2Model
from .utils import group_weights_with_prefix from .utils import AutoWeightsLoader
class ReLU(nn.Module): class ReLU(nn.Module):
...@@ -120,13 +119,5 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): ...@@ -120,13 +119,5 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weights_group = group_weights_with_prefix(weights) loader = AutoWeightsLoader(self)
loader.load_weights(weights)
self.model.load_weights(weights_group["model"])
score_dict = dict(self.score.named_parameters())
for name, loaded_weight in weights_group["score"]:
param = score_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -25,11 +25,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -25,11 +25,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import (flatten_bn,
group_weights_with_prefix,
init_vllm_registered_model,
merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs, NestedTensors from vllm.multimodal.base import MultiModalInputs, NestedTensors
...@@ -41,6 +36,8 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig ...@@ -41,6 +36,8 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, merge_multimodal_embeddings)
_AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25 _AUDIO_TOKENS_PER_SECOND = 6.25
...@@ -498,30 +495,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -498,30 +495,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components hf_to_vllm_mapper = WeightsMapper(
weights_group = group_weights_with_prefix(weights) orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
# load audio tower weights loader = AutoWeightsLoader(self,
audio_tower_weights = weights_group["audio_tower"] ignore_unexpected_prefixes=["audio_tower."])
audio_tower_params_dict = dict( loader.load_weights(weights, mapper=hf_to_vllm_mapper)
self.audio_tower.named_parameters(
prefix=self.audio_tower.base_model_prefix))
for name, loaded_weight in audio_tower_weights:
if name in audio_tower_params_dict:
param = audio_tower_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load projector weights
projector_weights = weights_group["multi_modal_projector"]
projector_params_dict = dict(
self.multi_modal_projector.named_parameters())
for name, loaded_weight in projector_weights:
param = projector_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
import itertools import itertools
from collections import UserDict from dataclasses import dataclass, field
from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol, from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, Union, overload) Protocol, Tuple, Union, overload)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -12,55 +12,184 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig, ...@@ -12,55 +12,184 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
SchedulerConfig) SchedulerConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors from vllm.multimodal.base import NestedTensors
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
WeightsMapping = Mapping[str, Optional[str]]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
class WeightsGroup(UserDict):
"""
Wraps grouped weights dictionary for a more informative error message
when attempting to access a weight component that does not exist.
"""
def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]: @dataclass
try: class WeightsMapper:
return super().__getitem__(key) """Maps the name of each weight if they match the following patterns."""
except KeyError as exc:
msg = (f"There is no weights named with the prefix: {key}. "
f"Available prefix: {set(self.keys())}")
raise KeyError(msg) from exc
orig_to_new_substr: WeightsMapping = field(default_factory=dict)
orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], def _map_name(self, key: str) -> Optional[str]:
prefix: str) -> Iterable[Tuple[str, torch.Tensor]]: for substr, new_key in self.orig_to_new_substr.items():
""" if substr in key:
Helper function to load weights for inner vLLM models. if new_key is None:
return None
See also: key = key.replace(substr, new_key, 1)
:ref:`init_vllm_registered_model`
""" for prefix, new_key in self.orig_to_new_prefix.items():
for name, loaded_weight in weights: if key.startswith(prefix):
name = name.split(".") if new_key is None:
if prefix == name.pop(0): return None
name = ".".join(name)
yield name, loaded_weight key = key.replace(prefix, new_key, 1)
for suffix, new_key in self.orig_to_new_suffix.items():
if key.endswith(suffix):
if new_key is None:
return None
key = new_key.join(key.rsplit(suffix, 1))
return key
def group_weights_with_prefix( def apply(
weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup: self, weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
return ((out_name, data) for name, data in weights
if (out_name := self._map_name(name)) is not None)
class AutoWeightsLoader:
""" """
Helper function to group weights with prefix Helper class to load weights into a :class:`torch.nn.Module`. It is able
to automatically detect child modules and parameters while iterating over
the weights only once.
The weight loading logic for individual modules can be overridden
by defining a ``load_weights`` method.
Similarly, the weight loading logic for individual parameters can be
overridden by defining a ``weight_loader`` method.
""" """
init_weights, repeated_weights = itertools.tee(weights, 2)
weights_prefix = {name.split(".")[0] for name, _ in init_weights} def __init__(
repeated_weights = itertools.tee(repeated_weights, len(weights_prefix)) self,
module: nn.Module,
return WeightsGroup({ *,
prefix: filter_weights(component, prefix) skip_prefixes: Optional[List[str]] = None,
for component, prefix in zip(repeated_weights, weights_prefix) ignore_unexpected_prefixes: Optional[List[str]] = None,
}) ) -> None:
super().__init__()
self.module = module
self.skip_prefixes = skip_prefixes or []
self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
def _groupby_prefix(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
weights_by_parts = ((weight_name.split(".", 1), weight_data)
for weight_name, weight_data in weights)
for prefix, group in itertools.groupby(weights_by_parts,
key=lambda x: x[0][0]):
yield (
prefix,
# Because maxsplit=1 in weight_name.split(...),
# the length of `parts` must either be 1 or 2
(("" if len(parts) == 1 else parts[1], weights_data)
for parts, weights_data in group),
)
def _get_qualname(self, prefix: str, rest: str) -> str:
if prefix == "":
return rest
if rest == "":
return prefix
return ".".join((prefix, rest))
def _can_skip(self, qualname: str) -> bool:
return any(qualname.startswith(p) for p in self.skip_prefixes)
def _can_ignore_unexpected(self, qualname: str) -> bool:
return any(
qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
def _load_param(
self,
base_prefix: str,
param: nn.Parameter,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> None:
for weight_name, weight_data in weights:
weight_qualname = self._get_qualname(base_prefix, weight_name)
if self._can_skip(weight_qualname):
continue
if weight_name != "":
if not self._can_ignore_unexpected(weight_qualname):
raise ValueError(
f"Attempted to load nested weight '{weight_qualname}' "
f"into a single parameter '{base_prefix}'")
continue
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight_data)
def _load_module(
self,
base_prefix: str,
module: nn.Module,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> None:
if isinstance(module, PPMissingLayer):
return
# Avoid infinite recursion since this function is typically
# called inside load_weights of the module itself
if module != self.module:
module_load_weights = getattr(module, "load_weights", None)
if callable(module_load_weights):
module_load_weights(weights)
return
child_modules = dict(module.named_children())
child_params = dict(module.named_parameters(recurse=False))
for child_prefix, child_weights in self._groupby_prefix(weights):
prefix = self._get_qualname(base_prefix, child_prefix)
if self._can_skip(prefix):
continue
if child_prefix in child_modules:
self._load_module(prefix, child_modules[child_prefix],
child_weights)
elif child_prefix in child_params:
self._load_param(prefix, child_params[child_prefix],
child_weights)
else:
if not self._can_ignore_unexpected(prefix):
msg = f"There is no module or parameter named '{prefix}'"
raise ValueError(msg)
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
*,
mapper: Optional[WeightsMapper] = None,
) -> None:
if mapper is not None:
weights = mapper.apply(weights)
self._load_module("", self.module, weights)
def init_vllm_registered_model( def init_vllm_registered_model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment