Unverified commit 868a8c5b, authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Fix Ultravox on V1 (#14929)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent b4ad56c1
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -36,7 +36,7 @@ from vllm.sequence import IntermediateTensors ...@@ -36,7 +36,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsV0Only) SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings, merge_multimodal_embeddings,
...@@ -50,14 +50,14 @@ _MAX_ENCODER_BATCH_SIZE = 16 ...@@ -50,14 +50,14 @@ _MAX_ENCODER_BATCH_SIZE = 16
class UltravoxAudioFeatureInputs(TypedDict): class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]
data: NestedTensors data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
"""Shape: `(batch_size, num_chunks, 80, M)`""" """Shape: `(batch_size, num_chunks, 80, M)`"""
lens: NestedTensors lens: Union[torch.Tensor, list[torch.Tensor]]
""" """
Length of the audio frames. Used for attention mask in WhisperEncoder. Length of the audio frames. Used for attention mask in WhisperEncoder.
Shape: `(batch_size, num_chunks)` Shape: `(batch_size, num_chunks)`
""" """
token_len: NestedTensors token_len: Union[torch.Tensor, list[torch.Tensor]]
""" """
Length of the audio tokens. Used for flattening the audio features. Length of the audio tokens. Used for flattening the audio features.
Shape: `(batch_size, num_chunks)` Shape: `(batch_size, num_chunks)`
...@@ -405,8 +405,7 @@ class ModifiedWhisperEncoder(WhisperEncoder): ...@@ -405,8 +405,7 @@ class ModifiedWhisperEncoder(WhisperEncoder):
UltravoxMultiModalProcessor, UltravoxMultiModalProcessor,
info=UltravoxProcessingInfo, info=UltravoxProcessingInfo,
dummy_inputs=UltravoxDummyInputsBuilder) dummy_inputs=UltravoxDummyInputsBuilder)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
SupportsV0Only):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
...@@ -506,6 +505,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -506,6 +505,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
if not isinstance(audio_features, (torch.Tensor, list)): if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. " raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}") f"Got type: {type(audio_features)}")
if not isinstance(audio_lens, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_lens. "
f"Got type: {type(audio_features)}")
if not isinstance(audio_token_len, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_token_len. "
f"Got type: {type(audio_features)}")
return UltravoxAudioFeatureInputs(type="audio_features", return UltravoxAudioFeatureInputs(type="audio_features",
data=audio_features, data=audio_features,
...@@ -523,7 +528,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -523,7 +528,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _process_audio_input( def _process_audio_input(
self, audio_input: UltravoxAudioInputs) -> NestedTensors: self,
audio_input: UltravoxAudioInputs,
) -> Union[NestedTensors, tuple[torch.Tensor, ...]]:
if audio_input["type"] == "audio_embeds": if audio_input["type"] == "audio_embeds":
return audio_input["data"] return audio_input["data"]
...@@ -531,13 +538,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -531,13 +538,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
audio_features = pad_and_concat_to_dim3(audio_input["data"]) audio_features = pad_and_concat_to_dim3(audio_input["data"])
if isinstance(audio_input['lens'], list): # [B1, B2] -> [B1+B2]
# [B1, B2] -> [B1+B2] audio_lens = flatten_bn(audio_input['lens'], concat=True)
audio_lens = torch.cat(audio_input['lens']) audio_token_len = flatten_bn(audio_input['token_len'], concat=True)
audio_token_len = torch.cat(audio_input['token_len'])
else:
audio_lens = flatten_bn(audio_input['lens'])
audio_token_len = flatten_bn(audio_input['token_len'])
embeddings = self._audio_features_to_embeddings( embeddings = self._audio_features_to_embeddings(
audio_features, audio_lens) audio_features, audio_lens)
...@@ -554,7 +557,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -554,7 +557,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
# Apply mask and flatten # Apply mask and flatten
flattened_embeddings = embeddings[mask] flattened_embeddings = embeddings[mask]
return flattened_embeddings # Return one tensor per input audio
embed_lens = [
token_len_item.sum().item()
for token_len_item in audio_input['token_len']
]
return flattened_embeddings.split(embed_lens)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
...@@ -646,7 +654,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -646,7 +654,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def pad_and_concat_to_dim3( def pad_and_concat_to_dim3(
features: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]] features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Pad and concatenate a list of tensors. Pad and concatenate a list of tensors.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment