Unverified Commit 36bf8150 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Model][VLM] Decouple weight loading logic for `Paligemma` (#8269)

parent e8071259
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -13,7 +14,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -13,7 +14,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaModel from vllm.model_executor.models.gemma import GemmaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
...@@ -22,14 +23,10 @@ from vllm.sequence import IntermediateTensors ...@@ -22,14 +23,10 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens) dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import merge_multimodal_embeddings from .utils import filter_weights, merge_multimodal_embeddings
logger = init_logger(__name__) logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.model": "language_model",
}
class PaliGemmaImagePixelInputs(TypedDict): class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
...@@ -151,8 +148,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -151,8 +148,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
projection_dim=config.vision_config.projection_dim) projection_dim=config.vision_config.projection_dim)
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = GemmaModel(config.text_config, cache_config, self.language_model = GemmaForCausalLM(config.text_config,
quant_config) cache_config, quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size self.unpadded_vocab_size = config.text_config.vocab_size
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
...@@ -252,7 +249,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -252,7 +249,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
vision_embeddings = vision_embeddings * (self.config.hidden_size** vision_embeddings = vision_embeddings * (self.config.hidden_size**
-0.5) -0.5)
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
...@@ -262,7 +260,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -262,7 +260,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
...@@ -271,78 +269,38 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -271,78 +269,38 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
return hidden_states return hidden_states
# Copied from vllm/model_executor/models/gemma.py
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.language_model.embed_tokens, return self.language_model.compute_logits(hidden_states,
hidden_states, sampling_metadata) sampling_metadata)
return logits
# Copied from vllm/model_executor/models/gemma.py
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
return next_tokens
# Adapted from vllm/model_executor/models/gemma.py
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ # prepare weight iterators for components
# (param_name, shard_name, shard_id) vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), # load vision tower
("qkv_proj", "v_proj", "v"), vit_weights = filter_weights(vit_weights, "vision_tower")
("gate_up_proj", "gate_proj", 0), self.vision_tower.load_weights(vit_weights)
("gate_up_proj", "up_proj", 1),
] # load mlp projector
params_dict = dict(self.named_parameters()) mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
loaded_params = set() mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in mlp_weights:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): param = mlp_params_dict[name]
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" not in name or self.vision_tower.shard_weight:
for (param_name, shard_name,
shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True
else:
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
unloaded_params = params_dict.keys() - loaded_params self.language_model.load_weights(llm_weights)
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
...@@ -529,6 +529,12 @@ class SiglipVisionModel(nn.Module): ...@@ -529,6 +529,12 @@ class SiglipVisionModel(nn.Module):
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
...@@ -544,6 +550,15 @@ class SiglipVisionModel(nn.Module): ...@@ -544,6 +550,15 @@ class SiglipVisionModel(nn.Module):
if layer_idx >= layer_count: if layer_idx >= layer_count:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment