"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "50b788a17a8a059dad354640bebcaaee7dd72f3f"
Unverified Commit 7f1bcd18 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[3/N] Initialize MM components in context managers (I-L) (#32650)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8be263c3
...@@ -547,20 +547,20 @@ class InternS1ForConditionalGeneration( ...@@ -547,20 +547,20 @@ class InternS1ForConditionalGeneration(
) )
self.downsample_ratio = config.downsample_ratio self.downsample_ratio = config.downsample_ratio
self.llm_arch_name = config.text_config.architectures[0] with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_tower = self._init_vision_model( self.vision_tower = self._init_vision_model(
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
self.multi_modal_projector = self._init_mlp1(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.multi_modal_projector = self._init_mlp1(config) with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.img_context_token_id = None self.img_context_token_id = None
self.video_context_token_id = None self.video_context_token_id = None
...@@ -699,8 +699,6 @@ class InternS1ForConditionalGeneration( ...@@ -699,8 +699,6 @@ class InternS1ForConditionalGeneration(
): ):
return image_input["data"] return image_input["data"]
assert self.vision_tower is not None
image_embeds = self.extract_feature(image_input["pixel_values"]) image_embeds = self.extract_feature(image_input["pixel_values"])
num_patches = image_input["num_patches"] num_patches = image_input["num_patches"]
...@@ -737,9 +735,6 @@ class InternS1ForConditionalGeneration( ...@@ -737,9 +735,6 @@ class InternS1ForConditionalGeneration(
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
self.visual_token_mask = None self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
......
...@@ -1092,22 +1092,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1092,22 +1092,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
self.downsample_ratio = config.downsample_ratio self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version self.ps_version = config.ps_version
self.llm_arch_name = config.text_config.architectures[0] llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == "InternLM2VEForCausalLM" self.is_mono = llm_arch_name == "InternLM2VEForCausalLM"
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.language_model = init_vllm_registered_model( with self._mark_tower_model(vllm_config, {"image", "video"}):
vllm_config=vllm_config, self.vision_model = self._init_vision_model(
hf_config=config.text_config, config,
prefix=maybe_prefix(prefix, "language_model"), quant_config=quant_config,
) is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.mlp1 = self._init_mlp1(config)
self.mlp1 = self._init_mlp1(config) with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.img_context_token_id = None self.img_context_token_id = None
self.video_context_token_id = None self.video_context_token_id = None
...@@ -1281,8 +1283,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1281,8 +1283,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
): ):
return image_input["data"] return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["pixel_values_flat"]) image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"] num_patches = image_input["num_patches"]
...@@ -1325,9 +1325,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1325,9 +1325,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
else: else:
self.visual_token_mask = None self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
......
...@@ -1342,11 +1342,14 @@ class IsaacForConditionalGeneration( ...@@ -1342,11 +1342,14 @@ class IsaacForConditionalGeneration(
"mrope_interleaved", rope_scaling["mrope_interleaved"] "mrope_interleaved", rope_scaling["mrope_interleaved"]
) )
target_cfg.rope_parameters = rope_parameters target_cfg.rope_parameters = rope_parameters
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, with self._mark_language_model(vllm_config):
architectures=["Qwen3ForCausalLM"], self.language_model = init_vllm_registered_model(
prefix=maybe_prefix(prefix, "language_model"), vllm_config=vllm_config,
) architectures=["Qwen3ForCausalLM"],
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
...@@ -1363,14 +1366,16 @@ class IsaacForConditionalGeneration( ...@@ -1363,14 +1366,16 @@ class IsaacForConditionalGeneration(
vision_cfg._attn_implementation = attn_impl vision_cfg._attn_implementation = attn_impl
hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
self.vision_embedding = IsaacVisionEmbedding(
vision_cfg=vision_cfg, with self._mark_tower_model(vllm_config, "image"):
hidden_dim=hidden_dim, self.vision_embedding = IsaacVisionEmbedding(
output_dim=config.hidden_size, vision_cfg=vision_cfg,
quant_config=quant_config, hidden_dim=hidden_dim,
multimodal_config=self.multimodal_config, output_dim=config.hidden_size,
prefix=maybe_prefix(prefix, "vision_embedding"), quant_config=quant_config,
) multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "vision_embedding"),
)
def iter_mm_grid_hw( def iter_mm_grid_hw(
self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
...@@ -1457,18 +1462,6 @@ class IsaacForConditionalGeneration( ...@@ -1457,18 +1462,6 @@ class IsaacForConditionalGeneration(
return () return ()
return self._process_image_input(image_input) return self._process_image_input(image_input)
def get_multimodal_embeddings(
self, **kwargs: object
) -> MultiModalEmbeddings | None:
# Backward compatibility for older runners.
embeddings = self.embed_multimodal(**kwargs)
if not embeddings:
return []
return embeddings
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -586,16 +586,21 @@ class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -586,16 +586,21 @@ class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.config = config self.config = config
self.vision_model = CustomQwen2VLVE._from_config(config.vision_config) with self._mark_tower_model(vllm_config, "image"):
self.abstractor = DynamicCAbstractor( self.vision_model = CustomQwen2VLVE._from_config(config.vision_config)
config.projector_config, num_input_tokens=self.vision_model.get_num_tokens() self.abstractor = DynamicCAbstractor(
) config.projector_config,
self.language_model = init_vllm_registered_model( num_input_tokens=self.vision_model.get_num_tokens(),
vllm_config=vllm_config, )
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "model"), with self._mark_language_model(vllm_config):
architectures=["LlamaForCausalLM"], self.language_model = init_vllm_registered_model(
) vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "model"),
architectures=["LlamaForCausalLM"],
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
...@@ -718,9 +723,6 @@ class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -718,9 +723,6 @@ class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
visual_embeds = self.forward_projector(visual_features, image_metas=image_metas) visual_embeds = self.forward_projector(visual_features, image_metas=image_metas)
return visual_embeds return visual_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -1242,7 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): ...@@ -1242,7 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
return _keye_field_config(hf_inputs) return _keye_field_config(hf_inputs)
class BaseKeyeModule(nn.Module): class BaseKeyeModule(nn.Module, SupportsMultiModal):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -1280,25 +1280,26 @@ class BaseKeyeModule(nn.Module): ...@@ -1280,25 +1280,26 @@ class BaseKeyeModule(nn.Module):
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.visual = KeyeSiglipVisionModel( with self._mark_tower_model(vllm_config, {"image", "video"}):
config.vision_config, self.visual = KeyeSiglipVisionModel(
quant_config=quant_config, config.vision_config,
multimodal_config=multimodal_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"), multimodal_config=multimodal_config,
) prefix=maybe_prefix(prefix, "visual"),
)
self.mlp_AR = self._build_projector( self.mlp_AR = self._build_projector(
config, config,
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "mlp_AR"), prefix=maybe_prefix(prefix, "mlp_AR"),
) )
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
prefix=maybe_prefix(prefix, "language_model"), vllm_config=vllm_config,
architectures=["Qwen3ForCausalLM"], prefix=maybe_prefix(prefix, "language_model"),
) architectures=["Qwen3ForCausalLM"],
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -1312,7 +1313,7 @@ class BaseKeyeModule(nn.Module): ...@@ -1312,7 +1313,7 @@ class BaseKeyeModule(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
) -> nn.Module: ) -> nn.Module:
raise ValueError("Need projector") raise NotImplementedError("Need projector")
def _process_image_input(self, image_input: Any) -> tuple[torch.Tensor, ...]: def _process_image_input(self, image_input: Any) -> tuple[torch.Tensor, ...]:
siglip_position_ids = list() siglip_position_ids = list()
...@@ -1429,9 +1430,6 @@ class BaseKeyeModule(nn.Module): ...@@ -1429,9 +1430,6 @@ class BaseKeyeModule(nn.Module):
return modalities return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
......
...@@ -42,7 +42,6 @@ ...@@ -42,7 +42,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
import copy
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
...@@ -50,23 +49,12 @@ from typing import Annotated, Any, Literal ...@@ -50,23 +49,12 @@ from typing import Annotated, Any, Literal
import torch import torch
from torch import nn from torch import nn
from transformers import BatchFeature, DeepseekV2Config from transformers import BatchFeature
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -92,7 +80,7 @@ from vllm.sequence import IntermediateTensors ...@@ -92,7 +80,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .vision import run_dp_sharded_mrope_vision_model from .vision import run_dp_sharded_mrope_vision_model
...@@ -315,48 +303,41 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -315,48 +303,41 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
super().__init__() super().__init__()
model_config = vllm_config.model_config model_config = vllm_config.model_config
config: KimiVLConfig = model_config.hf_config config: KimiVLConfig = model_config.hf_config
self.config = config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
assert isinstance(config.vision_config, MoonViTConfig) assert isinstance(config.vision_config, MoonViTConfig)
self.use_data_parallel = ( self.use_data_parallel = (
model_config.multimodal_config.mm_encoder_tp_mode == "data" model_config.multimodal_config.mm_encoder_tp_mode == "data"
) )
self.hidden_size = config.text_config.hidden_size self.hidden_size = config.text_config.hidden_size
self.vision_tower = MoonVitPretrainedModel(
config.vision_config,
multimodal_config=model_config.multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = KimiVLMultiModalProjector( with self._mark_tower_model(vllm_config, "image"):
config=config, self.vision_tower = MoonVitPretrainedModel(
use_data_parallel=self.use_data_parallel, config.vision_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"), multimodal_config=model_config.multimodal_config,
) prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = KimiVLMultiModalProjector(
config=config,
use_data_parallel=self.use_data_parallel,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.quant_config = quant_config with self._mark_language_model(vllm_config):
sub_vllm_config = copy.deepcopy(vllm_config) self.language_model = init_vllm_registered_model(
sub_vllm_config.model_config.hf_config = ( vllm_config=vllm_config,
sub_vllm_config.model_config.hf_config.text_config hf_config=config.text_config,
) prefix=maybe_prefix(prefix, "language_model"),
self.language_model = DeepseekV2Model( architectures=["DeepseekV2ForCausalLM"],
vllm_config=sub_vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.text_config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
) )
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
self.media_placeholder: int = self.config.media_placeholder_token_id self.media_placeholder: int = self.config.media_placeholder_token_id
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
...@@ -378,8 +359,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -378,8 +359,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# perform vt on processored pixel_values # perform vt on processored pixel_values
@torch.inference_mode() @torch.inference_mode()
def _process_image_pixels(self, inputs: KimiVLImagePixelInputs) -> torch.Tensor: def _process_image_pixels(self, inputs: KimiVLImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"] pixel_values = inputs["pixel_values"]
image_grid_hws = inputs["image_grid_hws"] image_grid_hws = inputs["image_grid_hws"]
if self.use_data_parallel: if self.use_data_parallel:
...@@ -399,9 +378,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -399,9 +378,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
lengths = [x.shape[0] for x in image_features] lengths = [x.shape[0] for x in image_features]
return self.multi_modal_projector(torch.cat(image_features)).split(lengths) return self.multi_modal_projector(torch.cat(image_features)).split(lengths)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> NestedTensors | None: def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
# Validate the multimodal input keyword arguments # Validate the multimodal input keyword arguments
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
...@@ -433,145 +409,8 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -433,145 +409,8 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states, **kwargs) return self.language_model.compute_logits(hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
config = self.config.text_config loader = AutoWeightsLoader(self)
_KEYS_TO_MODIFY_MAPPING = { return loader.load_weights(weights)
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
use_mha = (
config.model_type == "deepseek"
or config.qk_nope_head_dim + config.qk_rope_head_dim == 0
)
if use_mha:
stacked_params_mapping += [
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
if getattr(config, "n_routed_experts", None):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=config.n_routed_experts,
)
else:
expert_params_mapping = []
params_dict = dict(self.named_parameters())
for args in weights:
name, loaded_weight = args[:2]
kwargs = args[2] if len(args) > 2 else {}
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id, **kwargs)
break
else:
for idx, (
param_name,
weight_name,
expert_id,
shard_id,
) in enumerate(expert_params_mapping):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
expert_id=expert_id,
shard_id=shard_id,
**kwargs,
)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, **kwargs)
def get_spec_layer_idx_from_weight_name(
config: DeepseekV2Config, weight_name: str
) -> int | None:
if hasattr(config, "num_nextn_predict_layers") and (
config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i
return None
...@@ -546,38 +546,37 @@ class Lfm2VLForConditionalGeneration( ...@@ -546,38 +546,37 @@ class Lfm2VLForConditionalGeneration(
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if vision_config.model_type == "siglip2_vision_model": with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = Siglip2Model( if vision_config.model_type == "siglip2_vision_model":
config=vision_config, self.vision_tower = Siglip2Model(
quant_config=quant_config, config=vision_config,
multimodal_config=multimodal_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_tower"), multimodal_config=multimodal_config,
) prefix=maybe_prefix(prefix, "vision_tower"),
else: )
raise ValueError( else:
f"Unsupported visual tokenizer model_type: {vision_config.model_type}" raise ValueError(
) f"Unsupported visual tokenizer type: {vision_config.model_type}"
)
self.multi_modal_projector = Lfm2VLMultiModalProjector( self.multi_modal_projector = Lfm2VLMultiModalProjector(
config=config, config=config,
use_data_parallel=self.use_data_parallel, use_data_parallel=self.use_data_parallel,
prefix=f"{prefix}.multi_modal_projector", prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
hf_config=config.text_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language"), hf_config=config.text_config,
architectures=config.text_config.architectures, prefix=maybe_prefix(prefix, "language"),
) architectures=config.text_config.architectures,
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> LFM2VLImageInputs | None: ) -> LFM2VLImageInputs | None:
...@@ -714,8 +713,7 @@ class Lfm2VLForConditionalGeneration( ...@@ -714,8 +713,7 @@ class Lfm2VLForConditionalGeneration(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
logits = self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -268,27 +268,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -268,27 +268,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings. with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size)) self.image_newline = nn.Parameter(
self.multi_modal_projector = LlavaMultiModalProjector( torch.empty(config.text_config.hidden_size)
vision_hidden_size=vision_hidden_size, )
text_hidden_size=config.text_config.hidden_size, self.multi_modal_projector = LlavaMultiModalProjector(
projector_hidden_act=config.projector_hidden_act, vision_hidden_size=vision_hidden_size,
multimodal_projector_bias=config.multimodal_projector_bias, text_hidden_size=config.text_config.hidden_size,
) projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias,
)
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
hf_config=config.text_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), hf_config=config.text_config,
) prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -427,8 +430,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -427,8 +430,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self, self,
inputs: LlavaNextImagePixelInputs, inputs: LlavaNextImagePixelInputs,
) -> torch.Tensor | tuple[torch.Tensor, ...]: ) -> torch.Tensor | tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"] pixel_values = inputs["pixel_values"]
if isinstance(pixel_values, torch.Tensor): if isinstance(pixel_values, torch.Tensor):
...@@ -480,9 +481,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -480,9 +481,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
for i, patch_features_batch in enumerate(patch_embeddings) for i, patch_features_batch in enumerate(patch_embeddings)
] ]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -312,12 +312,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -312,12 +312,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<image>"
if modality.startswith("video"): if modality.startswith("video"):
return "<video>" return "<video>"
raise ValueError("Only image or video modality is supported") raise ValueError("Only video modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()
...@@ -329,26 +327,29 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -329,26 +327,29 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer with self._mark_tower_model(vllm_config, "video"):
self.vision_tower = init_vision_tower_for_llava( # Initialize the vision tower only up to the required feature layer
config, self.vision_tower = init_vision_tower_for_llava(
quant_config=quant_config, config,
multimodal_config=multimodal_config, quant_config=quant_config,
require_post_norm=False, multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"), require_post_norm=False,
) prefix=maybe_prefix(prefix, "vision_tower"),
self.vision_resampler = LlavaNextVideoPooler(config) )
self.multi_modal_projector = LlavaNextMultiModalProjector( self.vision_resampler = LlavaNextVideoPooler(config)
vision_hidden_size=config.vision_config.hidden_size, self.multi_modal_projector = LlavaNextMultiModalProjector(
text_hidden_size=config.text_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
projector_hidden_act=config.projector_hidden_act, text_hidden_size=config.text_config.hidden_size,
multimodal_projector_bias=config.multimodal_projector_bias, projector_hidden_act=config.projector_hidden_act,
) multimodal_projector_bias=config.multimodal_projector_bias,
self.language_model = init_vllm_registered_model( )
vllm_config=vllm_config,
hf_config=config.text_config, with self._mark_language_model(vllm_config):
prefix=maybe_prefix(prefix, "language_model"), self.language_model = init_vllm_registered_model(
) vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors self.language_model.model.make_empty_intermediate_tensors
...@@ -395,8 +396,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -395,8 +396,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return image_features return image_features
def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs): def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
assert self.vision_tower is not None
video_pixels = inputs["pixel_values_videos"] video_pixels = inputs["pixel_values_videos"]
if isinstance(video_pixels, torch.Tensor): if isinstance(video_pixels, torch.Tensor):
...@@ -419,9 +418,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -419,9 +418,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return [e.flatten(0, 1) for e in embeds] return [e.flatten(0, 1) for e in embeds]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None: if video_input is None:
......
...@@ -508,21 +508,26 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -508,21 +508,26 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_tower = init_vision_tower_for_llava( # Initialize the vision tower only up to the required feature layer
config, self.vision_tower = init_vision_tower_for_llava(
quant_config=quant_config, config,
multimodal_config=multimodal_config, quant_config=quant_config,
require_post_norm=False, multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"), require_post_norm=False,
) prefix=maybe_prefix(prefix, "vision_tower"),
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) )
self.language_model = init_vllm_registered_model( self.image_newline = nn.Parameter(
vllm_config=vllm_config, torch.empty(config.text_config.hidden_size)
hf_config=config.text_config, )
prefix=maybe_prefix(prefix, "language_model"), self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
)
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size)) with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors self.language_model.model.make_empty_intermediate_tensors
...@@ -726,8 +731,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -726,8 +731,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self, self,
inputs: LlavaOnevisionImagePixelInputs, inputs: LlavaOnevisionImagePixelInputs,
) -> torch.Tensor | list[torch.Tensor]: ) -> torch.Tensor | list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"] pixel_values = inputs["pixel_values"]
if isinstance(pixel_values, torch.Tensor): if isinstance(pixel_values, torch.Tensor):
...@@ -801,8 +804,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -801,8 +804,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return video_features return video_features
def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs): def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
assert self.vision_tower is not None
video_pixels = inputs["pixel_values_videos"] video_pixels = inputs["pixel_values_videos"]
if isinstance(video_pixels, torch.Tensor): if isinstance(video_pixels, torch.Tensor):
...@@ -862,9 +863,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -862,9 +863,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
image_feature = image_feature.view(batch_frames, -1, dim) image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature return image_feature
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment