Unverified Commit 2b8a38b6 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Extend `collect_children` and `no_init_weights` contexts (#32757)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 1bf1a34b
...@@ -2,11 +2,27 @@ ...@@ -2,11 +2,27 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
from vllm.utils.collection_utils import swap_dict_values from vllm.utils.collection_utils import common_prefix, swap_dict_values
@pytest.mark.parametrize( @pytest.mark.parametrize(
"obj,key1,key2", ("inputs", "expected_output"),
[
([""], ""),
(["a"], "a"),
(["a", "b"], ""),
(["a", "ab"], "a"),
(["a", "ab", "b"], ""),
(["abc", "a", "ab"], "a"),
(["aba", "abc", "ab"], "ab"),
],
)
def test_common_prefix(inputs, expected_output):
assert common_prefix(inputs) == expected_output
@pytest.mark.parametrize(
("obj", "key1", "key2"),
[ [
# Tests for both keys exist # Tests for both keys exist
({1: "a", 2: "b"}, 1, 2), ({1: "a", 2: "b"}, 1, 2),
......
...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -165,11 +165,7 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]() ...@@ -165,11 +165,7 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]: def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
from vllm.model_executor.models.adapters import ( from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model
as_embedding_model,
as_seq_cls_model,
try_create_mm_pooling_model_cls,
)
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
...@@ -189,15 +185,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -189,15 +185,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
) )
convert_type = model_config.convert_type convert_type = model_config.convert_type
if convert_type != "none" and supports_multimodal(model_cls):
logger.debug_once("Detected conversion of Multi Modal model.")
converted = try_create_mm_pooling_model_cls(model_cls)
if converted is not None:
logger.debug_once("Creating wrapper class to forward pooler.")
return converted, arch
else:
logger.debug_once("Attempting direct conversion.")
if convert_type == "none": if convert_type == "none":
pass pass
elif convert_type == "embed": elif convert_type == "embed":
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import inspect
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, TypeVar, cast from typing import TYPE_CHECKING, Any, TypeVar, cast
...@@ -18,10 +16,12 @@ from vllm.transformers_utils.config import ( ...@@ -18,10 +16,12 @@ from vllm.transformers_utils.config import (
) )
from vllm.transformers_utils.repo_utils import get_hf_file_bytes from vllm.transformers_utils.repo_utils import get_hf_file_bytes
from .interfaces import supports_multimodal
from .interfaces_base import VllmModelForPooling, is_pooling_model from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import Pooler
_T = TypeVar("_T", bound=type[nn.Module]) _T = TypeVar("_T", bound=type[nn.Module])
...@@ -124,20 +124,12 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: ...@@ -124,20 +124,12 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
return model_name + pooling_suffix return model_name + pooling_suffix
def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T: def _create_pooling_model_cls(orig_cls: _T) -> _T:
class CallVisitor(ast.NodeVisitor): # Lazy import
def __init__(self): from vllm.model_executor.layers.logits_processor import LogitsProcessor
self.calls = [] from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
self.calls.append(node.func.id)
self.generic_visit(node)
visitor = CallVisitor() from .utils import AutoWeightsLoader, StageMissingLayer, no_init_weights
visitor.visit(ast.parse(inspect.getsource(orig_cls)))
if "init_vllm_registered_model" not in visitor.calls:
return None
class ModelForPooling(orig_cls, VllmModelForPooling): class ModelForPooling(orig_cls, VllmModelForPooling):
is_pooling_model = True is_pooling_model = True
...@@ -149,90 +141,84 @@ def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T: ...@@ -149,90 +141,84 @@ def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T:
prefix: str = "", prefix: str = "",
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) with no_init_weights(
self,
self.pooler = self.get_language_model().pooler lambda mod: StageMissingLayer("output", mod),
targets=(LogitsProcessor, ParallelLMHead),
return ModelForPooling # type: ignore ):
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
# Used by SEQ_CLS_LOAD_METHODS
self.vllm_config = vllm_config
# If the model already defines a pooler instance, don't overwrite it
pooler = getattr(self, "pooler", None)
if not pooler and supports_multimodal(self):
# Try to get the pooler from the LM backbone
language_model = self.get_language_model()
if hasattr(language_model, "pooler"):
pooler = language_model.pooler
def _create_pooling_model_cls(orig_cls: _T) -> _T: if not pooler:
# Lazy import pooler = self._init_pooler(vllm_config, prefix=prefix)
from .utils import AutoWeightsLoader, WeightsMapper
class ModelForPooling(orig_cls, VllmModelForPooling): self.pooler = pooler
is_pooling_model = True
def __init__( def _init_pooler(
self, self,
*,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
prefix: str = "", prefix: str = "",
**kwargs: Any, ) -> "Pooler":
) -> None: raise NotImplementedError
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
# These are not used in pooling models
objects_to_clean = [self]
if language_model := getattr(self, "language_model", None):
objects_to_clean.append(language_model)
for obj in objects_to_clean:
for attr in ("lm_head", "logits_processor"):
if hasattr(obj, attr):
delattr(obj, attr)
# If the model already defines a pooler instance, don't overwrite it def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
if not getattr(self, "pooler", None): params_dict = dict(self.named_parameters())
self._init_pooler(vllm_config, prefix=prefix)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): # We support loading from both `*ForCausalLM` and `*Model`
raise NotImplementedError candidate_prefixes = ["", "model."]
target_prefix = ""
def load_weights( seen_weights = list[tuple[str, torch.Tensor]]()
self, for name, loaded_weight in weights:
weights: Iterable[tuple[str, torch.Tensor]], seen_weights.append((name, loaded_weight))
load_lm_head: bool = False,
):
# TODO: Support uninitialized params tracking
# For most pooling models: We have deleted this attribute, so don't load it.
# For converting an LLM into a seq cls model, we need the lm_head.
if not load_lm_head:
weights = (
(name, data)
for name, data in weights
if not name.startswith("lm_head.")
)
# If `*ForCausalLM` defines `load_weights` on the inner model try:
# and there are no other inner modules with parameters, target_prefix = next(
# we support loading from both `*Model` and `*ForCausalLM` prefix
if hasattr(self, "model") and hasattr(self.model, "load_weights"): for prefix in candidate_prefixes
# Whether only `self.model` contains parameters if prefix + name in params_dict
model_is_only_param = all( )
name == "model" or next(child.parameters(), None) is None break
for name, child in self.named_children() except StopIteration:
# The weight might not exist on the model
# (to be handled by AutoWeightsLoader)
pass
if target_prefix:
target_model = self
for attr in target_prefix.split("."):
if attr:
target_model = getattr(self, attr)
logger.info(
"Mapping weights to %s as they are "
"relative to this model instead of %s.",
target_model._get_name(),
self._get_name(),
) )
if model_is_only_param: mapped_weights = (
mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) (target_prefix + name, weight)
weights = mapper.apply(weights) for name, weight in (*seen_weights, *weights)
)
loaded_params = self.model.load_weights(weights)
loaded_params = {f"model.{name}" for name in loaded_params}
return loaded_params
# For most other models def default_load_weights(weights):
if hasattr(orig_cls, "load_weights"):
return orig_cls.load_weights(self, weights) # type: ignore
# Fallback
else:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
load_weights = getattr(super(), "load_weights", default_load_weights)
return load_weights(mapped_weights)
return ModelForPooling # type: ignore return ModelForPooling # type: ignore
...@@ -255,11 +241,15 @@ def as_embedding_model(cls: _T) -> _T: ...@@ -255,11 +241,15 @@ def as_embedding_model(cls: _T) -> _T:
from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.layers.pooler import DispatchPooler
class ModelForEmbedding(_create_pooling_model_cls(cls)): class ModelForEmbedding(_create_pooling_model_cls(cls)):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): def _init_pooler(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> "Pooler":
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler.for_embedding(pooler_config) return DispatchPooler.for_embedding(pooler_config)
ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding") ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
...@@ -292,7 +282,11 @@ def as_seq_cls_model(cls: _T) -> _T: ...@@ -292,7 +282,11 @@ def as_seq_cls_model(cls: _T) -> _T:
class ModelForSequenceClassification( class ModelForSequenceClassification(
_create_pooling_model_cls(cls), SupportsCrossEncoding _create_pooling_model_cls(cls), SupportsCrossEncoding
): ):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): def _init_pooler(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> "Pooler":
text_config = vllm_config.model_config.hf_config.get_text_config() text_config = vllm_config.model_config.hf_config.get_text_config()
model_config = vllm_config.model_config model_config = vllm_config.model_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
...@@ -310,9 +304,7 @@ def as_seq_cls_model(cls: _T) -> _T: ...@@ -310,9 +304,7 @@ def as_seq_cls_model(cls: _T) -> _T:
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler.for_seq_cls( return DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
pooler_config, classifier=self.score
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
hf_config = self.config hf_config = self.config
...@@ -424,7 +416,7 @@ def load_weights_using_from_2_way_softmax( ...@@ -424,7 +416,7 @@ def load_weights_using_from_2_way_softmax(
pooling_model_cls = next( pooling_model_cls = next(
x for x in type(model).__mro__ if x.__name__ == "ModelForPooling" x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
) )
loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True) loaded_weights = pooling_model_cls.load_weights(model, weights)
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
......
...@@ -44,11 +44,11 @@ from .interfaces import ( ...@@ -44,11 +44,11 @@ from .interfaces import (
SupportsLoRA, SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
TowerMissingLayer,
) )
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
StageMissingLayer,
WeightsMapper, WeightsMapper,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
...@@ -426,9 +426,9 @@ class BagelForConditionalGeneration( ...@@ -426,9 +426,9 @@ class BagelForConditionalGeneration(
hidden_size=llm_hidden_size, hidden_size=llm_hidden_size,
) )
else: else:
self.vit_model = TowerMissingLayer("image") self.vit_model = StageMissingLayer("image_tower")
self.connector = TowerMissingLayer("image") self.connector = StageMissingLayer("image_tower")
self.vit_pos_embed = TowerMissingLayer("image") self.vit_pos_embed = StageMissingLayer("image_tower")
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
......
...@@ -935,9 +935,20 @@ class ChameleonForConditionalGeneration( ...@@ -935,9 +935,20 @@ class ChameleonForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.model = ChameleonModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") with self._mark_composite_model(
) vllm_config,
language_targets=(
ChameleonDecoderLayer
if not self.config.swin_norm
else ChameleonSwinDecoderLayer
),
tower_targets={"image": ChameleonVQVAE},
):
self.model = ChameleonModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
...@@ -970,9 +981,6 @@ class ChameleonForConditionalGeneration( ...@@ -970,9 +981,6 @@ class ChameleonForConditionalGeneration(
resolve_bindings={"h": expected_h, "w": expected_w}, resolve_bindings={"h": expected_h, "w": expected_w},
) )
def get_language_model(self) -> torch.nn.Module:
return self.model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -539,10 +539,7 @@ class Gemma3ForConditionalGeneration( ...@@ -539,10 +539,7 @@ class Gemma3ForConditionalGeneration(
) )
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
if hasattr(self.language_model, "logits_processor"): self.language_model.logits_processor.scale *= logit_scale
# The logits processor can be unset if we're using
# automatic conversion to pooling model.
self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
......
...@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors ...@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .chatglm import ChatGLMBaseModel, ChatGLMModel from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsLoRA, SupportsLoRA,
...@@ -591,11 +591,16 @@ class GLM4VForCausalLM( ...@@ -591,11 +591,16 @@ class GLM4VForCausalLM(
prefix: str = "", prefix: str = "",
transformer_type: type[GLM4VModel] = GLM4VModel, transformer_type: type[GLM4VModel] = GLM4VModel,
) -> None: ) -> None:
super().__init__( with self._mark_composite_model(
vllm_config=vllm_config, vllm_config,
prefix=prefix, language_targets=GLMTransformer,
transformer_type=transformer_type, tower_targets={"image": EVA2CLIPModel},
) ):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
transformer_type=transformer_type,
)
self.transformer: GLM4VModel self.transformer: GLM4VModel
...@@ -752,9 +757,6 @@ class GLM4VForCausalLM( ...@@ -752,9 +757,6 @@ class GLM4VForCausalLM(
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
def get_language_model(self) -> torch.nn.Module:
return self.transformer
embed_input_ids = SupportsMultiModal.embed_input_ids embed_input_ids = SupportsMultiModal.embed_input_ids
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
......
...@@ -57,7 +57,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -57,7 +57,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import ( from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer, Idefics2VisionTransformer as Idefics3VisionTransformer,
) )
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
)
from .llama import LlamaModel from .llama import LlamaModel
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix
...@@ -604,9 +608,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo ...@@ -604,9 +608,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.model = Idefics3Model( with self._mark_composite_model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config,
) language_targets=LlamaModel,
tower_targets={"image": (Idefics3VisionTransformer, Idefics3Connector)},
):
self.model = Idefics3Model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
self.image_token_id = self.config.image_token_id self.image_token_id = self.config.image_token_id
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
...@@ -669,9 +680,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo ...@@ -669,9 +680,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
num_patches = image_input["num_patches"] num_patches = image_input["num_patches"]
return [e.flatten(0, 1) for e in image_features.split(num_patches.tolist())] return [e.flatten(0, 1) for e in image_features.split(num_patches.tolist())]
def get_language_model(self) -> torch.nn.Module:
return self.model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable, Mapping, MutableSequence from collections.abc import Callable, Iterable, Mapping, MutableSequence
from contextlib import contextmanager, nullcontext from contextlib import ExitStack, contextmanager, nullcontext
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
ClassVar, ClassVar,
...@@ -25,6 +25,7 @@ from vllm.inputs import TokensPrompt ...@@ -25,6 +25,7 @@ from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.collection_utils import common_prefix
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
from .interfaces_base import VllmModel, is_pooling_model from .interfaces_base import VllmModel, is_pooling_model
...@@ -70,46 +71,8 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor: ...@@ -70,46 +71,8 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
return is_multimodal return is_multimodal
class LMMissingLayer(nn.Module): # Cache results of `SupportsMultiModal.get_language_model`
def make_empty_intermediate_tensors(self, *args, **kwargs): _language_model_by_module = dict[nn.Module, VllmModel]()
raise RuntimeError("This module should not be called in MM encoder-only mode")
def __call__(self, *args, **kwargs):
raise RuntimeError("This module should not be called in MM encoder-only mode")
class TowerMissingLayer(nn.Module):
def __init__(self, modalities: set[str] | str) -> None:
if isinstance(modalities, str):
modalities = {modalities}
super().__init__()
self.modalities = modalities
def __call__(self, *args, **kwargs):
raise RuntimeError(
f"This module should not be called when the following "
f"modalities are disabled: {self.modalities}"
)
@contextmanager
def _no_init_weights(module: nn.Module, placeholder: Callable[[], nn.Module]):
"""
Within this context, prevent weight initialization from using device memory and
replace direct child assignments to `module` with the result of `placeholder()`.
"""
def callback(module_, name, submodule):
if module_ is module:
return placeholder()
return submodule
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with torch.device("meta"):
yield
@runtime_checkable @runtime_checkable
...@@ -187,31 +150,61 @@ class SupportsMultiModal(Protocol): ...@@ -187,31 +150,61 @@ class SupportsMultiModal(Protocol):
Returns: Returns:
torch.nn.Module: The core language model component. torch.nn.Module: The core language model component.
""" """
# Cached
if self in _language_model_by_module:
return _language_model_by_module[self]
if self._language_model_names: if self._language_model_names:
return getattr(self, self._language_model_names[0]) mod = self
for attr in common_prefix(
[name.split(".") for name in self._language_model_names]
):
if attr:
mod = getattr(mod, attr)
if mod is not self and hasattr(mod, "embed_input_ids"):
_language_model_by_module[self] = mod
return mod
# Fallback
for mod in self.children():
if hasattr(mod, "embed_input_ids"):
_language_model_by_module[self] = mod
return mod
raise NotImplementedError( raise NotImplementedError(
f"No language model found in {type(self).__name__}! " f"No language model found in {type(self).__name__}! "
"You should initialize it inside `_mark_language_model`." "You should initialize it via `_mark_language_model`."
) )
@contextmanager @contextmanager
def _mark_language_model(self, vllm_config: VllmConfig): def _mark_language_model(
""" self,
Mark each child module that was assigned to this model vllm_config: VllmConfig,
during this context as a language model component. *,
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
""" """
mm_config = vllm_config.model_config.multimodal_config Mark each child module that was assigned to this model during this context
as a language model component.
Language model components are automatically skipped in `--mm-encoder-only`
mode.
children_names = list[str]() If `targets` is set, instead include descendants that are an instance
of `targets`, even if they aren't direct children.
"""
from .utils import StageMissingLayer, collect_children, no_init_weights
def callback(module_, name, submodule): mm_config = vllm_config.model_config.multimodal_config
if module_ is self:
children_names.append(name)
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117 with collect_children(self, targets=targets) as children_names: # noqa: SIM117
with ( with (
_no_init_weights(self, LMMissingLayer) no_init_weights(
self,
lambda mod: StageMissingLayer("language_model", mod),
targets=targets,
)
if mm_config.mm_encoder_only if mm_config.mm_encoder_only
else nullcontext() else nullcontext()
): ):
...@@ -220,25 +213,42 @@ class SupportsMultiModal(Protocol): ...@@ -220,25 +213,42 @@ class SupportsMultiModal(Protocol):
self._language_model_names = children_names self._language_model_names = children_names
@contextmanager @contextmanager
def _mark_tower_model(self, vllm_config: VllmConfig, modalities: set[str] | str): def _mark_tower_model(
self,
vllm_config: VllmConfig,
modalities: set[str] | str,
*,
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
""" """
Mark each child module that was assigned to this model Mark each child module that was assigned to this model during this context
during this context as a tower model component. as a tower model component.
Tower model components are automatically skipped when `--limit-mm-per-prompt`
is set to zero for all of their modalities.
If `targets` is set, instead include descendants that are an instance
of `targets`, even if they aren't direct children.
""" """
from .utils import StageMissingLayer, collect_children, no_init_weights
if isinstance(modalities, str): if isinstance(modalities, str):
modalities = {modalities} modalities = {modalities}
mm_config = vllm_config.model_config.multimodal_config if modalities == {"image", "video"}:
stage_name = "vision_tower"
children_names = list[str]() else:
stage_name = "_".join([*modalities, "tower"])
def callback(module_, name, submodule): mm_config = vllm_config.model_config.multimodal_config
if module_ is self:
children_names.append(name)
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117 with collect_children(self, targets=targets) as children_names: # noqa: SIM117
with ( with (
_no_init_weights(self, lambda: TowerMissingLayer(modalities)) no_init_weights(
self,
lambda mod: StageMissingLayer(stage_name, mod),
targets=targets,
)
if all(mm_config.get_limit_per_prompt(m) == 0 for m in modalities) if all(mm_config.get_limit_per_prompt(m) == 0 for m in modalities)
else nullcontext() else nullcontext()
): ):
...@@ -246,6 +256,37 @@ class SupportsMultiModal(Protocol): ...@@ -246,6 +256,37 @@ class SupportsMultiModal(Protocol):
self._tower_model_names = children_names self._tower_model_names = children_names
@contextmanager
def _mark_composite_model(
self,
vllm_config: VllmConfig,
*,
language_targets: type[nn.Module] | tuple[type[nn.Module], ...],
tower_targets: dict[str, type[nn.Module] | tuple[type[nn.Module], ...]],
):
"""
Composite wrapper over `_mark_language_model` and
`_mark_tower_model` by modality.
"""
with ExitStack() as stack:
stack.enter_context(
self._mark_language_model(
vllm_config,
targets=language_targets,
)
)
for modality, modality_targets in tower_targets.items():
stack.enter_context(
self._mark_tower_model(
vllm_config,
modality,
targets=modality_targets,
)
)
yield
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int: def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
""" """
Implement this function to enable LoRA support Implement this function to enable LoRA support
......
...@@ -41,10 +41,12 @@ from vllm.sequence import IntermediateTensors ...@@ -41,10 +41,12 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
from .utils import ( from .utils import (
StageMissingLayer,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_empty_intermediate_tensors_factory,
make_layers, make_layers,
maybe_prefix, maybe_prefix,
no_init_weights,
) )
...@@ -413,10 +415,16 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): ...@@ -413,10 +415,16 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
prefix: str = "", prefix: str = "",
model_type: type[InternLM2Model] = InternLM2Model, model_type: type[InternLM2Model] = InternLM2Model,
): ):
super().__init__(vllm_config=vllm_config, prefix=prefix, model_type=model_type) with no_init_weights(
self,
for attr in ("output", "logits_processor"): lambda mod: StageMissingLayer("output", mod),
delattr(self, attr) targets=(LogitsProcessor, ParallelLMHead),
):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
model_type=model_type,
)
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.head_dtype = vllm_config.model_config.head_dtype self.head_dtype = vllm_config.model_config.head_dtype
......
...@@ -1035,11 +1035,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1035,11 +1035,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
) )
with self._mark_tower_model(vllm_config, {"image", "video"}): with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vpm = vpm = self.init_vision_module( self.vpm = self.init_vision_module(
config, quant_config, prefix=maybe_prefix(prefix, "vpm") config, quant_config, prefix=maybe_prefix(prefix, "vpm")
) )
self.vision_dim = ( self.vision_dim = (
vpm.embed_dim if self.version == (2, 0) else vpm.embeddings.embed_dim self.vpm.embed_dim
if self.version == (2, 0)
else self.vpm.embeddings.embed_dim
) )
self.embed_dim = self.config.hidden_size self.embed_dim = self.config.hidden_size
......
...@@ -70,20 +70,15 @@ from vllm.sequence import IntermediateTensors ...@@ -70,20 +70,15 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
LMMissingLayer,
MixtureOfExperts, MixtureOfExperts,
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle3, SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
TowerMissingLayer,
) )
from .llama4 import Llama4ForCausalLM from .llama4 import Llama4ForCausalLM
from .utils import ( from .utils import AutoWeightsLoader, StageMissingLayer, maybe_prefix
AutoWeightsLoader,
maybe_prefix,
)
from .vision import run_dp_sharded_vision_model from .vision import run_dp_sharded_vision_model
...@@ -1024,7 +1019,7 @@ class Llama4ForConditionalGeneration( ...@@ -1024,7 +1019,7 @@ class Llama4ForConditionalGeneration(
renamed = self._rename_weight_for_modelopt_checkpoint(name) renamed = self._rename_weight_for_modelopt_checkpoint(name)
attr = renamed.split(".", 1)[0] attr = renamed.split(".", 1)[0]
if isinstance(getattr(self, attr), (LMMissingLayer, TowerMissingLayer)): if isinstance(getattr(self, attr), StageMissingLayer):
continue continue
if renamed.startswith("language_model."): if renamed.startswith("language_model."):
......
...@@ -1513,7 +1513,7 @@ class NemotronH_Nano_VL_V2( ...@@ -1513,7 +1513,7 @@ class NemotronH_Nano_VL_V2(
self.video_pruning_rate = multimodal_config.video_pruning_rate self.video_pruning_rate = multimodal_config.video_pruning_rate
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
self.language_model = language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=config.text_config, hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"), prefix=maybe_prefix(prefix, "language_model"),
...@@ -1542,7 +1542,7 @@ class NemotronH_Nano_VL_V2( ...@@ -1542,7 +1542,7 @@ class NemotronH_Nano_VL_V2(
ReLUSquaredActivation(), ReLUSquaredActivation(),
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False), nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
) )
self.mlp1 = mlp1.to(language_model.config.dtype) self.mlp1 = mlp1.to(self.language_model.config.dtype)
self.config = config self.config = config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
......
...@@ -1025,12 +1025,12 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -1025,12 +1025,12 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.mlp_AR = Projector(config, config.vision_config) self.mlp_AR = Projector(config, config.vision_config)
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
self.language_model = language_model = Ernie4_5ForCausalLM( self.language_model = Ernie4_5ForCausalLM(
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), prefix=maybe_prefix(prefix, "language_model"),
) )
for layer in language_model.model.layers: for layer in self.language_model.model.layers:
if not isinstance(layer, PPMissingLayer): if not isinstance(layer, PPMissingLayer):
layer.self_attn.rotary_emb.is_neox_style = True layer.self_attn.rotary_emb.is_neox_style = True
......
...@@ -314,13 +314,14 @@ class PaliGemmaForConditionalGeneration( ...@@ -314,13 +314,14 @@ class PaliGemmaForConditionalGeneration(
config.text_config.architectures = ["Gemma2ForCausalLM"] config.text_config.architectures = ["Gemma2ForCausalLM"]
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
self.language_model = language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=config.text_config, hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"), prefix=maybe_prefix(prefix, "language_model"),
) )
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
language_model.logits_processor.scale *= logit_scale self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
......
...@@ -461,15 +461,16 @@ class Qwen3VLMoeForConditionalGeneration( ...@@ -461,15 +461,16 @@ class Qwen3VLMoeForConditionalGeneration(
] ]
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
self.language_model = language_model = Qwen3MoeLLMForCausalLM( self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
) )
# Whether to include the gate_up_proj mapping is determined by # Whether to include the gate_up_proj mapping is determined by
# the language model. # the language model.
self.packed_modules_mapping = ( self.packed_modules_mapping = (
self.packed_modules_mapping | language_model.packed_modules_mapping self.packed_modules_mapping | self.language_model.packed_modules_mapping
) )
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
......
...@@ -58,7 +58,7 @@ from .interfaces import ( ...@@ -58,7 +58,7 @@ from .interfaces import (
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
) )
from .qwen import QWenBaseModel, QWenModel from .qwen import QWenBaseModel, QWenBlock, QWenModel
class QwenImagePixelInputs(TensorSchema): class QwenImagePixelInputs(TensorSchema):
...@@ -757,11 +757,16 @@ class QwenVLForConditionalGeneration( ...@@ -757,11 +757,16 @@ class QwenVLForConditionalGeneration(
prefix: str = "", prefix: str = "",
transformer_type: type[QwenVLModel] = QwenVLModel, transformer_type: type[QwenVLModel] = QwenVLModel,
) -> None: ) -> None:
super().__init__( with self._mark_composite_model(
vllm_config=vllm_config, vllm_config,
prefix=prefix, language_targets=QWenBlock,
transformer_type=transformer_type, tower_targets={"image": VisionTransformer},
) ):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
transformer_type=transformer_type,
)
self.transformer: QwenVLModel self.transformer: QwenVLModel
...@@ -795,9 +800,6 @@ class QwenVLForConditionalGeneration( ...@@ -795,9 +800,6 @@ class QwenVLForConditionalGeneration(
return self.transformer.visual(image_input["data"]) return self.transformer.visual(image_input["data"])
def get_language_model(self) -> torch.nn.Module:
return self.transformer
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -2,13 +2,15 @@ ...@@ -2,13 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools import itertools
from collections.abc import Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Literal, Protocol, overload from typing import Any, Literal, Protocol, overload
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.func import functional_call from torch.func import functional_call
from torch.nn.modules.module import register_module_module_registration_hook
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -24,11 +26,7 @@ from vllm.model_executor.model_loader.online_quantization import ( ...@@ -24,11 +26,7 @@ from vllm.model_executor.model_loader.online_quantization import (
support_quantized_model_reload_from_hp_weights, support_quantized_model_reload_from_hp_weights,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import supports_any_eagle
LMMissingLayer,
TowerMissingLayer,
supports_any_eagle,
)
from vllm.multimodal import NestedTensors from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
...@@ -214,8 +212,8 @@ class AutoWeightsLoader: ...@@ -214,8 +212,8 @@ class AutoWeightsLoader:
continue continue
raise ValueError( raise ValueError(
f"Attempted to load nested weight '{weight_qualname}' " f"Attempted to load nested weight {weight_qualname!r} "
f"into a single parameter '{base_prefix}'" f"into a single parameter {base_prefix!r}"
) )
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
...@@ -254,7 +252,7 @@ class AutoWeightsLoader: ...@@ -254,7 +252,7 @@ class AutoWeightsLoader:
module: nn.Module, module: nn.Module,
weights: Iterable[tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]: ) -> Iterable[str]:
if isinstance(module, (LMMissingLayer, TowerMissingLayer, PPMissingLayer)): if isinstance(module, (StageMissingLayer, PPMissingLayer)):
return return
# Avoid infinite recursion since this function is typically # Avoid infinite recursion since this function is typically
...@@ -316,9 +314,14 @@ class AutoWeightsLoader: ...@@ -316,9 +314,14 @@ class AutoWeightsLoader:
continue continue
desc_param_keys = {
base_prefix + k for k, _ in module.named_parameters(recurse=True)
}
msg = ( msg = (
f"There is no module or parameter named '{prefix}' " f"There is no module or parameter named {prefix!r} "
f"in {type(self.module).__name__}" f"in {self.module._get_name()}. "
f"The available parameters belonging to {base_prefix} "
f"({module._get_name()}) are: {desc_param_keys}"
) )
raise ValueError(msg) raise ValueError(msg)
...@@ -496,6 +499,100 @@ def isin_list( ...@@ -496,6 +499,100 @@ def isin_list(
return torch.isin(elements, test_elements) return torch.isin(elements, test_elements)
class StageMissingLayer(nn.Module):
def __init__(self, stage_name: str, module: nn.Module | None = None) -> None:
super().__init__()
self.stage_name = stage_name
# Don't register this as a child module in order to
# avoid missing keys when loading weights
self.__dict__["module"] = module
def __getattr__(self, name: str):
return getattr(self.__dict__["module"], name)
def __call__(self, *args, **kwargs):
raise RuntimeError(f"{self} should not be called")
def extra_repr(self) -> str:
return f"stage_name={self.stage_name!r}"
@contextmanager
def collect_children(
module: nn.Module,
*,
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
"""
Within this context, collect all direct child assignments to `module`,
returning a list of children names that is internally updated until the
context is exited.
If `targets` is set, instead collect descendents of `module`
that are an instance of `targets`, even if they aren't direct children.
"""
children_names = list[str]()
if targets is None:
def hook(module_: nn.Module, name: str, submodule: nn.Module):
if module_ is module:
children_names.append(name)
with register_module_module_registration_hook(hook):
yield children_names
else:
yield children_names
for name, module_ in module.named_modules():
if isinstance(module_, targets):
children_names.append(name)
@contextmanager
def no_init_weights(
module: nn.Module,
placeholder: Callable[[nn.Module], nn.Module],
*,
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
"""
Within this context, prevent weight initialization from using device memory and
replace direct child assignments to `module` with the result of `placeholder()`.
If `targets` is set, instead prevent weight initialization and
replace assignments where the child is an instance of `targets`,
even if they aren't direct children of `module`.
"""
if targets is None:
def hook(module_: nn.Module, name: str, submodule: nn.Module):
if module_ is module:
return placeholder(submodule)
return submodule
with register_module_module_registration_hook(hook), torch.device("meta"):
yield
else:
def hook(module_: nn.Module, name: str, submodule: nn.Module):
if isinstance(module_, targets):
submodule.to("meta") # Free memory
if isinstance(submodule, targets):
submodule.to("meta") # Free memory
return placeholder(submodule)
return submodule
# Not all descendents are targeted, so we can't use a blanket
# `torch.device("meta")` context
with register_module_module_registration_hook(hook):
yield
class LayerFn(Protocol): class LayerFn(Protocol):
def __call__(self, prefix: str) -> torch.nn.Module: ... def __call__(self, prefix: str) -> torch.nn.Module: ...
...@@ -627,7 +724,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]: ...@@ -627,7 +724,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
missing_layer_names = [] missing_layer_names = []
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, PPMissingLayer): if isinstance(module, (StageMissingLayer, PPMissingLayer)):
# NOTE: the trailing dot is used to match the prefix of the layer. # NOTE: the trailing dot is used to match the prefix of the layer.
# without the dot, we could match a layer that is not missing, # without the dot, we could match a layer that is not missing,
# e.g., 'encoder.layer.1' would match 'encoder.layer.11' # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
...@@ -639,7 +736,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]: ...@@ -639,7 +736,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
"""Check if a parameter is missing in a pipeline parallel model.""" """Check if a parameter is missing in a pipeline parallel model."""
if isinstance(model, PPMissingLayer): if isinstance(model, (StageMissingLayer, PPMissingLayer)):
return True return True
return any( return any(
......
...@@ -909,7 +909,12 @@ class WhisperForConditionalGeneration( ...@@ -909,7 +909,12 @@ class WhisperForConditionalGeneration(
self.config = config self.config = config
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) with self._mark_composite_model(
vllm_config,
language_targets=WhisperDecoder,
tower_targets={"audio": WhisperEncoder},
):
self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
self.proj_out = ParallelLMHead( self.proj_out = ParallelLMHead(
config.vocab_size, config.vocab_size,
...@@ -937,9 +942,6 @@ class WhisperForConditionalGeneration( ...@@ -937,9 +942,6 @@ class WhisperForConditionalGeneration(
) )
return decoder_outputs return decoder_outputs
def get_language_model(self) -> torch.nn.Module:
return self.model.decoder
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
# Required as part of SupportsMultiModal interface. # Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
......
...@@ -7,10 +7,10 @@ This is similar in concept to the `collections` module. ...@@ -7,10 +7,10 @@ This is similar in concept to the `collections` module.
""" """
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
from typing import Generic, Literal, TypeVar from typing import Generic, Literal, TypeVar
from typing_extensions import TypeIs, assert_never from typing_extensions import TypeIs, assert_never, overload
T = TypeVar("T") T = TypeVar("T")
...@@ -74,6 +74,34 @@ def is_list_of( ...@@ -74,6 +74,34 @@ def is_list_of(
assert_never(check) assert_never(check)
@overload
def common_prefix(items: Sequence[str]) -> str: ...
@overload
def common_prefix(items: Sequence[Sequence[T]]) -> Sequence[T]: ...
def common_prefix(items: Sequence[Sequence[T] | str]) -> Sequence[T] | str:
"""Find the longest prefix common to all items."""
if len(items) == 0:
return []
if len(items) == 1:
return items[0]
shortest = min(items, key=len)
if not shortest:
return shortest[:0]
for match_len in range(1, len(shortest) + 1):
match = shortest[:match_len]
for item in items:
if item[:match_len] != match:
return shortest[: match_len - 1]
return shortest
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]: def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
"""Yield successive chunk_size chunks from lst.""" """Yield successive chunk_size chunks from lst."""
for i in range(0, len(lst), chunk_size): for i in range(0, len(lst), chunk_size):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment