Unverified Commit 13370712 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Replace embedding models with pooling adapter (#10769)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 7e4bbda5
...@@ -422,9 +422,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -422,9 +422,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
prefix=maybe_prefix(prefix, "vision_tower")) prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model")) hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.image_newline = nn.Parameter( self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)) torch.empty(config.text_config.hidden_size))
......
...@@ -151,9 +151,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -151,9 +151,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.quant_config = quant_config self.quant_config = quant_config
config.text_config.architectures = ["GemmaForCausalLM"] config.text_config.architectures = ["GemmaForCausalLM"]
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model")) hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale self.language_model.logits_processor.scale *= logit_scale
......
...@@ -29,24 +29,22 @@ from vllm.config import ModelConfig, VllmConfig ...@@ -29,24 +29,22 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -536,7 +534,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -536,7 +534,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -556,18 +553,17 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -556,18 +553,17 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
quant_config, quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))
# The prefix is empty intentionally because default prefix of self.language_model = init_vllm_registered_model(
# LlamaForCausalLM is "model" vllm_config=vllm_config,
self.language_model = LlamaForCausalLM(vllm_config=vllm_config, # The prefix is empty intentionally because default prefix of
prefix="") # LlamaForCausalLM is "model"
prefix="",
# The same model class supports both language generation and embedding # We don't directly initialize vLLM's LlamaForCausalLM so we
# because the architecture name is the same # can automatically apply embedding wrapper if this model is
self._pooler = Pooler.from_config_with_defaults( # initialized as an embedding model
pooler_config, architectures=["LlamaForCausalLM"],
pooling_type=PoolingType.LAST, )
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
...@@ -739,13 +735,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -739,13 +735,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
......
...@@ -172,9 +172,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -172,9 +172,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# init MistralForCausalLM # init MistralForCausalLM
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model")) hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.vision_encoder = VisionTransformer(self.vision_args) self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_language_adapter = VisionLanguageAdapter( self.vision_language_adapter = VisionLanguageAdapter(
......
...@@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType ...@@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -55,6 +56,8 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, ...@@ -55,6 +56,8 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
logger = init_logger(__name__)
class Qwen2MLP(nn.Module): class Qwen2MLP(nn.Module):
...@@ -433,7 +436,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -433,7 +436,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
...@@ -454,14 +456,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -454,14 +456,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -499,13 +493,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -499,13 +493,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
...@@ -553,6 +540,15 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -553,6 +540,15 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
self.model = Qwen2Model(vllm_config=vllm_config, self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
# TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
# after changing the default pooling method
if pooler_config.pooling_type is None:
logger.warning(
"This embedding model will default to last-token pooling in "
"an upcoming version. To avoid breaking changes, you should "
"pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`"
" explicitly.")
self._pooler = Pooler.from_config_with_defaults( self._pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=PoolingType.MEAN, pooling_type=PoolingType.MEAN,
......
...@@ -50,7 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU ...@@ -50,7 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
...@@ -59,14 +58,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -59,14 +58,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
MultiModalKwargs, NestedTensors) MultiModalKwargs, NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
...@@ -1070,7 +1068,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1070,7 +1068,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Qwen2-VL currently does not support prefix caching" "Qwen2-VL currently does not support prefix caching"
...@@ -1102,11 +1099,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1102,11 +1099,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
...@@ -1361,13 +1354,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1361,13 +1354,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -20,6 +20,7 @@ import torch.nn as nn ...@@ -20,6 +20,7 @@ import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free, from .interfaces import (has_inner_state, is_attention_free,
supports_cross_encoding, supports_multimodal, supports_cross_encoding, supports_multimodal,
supports_pp) supports_pp)
...@@ -107,15 +108,15 @@ _EMBEDDING_MODELS = { ...@@ -107,15 +108,15 @@ _EMBEDDING_MODELS = {
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"LlamaModel": ("llama", "LlamaEmbeddingModel"), "LlamaModel": ("llama", "LlamaForCausalLM"),
**{ **{
# Multiple models share the same architecture, so we include them all # Multiple models share the same architecture, so we include them all
k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items() k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
if arch == "LlamaForCausalLM" if arch == "LlamaForCausalLM"
}, },
"MistralModel": ("llama", "LlamaEmbeddingModel"), "MistralModel": ("llama", "LlamaForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
...@@ -125,7 +126,7 @@ _EMBEDDING_MODELS = { ...@@ -125,7 +126,7 @@ _EMBEDDING_MODELS = {
# [Multimodal] # [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501, "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
} }
_CROSS_ENCODER_MODELS = { _CROSS_ENCODER_MODELS = {
...@@ -208,6 +209,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { ...@@ -208,6 +209,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
@dataclass(frozen=True) @dataclass(frozen=True)
class _ModelInfo: class _ModelInfo:
architecture: str
is_text_generation_model: bool is_text_generation_model: bool
is_embedding_model: bool is_embedding_model: bool
supports_cross_encoding: bool supports_cross_encoding: bool
...@@ -218,9 +220,19 @@ class _ModelInfo: ...@@ -218,9 +220,19 @@ class _ModelInfo:
@staticmethod @staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
is_embedding_model_ = is_embedding_model(model)
if not is_embedding_model_:
try:
as_embedding_model(model)
except Exception:
pass
else:
is_embedding_model_ = True
return _ModelInfo( return _ModelInfo(
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model), is_text_generation_model=is_text_generation_model(model),
is_embedding_model=is_embedding_model(model), is_embedding_model=is_embedding_model_,
supports_cross_encoding=supports_cross_encoding(model), supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model), supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model), supports_pp=supports_pp(model),
...@@ -399,13 +411,13 @@ class _ModelRegistry: ...@@ -399,13 +411,13 @@ class _ModelRegistry:
def inspect_model_cls( def inspect_model_cls(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, List[str]],
) -> _ModelInfo: ) -> Tuple[_ModelInfo, str]:
architectures = self._normalize_archs(architectures) architectures = self._normalize_archs(architectures)
for arch in architectures: for arch in architectures:
model_info = self._try_inspect_model_cls(arch) model_info = self._try_inspect_model_cls(arch)
if model_info is not None: if model_info is not None:
return model_info return (model_info, arch)
return self._raise_for_unsupported(architectures) return self._raise_for_unsupported(architectures)
...@@ -426,39 +438,50 @@ class _ModelRegistry: ...@@ -426,39 +438,50 @@ class _ModelRegistry:
self, self,
architectures: Union[str, List[str]], architectures: Union[str, List[str]],
) -> bool: ) -> bool:
return self.inspect_model_cls(architectures).is_text_generation_model model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_text_generation_model
def is_embedding_model( def is_embedding_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, List[str]],
) -> bool: ) -> bool:
return self.inspect_model_cls(architectures).is_embedding_model model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_embedding_model
def is_cross_encoder_model( def is_cross_encoder_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, List[str]],
) -> bool: ) -> bool:
return self.inspect_model_cls(architectures).supports_cross_encoding model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_cross_encoding
def is_multimodal_model( def is_multimodal_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, List[str]],
) -> bool: ) -> bool:
return self.inspect_model_cls(architectures).supports_multimodal model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal
def is_pp_supported_model( def is_pp_supported_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, List[str]],
) -> bool: ) -> bool:
return self.inspect_model_cls(architectures).supports_pp model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_pp
def model_has_inner_state(self, architectures: Union[str, def model_has_inner_state(
List[str]]) -> bool: self,
return self.inspect_model_cls(architectures).has_inner_state architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.has_inner_state
def is_attention_free_model(self, architectures: Union[str, def is_attention_free_model(
List[str]]) -> bool: self,
return self.inspect_model_cls(architectures).is_attention_free architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_attention_free
ModelRegistry = _ModelRegistry({ ModelRegistry = _ModelRegistry({
......
...@@ -360,9 +360,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -360,9 +360,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
)) ))
self.multi_modal_projector = UltravoxProjector(config) self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model")) hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if config.text_model_id is not None: if config.text_model_id is not None:
# this prefix is not for initialization, but for loading weights # this prefix is not for initialization, but for loading weights
# note the trailing dot # note the trailing dot
......
...@@ -173,8 +173,15 @@ class AutoWeightsLoader: ...@@ -173,8 +173,15 @@ class AutoWeightsLoader:
module_load_weights = getattr(module, "load_weights", None) module_load_weights = getattr(module, "load_weights", None)
if callable(module_load_weights): if callable(module_load_weights):
loaded_params = module_load_weights(weights) loaded_params = module_load_weights(weights)
yield from map(lambda x: self._get_qualname(base_prefix, x), if loaded_params is None:
loaded_params) logger.warning(
"Unable to collect loaded parameters "
"for module %s", module)
else:
yield from map(
lambda x: self._get_qualname(base_prefix, x),
loaded_params,
)
child_modules = dict(module.named_children()) child_modules = dict(module.named_children())
child_params = dict(module.named_parameters(recurse=False)) child_params = dict(module.named_parameters(recurse=False))
...@@ -232,17 +239,24 @@ class AutoWeightsLoader: ...@@ -232,17 +239,24 @@ class AutoWeightsLoader:
def init_vllm_registered_model( def init_vllm_registered_model(
hf_config: PretrainedConfig,
vllm_config: VllmConfig, vllm_config: VllmConfig,
*,
prefix: str = "", prefix: str = "",
hf_config: Optional[PretrainedConfig] = None,
architectures: Optional[list[str]] = None,
) -> nn.Module: ) -> nn.Module:
""" """
Helper function to initialize an inner model registered to vLLM, Helper function to initialize an inner model registered to vLLM,
based on the arguments passed to the outer vLLM model. based on the arguments passed to the outer vLLM model.
""" """
from vllm.model_executor.model_loader.loader import _initialize_model from vllm.model_executor.model_loader.loader import _initialize_model
vllm_config = vllm_config.with_hf_config(hf_config)
return _initialize_model(vllm_config, prefix) if hf_config is not None:
vllm_config = vllm_config.with_hf_config(hf_config)
return _initialize_model(vllm_config=vllm_config,
prefix=prefix,
architectures=architectures)
@overload @overload
......
...@@ -7,7 +7,7 @@ from torch import nn ...@@ -7,7 +7,7 @@ from torch import nn
from vllm.inputs import InputContext from vllm.inputs import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (get_allowed_kwarg_only_overrides, from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs) resolve_mm_processor_kwargs)
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -54,8 +54,8 @@ class MultiModalPlugin(ABC): ...@@ -54,8 +54,8 @@ class MultiModalPlugin(ABC):
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {} self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {} self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()
@abstractmethod @abstractmethod
def get_data_key(self) -> str: def get_data_key(self) -> str:
......
...@@ -9,6 +9,7 @@ from typing_extensions import TypeAlias ...@@ -9,6 +9,7 @@ from typing_extensions import TypeAlias
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import ClassRegistry
from .audio import AudioPlugin from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
...@@ -62,8 +63,8 @@ class MultiModalRegistry: ...@@ -62,8 +63,8 @@ class MultiModalRegistry:
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
self._plugins = {p.get_data_key(): p for p in plugins} self._plugins = {p.get_data_key(): p for p in plugins}
self._processor_factories: Dict[Type[nn.Module], self._processor_factories = ClassRegistry[nn.Module,
MultiModalProcessorFactory] = {} MultiModalProcessorFactory]()
# This is used for non-multimodal models # This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
......
...@@ -20,7 +20,7 @@ import uuid ...@@ -20,7 +20,7 @@ import uuid
import warnings import warnings
import weakref import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections import defaultdict from collections import UserDict, defaultdict
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
from platform import uname from platform import uname
...@@ -1517,13 +1517,13 @@ class AtomicCounter: ...@@ -1517,13 +1517,13 @@ class AtomicCounter:
# Adapted from: https://stackoverflow.com/a/47212782/5082708 # Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]): class LazyDict(Mapping[str, T], Generic[T]):
def __init__(self, factory: Dict[str, Callable[[], T]]): def __init__(self, factory: Dict[str, Callable[[], T]]):
self._factory = factory self._factory = factory
self._dict: Dict[str, T] = {} self._dict: Dict[str, T] = {}
def __getitem__(self, key) -> T: def __getitem__(self, key: str) -> T:
if key not in self._dict: if key not in self._dict:
if key not in self._factory: if key not in self._factory:
raise KeyError(key) raise KeyError(key)
...@@ -1540,6 +1540,22 @@ class LazyDict(Mapping, Generic[T]): ...@@ -1540,6 +1540,22 @@ class LazyDict(Mapping, Generic[T]):
return len(self._factory) return len(self._factory)
class ClassRegistry(UserDict[type[T], _V]):
def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
raise KeyError(key)
def __contains__(self, key: object) -> bool:
if not isinstance(key, type):
return False
return any(cls in self.data for cls in key.mro())
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
""" """
Create a weak reference to a tensor. Create a weak reference to a tensor.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment