Unverified Commit 193069d1 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[5/N] Initialize MM components in context managers (Q-Z) (#32695)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f0feb1cf
...@@ -334,20 +334,21 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports ...@@ -334,20 +334,21 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.audio_tower = Qwen2AudioEncoder(config.audio_config)
self.multi_modal_projector = Qwen2AudioMultiModalProjector(
config.audio_config.d_model, config.text_config.hidden_size
)
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = init_vllm_registered_model( with self._mark_tower_model(vllm_config, "audio"):
vllm_config=vllm_config, self.audio_tower = Qwen2AudioEncoder(config.audio_config)
hf_config=config.text_config, self.multi_modal_projector = Qwen2AudioMultiModalProjector(
prefix=maybe_prefix(prefix, "language_model"), config.audio_config.d_model, config.text_config.hidden_size
architectures=["Qwen2ForCausalLM"], )
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -441,9 +442,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports ...@@ -441,9 +442,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
masked_audio_features, audio_output_lengths.flatten().tolist() masked_audio_features, audio_output_lengths.flatten().tolist()
) )
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
......
...@@ -1612,32 +1612,14 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1612,32 +1612,14 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = thinker_config self.config = thinker_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.audio_tower = Qwen3OmniMoeAudioEncoder(
thinker_config.audio_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "audio_tower"),
)
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = Qwen3MoeLLMForCausalLM( with self._mark_tower_model(vllm_config, "audio"):
vllm_config=vllm_config.with_hf_config( self.audio_tower = Qwen3OmniMoeAudioEncoder(
thinker_config.text_config, architectures=["Qwen3MoeForCausalLM"] thinker_config.audio_config,
), multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "language_model"), prefix=maybe_prefix(prefix, "audio_tower"),
) )
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
self.use_deepstack = hasattr( self.use_deepstack = hasattr(
thinker_config.vision_config, "deepstack_visual_indexes" thinker_config.vision_config, "deepstack_visual_indexes"
...@@ -1647,22 +1629,48 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1647,22 +1629,48 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
if self.use_deepstack if self.use_deepstack
else 0 else 0
) )
# register buffer for deepstack
self.deepstack_input_embeds = (
[
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
thinker_config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
if self.use_deepstack
else None
)
self.visual_dim = thinker_config.vision_config.out_hidden_size self.visual_dim = thinker_config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level self.multiscale_dim = self.visual_dim * self.deepstack_num_level
def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors: with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
# register buffer for deepstack
if self.use_deepstack:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
thinker_config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
with self._mark_language_model(vllm_config):
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config.with_hf_config(
thinker_config.text_config,
architectures=["Qwen3MoeForCausalLM"],
),
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _get_deepstack_input_embeds(
self,
num_tokens: int,
) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped
# get deepstack_input_embeds from buffer, and clear the buffer # get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors( return IntermediateTensors(
{ {
...@@ -1674,6 +1682,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1674,6 +1682,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
) )
def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None: def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# set deepstack_input_embeds to buffer # set deepstack_input_embeds to buffer
num_tokens = deepstack_input_embeds.size(1) num_tokens = deepstack_input_embeds.size(1)
if num_tokens > self.deepstack_input_embeds[0].size(0): if num_tokens > self.deepstack_input_embeds[0].size(0):
...@@ -1692,6 +1703,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1692,6 +1703,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
) )
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# clear deepstack_input_embeds in buffer # clear deepstack_input_embeds in buffer
if num_tokens > 0: if num_tokens > 0:
for idx in range(self.deepstack_num_level): for idx in range(self.deepstack_num_level):
...@@ -1726,9 +1740,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1726,9 +1740,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
) )
return mm_input_by_modality return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
...@@ -1844,11 +1855,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1844,11 +1855,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
if ( if inputs_embeds is not None and get_pp_group().is_first_rank:
self.use_deepstack
and inputs_embeds is not None
and get_pp_group().is_first_rank
):
deepstack_input_embeds = self._get_deepstack_input_embeds( deepstack_input_embeds = self._get_deepstack_input_embeds(
inputs_embeds.size(0) inputs_embeds.size(0)
) )
......
...@@ -1321,7 +1321,13 @@ class Qwen3VLForConditionalGeneration( ...@@ -1321,7 +1321,13 @@ class Qwen3VLForConditionalGeneration(
num_layers = len(self.language_model.model.layers) num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3) return (2, num_layers // 2, num_layers - 3)
def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors: def _get_deepstack_input_embeds(
self,
num_tokens: int,
) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped
# get deepstack_input_embeds from buffer, and clear the buffer # get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors( return IntermediateTensors(
{ {
...@@ -1333,6 +1339,9 @@ class Qwen3VLForConditionalGeneration( ...@@ -1333,6 +1339,9 @@ class Qwen3VLForConditionalGeneration(
) )
def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None: def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# set deepstack_input_embeds to buffer # set deepstack_input_embeds to buffer
num_tokens = deepstack_input_embeds.size(1) num_tokens = deepstack_input_embeds.size(1)
if num_tokens > self.deepstack_input_embeds[0].size(0): if num_tokens > self.deepstack_input_embeds[0].size(0):
...@@ -1351,6 +1360,9 @@ class Qwen3VLForConditionalGeneration( ...@@ -1351,6 +1360,9 @@ class Qwen3VLForConditionalGeneration(
) )
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# clear deepstack_input_embeds in buffer # clear deepstack_input_embeds in buffer
if num_tokens > 0: if num_tokens > 0:
for idx in range(self.deepstack_num_level): for idx in range(self.deepstack_num_level):
...@@ -2037,11 +2049,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -2037,11 +2049,7 @@ class Qwen3VLForConditionalGeneration(
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
if ( if inputs_embeds is not None and get_pp_group().is_first_rank:
self.use_deepstack
and inputs_embeds is not None
and get_pp_group().is_first_rank
):
deepstack_input_embeds = self._get_deepstack_input_embeds( deepstack_input_embeds = self._get_deepstack_input_embeds(
inputs_embeds.size(0) inputs_embeds.size(0)
) )
......
...@@ -620,7 +620,6 @@ class RadioInternVisionModel(nn.Module): ...@@ -620,7 +620,6 @@ class RadioInternVisionModel(nn.Module):
x: torch.Tensor, x: torch.Tensor,
imgs_sizes: torch.Tensor | None = None, imgs_sizes: torch.Tensor | None = None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
assert self.patch_generator is not None
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes) hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
attn_mask = None attn_mask = None
if imgs_sizes is not None and len(imgs_sizes) > 1: if imgs_sizes is not None and len(imgs_sizes) > 1:
......
...@@ -1033,20 +1033,22 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1033,20 +1033,22 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.text_embed_dim = text_config.hidden_size self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size self.vision_embed_dim = vision_config.hidden_size
self.text_projection_size = text_config.projection_size
self.text_model = SiglipTextTransformer( with self._mark_language_model(vllm_config):
text_config, self.text_model = SiglipTextTransformer(
quant_config=quant_config, text_config,
prefix=maybe_prefix(prefix, "text_model"), quant_config=quant_config,
) prefix=maybe_prefix(prefix, "text_model"),
self.vision_model = SiglipVisionTransformer( )
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.text_projection_size = text_config.projection_size with self._mark_tower_model(vllm_config, "image"):
self.vision_model = SiglipVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
...@@ -1155,9 +1157,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1155,9 +1157,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
return self.get_image_features(pixel_values) return self.get_image_features(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.text_model
def _embed_text_input_ids( def _embed_text_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -674,24 +674,26 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -674,24 +674,26 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.downsample_ratio = config.downsample_ratio self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version self.ps_version = config.ps_version
self.llm_arch_name = config.text_config.architectures[0] llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == "SkyworkLM2VEForCausalLM" self.is_mono = llm_arch_name == "SkyworkLM2VEForCausalLM"
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.language_model = init_vllm_registered_model( with self._mark_tower_model(vllm_config, "image"):
vllm_config=vllm_config, self.vision_model = self._init_vision_model(
hf_config=config.text_config, config,
prefix=maybe_prefix(prefix, "language_model"), quant_config=quant_config,
) is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.mlp1 = self._init_mlp1(
config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
)
self.mlp1 = self._init_mlp1( with self._mark_language_model(vllm_config):
config, quant_config, prefix=maybe_prefix(prefix, "mlp1") self.language_model = init_vllm_registered_model(
) vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.img_context_token_id = None self.img_context_token_id = None
self.visual_token_mask = None self.visual_token_mask = None
...@@ -838,8 +840,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -838,8 +840,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["pixel_values_flat"]) image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"] num_patches = image_input["num_patches"]
...@@ -867,9 +867,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -867,9 +867,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else: else:
self.visual_token_mask = None self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -423,38 +423,43 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -423,38 +423,43 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config # Storing the Tarsier-specific HF config self.config = config # Storing the Tarsier-specific HF config
self.vision_tower = init_vision_tower_for_tarsier(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
projector_bias = getattr(config, "multimodal_projector_bias", True)
self.multi_modal_projector = TarsierMultiModalProjector( with self._mark_tower_model(vllm_config, "image"):
vision_hidden_size=config.vision_config.hidden_size, self.vision_tower = init_vision_tower_for_tarsier(
text_hidden_size=config.text_config.hidden_size, config,
projector_hidden_act=config.projector_hidden_act, quant_config=quant_config,
multimodal_projector_bias=projector_bias, multimodal_config=multimodal_config,
quant_config=quant_config, require_post_norm=False,
prefix=maybe_prefix(prefix, "multi_modal_projector"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
self.language_model = init_vllm_registered_model( projector_bias = getattr(config, "multimodal_projector_bias", True)
vllm_config=vllm_config,
hf_config=config.text_config, # Use text_config from Tarsier's main config self.multi_modal_projector = TarsierMultiModalProjector(
prefix=maybe_prefix(prefix, "language_model"), vision_hidden_size=config.vision_config.hidden_size,
) text_hidden_size=config.text_config.hidden_size,
self.register_buffer( projector_hidden_act=config.projector_hidden_act,
"image_newline_idx_tensor", multimodal_projector_bias=projector_bias,
torch.tensor([config.image_newline_idx], dtype=torch.long), quant_config=quant_config,
persistent=False, prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
self.register_buffer( self.register_buffer(
"image_new_idx_tensor", "image_newline_idx_tensor",
torch.tensor([config.image_new_idx], dtype=torch.long), torch.tensor([config.image_newline_idx], dtype=torch.long),
persistent=False, persistent=False,
) )
self.register_buffer(
"image_new_idx_tensor",
torch.tensor([config.image_new_idx], dtype=torch.long),
persistent=False,
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
# Use text_config from Tarsier's main config
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -547,7 +552,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -547,7 +552,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self, self,
inputs: TarsierImagePixelInputs, inputs: TarsierImagePixelInputs,
) -> torch.Tensor | tuple[torch.Tensor, ...]: ) -> torch.Tensor | tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"] pixel_values = inputs["pixel_values"]
image_features_selected = self._image_pixels_to_features( image_features_selected = self._image_pixels_to_features(
self.vision_tower, pixel_values self.vision_tower, pixel_values
...@@ -575,11 +579,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -575,11 +579,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
"Incorrect type of image_embeds. " "Incorrect type of image_embeds. "
f"Got type: {type(projected_features)}. " f"Got type: {type(projected_features)}. "
) )
assert self.vision_tower is not None
return self._process_image_pixels(image_input)
def get_language_model(self) -> torch.nn.Module: return self._process_image_pixels(image_input)
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -543,7 +543,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -543,7 +543,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
assert self.multi_modal_config assert self.multi_modal_config
self.secondary_weights = [] self.secondary_weights = []
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.audio_model_id is not None: if config.audio_model_id is not None:
# this prefix is not for initialization, but for loading weights # this prefix is not for initialization, but for loading weights
# note the trailing dot # note the trailing dot
...@@ -554,15 +553,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -554,15 +553,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
prefix="audio_tower.", prefix="audio_tower.",
) )
) )
if config.num_projector_layers > 0:
self.multi_modal_projector = UltravoxTransformerProjector(config)
else:
self.multi_modal_projector = UltravoxFeedForwardProjector(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.wrapped_model_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if config.text_model_id is not None: if config.text_model_id is not None:
# this prefix is not for initialization, but for loading weights # this prefix is not for initialization, but for loading weights
# note the trailing dot # note the trailing dot
...@@ -574,6 +564,20 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -574,6 +564,20 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
) )
) )
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.num_projector_layers > 0:
self.multi_modal_projector = UltravoxTransformerProjector(config)
else:
self.multi_modal_projector = UltravoxFeedForwardProjector(config)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.wrapped_model_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
...@@ -681,9 +685,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -681,9 +685,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
] ]
return flattened_embeddings.split(embed_lens) return flattened_embeddings.split(embed_lens)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
......
...@@ -366,22 +366,22 @@ class VoxtralForConditionalGeneration( ...@@ -366,22 +366,22 @@ class VoxtralForConditionalGeneration(
self.config = config self.config = config
self.downsample_factor = self.config.audio_config.downsample_factor self.downsample_factor = self.config.audio_config.downsample_factor
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
hf_config=config.text_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), hf_config=config.text_config,
) prefix=maybe_prefix(prefix, "language_model"),
self.whisper_encoder = VoxtralEncoderModel( )
vllm_config.with_hf_config(config.audio_config),
prefix=maybe_prefix(prefix, "whisper_encoder"),
)
self.audio_language_adapter = AudioLanguageAdapter(
hidden_size=config.audio_config.d_model * self.downsample_factor,
dim=config.text_config.hidden_size,
)
def get_language_model(self) -> torch.nn.Module: with self._mark_tower_model(vllm_config, "audio"):
return self.language_model self.whisper_encoder = VoxtralEncoderModel(
vllm_config.with_hf_config(config.audio_config),
prefix=maybe_prefix(prefix, "whisper_encoder"),
)
self.audio_language_adapter = AudioLanguageAdapter(
hidden_size=config.audio_config.d_model * self.downsample_factor,
dim=config.text_config.hidden_size,
)
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
"""Get module prefix for multimodal models to filter LoRA modules.""" """Get module prefix for multimodal models to filter LoRA modules."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment