Unverified Commit 4753f3bf authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Use context managers for encoder- and LM-only mode (#32605)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 6c01ffb8
......@@ -38,7 +38,7 @@ Encoder engines should be launched with the following flags:
- `--max-num-batched-tokens=<large value>` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager.
- `--convert "mm_encoder_only"` **(Optional)** - The language model is skipped during initialization to reduce device memory usage. **Models using this option must implement the `get_language_model_spec` interface.**
- `--mm-encoder-only` **(Optional)** - The language model is skipped during initialization to reduce device memory usage. **Models using this option must initialize the language component inside the context of `SupportsMultiModal._mark_language_model`.**
## Local media inputs
......
......@@ -306,6 +306,7 @@ class ModelConfig:
mm_processor_cache_gb: InitVar[float | None] = None
mm_processor_cache_type: InitVar[MMCacheType | None] = None
mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
mm_encoder_only: InitVar[bool | None] = None
mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None
interleave_mm_strings: InitVar[bool | None] = None
......@@ -420,6 +421,7 @@ class ModelConfig:
mm_processor_cache_gb: float | None,
mm_processor_cache_type: MMCacheType | None,
mm_shm_cache_max_object_size_mb: int | None,
mm_encoder_only: bool | None,
mm_encoder_tp_mode: MMEncoderTPMode | None,
mm_encoder_attn_backend: AttentionBackendEnum | str | None,
interleave_mm_strings: bool | None,
......@@ -497,6 +499,15 @@ class ModelConfig:
)
self.model_arch_config = self.get_model_arch_config()
if self.convert == "mm_encoder_only":
logger.warning_once(
"`--convert mm_encoder_only` is deprecated and "
"will be removed in v0.15. "
"Please use --mm-encoder-only` instead."
)
mm_encoder_only = True
self.convert = "none"
architectures = self.architectures
registry = self.registry
is_generative_model = registry.is_text_generation_model(architectures, self)
......@@ -583,6 +594,7 @@ class ModelConfig:
mm_processor_cache_gb=mm_processor_cache_gb,
mm_processor_cache_type=mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb,
mm_encoder_only=mm_encoder_only,
mm_encoder_tp_mode=mm_encoder_tp_mode,
mm_encoder_attn_backend=mm_encoder_attn_backend,
interleave_mm_strings=interleave_mm_strings,
......
......@@ -108,6 +108,12 @@ class MultiModalConfig:
"""Size limit (in MiB) for each object stored in the multi-modal processor
shared memory cache. Only effective when `mm_processor_cache_type` is
`"shm"`."""
mm_encoder_only: bool = False
"""
When enabled, skips the language component of the model.
This is usually only valid in disaggregated Encoder process.
"""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""Indicates how to optimize multi-modal encoder inference using tensor
parallelism (TP).
......
......@@ -467,6 +467,7 @@ class EngineArgs:
mm_shm_cache_max_object_size_mb: int = (
MultiModalConfig.mm_shm_cache_max_object_size_mb
)
mm_encoder_only: bool = MultiModalConfig.mm_encoder_only
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
MultiModalConfig.mm_encoder_attn_backend
......@@ -973,6 +974,9 @@ class EngineArgs:
"--mm-shm-cache-max-object-size-mb",
**multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
)
multimodal_group.add_argument(
"--mm-encoder-only", **multimodal_kwargs["mm_encoder_only"]
)
multimodal_group.add_argument(
"--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
)
......@@ -1256,6 +1260,7 @@ class EngineArgs:
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_processor_cache_type=self.mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
mm_encoder_only=self.mm_encoder_only,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
mm_encoder_attn_backend=self.mm_encoder_attn_backend,
pooler_config=self.pooler_config,
......
......@@ -189,9 +189,7 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
)
convert_type = model_config.convert_type
if convert_type not in ["none", "mm_encoder_only"] and supports_multimodal(
model_cls
):
if convert_type != "none" and supports_multimodal(model_cls):
logger.debug_once("Detected conversion of Multi Modal model.")
converted = try_create_mm_pooling_model_cls(model_cls)
if converted is not None:
......@@ -202,11 +200,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
if convert_type == "none":
pass
elif convert_type == "mm_encoder_only":
logger.debug_once("Converting to mm encoder only model.")
from vllm.model_executor.models.adapters import as_mm_encoder_only_model
model_cls = as_mm_encoder_only_model(model_cls)
elif convert_type == "embed":
logger.debug_once("Converting to embedding model.")
model_cls = as_embedding_model(model_cls)
......
......@@ -529,64 +529,3 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
method = getattr(hf_config, "method", getattr(text_config, "method", None))
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
return SEQ_CLS_LOAD_METHODS[method](model, weights)
def as_mm_encoder_only_model(cls: _T) -> _T:
"""
Subclass an existing vLLM vl model to support mm encoder only for
EPD encoder instances.
"""
if not hasattr(cls, "embed_multimodal"):
# Submodel case: return the original class.
return cls
if not hasattr(cls, "get_language_model_spec"):
raise TypeError(f"{cls} need to implement `get_language_model_spec` method.")
lm_model_cls, lm_attr = cls.get_language_model_spec()
if lm_model_cls is None or lm_attr is None:
raise TypeError(
f"{cls}.get_language_model_spec() must return (lm_model_cls, lm_attr)"
)
class DummyLM(nn.Module):
def __init__(self, *args, **kwargs):
self.make_empty_intermediate_tensors = None
class ModelForMMEncoderOnly(cls):
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
self.is_mm_encoder_only_model = True
origin_init = lm_model_cls.__init__
try:
lm_model_cls.__init__ = DummyLM.__init__
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
if hasattr(self, lm_attr):
delattr(self, lm_attr)
finally:
lm_model_cls.__init__ = origin_init
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
from .utils import AutoWeightsLoader
origin_init_ = AutoWeightsLoader.__init__
def _new_init_(self, *args, **kwargs):
origin_init_(self, *args, **kwargs)
self.skip_prefixes = (self.skip_prefixes or []) + [f"{lm_attr}."]
try:
AutoWeightsLoader.__init__ = _new_init_
result = super().load_weights(weights)
finally:
AutoWeightsLoader.__init__ = origin_init_
return result
return ModelForMMEncoderOnly # type: ignore
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable, Mapping, MutableSequence
from contextlib import contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
ClassVar,
......@@ -69,6 +70,46 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
return is_multimodal
class LMMissingLayer(nn.Module):
packed_modules_mapping: dict[str, list[str]] = {}
def make_empty_intermediate_tensors(self, *args, **kwargs):
raise RuntimeError("This module should not be called in MM encoder-only mode")
def __call__(self, *args, **kwargs):
raise RuntimeError("This module should not be called in MM encoder-only mode")
class TowerMissingLayer(nn.Module):
packed_modules_mapping: dict[str, list[str]] = {}
def __init__(self, modalities: set[str]) -> None:
super().__init__()
self.modalities = modalities
def __call__(self, *args, **kwargs):
raise RuntimeError(f"The following modalities are disabled: {self.modalities}")
@contextmanager
def _no_init_weights(module: nn.Module, placeholder: Callable[[], nn.Module]):
"""
Within this context, prevent weight initialization from using device memory and
replace direct child assignments to `module` with the result of `placeholder()`.
"""
def callback(module_, name, submodule):
if module_ is module:
return placeholder()
return submodule
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with torch.device("meta"):
yield
@runtime_checkable
class SupportsMultiModal(Protocol):
"""The interface required for all multi-modal models."""
......@@ -105,6 +146,16 @@ class SupportsMultiModal(Protocol):
Set internally by `MultiModalRegistry.register_processor`.
"""
_language_model_names: list[str] = []
"""
Set internally by `_mark_language_model`.
"""
_tower_model_names: list[str] = []
"""
Set internally by `_mark_tower_model`.
"""
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
"""
......@@ -134,7 +185,64 @@ class SupportsMultiModal(Protocol):
Returns:
torch.nn.Module: The core language model component.
"""
...
if self._language_model_names:
return getattr(self, self._language_model_names[0])
raise NotImplementedError(
f"No language model found in {type(self).__name__}! "
"You should initialize it inside `_mark_language_model`."
)
@contextmanager
def _mark_language_model(self, vllm_config: VllmConfig):
"""
Mark each child module that was assigned to this model
during this context as a language model component.
"""
mm_config = vllm_config.model_config.multimodal_config
children_names = list[str]()
def callback(module_, name, submodule):
if module_ is self:
children_names.append(name)
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with (
_no_init_weights(self, LMMissingLayer)
if mm_config.mm_encoder_only
else nullcontext()
):
yield
self._language_model_names = children_names
@contextmanager
def _mark_tower_model(self, vllm_config: VllmConfig, modalities: set[str] | str):
"""
Mark each child module that was assigned to this model
during this context as a tower model component.
"""
if isinstance(modalities, str):
modalities = {modalities}
mm_config = vllm_config.model_config.multimodal_config
children_names = list[str]()
def callback(module_, name, submodule):
if module_ is self:
children_names.append(name)
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with (
_no_init_weights(self, lambda: TowerMissingLayer(modalities))
if all(mm_config.get_limit_per_prompt(m) == 0 for m in modalities)
else nullcontext()
):
yield
self._tower_model_names = children_names
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
"""
......@@ -154,14 +262,6 @@ class SupportsMultiModal(Protocol):
"""
...
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""
Return the language model spec:
(language model class, language model attr)
"""
return None, None
@overload
def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...
......@@ -299,10 +399,6 @@ def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
return getattr(model, "supports_encoder_tp_data", False)
def supports_mm_encoder_only(model: type[object] | object) -> bool:
return getattr(model, "is_mm_encoder_only_model", False)
@overload
def supports_multimodal_pruning(
model: type[object],
......
......@@ -550,8 +550,7 @@ class LlavaForConditionalGeneration(
):
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
......@@ -567,10 +566,8 @@ class LlavaForConditionalGeneration(
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
else:
self.vision_tower = None
self.multi_modal_projector = None
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
......@@ -631,8 +628,6 @@ class LlavaForConditionalGeneration(
self,
inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
return self._image_pixels_to_features(self.vision_tower, pixel_values)
......@@ -644,7 +639,6 @@ class LlavaForConditionalGeneration(
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
if isinstance(image_features, torch.Tensor):
......@@ -656,9 +650,6 @@ class LlavaForConditionalGeneration(
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
......@@ -727,11 +718,7 @@ class LlavaForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.vision_tower is None and self.multi_modal_projector is None:
skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
......
......@@ -457,8 +457,7 @@ class Mistral3ForConditionalGeneration(
):
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
......@@ -476,10 +475,8 @@ class Mistral3ForConditionalGeneration(
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
else:
self.vision_tower = None
self.multi_modal_projector = None
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
......@@ -534,9 +531,6 @@ class Mistral3ForConditionalGeneration(
image_embeds = (image_embeds,)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
......@@ -607,11 +601,7 @@ class Mistral3ForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.vision_tower is None and self.multi_modal_projector is None:
skip_prefixes = ["vision_tower.", "multi_modal_projector."]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
......
......@@ -70,12 +70,14 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
LMMissingLayer,
MixtureOfExperts,
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
TowerMissingLayer,
)
from .llama4 import Llama4ForCausalLM
from .utils import (
......@@ -773,7 +775,8 @@ class Llama4ForConditionalGeneration(
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
from vllm.compilation.backends import set_model_tag
with (
......@@ -792,9 +795,8 @@ class Llama4ForConditionalGeneration(
quant_config=None,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
else:
self.vision_model = None
self.multi_modal_projector = None
with self._mark_language_model(vllm_config):
self.language_model = initialize_model(
vllm_config=vllm_config.with_hf_config(
config.text_config, ["LlamaForCausalLM"]
......@@ -892,9 +894,6 @@ class Llama4ForConditionalGeneration(
for img in vision_embeddings_flat.split(patches_per_image, dim=0)
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
......@@ -1024,6 +1023,10 @@ class Llama4ForConditionalGeneration(
for name, weight in weights:
renamed = self._rename_weight_for_modelopt_checkpoint(name)
attr = renamed.split(".", 1)[0]
if isinstance(getattr(self, attr), (LMMissingLayer, TowerMissingLayer)):
continue
if renamed.startswith("language_model."):
language_model_weights.append((renamed, weight))
else:
......@@ -1133,10 +1136,6 @@ class Llama4ForConditionalGeneration(
weights
)
# Skip loading vision model and projector if they're not initialized.
if self.vision_model is None and self.multi_modal_projector is None:
other_weights = []
# Handle expert scale parameters
regular_weights, expert_scale_weights, updated_params_from_experts = (
self._handle_expert_scale_broadcasting(language_model_weights, params_dict)
......
......@@ -239,7 +239,7 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
multimodal_config.is_multimodal_pruning_enabled()
)
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.visual = OpenCUAVisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
......@@ -247,9 +247,8 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = None
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
......
......@@ -398,13 +398,14 @@ class PixtralForConditionalGeneration(
self.vision_args = VisionEncoderArgs(**vision_args)
# init MistralForCausalLM
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_encoder = VisionTransformer(self.vision_args)
self.pre_mm_projector_norm = (
RMSNorm(self.vision_args.hidden_size, eps=1e-5)
......@@ -423,11 +424,6 @@ class PixtralForConditionalGeneration(
self.vision_language_adapter = VisionLanguageAdapter(
self.vision_args, dim=config.text_config.hidden_size
)
else:
self.vision_encoder = None
self.pre_mm_projector_norm = None
self.patch_merger = None
self.vision_language_adapter = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
......@@ -449,10 +445,6 @@ class PixtralForConditionalGeneration(
self,
image_input: PixtralImagePixelInputs,
) -> tuple[torch.Tensor, ...]:
assert (
self.vision_encoder is not None and self.vision_language_adapter is not None
)
images = image_input["images"]
image_features = self.vision_encoder(images)
feature_sizes = [image_feature.shape[0] for image_feature in image_features]
......@@ -477,9 +469,6 @@ class PixtralForConditionalGeneration(
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
......
......@@ -822,6 +822,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config
self.config = thinker_config
self.multimodal_config = multimodal_config
self.quant_config = quant_config
# force "use_flash_attention_2=True" to audio tower to align
# the results.
......@@ -836,14 +837,10 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
"in the audio tower part."
)
if multimodal_config.get_limit_per_prompt("audio"):
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
else:
self.audio_tower = None
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen2_5_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
......@@ -851,10 +848,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
else:
self.visual = None
self.quant_config = quant_config
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
......@@ -895,9 +890,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
)
return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_mrope_input_positions(
self,
input_tokens: list[int],
......@@ -1175,19 +1167,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = ["talker.", "token2wav."]
if self.audio_tower is None:
skip_prefixes.extend(["audio_tower."])
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(
self,
skip_prefixes=skip_prefixes,
)
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights
loader = AutoWeightsLoader(self, skip_prefixes=["talker.", "token2wav."])
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""
......
......@@ -35,7 +35,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature, Qwen2ForCausalLM
from transformers import BatchFeature
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig,
......@@ -1145,9 +1145,7 @@ class Qwen2_5_VLForConditionalGeneration(
multimodal_config.is_multimodal_pruning_enabled()
)
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
......@@ -1155,9 +1153,8 @@ class Qwen2_5_VLForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
else:
self.visual = None
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
......@@ -1447,9 +1444,6 @@ class Qwen2_5_VLForConditionalGeneration(
)
return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
......@@ -1516,10 +1510,7 @@ class Qwen2_5_VLForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
......@@ -1550,11 +1541,3 @@ class Qwen2_5_VLForConditionalGeneration(
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""
Return the language model spec:
(language model class, language model attr)
"""
return Qwen2ForCausalLM, "language_model"
......@@ -1233,9 +1233,7 @@ class Qwen2VLForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
......@@ -1243,9 +1241,8 @@ class Qwen2VLForConditionalGeneration(
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = None
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
......@@ -1371,9 +1368,6 @@ class Qwen2VLForConditionalGeneration(
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
......@@ -1437,10 +1431,7 @@ class Qwen2VLForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
......
......@@ -1277,11 +1277,16 @@ class Qwen3VLForConditionalGeneration(
multimodal_config.is_multimodal_pruning_enabled()
)
if not multimodal_config.get_limit_per_prompt(
"image"
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
......@@ -1290,22 +1295,8 @@ class Qwen3VLForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"),
)
self.language_model = Qwen3LLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
# register buffer for deepstack
if self.use_deepstack and self.visual is not None:
if self.use_deepstack:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
......@@ -1313,10 +1304,15 @@ class Qwen3VLForConditionalGeneration(
)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
with self._mark_language_model(vllm_config):
self.language_model = Qwen3LLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers
......@@ -1893,9 +1889,6 @@ class Qwen3VLForConditionalGeneration(
return torch.from_numpy(llm_positions), mrope_position_delta
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
......@@ -2076,10 +2069,7 @@ class Qwen3VLForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
......@@ -2110,11 +2100,3 @@ class Qwen3VLForConditionalGeneration(
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""
Return the language model spec:
(language model class, language model attr)
"""
return Qwen3LLMForCausalLM, "language_model"
......@@ -424,11 +424,16 @@ class Qwen3VLMoeForConditionalGeneration(
multimodal_config.is_multimodal_pruning_enabled()
)
if not multimodal_config.get_limit_per_prompt(
"image"
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
......@@ -437,9 +442,21 @@ class Qwen3VLMoeForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"),
)
# register buffer for deepstack
if self.use_deepstack:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
with self._mark_language_model(vllm_config):
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
# Whether to include the gate_up_proj mapping is determined by
# the language model.
self.packed_modules_mapping = (
......@@ -450,25 +467,5 @@ class Qwen3VLMoeForConditionalGeneration(
self.language_model.make_empty_intermediate_tensors
)
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
# register buffer for deepstack
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
# Set MoE hyperparameters
self.set_moe_parameters()
......@@ -942,7 +942,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = Step3VisionTransformer(
config.vision_config,
None,
......@@ -967,12 +967,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
config.hidden_size,
bias=config.projector_bias,
)
else:
self.vision_model = None
self.vit_downsampler = None
self.vit_downsampler2 = None
self.vit_large_projector = None
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
......@@ -1071,9 +1067,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
)
return merged_image_features
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
......@@ -1133,15 +1126,5 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
skip_prefixes = []
if self.vision_model is None and self.vit_large_projector is None:
skip_prefixes = [
"vision_model.",
"vit_downsampler.",
"vit_downsampler2.",
"vit_large_projector.",
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
......@@ -504,7 +504,8 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = PerceptionEncoder(
config.vision_config,
get_act_fn(config.vision_config.hidden_act),
......@@ -521,10 +522,8 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
prefix=maybe_prefix(prefix, "vit_large_projector"),
disable_tp=self.use_data_parallel,
)
else:
self.vision_model = None
self.vit_large_projector = None
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
......
......@@ -24,7 +24,11 @@ from vllm.model_executor.model_loader.online_quantization import (
support_quantized_model_reload_from_hp_weights,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.model_executor.models.interfaces import (
LMMissingLayer,
TowerMissingLayer,
supports_any_eagle,
)
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
......@@ -250,7 +254,7 @@ class AutoWeightsLoader:
module: nn.Module,
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]:
if isinstance(module, PPMissingLayer):
if isinstance(module, (LMMissingLayer, TowerMissingLayer, PPMissingLayer)):
return
# Avoid infinite recursion since this function is typically
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment