Unverified Commit 62965de5 authored by Farzad Abdolhosseini's avatar Farzad Abdolhosseini Committed by GitHub
Browse files

[Model] Ultravox: Support Llama 4 and Gemma 3 backends (#17818)


Signed-off-by: default avatarFarzad Abdolhosseini <farzad@fixie.ai>
Signed-off-by: default avatarPatrick Li <patrick8289@gmail.com>
Co-authored-by: default avatarPatrick Li <patrick8289@gmail.com>
parent 7ae75fa6
...@@ -221,6 +221,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -221,6 +221,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501 "fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf", "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
is_available_online=False), is_available_online=False),
"Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
is_available_online=False),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
"Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"), "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"),
"FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501 "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501
......
...@@ -89,6 +89,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -89,6 +89,7 @@ _TEXT_GENERATION_MODELS = {
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501
# For decapoda-research/llama-* # For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
......
...@@ -39,9 +39,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -39,9 +39,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
merge_multimodal_embeddings, merge_multimodal_embeddings,
merge_multimodal_embeddings_from_map) merge_multimodal_embeddings_from_map)
_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>" _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
_MAX_ENCODER_BATCH_SIZE = 16 _MAX_ENCODER_BATCH_SIZE = 16
...@@ -80,14 +78,15 @@ class UltravoxProcessingInfo(BaseProcessingInfo): ...@@ -80,14 +78,15 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
sampling_rate: Optional[int] = None, sampling_rate: Optional[int] = None,
**kwargs: object, **kwargs: object,
) -> ProcessorMixin: ) -> ProcessorMixin:
config = self.ctx.model_config.hf_config
hf_processor = self.ctx.get_hf_processor(**kwargs) hf_processor = self.ctx.get_hf_processor(**kwargs)
# NOTE: Ultravox processing definition uses '<|eot_id|>' as the # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
# placeholder that will cause confusion with the actual end of turn # placeholder that will cause confusion with the actual end of turn
# token, thus we override placeholder with a reserved special # token, thus we override placeholder with a reserved token.
# token.
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN hf_processor.audio_replacement_token_id = config.audio_token_index
return hf_processor return hf_processor
def get_feature_extractor( def get_feature_extractor(
...@@ -274,7 +273,7 @@ class UltravoxProjector(nn.Module): ...@@ -274,7 +273,7 @@ class UltravoxProjector(nn.Module):
else: else:
self.act = get_act_fn(config.projector_act) self.act = get_act_fn(config.projector_act)
dim_out = config.text_config.hidden_size dim_out = config.text_hidden_size
self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False) self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
# Ultravox v0.4.1 and below use layer_norm after the second linear layer # Ultravox v0.4.1 and below use layer_norm after the second linear layer
...@@ -572,9 +571,14 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -572,9 +571,14 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) # The audio token index is not included in the embedding table
if multimodal_embeddings is not None \ # We need to remove it before embedding lookup
and len(multimodal_embeddings) != 0: safe_input_ids = input_ids.clone()
safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0
inputs_embeds = self.language_model.get_input_embeddings(
safe_input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
# TODO(ywang96): remove this block after v0 is deprecated. # TODO(ywang96): remove this block after v0 is deprecated.
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
...@@ -585,7 +589,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -585,7 +589,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
else: else:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
_AUDIO_PLACEHOLDER_TOKEN) self.config.audio_token_index)
return inputs_embeds return inputs_embeds
def forward(self, def forward(self,
...@@ -623,10 +627,14 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -623,10 +627,14 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
multimodal_embeddings) multimodal_embeddings)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, language_model = self.language_model
positions, if hasattr(language_model, "language_model"):
intermediate_tensors, language_model = language_model.language_model
inputs_embeds=inputs_embeds)
hidden_states = language_model.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -45,6 +45,7 @@ class UltravoxConfig(transformers.PretrainedConfig): ...@@ -45,6 +45,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
""" """
model_type = "ultravox" model_type = "ultravox"
audio_token = "<|audio|>"
is_composition = False is_composition = False
def __init__( def __init__(
...@@ -80,29 +81,32 @@ class UltravoxConfig(transformers.PretrainedConfig): ...@@ -80,29 +81,32 @@ class UltravoxConfig(transformers.PretrainedConfig):
# Avoid circular import # Avoid circular import
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
self.text_config = get_config(text_model_id, text_config_obj = get_config(text_model_id,
trust_remote_code=False) trust_remote_code=False)
else: else:
text_config = text_config or {} text_config = text_config or {}
self.text_config = transformers.CONFIG_MAPPING[text_config.get( text_config_obj = transformers.CONFIG_MAPPING[text_config.get(
"model_type", "llama")](**text_config) "model_type", "llama")](**text_config)
inner_text_config = text_config_obj.get_text_config()
if audio_model_id is not None: if audio_model_id is not None:
# Avoid circular import # Avoid circular import
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
self.audio_config = get_config(audio_model_id, audio_config = get_config(audio_model_id, trust_remote_code=False)
trust_remote_code=False)
else: else:
audio_config = audio_config or {} audio_config = audio_config or {}
self.audio_config = transformers.CONFIG_MAPPING[audio_config.get( audio_config = transformers.CONFIG_MAPPING[audio_config.get(
"model_type", "whisper")](**audio_config) "model_type", "whisper")](**audio_config)
self.text_config = text_config_obj
self.audio_config = audio_config
self.text_model_lora_config = text_model_lora_config or {} self.text_model_lora_config = text_model_lora_config or {}
self.audio_model_lora_config = audio_model_lora_config or {} self.audio_model_lora_config = audio_model_lora_config or {}
self.vocab_size = self.text_config.vocab_size self.vocab_size = inner_text_config.vocab_size
self.initializer_range = inner_text_config.initializer_range
self.initializer_range = self.text_config.initializer_range self.text_hidden_size = inner_text_config.hidden_size
super().__init__(**kwargs) super().__init__(**kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment