"vscode:/vscode.git/clone" did not exist on "09e4576f65b751fc682983a296e246f239979558"
Unverified commit b75e85de, authored by Cyrus Leung, committed by GitHub
Browse files

[1/N] Initialize MM components in context managers (A-D) (#32632)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 4753f3bf
...@@ -15,9 +15,7 @@ from vllm.distributed import get_tensor_model_parallel_rank ...@@ -15,9 +15,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
...@@ -539,6 +537,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -539,6 +537,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = AriaVisionTransformer( self.vision_tower = AriaVisionTransformer(
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -547,22 +547,12 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -547,22 +547,12 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.multi_modal_projector = AriaProjector( self.multi_modal_projector = AriaProjector(
config, prefix=maybe_prefix(prefix, "multi_modal_projector") config, prefix=maybe_prefix(prefix, "multi_modal_projector")
) )
self.vocab_size = config.text_config.vocab_size
with self._mark_language_model(vllm_config):
self.language_model = AriaTextModel( self.language_model = AriaTextModel(
vllm_config=vllm_config.with_hf_config(config.text_config), vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model.model"), prefix=maybe_prefix(prefix, "language_model.model"),
) )
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
)
self.lm_head = ParallelLMHead(
self.vocab_size,
config.text_config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.vocab_size, scale=logit_scale)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
...@@ -618,9 +608,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -618,9 +608,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask) return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
...@@ -654,9 +641,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -654,9 +641,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: def compute_logits(
logits = self.logits_processor(self.lm_head, hidden_states) self,
return logits hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -460,14 +460,15 @@ class AudioFlamingo3ForConditionalGeneration( ...@@ -460,14 +460,15 @@ class AudioFlamingo3ForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.quant_config = quant_config
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = AudioFlamingo3Encoder( self.audio_tower = AudioFlamingo3Encoder(
config.audio_config, config.audio_config,
) )
self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config) self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)
self.quant_config = quant_config with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=config.text_config, hf_config=config.text_config,
...@@ -599,9 +600,6 @@ class AudioFlamingo3ForConditionalGeneration( ...@@ -599,9 +600,6 @@ class AudioFlamingo3ForConditionalGeneration(
current_idx += count current_idx += count
return tuple(grouped_embeddings) return tuple(grouped_embeddings)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
......
...@@ -343,14 +343,16 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -343,14 +343,16 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self.quant_config = quant_config self.quant_config = quant_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = SiglipVisionModel( self.vision_tower = SiglipVisionModel(
config.vision_config, config.vision_config,
quant_config, quant_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
) )
self.vocab_size = config.text_config.vocab_size
self.multi_modal_projector = AyaVisionMultiModalProjector(config) self.multi_modal_projector = AyaVisionMultiModalProjector(config)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=config.text_config, hf_config=config.text_config,
...@@ -410,9 +412,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -410,9 +412,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
}, },
) )
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -44,6 +44,7 @@ from .interfaces import ( ...@@ -44,6 +44,7 @@ from .interfaces import (
SupportsLoRA, SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
TowerMissingLayer,
) )
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import ( from .utils import (
...@@ -373,6 +374,7 @@ class BagelForConditionalGeneration( ...@@ -373,6 +374,7 @@ class BagelForConditionalGeneration(
# Initialize language model (Qwen2) # Initialize language model (Qwen2)
# Pass the llm_config from BagelConfig to initialize Qwen2 properly # Pass the llm_config from BagelConfig to initialize Qwen2 properly
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=config.llm_config, hf_config=config.llm_config,
...@@ -398,6 +400,7 @@ class BagelForConditionalGeneration( ...@@ -398,6 +400,7 @@ class BagelForConditionalGeneration(
) )
vit_config.vision_use_head = False vit_config.vision_use_head = False
with self._mark_tower_model(vllm_config, "image"):
self.vit_model = SiglipVisionModel( self.vit_model = SiglipVisionModel(
config=vit_config, config=vit_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -423,9 +426,9 @@ class BagelForConditionalGeneration( ...@@ -423,9 +426,9 @@ class BagelForConditionalGeneration(
hidden_size=llm_hidden_size, hidden_size=llm_hidden_size,
) )
else: else:
self.vit_model = None self.vit_model = TowerMissingLayer("image")
self.connector = None self.connector = TowerMissingLayer("image")
self.vit_pos_embed = None self.vit_pos_embed = TowerMissingLayer("image")
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -502,9 +505,6 @@ class BagelForConditionalGeneration( ...@@ -502,9 +505,6 @@ class BagelForConditionalGeneration(
return self._process_image_input(image_input) return self._process_image_input(image_input)
def get_language_model(self) -> nn.Module:
return self.language_model
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -540,14 +540,6 @@ class BagelForConditionalGeneration( ...@@ -540,14 +540,6 @@ class BagelForConditionalGeneration(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights from checkpoint.""" """Load weights from checkpoint."""
skip_prefixes = []
# Skip vit_pos_embed.pos_embed as it's handled by PositionEmbedding module
skip_prefixes.append("vit_pos_embed.pos_embed")
# If visual understanding is disabled, skip vision-related weights
if self.vit_model is None:
skip_prefixes.extend(["vit_model.", "connector.", "vit_pos_embed"])
# Skip generation-related weights since we only support text2text and image2text # Skip generation-related weights since we only support text2text and image2text
# Filter out all image generation components: # Filter out all image generation components:
# - 'moe_gen': MoE generation weights # - 'moe_gen': MoE generation weights
...@@ -587,5 +579,6 @@ class BagelForConditionalGeneration( ...@@ -587,5 +579,6 @@ class BagelForConditionalGeneration(
filtered_weights.append((name, tensor)) filtered_weights.append((name, tensor))
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) # Skip vit_pos_embed.pos_embed as it's handled by PositionEmbedding module
loader = AutoWeightsLoader(self, skip_prefixes=["vit_pos_embed.pos_embed"])
return loader.load_weights(filtered_weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(filtered_weights, mapper=self.hf_to_vllm_mapper)
...@@ -549,26 +549,26 @@ class Blip2ForConditionalGeneration( ...@@ -549,26 +549,26 @@ class Blip2ForConditionalGeneration(
+ 1 # include class token + 1 # include class token
) )
# TODO: Optionally initializes this for supporting embeddings. with self._mark_tower_model(vllm_config, "image"):
self.vision_model = BlipVisionModel(vision_config, quant_config) self.vision_model = BlipVisionModel(vision_config, quant_config)
self.query_tokens = nn.Parameter( self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size) torch.zeros(
1, config.num_query_tokens, config.qformer_config.hidden_size
)
) )
self.qformer = Blip2QFormerModel( self.qformer = Blip2QFormerModel(
config.qformer_config, config.qformer_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qformer", prefix=f"{prefix}.qformer",
) )
self.language_projection = nn.Linear( self.language_projection = nn.Linear(
config.qformer_config.hidden_size, config.qformer_config.hidden_size,
config.text_config.hidden_size, config.text_config.hidden_size,
bias=True, bias=True,
) )
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=config.text_config, hf_config=config.text_config,
...@@ -614,8 +614,6 @@ class Blip2ForConditionalGeneration( ...@@ -614,8 +614,6 @@ class Blip2ForConditionalGeneration(
return image_features return image_features
def _process_image_pixels(self, inputs: Blip2ImagePixelInputs) -> torch.Tensor: def _process_image_pixels(self, inputs: Blip2ImagePixelInputs) -> torch.Tensor:
assert self.vision_model is not None
pixel_values = inputs["data"] pixel_values = inputs["data"]
return self._image_pixels_to_features(self.vision_model, pixel_values) return self._image_pixels_to_features(self.vision_model, pixel_values)
...@@ -624,7 +622,6 @@ class Blip2ForConditionalGeneration( ...@@ -624,7 +622,6 @@ class Blip2ForConditionalGeneration(
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input) image_features = self._process_image_pixels(image_input)
query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1) query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)
...@@ -635,9 +632,6 @@ class Blip2ForConditionalGeneration( ...@@ -635,9 +632,6 @@ class Blip2ForConditionalGeneration(
return self.language_projection(query_output) return self.language_projection(query_output)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -853,28 +853,30 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -853,28 +853,30 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.text_embed_dim = text_config.hidden_size self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size self.vision_embed_dim = vision_config.hidden_size
with self._mark_language_model(vllm_config):
self.text_model = CLIPTextTransformer( self.text_model = CLIPTextTransformer(
text_config, text_config,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "text_model"), prefix=maybe_prefix(prefix, "text_model"),
) )
self.text_projection = nn.Linear(
self.text_embed_dim,
self.projection_dim,
bias=False,
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = CLIPVisionTransformer( self.vision_model = CLIPVisionTransformer(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
) )
self.visual_projection = nn.Linear( self.visual_projection = nn.Linear(
self.vision_embed_dim, self.vision_embed_dim,
self.projection_dim, self.projection_dim,
bias=False, bias=False,
) )
self.text_projection = nn.Linear(
self.text_embed_dim,
self.projection_dim,
bias=False,
)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
...@@ -940,9 +942,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -940,9 +942,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
return self.get_image_features(pixel_values) return self.get_image_features(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.text_model
def _embed_text_input_ids( def _embed_text_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -353,15 +353,17 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo ...@@ -353,15 +353,17 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self._patch_quant_config(config, quant_config) self._patch_quant_config(config, quant_config)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = SiglipVisionModel( self.vision_tower = SiglipVisionModel(
config.vision_config, config.vision_config,
quant_config, quant_config,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
self.vocab_size = config.text_config.vocab_size
self.multi_modal_projector = Cohere2VisionMultiModalProjector( self.multi_modal_projector = Cohere2VisionMultiModalProjector(
config, prefix=maybe_prefix(prefix, "multi_modal_projector") config, prefix=maybe_prefix(prefix, "multi_modal_projector")
) )
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=config.text_config, hf_config=config.text_config,
...@@ -437,9 +439,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo ...@@ -437,9 +439,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
): ):
quant_config.modules_to_not_convert.append("vision_tower") quant_config.modules_to_not_convert.append("vision_tower")
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -383,6 +383,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports ...@@ -383,6 +383,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
with self._mark_tower_model(vllm_config, "image"):
self.sam_model = build_sam_vit_b() self.sam_model = build_sam_vit_b()
clip_vision_config = CLIPVisionConfig( clip_vision_config = CLIPVisionConfig(
hidden_size=1024, hidden_size=1024,
...@@ -418,6 +419,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports ...@@ -418,6 +419,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
) )
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=self.text_config, hf_config=self.text_config,
...@@ -552,9 +554,6 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports ...@@ -552,9 +554,6 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
return vision_features return vision_features
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -374,6 +374,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -374,6 +374,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN] self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]
with self._mark_tower_model(vllm_config, "image"):
self.vision = self._init_vision_module( self.vision = self._init_vision_module(
self.vision_config, quant_config, maybe_prefix(prefix, "vision") self.vision_config, quant_config, maybe_prefix(prefix, "vision")
) )
...@@ -400,6 +401,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -400,6 +401,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
) )
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=self.text_config, hf_config=self.text_config,
...@@ -603,9 +605,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -603,9 +605,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop pixel_values=pixel_values, images_spatial_crop=images_spatial_crop
) )
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -689,12 +689,15 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -689,12 +689,15 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
else: else:
vision_config = self.config.vision_config vision_config = self.config.vision_config
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = DotsVisionTransformer( self.vision_tower = DotsVisionTransformer(
vision_config, vision_config,
quant_config=self.quant_config, quant_config=self.quant_config,
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
with self._mark_language_model(vllm_config):
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model( self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=self.config, hf_config=self.config,
...@@ -763,9 +766,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -763,9 +766,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
return image_embeds.split(sizes) return image_embeds.split(sizes)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int: def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
merge_size = self.vision_tower.spatial_merge_size merge_size = self.vision_tower.spatial_merge_size
return num_image_tokens * (merge_size**2) return num_image_tokens * (merge_size**2)
......
...@@ -83,7 +83,10 @@ class LMMissingLayer(nn.Module): ...@@ -83,7 +83,10 @@ class LMMissingLayer(nn.Module):
class TowerMissingLayer(nn.Module): class TowerMissingLayer(nn.Module):
packed_modules_mapping: dict[str, list[str]] = {} packed_modules_mapping: dict[str, list[str]] = {}
def __init__(self, modalities: set[str]) -> None: def __init__(self, modalities: set[str] | str) -> None:
if isinstance(modalities, str):
modalities = {modalities}
super().__init__() super().__init__()
self.modalities = modalities self.modalities = modalities
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment