Unverified Commit 4753f3bf authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Use context managers for encoder- and LM-only mode (#32605)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 6c01ffb8
...@@ -38,7 +38,7 @@ Encoder engines should be launched with the following flags: ...@@ -38,7 +38,7 @@ Encoder engines should be launched with the following flags:
- `--max-num-batched-tokens=<large value>` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager. - `--max-num-batched-tokens=<large value>` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager.
- `--convert "mm_encoder_only"` **(Optional)** - The language model is skipped during initialization to reduce device memory usage. **Models using this option must implement the `get_language_model_spec` interface.** - `--mm-encoder-only` **(Optional)** - The language model is skipped during initialization to reduce device memory usage. **Models using this option must initialize the language component inside the context of `SupportsMultiModal._mark_language_model`.**
## Local media inputs ## Local media inputs
......
...@@ -306,6 +306,7 @@ class ModelConfig: ...@@ -306,6 +306,7 @@ class ModelConfig:
mm_processor_cache_gb: InitVar[float | None] = None mm_processor_cache_gb: InitVar[float | None] = None
mm_processor_cache_type: InitVar[MMCacheType | None] = None mm_processor_cache_type: InitVar[MMCacheType | None] = None
mm_shm_cache_max_object_size_mb: InitVar[int | None] = None mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
mm_encoder_only: InitVar[bool | None] = None
mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None
interleave_mm_strings: InitVar[bool | None] = None interleave_mm_strings: InitVar[bool | None] = None
...@@ -420,6 +421,7 @@ class ModelConfig: ...@@ -420,6 +421,7 @@ class ModelConfig:
mm_processor_cache_gb: float | None, mm_processor_cache_gb: float | None,
mm_processor_cache_type: MMCacheType | None, mm_processor_cache_type: MMCacheType | None,
mm_shm_cache_max_object_size_mb: int | None, mm_shm_cache_max_object_size_mb: int | None,
mm_encoder_only: bool | None,
mm_encoder_tp_mode: MMEncoderTPMode | None, mm_encoder_tp_mode: MMEncoderTPMode | None,
mm_encoder_attn_backend: AttentionBackendEnum | str | None, mm_encoder_attn_backend: AttentionBackendEnum | str | None,
interleave_mm_strings: bool | None, interleave_mm_strings: bool | None,
...@@ -497,6 +499,15 @@ class ModelConfig: ...@@ -497,6 +499,15 @@ class ModelConfig:
) )
self.model_arch_config = self.get_model_arch_config() self.model_arch_config = self.get_model_arch_config()
if self.convert == "mm_encoder_only":
    # Backward compatibility: the legacy `--convert mm_encoder_only` spelling
    # is translated into the dedicated `mm_encoder_only` flag, and `convert`
    # is reset to "none" so no model conversion is requested for it.
    logger.warning_once(
        "`--convert mm_encoder_only` is deprecated and "
        "will be removed in v0.15. "
        "Please use `--mm-encoder-only` instead."
    )
    mm_encoder_only = True
    self.convert = "none"
architectures = self.architectures architectures = self.architectures
registry = self.registry registry = self.registry
is_generative_model = registry.is_text_generation_model(architectures, self) is_generative_model = registry.is_text_generation_model(architectures, self)
...@@ -583,6 +594,7 @@ class ModelConfig: ...@@ -583,6 +594,7 @@ class ModelConfig:
mm_processor_cache_gb=mm_processor_cache_gb, mm_processor_cache_gb=mm_processor_cache_gb,
mm_processor_cache_type=mm_processor_cache_type, mm_processor_cache_type=mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb, mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb,
mm_encoder_only=mm_encoder_only,
mm_encoder_tp_mode=mm_encoder_tp_mode, mm_encoder_tp_mode=mm_encoder_tp_mode,
mm_encoder_attn_backend=mm_encoder_attn_backend, mm_encoder_attn_backend=mm_encoder_attn_backend,
interleave_mm_strings=interleave_mm_strings, interleave_mm_strings=interleave_mm_strings,
......
...@@ -108,6 +108,12 @@ class MultiModalConfig: ...@@ -108,6 +108,12 @@ class MultiModalConfig:
"""Size limit (in MiB) for each object stored in the multi-modal processor """Size limit (in MiB) for each object stored in the multi-modal processor
shared memory cache. Only effective when `mm_processor_cache_type` is shared memory cache. Only effective when `mm_processor_cache_type` is
`"shm"`.""" `"shm"`."""
mm_encoder_only: bool = False
"""
When enabled, skips the language component of the model.
This is usually only valid in a disaggregated Encoder process.
"""
mm_encoder_tp_mode: MMEncoderTPMode = "weights" mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""Indicates how to optimize multi-modal encoder inference using tensor """Indicates how to optimize multi-modal encoder inference using tensor
parallelism (TP). parallelism (TP).
......
...@@ -467,6 +467,7 @@ class EngineArgs: ...@@ -467,6 +467,7 @@ class EngineArgs:
mm_shm_cache_max_object_size_mb: int = ( mm_shm_cache_max_object_size_mb: int = (
MultiModalConfig.mm_shm_cache_max_object_size_mb MultiModalConfig.mm_shm_cache_max_object_size_mb
) )
mm_encoder_only: bool = MultiModalConfig.mm_encoder_only
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
mm_encoder_attn_backend: AttentionBackendEnum | str | None = ( mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
MultiModalConfig.mm_encoder_attn_backend MultiModalConfig.mm_encoder_attn_backend
...@@ -973,6 +974,9 @@ class EngineArgs: ...@@ -973,6 +974,9 @@ class EngineArgs:
"--mm-shm-cache-max-object-size-mb", "--mm-shm-cache-max-object-size-mb",
**multimodal_kwargs["mm_shm_cache_max_object_size_mb"], **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
) )
multimodal_group.add_argument(
"--mm-encoder-only", **multimodal_kwargs["mm_encoder_only"]
)
multimodal_group.add_argument( multimodal_group.add_argument(
"--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"] "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
) )
...@@ -1256,6 +1260,7 @@ class EngineArgs: ...@@ -1256,6 +1260,7 @@ class EngineArgs:
mm_processor_cache_gb=self.mm_processor_cache_gb, mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_processor_cache_type=self.mm_processor_cache_type, mm_processor_cache_type=self.mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb, mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
mm_encoder_only=self.mm_encoder_only,
mm_encoder_tp_mode=self.mm_encoder_tp_mode, mm_encoder_tp_mode=self.mm_encoder_tp_mode,
mm_encoder_attn_backend=self.mm_encoder_attn_backend, mm_encoder_attn_backend=self.mm_encoder_attn_backend,
pooler_config=self.pooler_config, pooler_config=self.pooler_config,
......
...@@ -189,9 +189,7 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -189,9 +189,7 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
) )
convert_type = model_config.convert_type convert_type = model_config.convert_type
if convert_type not in ["none", "mm_encoder_only"] and supports_multimodal( if convert_type != "none" and supports_multimodal(model_cls):
model_cls
):
logger.debug_once("Detected conversion of Multi Modal model.") logger.debug_once("Detected conversion of Multi Modal model.")
converted = try_create_mm_pooling_model_cls(model_cls) converted = try_create_mm_pooling_model_cls(model_cls)
if converted is not None: if converted is not None:
...@@ -202,11 +200,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -202,11 +200,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
if convert_type == "none": if convert_type == "none":
pass pass
elif convert_type == "mm_encoder_only":
logger.debug_once("Converting to mm encoder only model.")
from vllm.model_executor.models.adapters import as_mm_encoder_only_model
model_cls = as_mm_encoder_only_model(model_cls)
elif convert_type == "embed": elif convert_type == "embed":
logger.debug_once("Converting to embedding model.") logger.debug_once("Converting to embedding model.")
model_cls = as_embedding_model(model_cls) model_cls = as_embedding_model(model_cls)
......
...@@ -529,64 +529,3 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]): ...@@ -529,64 +529,3 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
method = getattr(hf_config, "method", getattr(text_config, "method", None)) method = getattr(hf_config, "method", getattr(text_config, "method", None))
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported" assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
return SEQ_CLS_LOAD_METHODS[method](model, weights) return SEQ_CLS_LOAD_METHODS[method](model, weights)
def as_mm_encoder_only_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM vl model to support mm encoder only for
    EPD encoder instances.

    The returned subclass constructs the model without running the real
    language model constructor, removes the language model attribute
    afterwards, and skips that attribute's weights in `load_weights`.

    Args:
        cls: The model class to adapt.

    Returns:
        `cls` itself when it has no `embed_multimodal` attribute, otherwise
        a `ModelForMMEncoderOnly` subclass of `cls`.

    Raises:
        TypeError: If `cls` lacks `get_language_model_spec`, or if that
            method returns `None` for either element of its pair.
    """
    if not hasattr(cls, "embed_multimodal"):
        # Submodel case: return the original class.
        return cls

    if not hasattr(cls, "get_language_model_spec"):
        raise TypeError(f"{cls} need to implement `get_language_model_spec` method.")

    # `lm_model_cls` is the class whose constructor gets stubbed out below;
    # `lm_attr` is the attribute name under which `cls` stores its instance.
    lm_model_cls, lm_attr = cls.get_language_model_spec()

    if lm_model_cls is None or lm_attr is None:
        raise TypeError(
            f"{cls}.get_language_model_spec() must return (lm_model_cls, lm_attr)"
        )

    class DummyLM(nn.Module):
        # Only this `__init__` is used: it is swapped in for
        # `lm_model_cls.__init__`, so the language model is constructed
        # without running its real constructor. It only sets
        # `make_empty_intermediate_tensors` to None.
        def __init__(self, *args, **kwargs):
            self.make_empty_intermediate_tensors = None

    class ModelForMMEncoderOnly(cls):
        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            # Flag read by `supports_mm_encoder_only` via `getattr`.
            self.is_mm_encoder_only_model = True

            # The patch below replaces a class-level attribute, so it affects
            # every construction of `lm_model_cls` until the `finally` block
            # restores the original constructor.
            origin_init = lm_model_cls.__init__
            try:
                lm_model_cls.__init__ = DummyLM.__init__
                super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

                # Drop the stubbed language model from the instance.
                if hasattr(self, lm_attr):
                    delattr(self, lm_attr)
            finally:
                lm_model_cls.__init__ = origin_init

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
            from .utils import AutoWeightsLoader

            origin_init_ = AutoWeightsLoader.__init__

            def _new_init_(self, *args, **kwargs):
                # Run the real constructor, then add the language model
                # prefix to the prefixes the loader skips.
                origin_init_(self, *args, **kwargs)
                self.skip_prefixes = (self.skip_prefixes or []) + [f"{lm_attr}."]

            # Same temporary class-level patching as in `__init__`: any
            # `AutoWeightsLoader` created while the parent `load_weights`
            # runs gets the extra skip prefix; the original constructor is
            # restored afterwards even if loading fails.
            try:
                AutoWeightsLoader.__init__ = _new_init_
                result = super().load_weights(weights)
            finally:
                AutoWeightsLoader.__init__ = origin_init_
            return result

    return ModelForMMEncoderOnly  # type: ignore
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable, Mapping, MutableSequence from collections.abc import Callable, Iterable, Mapping, MutableSequence
from contextlib import contextmanager, nullcontext
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
ClassVar, ClassVar,
...@@ -69,6 +70,46 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor: ...@@ -69,6 +70,46 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
return is_multimodal return is_multimodal
class LMMissingLayer(nn.Module):
    """
    Placeholder module standing in for a language model component that was
    not initialized (MM encoder-only mode).

    Any attempt to call it, or to call `make_empty_intermediate_tensors`,
    raises `RuntimeError`.
    """

    # Intentionally empty. Presumably lets code that inspects this attribute
    # on submodules find nothing to pack -- TODO confirm against callers.
    packed_modules_mapping: dict[str, list[str]] = {}

    def __call__(self, *args, **kwargs):
        message = "This module should not be called in MM encoder-only mode"
        raise RuntimeError(message)

    def make_empty_intermediate_tensors(self, *args, **kwargs):
        message = "This module should not be called in MM encoder-only mode"
        raise RuntimeError(message)
class TowerMissingLayer(nn.Module):
    """
    Placeholder module standing in for a tower component whose modalities
    are disabled.

    Calling it raises `RuntimeError` naming the disabled modalities.
    """

    # Intentionally empty. Presumably lets code that inspects this attribute
    # on submodules find nothing to pack -- TODO confirm against callers.
    packed_modules_mapping: dict[str, list[str]] = {}

    def __init__(self, modalities: set[str]) -> None:
        super().__init__()

        # Kept only so the error raised on use can name the modalities.
        self.modalities = modalities

    def __call__(self, *args, **kwargs):
        message = f"The following modalities are disabled: {self.modalities}"
        raise RuntimeError(message)
@contextmanager
def _no_init_weights(module: nn.Module, placeholder: Callable[[], nn.Module]):
"""
Within this context, prevent weight initialization from using device memory and
replace direct child assignments to `module` with the result of `placeholder()`.
"""
def callback(module_, name, submodule):
if module_ is module:
return placeholder()
return submodule
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with torch.device("meta"):
yield
@runtime_checkable @runtime_checkable
class SupportsMultiModal(Protocol): class SupportsMultiModal(Protocol):
"""The interface required for all multi-modal models.""" """The interface required for all multi-modal models."""
...@@ -105,6 +146,16 @@ class SupportsMultiModal(Protocol): ...@@ -105,6 +146,16 @@ class SupportsMultiModal(Protocol):
Set internally by `MultiModalRegistry.register_processor`. Set internally by `MultiModalRegistry.register_processor`.
""" """
_language_model_names: list[str] = []
"""
Set internally by `_mark_language_model`.
"""
_tower_model_names: list[str] = []
"""
Set internally by `_mark_tower_model`.
"""
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
""" """
...@@ -134,7 +185,64 @@ class SupportsMultiModal(Protocol): ...@@ -134,7 +185,64 @@ class SupportsMultiModal(Protocol):
Returns: Returns:
torch.nn.Module: The core language model component. torch.nn.Module: The core language model component.
""" """
... if self._language_model_names:
return getattr(self, self._language_model_names[0])
raise NotImplementedError(
f"No language model found in {type(self).__name__}! "
"You should initialize it inside `_mark_language_model`."
)
@contextmanager
def _mark_language_model(self, vllm_config: VllmConfig):
    """
    Record the child modules assigned to this model inside the context as
    its language model components.

    When `mm_encoder_only` is set in the multimodal config, the assignments
    are made under `_no_init_weights`, so each recorded child is replaced by
    an `LMMissingLayer` placeholder. On normal exit the recorded attribute
    names are stored in `self._language_model_names`.
    """
    mm_config = vllm_config.model_config.multimodal_config
    recorded_names: list[str] = []

    def _record_direct_child(owner, name, submodule):
        # Returns None, so the registered submodule is left unchanged here.
        if owner is self:
            recorded_names.append(name)

    recorder = torch.nn.modules.module.register_module_module_registration_hook(
        _record_direct_child
    )
    with recorder:
        if mm_config.mm_encoder_only:
            init_context = _no_init_weights(self, LMMissingLayer)
        else:
            init_context = nullcontext()

        with init_context:
            yield

    self._language_model_names = recorded_names
@contextmanager
def _mark_tower_model(self, vllm_config: VllmConfig, modalities: set[str] | str):
    """
    Mark each child module that was assigned to this model
    during this context as a tower model component.

    Args:
        vllm_config: Config whose multimodal settings provide the per-prompt
            limit for each modality.
        modalities: The modality (or set of modalities) served by the tower.
            A single string is treated as a one-element set.

    When every given modality has a per-prompt limit of 0, the assignments
    are made under `_no_init_weights`, so each child is replaced by a
    `TowerMissingLayer` placeholder.

    On normal exit the recorded attribute names are appended to
    `self._tower_model_names`, so a model that uses this context more than
    once (e.g. separate audio and vision towers) keeps the names from every
    context rather than only the last one.
    """
    if isinstance(modalities, str):
        modalities = {modalities}

    mm_config = vllm_config.model_config.multimodal_config
    children_names = list[str]()

    def callback(module_, name, submodule):
        # Returns None, so the registered submodule is left unchanged here.
        if module_ is self:
            children_names.append(name)

    with torch.nn.modules.module.register_module_module_registration_hook(callback):  # noqa: E501,SIM117
        with (
            _no_init_weights(self, lambda: TowerMissingLayer(modalities))
            if all(mm_config.get_limit_per_prompt(m) == 0 for m in modalities)
            else nullcontext()
        ):
            yield

    # Build a new list instead of extending in place: the existing value may
    # be the shared class-level default, which must not be mutated.
    self._tower_model_names = [*self._tower_model_names, *children_names]
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int: def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
""" """
...@@ -154,14 +262,6 @@ class SupportsMultiModal(Protocol): ...@@ -154,14 +262,6 @@ class SupportsMultiModal(Protocol):
""" """
... ...
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""
Return the language model spec:
(language model class, language model attr)
"""
return None, None
@overload @overload
def embed_input_ids(self, input_ids: Tensor) -> Tensor: ... def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...
...@@ -299,10 +399,6 @@ def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool: ...@@ -299,10 +399,6 @@ def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
return getattr(model, "supports_encoder_tp_data", False) return getattr(model, "supports_encoder_tp_data", False)
def supports_mm_encoder_only(model: type[object] | object) -> bool:
return getattr(model, "is_mm_encoder_only_model", False)
@overload @overload
def supports_multimodal_pruning( def supports_multimodal_pruning(
model: type[object], model: type[object],
......
...@@ -550,8 +550,7 @@ class LlavaForConditionalGeneration( ...@@ -550,8 +550,7 @@ class LlavaForConditionalGeneration(
): ):
config.projector_hidden_act = "gelu" config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings. with self._mark_tower_model(vllm_config, "image"):
if multimodal_config.get_limit_per_prompt("image"):
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
...@@ -567,15 +566,13 @@ class LlavaForConditionalGeneration( ...@@ -567,15 +566,13 @@ class LlavaForConditionalGeneration(
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"), prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
else:
self.vision_tower = None with self._mark_language_model(vllm_config):
self.multi_modal_projector = None self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
self.language_model = init_vllm_registered_model( hf_config=config.text_config,
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model"),
hf_config=config.text_config, )
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -631,8 +628,6 @@ class LlavaForConditionalGeneration( ...@@ -631,8 +628,6 @@ class LlavaForConditionalGeneration(
self, self,
inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs, inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
) -> torch.Tensor | tuple[torch.Tensor, ...]: ) -> torch.Tensor | tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"] pixel_values = inputs["pixel_values"]
return self._image_pixels_to_features(self.vision_tower, pixel_values) return self._image_pixels_to_features(self.vision_tower, pixel_values)
...@@ -644,7 +639,6 @@ class LlavaForConditionalGeneration( ...@@ -644,7 +639,6 @@ class LlavaForConditionalGeneration(
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input) image_features = self._process_image_pixels(image_input)
if isinstance(image_features, torch.Tensor): if isinstance(image_features, torch.Tensor):
...@@ -656,9 +650,6 @@ class LlavaForConditionalGeneration( ...@@ -656,9 +650,6 @@ class LlavaForConditionalGeneration(
image_embeds = torch.split(image_embeds, feature_sizes) image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
...@@ -727,11 +718,7 @@ class LlavaForConditionalGeneration( ...@@ -727,11 +718,7 @@ class LlavaForConditionalGeneration(
return self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = [] loader = AutoWeightsLoader(self)
if self.vision_tower is None and self.multi_modal_projector is None:
skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
......
...@@ -457,8 +457,7 @@ class Mistral3ForConditionalGeneration( ...@@ -457,8 +457,7 @@ class Mistral3ForConditionalGeneration(
): ):
config.projector_hidden_act = "gelu" config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings. with self._mark_tower_model(vllm_config, "image"):
if multimodal_config.get_limit_per_prompt("image"):
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
...@@ -476,15 +475,13 @@ class Mistral3ForConditionalGeneration( ...@@ -476,15 +475,13 @@ class Mistral3ForConditionalGeneration(
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"), prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
else:
self.vision_tower = None
self.multi_modal_projector = None
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
hf_config=config.text_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), hf_config=config.text_config,
) prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -534,9 +531,6 @@ class Mistral3ForConditionalGeneration( ...@@ -534,9 +531,6 @@ class Mistral3ForConditionalGeneration(
image_embeds = (image_embeds,) image_embeds = (image_embeds,)
return image_embeds return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
...@@ -607,11 +601,7 @@ class Mistral3ForConditionalGeneration( ...@@ -607,11 +601,7 @@ class Mistral3ForConditionalGeneration(
return self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = [] loader = AutoWeightsLoader(self)
if self.vision_tower is None and self.multi_modal_projector is None:
skip_prefixes = ["vision_tower.", "multi_modal_projector."]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
......
...@@ -70,12 +70,14 @@ from vllm.sequence import IntermediateTensors ...@@ -70,12 +70,14 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
LMMissingLayer,
MixtureOfExperts, MixtureOfExperts,
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle3, SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
TowerMissingLayer,
) )
from .llama4 import Llama4ForCausalLM from .llama4 import Llama4ForCausalLM
from .utils import ( from .utils import (
...@@ -773,7 +775,8 @@ class Llama4ForConditionalGeneration( ...@@ -773,7 +775,8 @@ class Llama4ForConditionalGeneration(
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
from vllm.compilation.backends import set_model_tag from vllm.compilation.backends import set_model_tag
with ( with (
...@@ -792,16 +795,15 @@ class Llama4ForConditionalGeneration( ...@@ -792,16 +795,15 @@ class Llama4ForConditionalGeneration(
quant_config=None, quant_config=None,
prefix=maybe_prefix(prefix, "multi_modal_projector"), prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
else:
self.vision_model = None with self._mark_language_model(vllm_config):
self.multi_modal_projector = None self.language_model = initialize_model(
self.language_model = initialize_model( vllm_config=vllm_config.with_hf_config(
vllm_config=vllm_config.with_hf_config( config.text_config, ["LlamaForCausalLM"]
config.text_config, ["LlamaForCausalLM"] ),
), prefix=maybe_prefix(prefix, "language_model"),
prefix=maybe_prefix(prefix, "language_model"), model_class=Llama4ForCausalLM,
model_class=Llama4ForCausalLM, )
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -892,9 +894,6 @@ class Llama4ForConditionalGeneration( ...@@ -892,9 +894,6 @@ class Llama4ForConditionalGeneration(
for img in vision_embeddings_flat.split(patches_per_image, dim=0) for img in vision_embeddings_flat.split(patches_per_image, dim=0)
] ]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
...@@ -1024,6 +1023,10 @@ class Llama4ForConditionalGeneration( ...@@ -1024,6 +1023,10 @@ class Llama4ForConditionalGeneration(
for name, weight in weights: for name, weight in weights:
renamed = self._rename_weight_for_modelopt_checkpoint(name) renamed = self._rename_weight_for_modelopt_checkpoint(name)
attr = renamed.split(".", 1)[0]
if isinstance(getattr(self, attr), (LMMissingLayer, TowerMissingLayer)):
continue
if renamed.startswith("language_model."): if renamed.startswith("language_model."):
language_model_weights.append((renamed, weight)) language_model_weights.append((renamed, weight))
else: else:
...@@ -1133,10 +1136,6 @@ class Llama4ForConditionalGeneration( ...@@ -1133,10 +1136,6 @@ class Llama4ForConditionalGeneration(
weights weights
) )
# Skip loading vision model and projector if they're not initialized.
if self.vision_model is None and self.multi_modal_projector is None:
other_weights = []
# Handle expert scale parameters # Handle expert scale parameters
regular_weights, expert_scale_weights, updated_params_from_experts = ( regular_weights, expert_scale_weights, updated_params_from_experts = (
self._handle_expert_scale_broadcasting(language_model_weights, params_dict) self._handle_expert_scale_broadcasting(language_model_weights, params_dict)
......
...@@ -239,7 +239,7 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -239,7 +239,7 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
multimodal_config.is_multimodal_pruning_enabled() multimodal_config.is_multimodal_pruning_enabled()
) )
if multimodal_config.get_limit_per_prompt("image"): with self._mark_tower_model(vllm_config, "image"):
self.visual = OpenCUAVisionTransformer( self.visual = OpenCUAVisionTransformer(
vision_config=config.vision_config, vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
...@@ -247,15 +247,14 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -247,15 +247,14 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
multimodal_config=self.multimodal_config, multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
else:
self.visual = None
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
hf_config=config.text_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), hf_config=config.text_config,
architectures=["Qwen2ForCausalLM"], prefix=maybe_prefix(prefix, "language_model"),
) architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
......
...@@ -398,13 +398,14 @@ class PixtralForConditionalGeneration( ...@@ -398,13 +398,14 @@ class PixtralForConditionalGeneration(
self.vision_args = VisionEncoderArgs(**vision_args) self.vision_args = VisionEncoderArgs(**vision_args)
# init MistralForCausalLM # init MistralForCausalLM
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
hf_config=config.text_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), hf_config=config.text_config,
) prefix=maybe_prefix(prefix, "language_model"),
)
if multimodal_config.get_limit_per_prompt("image"): with self._mark_tower_model(vllm_config, "image"):
self.vision_encoder = VisionTransformer(self.vision_args) self.vision_encoder = VisionTransformer(self.vision_args)
self.pre_mm_projector_norm = ( self.pre_mm_projector_norm = (
RMSNorm(self.vision_args.hidden_size, eps=1e-5) RMSNorm(self.vision_args.hidden_size, eps=1e-5)
...@@ -423,11 +424,6 @@ class PixtralForConditionalGeneration( ...@@ -423,11 +424,6 @@ class PixtralForConditionalGeneration(
self.vision_language_adapter = VisionLanguageAdapter( self.vision_language_adapter = VisionLanguageAdapter(
self.vision_args, dim=config.text_config.hidden_size self.vision_args, dim=config.text_config.hidden_size
) )
else:
self.vision_encoder = None
self.pre_mm_projector_norm = None
self.patch_merger = None
self.vision_language_adapter = None
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -449,10 +445,6 @@ class PixtralForConditionalGeneration( ...@@ -449,10 +445,6 @@ class PixtralForConditionalGeneration(
self, self,
image_input: PixtralImagePixelInputs, image_input: PixtralImagePixelInputs,
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
assert (
self.vision_encoder is not None and self.vision_language_adapter is not None
)
images = image_input["images"] images = image_input["images"]
image_features = self.vision_encoder(images) image_features = self.vision_encoder(images)
feature_sizes = [image_feature.shape[0] for image_feature in image_features] feature_sizes = [image_feature.shape[0] for image_feature in image_features]
...@@ -477,9 +469,6 @@ class PixtralForConditionalGeneration( ...@@ -477,9 +469,6 @@ class PixtralForConditionalGeneration(
image_embeds = torch.split(image_embeds, feature_sizes) image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -822,6 +822,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -822,6 +822,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = thinker_config self.config = thinker_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.quant_config = quant_config
# force "use_flash_attention_2=True" to audio tower to align # force "use_flash_attention_2=True" to audio tower to align
# the results. # the results.
...@@ -836,14 +837,10 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -836,14 +837,10 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
"in the audio tower part." "in the audio tower part."
) )
if multimodal_config.get_limit_per_prompt("audio"): with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
else:
self.audio_tower = None
if multimodal_config.get_limit_per_prompt( with self._mark_tower_model(vllm_config, {"image", "video"}):
"image"
) or multimodal_config.get_limit_per_prompt("video"):
self.visual = Qwen2_5_VisionTransformer( self.visual = Qwen2_5_VisionTransformer(
vision_config=thinker_config.vision_config, vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
...@@ -851,16 +848,14 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -851,16 +848,14 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
) )
else:
self.visual = None
self.quant_config = quant_config with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), prefix=maybe_prefix(prefix, "language_model"),
hf_config=thinker_config.text_config, hf_config=thinker_config.text_config,
architectures=["Qwen2ForCausalLM"], architectures=["Qwen2ForCausalLM"],
) )
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -895,9 +890,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -895,9 +890,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
) )
return mm_input_by_modality return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_mrope_input_positions( def get_mrope_input_positions(
self, self,
input_tokens: list[int], input_tokens: list[int],
...@@ -1175,19 +1167,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1175,19 +1167,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
return self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = ["talker.", "token2wav."] loader = AutoWeightsLoader(self, skip_prefixes=["talker.", "token2wav."])
if self.audio_tower is None: return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
skip_prefixes.extend(["audio_tower."])
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(
self,
skip_prefixes=skip_prefixes,
)
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
""" """
......
...@@ -35,7 +35,7 @@ import numpy as np ...@@ -35,7 +35,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import BatchFeature, Qwen2ForCausalLM from transformers import BatchFeature
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLConfig,
...@@ -1145,9 +1145,7 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1145,9 +1145,7 @@ class Qwen2_5_VLForConditionalGeneration(
multimodal_config.is_multimodal_pruning_enabled() multimodal_config.is_multimodal_pruning_enabled()
) )
if multimodal_config.get_limit_per_prompt( with self._mark_tower_model(vllm_config, {"image", "video"}):
"image"
) or multimodal_config.get_limit_per_prompt("video"):
self.visual = Qwen2_5_VisionTransformer( self.visual = Qwen2_5_VisionTransformer(
vision_config=config.vision_config, vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
...@@ -1155,14 +1153,13 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1155,14 +1153,13 @@ class Qwen2_5_VLForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
) )
else:
self.visual = None
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
prefix=maybe_prefix(prefix, "language_model"), vllm_config=vllm_config,
architectures=["Qwen2ForCausalLM"], prefix=maybe_prefix(prefix, "language_model"),
) architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -1447,9 +1444,6 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1447,9 +1444,6 @@ class Qwen2_5_VLForConditionalGeneration(
) )
return mm_input_by_modality return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
...@@ -1516,10 +1510,7 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1516,10 +1510,7 @@ class Qwen2_5_VLForConditionalGeneration(
return self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = [] loader = AutoWeightsLoader(self)
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
...@@ -1550,11 +1541,3 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1550,11 +1541,3 @@ class Qwen2_5_VLForConditionalGeneration(
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2 return num_vision_tokens // merge_size**2
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""
Return the language model spec:
(language model class, language model attr)
"""
return Qwen2ForCausalLM, "language_model"
...@@ -1233,9 +1233,7 @@ class Qwen2VLForConditionalGeneration( ...@@ -1233,9 +1233,7 @@ class Qwen2VLForConditionalGeneration(
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
if multimodal_config.get_limit_per_prompt( with self._mark_tower_model(vllm_config, {"image", "video"}):
"image"
) or multimodal_config.get_limit_per_prompt("video"):
self.visual = Qwen2VisionTransformer( self.visual = Qwen2VisionTransformer(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
...@@ -1243,14 +1241,13 @@ class Qwen2VLForConditionalGeneration( ...@@ -1243,14 +1241,13 @@ class Qwen2VLForConditionalGeneration(
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
else:
self.visual = None
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
prefix=maybe_prefix(prefix, "language_model"), vllm_config=vllm_config,
architectures=["Qwen2ForCausalLM"], prefix=maybe_prefix(prefix, "language_model"),
) architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -1371,9 +1368,6 @@ class Qwen2VLForConditionalGeneration( ...@@ -1371,9 +1368,6 @@ class Qwen2VLForConditionalGeneration(
return modalities return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
...@@ -1437,10 +1431,7 @@ class Qwen2VLForConditionalGeneration( ...@@ -1437,10 +1431,7 @@ class Qwen2VLForConditionalGeneration(
return self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = [] loader = AutoWeightsLoader(self)
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
......
...@@ -1277,11 +1277,16 @@ class Qwen3VLForConditionalGeneration( ...@@ -1277,11 +1277,16 @@ class Qwen3VLForConditionalGeneration(
multimodal_config.is_multimodal_pruning_enabled() multimodal_config.is_multimodal_pruning_enabled()
) )
if not multimodal_config.get_limit_per_prompt( self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
"image" self.deepstack_num_level = (
) and not multimodal_config.get_limit_per_prompt("video"): len(config.vision_config.deepstack_visual_indexes)
self.visual = None if self.use_deepstack
else: else 0
)
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer( self.visual = Qwen3_VisionTransformer(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
...@@ -1290,34 +1295,25 @@ class Qwen3VLForConditionalGeneration( ...@@ -1290,34 +1295,25 @@ class Qwen3VLForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
self.language_model = Qwen3LLMForCausalLM( # register buffer for deepstack
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") if self.use_deepstack:
) self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
with self._mark_language_model(vllm_config):
self.language_model = Qwen3LLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
# register buffer for deepstack
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers self.language_model.model.aux_hidden_state_layers = layers
...@@ -1893,9 +1889,6 @@ class Qwen3VLForConditionalGeneration( ...@@ -1893,9 +1889,6 @@ class Qwen3VLForConditionalGeneration(
return torch.from_numpy(llm_positions), mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
...@@ -2076,10 +2069,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -2076,10 +2069,7 @@ class Qwen3VLForConditionalGeneration(
return self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = [] loader = AutoWeightsLoader(self)
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
...@@ -2110,11 +2100,3 @@ class Qwen3VLForConditionalGeneration( ...@@ -2110,11 +2100,3 @@ class Qwen3VLForConditionalGeneration(
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2 return num_vision_tokens // merge_size**2
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""
Return the language model spec:
(language model class, language model attr)
"""
return Qwen3LLMForCausalLM, "language_model"
...@@ -424,11 +424,16 @@ class Qwen3VLMoeForConditionalGeneration( ...@@ -424,11 +424,16 @@ class Qwen3VLMoeForConditionalGeneration(
multimodal_config.is_multimodal_pruning_enabled() multimodal_config.is_multimodal_pruning_enabled()
) )
if not multimodal_config.get_limit_per_prompt( self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
"image" self.deepstack_num_level = (
) and not multimodal_config.get_limit_per_prompt("video"): len(config.vision_config.deepstack_visual_indexes)
self.visual = None if self.use_deepstack
else: else 0
)
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer( self.visual = Qwen3_VisionTransformer(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
...@@ -437,9 +442,21 @@ class Qwen3VLMoeForConditionalGeneration( ...@@ -437,9 +442,21 @@ class Qwen3VLMoeForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
self.language_model = Qwen3MoeLLMForCausalLM( # register buffer for deepstack
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") if self.use_deepstack:
) self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
with self._mark_language_model(vllm_config):
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
# Whether to include the gate_up_proj mapping is determined by # Whether to include the gate_up_proj mapping is determined by
# the language model. # the language model.
self.packed_modules_mapping = ( self.packed_modules_mapping = (
...@@ -450,25 +467,5 @@ class Qwen3VLMoeForConditionalGeneration( ...@@ -450,25 +467,5 @@ class Qwen3VLMoeForConditionalGeneration(
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
# register buffer for deepstack
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
# Set MoE hyperparameters # Set MoE hyperparameters
self.set_moe_parameters() self.set_moe_parameters()
...@@ -942,7 +942,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -942,7 +942,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config.get_limit_per_prompt("image"): with self._mark_tower_model(vllm_config, "image"):
self.vision_model = Step3VisionTransformer( self.vision_model = Step3VisionTransformer(
config.vision_config, config.vision_config,
None, None,
...@@ -967,17 +967,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -967,17 +967,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
config.hidden_size, config.hidden_size,
bias=config.projector_bias, bias=config.projector_bias,
) )
else:
self.vision_model = None with self._mark_language_model(vllm_config):
self.vit_downsampler = None self.language_model = init_vllm_registered_model(
self.vit_downsampler2 = None vllm_config=vllm_config,
self.vit_large_projector = None hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
self.language_model = init_vllm_registered_model( )
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -1071,9 +1067,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1071,9 +1067,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
) )
return merged_image_features return merged_image_features
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
...@@ -1133,15 +1126,5 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1133,15 +1126,5 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
skip_prefixes = [] loader = AutoWeightsLoader(self)
if self.vision_model is None and self.vit_large_projector is None: return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
skip_prefixes = [
"vision_model.",
"vit_downsampler.",
"vit_downsampler2.",
"vit_large_projector.",
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights
...@@ -504,7 +504,8 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration): ...@@ -504,7 +504,8 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = PerceptionEncoder( self.vision_model = PerceptionEncoder(
config.vision_config, config.vision_config,
get_act_fn(config.vision_config.hidden_act), get_act_fn(config.vision_config.hidden_act),
...@@ -521,15 +522,13 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration): ...@@ -521,15 +522,13 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
prefix=maybe_prefix(prefix, "vit_large_projector"), prefix=maybe_prefix(prefix, "vit_large_projector"),
disable_tp=self.use_data_parallel, disable_tp=self.use_data_parallel,
) )
else:
self.vision_model = None
self.vit_large_projector = None
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
hf_config=config.text_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), hf_config=config.text_config,
) prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
......
...@@ -24,7 +24,11 @@ from vllm.model_executor.model_loader.online_quantization import ( ...@@ -24,7 +24,11 @@ from vllm.model_executor.model_loader.online_quantization import (
support_quantized_model_reload_from_hp_weights, support_quantized_model_reload_from_hp_weights,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import supports_any_eagle from vllm.model_executor.models.interfaces import (
LMMissingLayer,
TowerMissingLayer,
supports_any_eagle,
)
from vllm.multimodal import NestedTensors from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
...@@ -250,7 +254,7 @@ class AutoWeightsLoader: ...@@ -250,7 +254,7 @@ class AutoWeightsLoader:
module: nn.Module, module: nn.Module,
weights: Iterable[tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]: ) -> Iterable[str]:
if isinstance(module, PPMissingLayer): if isinstance(module, (LMMissingLayer, TowerMissingLayer, PPMissingLayer)):
return return
# Avoid infinite recursion since this function is typically # Avoid infinite recursion since this function is typically
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment