Unverified commit e1a34c3a, authored by Cyrus Leung, committed by GitHub
Browse files

[2/N] Initialize MM components in context managers (E-H) (#32641)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 148117ea
...@@ -590,8 +590,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -590,8 +590,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
def _process_image_input( def _process_image_input(
self, image_input: AriaImagePixelInputs self, image_input: AriaImagePixelInputs
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"] pixel_values = image_input["pixel_values"]
pixel_mask = image_input["pixel_mask"] pixel_mask = image_input["pixel_mask"]
......
...@@ -382,7 +382,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -382,7 +382,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def _process_image_input( def _process_image_input(
self, image_input: AyaVisionImagePixelInputs, **kwargs self, image_input: AyaVisionImagePixelInputs, **kwargs
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"] pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"] num_patches = image_input["num_patches"]
image_features = self._image_pixels_to_features( image_features = self._image_pixels_to_features(
......
...@@ -391,8 +391,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo ...@@ -391,8 +391,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
Returns: Returns:
List of flattened image embeddings, one per image List of flattened image embeddings, one per image
""" """
assert self.vision_tower is not None, "Vision tower is required"
pixel_values = image_input["pixel_values"] pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"] num_patches = image_input["num_patches"]
......
...@@ -1303,27 +1303,28 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1303,27 +1303,28 @@ class Ernie4_5_VLMoeForConditionalGeneration(
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.vision_model = Ernie4_5_VisionTransformer( with self._mark_tower_model(vllm_config, {"image", "video"}):
config.vision_config, self.vision_model = Ernie4_5_VisionTransformer(
norm_eps=getattr(config, "rms_norm_eps", 1e-6), config.vision_config,
quant_config=quant_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6),
multimodal_config=multimodal_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"), multimodal_config=multimodal_config,
) prefix=maybe_prefix(prefix, "vision_model"),
)
self.language_model = Ernie4_5_VLMoeForCausalLM( self.resampler_model = VariableResolutionResamplerModel(
vllm_config=vllm_config, self.config.pixel_hidden_size,
prefix=maybe_prefix(prefix, "language_model"), self.config.hidden_size,
) self.config.spatial_conv_size,
self.config.temporal_conv_size,
config=self.config,
prefix=maybe_prefix(prefix, "resampler_model"),
)
self.resampler_model = VariableResolutionResamplerModel( with self._mark_language_model(vllm_config):
self.config.pixel_hidden_size, self.language_model = Ernie4_5_VLMoeForCausalLM(
self.config.hidden_size, vllm_config=vllm_config,
self.config.spatial_conv_size, prefix=maybe_prefix(prefix, "language_model"),
self.config.temporal_conv_size, )
config=self.config,
prefix=maybe_prefix(prefix, "resampler_model"),
)
self.visual_token_mask = None self.visual_token_mask = None
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
...@@ -1522,9 +1523,6 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1522,9 +1523,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> Ernie4_5_VLImageInputs | None: ) -> Ernie4_5_VLImageInputs | None:
......
...@@ -287,16 +287,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -287,16 +287,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.image_token_id = _IMAGE_TOKEN_ID self.image_token_id = _IMAGE_TOKEN_ID
self.image_feature_size = config.patch_size**2 * config.num_channels self.image_feature_size = config.patch_size**2 * config.num_channels
self.vision_embed_tokens = ColumnParallelLinear( with self._mark_tower_model(vllm_config, "image"):
self.image_feature_size, self.vision_embed_tokens = ColumnParallelLinear(
config.hidden_size, self.image_feature_size,
quant_config=quant_config, config.hidden_size,
gather_output=True, quant_config=quant_config,
) gather_output=True,
self.language_model = PersimmonForCausalLM( )
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"), with self._mark_language_model(vllm_config):
) self.language_model = PersimmonForCausalLM(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
...@@ -323,14 +327,10 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -323,14 +327,10 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_patches_flat = image_input["image_patches_flat"] image_patches_flat = image_input["image_patches_flat"]
patches_per_image = image_input["patches_per_image"] patches_per_image = image_input["patches_per_image"]
assert self.vision_embed_tokens is not None
vision_embeddings_flat, _ = self.vision_embed_tokens(image_patches_flat) vision_embeddings_flat, _ = self.vision_embed_tokens(image_patches_flat)
return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0) return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
...@@ -361,10 +361,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -361,10 +361,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
logits = self.language_model.logits_processor( return self.language_model.compute_logits(hidden_states)
self.language_model.lm_head, hidden_states
)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -522,25 +522,27 @@ class Gemma3ForConditionalGeneration( ...@@ -522,25 +522,27 @@ class Gemma3ForConditionalGeneration(
self.quant_config = quant_config self.quant_config = quant_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel( with self._mark_tower_model(vllm_config, "image"):
config.vision_config, self.vision_tower = SiglipVisionModel(
quant_config, config.vision_config,
prefix=maybe_prefix(prefix, "vision_tower"), quant_config,
) prefix=maybe_prefix(prefix, "vision_tower"),
self.multi_modal_projector = Gemma3MultiModalProjector(config) )
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, with self._mark_language_model(vllm_config):
hf_config=config.text_config, self.language_model = init_vllm_registered_model(
prefix=maybe_prefix(prefix, "language_model"), vllm_config=vllm_config,
architectures=["Gemma3ForCausalLM"], hf_config=config.text_config,
) prefix=maybe_prefix(prefix, "language_model"),
logit_scale = getattr(config, "logit_scale", 1.0) architectures=["Gemma3ForCausalLM"],
)
if hasattr(self.language_model, "logits_processor"): logit_scale = getattr(config, "logit_scale", 1.0)
# The logits processor can be unset if we're using if hasattr(self.language_model, "logits_processor"):
# automatic conversion to pooling model. # The logits processor can be unset if we're using
self.language_model.logits_processor.scale *= logit_scale # automatic conversion to pooling model.
self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -579,8 +581,6 @@ class Gemma3ForConditionalGeneration( ...@@ -579,8 +581,6 @@ class Gemma3ForConditionalGeneration(
self, self,
image_input: Gemma3ImageInputs, image_input: Gemma3ImageInputs,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"] pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"] num_patches = image_input["num_patches"]
...@@ -592,9 +592,6 @@ class Gemma3ForConditionalGeneration( ...@@ -592,9 +592,6 @@ class Gemma3ForConditionalGeneration(
return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())] return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -503,31 +503,35 @@ class Gemma3nForConditionalGeneration( ...@@ -503,31 +503,35 @@ class Gemma3nForConditionalGeneration(
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.vocab_size = config.text_config.vocab_size self.vocab_size = config.text_config.vocab_size
self.vision_tower = AutoModel.from_config(config=config.vision_config) with self._mark_tower_model(vllm_config, "image"):
self.audio_tower = AutoModel.from_config(config=config.audio_config) self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.embed_vision = Gemma3nMultimodalEmbedder( self.embed_vision = Gemma3nMultimodalEmbedder(
config.vision_config, config.text_config config.vision_config, config.text_config
) )
self.embed_audio = Gemma3nMultimodalEmbedder(
config.audio_config, config.text_config
)
self.language_model: nn.Module = init_vllm_registered_model( with self._mark_tower_model(vllm_config, "audio"):
vllm_config=vllm_config, self.audio_tower = AutoModel.from_config(config=config.audio_config)
hf_config=config.text_config, self.embed_audio = Gemma3nMultimodalEmbedder(
prefix=maybe_prefix(prefix, "language_model"), config.audio_config, config.text_config
architectures=["Gemma3nForCausalLM"], )
)
self.language_model = cast(Gemma3nForCausalLM, self.language_model) with self._mark_language_model(vllm_config):
# NOTE (NickLucche) In order to be compatible with cudagraph, the self.language_model: Gemma3nForCausalLM = init_vllm_registered_model(
# buffer needs to be consistent, so we pre-allocate here. vllm_config=vllm_config,
self.per_layer_embeddings = torch.zeros( hf_config=config.text_config,
vllm_config.scheduler_config.max_num_batched_tokens, prefix=maybe_prefix(prefix, "language_model"),
self.config.text_config.num_hidden_layers, architectures=["Gemma3nForCausalLM"],
self.config.text_config.hidden_size_per_layer_input, )
device=self.language_model.model.embed_tokens.weight.device,
dtype=self.language_model.model.embed_tokens.weight.dtype, # NOTE (NickLucche) In order to be compatible with cudagraph, the
) # buffer needs to be consistent, so we pre-allocate here.
self.per_layer_embeddings = torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
self.config.text_config.num_hidden_layers,
self.config.text_config.hidden_size_per_layer_input,
device=self.language_model.model.embed_tokens.weight.device,
dtype=self.language_model.model.embed_tokens.weight.dtype,
)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
...@@ -583,8 +587,6 @@ class Gemma3nForConditionalGeneration( ...@@ -583,8 +587,6 @@ class Gemma3nForConditionalGeneration(
self, self,
image_input: Gemma3nImageInputs, image_input: Gemma3nImageInputs,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"] pixel_values = image_input["pixel_values"]
vision_outputs = self.vision_tower( vision_outputs = self.vision_tower(
pixel_values=pixel_values, do_pooling=False, return_dict=True pixel_values=pixel_values, do_pooling=False, return_dict=True
...@@ -609,7 +611,6 @@ class Gemma3nForConditionalGeneration( ...@@ -609,7 +611,6 @@ class Gemma3nForConditionalGeneration(
self, self,
audio_input: Gemma3nAudioInputs, audio_input: Gemma3nAudioInputs,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
assert self.audio_tower is not None
# Run on padded features to enable batching # Run on padded features to enable batching
input_features = audio_input["input_features_padded"].squeeze(1) input_features = audio_input["input_features_padded"].squeeze(1)
input_features_mask = audio_input["input_features_mask"].squeeze(1) input_features_mask = audio_input["input_features_mask"].squeeze(1)
...@@ -651,9 +652,6 @@ class Gemma3nForConditionalGeneration( ...@@ -651,9 +652,6 @@ class Gemma3nForConditionalGeneration(
# Return a list of embeddings instead of a batched tensor # Return a list of embeddings instead of a batched tensor
return audio_features.unbind(0) return audio_features.unbind(0)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if mm_input_by_modality is None: if mm_input_by_modality is None:
......
...@@ -1434,13 +1434,14 @@ class Glm4vForConditionalGeneration( ...@@ -1434,13 +1434,14 @@ class Glm4vForConditionalGeneration(
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.visual = Glm4vVisionTransformer( with self._mark_tower_model(vllm_config, {"image", "video"}):
config.vision_config, self.visual = Glm4vVisionTransformer(
norm_eps=getattr(config, "rms_norm_eps", 1e-5), config.vision_config,
quant_config=quant_config, norm_eps=getattr(config, "rms_norm_eps", 1e-5),
multimodal_config=multimodal_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"), multimodal_config=multimodal_config,
) prefix=maybe_prefix(prefix, "visual"),
)
if config.model_type == "glm4v": if config.model_type == "glm4v":
architectures = ["Glm4ForCausalLM"] architectures = ["Glm4ForCausalLM"]
...@@ -1449,12 +1450,13 @@ class Glm4vForConditionalGeneration( ...@@ -1449,12 +1450,13 @@ class Glm4vForConditionalGeneration(
else: else:
architectures = None architectures = None
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
hf_config=config.text_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), hf_config=config.text_config,
architectures=architectures, prefix=maybe_prefix(prefix, "language_model"),
) architectures=architectures,
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -1578,9 +1580,6 @@ class Glm4vForConditionalGeneration( ...@@ -1578,9 +1580,6 @@ class Glm4vForConditionalGeneration(
) )
return mm_input_by_modality return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
......
...@@ -944,26 +944,27 @@ class GlmAsrForConditionalGeneration( ...@@ -944,26 +944,27 @@ class GlmAsrForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# Use optimized vLLM native encoder
self.audio_tower = GlmAsrEncoder(
config.audio_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "audio_tower"),
)
self.multi_modal_projector = GlmAsrMultiModalProjector(
config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = init_vllm_registered_model( with self._mark_tower_model(vllm_config, "audio"):
vllm_config=vllm_config, self.audio_tower = GlmAsrEncoder(
hf_config=config.text_config, config.audio_config,
prefix=maybe_prefix(prefix, "language_model"), quant_config=quant_config,
architectures=["LlamaForCausalLM"], prefix=maybe_prefix(prefix, "audio_tower"),
) )
self.multi_modal_projector = GlmAsrMultiModalProjector(
config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["LlamaForCausalLM"],
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -1063,9 +1064,6 @@ class GlmAsrForConditionalGeneration( ...@@ -1063,9 +1064,6 @@ class GlmAsrForConditionalGeneration(
) )
return _group_audio_embeddings(chunk_embeddings, chunk_counts) return _group_audio_embeddings(chunk_embeddings, chunk_counts)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
......
...@@ -597,27 +597,29 @@ class GraniteSpeechForConditionalGeneration( ...@@ -597,27 +597,29 @@ class GraniteSpeechForConditionalGeneration(
self.quant_config = quant_config self.quant_config = quant_config
self.cache_config = cache_config self.cache_config = cache_config
# The language model is typically a Granite LLM with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model( # The language model is typically a Granite LLM
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
hf_config=config.text_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), hf_config=config.text_config,
) prefix=maybe_prefix(prefix, "language_model"),
)
# Conformer encoder with self._mark_tower_model(vllm_config, "audio"):
self.encoder = GraniteSpeechCTCEncoder( # Conformer encoder
config=config.encoder_config, self.encoder = GraniteSpeechCTCEncoder(
quant_config=quant_config, config=config.encoder_config,
prefix=f"{prefix}.encoder", quant_config=quant_config,
) prefix=f"{prefix}.encoder",
)
# Blip2 QFormer # Blip2 QFormer
self.projector = GraniteSpeechEncoderProjector( self.projector = GraniteSpeechEncoderProjector(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.projector", prefix=f"{prefix}.projector",
) )
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -770,9 +772,6 @@ class GraniteSpeechForConditionalGeneration( ...@@ -770,9 +772,6 @@ class GraniteSpeechForConditionalGeneration(
# Split variable length features into a tuple # Split variable length features into a tuple
return torch.split(masked_embeds, audio_input["audio_embed_sizes"]) return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal( def embed_multimodal(
self, self,
**kwargs: object, **kwargs: object,
......
...@@ -877,7 +877,7 @@ class HunYuanVLForConditionalGeneration( ...@@ -877,7 +877,7 @@ class HunYuanVLForConditionalGeneration(
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
if multimodal_config.get_limit_per_prompt("image"): with self._mark_tower_model(vllm_config, {"image"}):
attn_backend_override = ( attn_backend_override = (
multimodal_config.mm_encoder_attn_backend multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None if multimodal_config is not None
...@@ -890,17 +890,16 @@ class HunYuanVLForConditionalGeneration( ...@@ -890,17 +890,16 @@ class HunYuanVLForConditionalGeneration(
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
else:
self.visual = None with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
self.language_model = init_vllm_registered_model( vllm_config=vllm_config,
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model.model"),
prefix=maybe_prefix(prefix, "language_model.model"), architectures=[
architectures=[ "HunYuanDenseV1ForCausalLM",
"HunYuanDenseV1ForCausalLM", "HunYuanMoEV1ForCausalLM",
"HunYuanMoEV1ForCausalLM", ],
], )
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -970,9 +969,6 @@ class HunYuanVLForConditionalGeneration( ...@@ -970,9 +969,6 @@ class HunYuanVLForConditionalGeneration(
) )
return mm_input_by_modality return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
......
...@@ -15,7 +15,6 @@ from einops import rearrange ...@@ -15,7 +15,6 @@ from einops import rearrange
from timm.layers import LayerNorm, LayerNorm2d from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage from timm.models.regnet import RegStage
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from transformers.modeling_utils import no_init_weights
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
...@@ -625,8 +624,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -625,8 +624,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config, vision_config config, vision_config
) )
# init models & parameters with self._mark_tower_model(vllm_config, {"image", "video"}):
with no_init_weights(): # weight will be loaded in from_pretrained
self.vision_model = init_vision_tower_for_hcxvision( self.vision_model = init_vision_tower_for_hcxvision(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -635,20 +633,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -635,20 +633,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
) )
self.mm_projector = self._init_mm_projector(config, text_config, vision_config) self.mm_projector = self._init_mm_projector(
config, text_config, vision_config
)
self.lm_head_vocab_size = getattr( if config.anyres:
text_config, "padded_vocab_size", text_config.vocab_size self.image_newline = nn.Parameter(
) torch.empty(text_config.hidden_size, dtype=self.dtype)
self.language_model = init_vllm_registered_model( )
vllm_config=vllm_config,
hf_config=text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if config.anyres: with self._mark_language_model(vllm_config):
self.image_newline = nn.Parameter( self.language_model = init_vllm_registered_model(
torch.empty(text_config.hidden_size, dtype=self.dtype) vllm_config=vllm_config,
hf_config=text_config,
prefix=maybe_prefix(prefix, "language_model"),
) )
self.config = config self.config = config
...@@ -726,9 +724,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -726,9 +724,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return modalities return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal( def embed_multimodal(
self, self,
**kwargs: object, **kwargs: object,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment