Unverified commit cb55ad86, authored by Benji Beck, committed by GitHub
Browse files

Migrate ultravox inputs to TensorSchema (#23503)


Signed-off-by: Benji Beck <benjibeck@meta.com>
parent 712b273f
......@@ -4,7 +4,7 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union
import torch
from torch import nn
......@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
......@@ -43,26 +44,37 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
_MAX_ENCODER_BATCH_SIZE = 16
class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
"""Shape: `(batch_size, num_chunks, 80, M)`"""
lens: Union[torch.Tensor, list[torch.Tensor]]
class UltravoxAudioFeatureInputs(TensorSchema):
"""
Length of the audio frames. Used for attention mask in WhisperEncoder.
Shape: `(batch_size, num_chunks)`
Dimensions:
- b: batch size
- n: number of chunks
- t: Time frames (M)
- nmb: Number of mel bins
"""
token_len: Union[torch.Tensor, list[torch.Tensor]]
type: Literal["audio_features"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor],
list[list[torch.Tensor]]],
TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"})]
lens: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("b", "n", dynamic_dims={"n"})]
"""Length of the audio frames. Used for attention mask in WhisperEncoder."""
token_len: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("b", "n", dynamic_dims={"n"})]
"""Length of the audio tokens. Used for flattening the audio features."""
class UltravoxAudioEmbeddingInputs(TensorSchema):
"""
Length of the audio tokens. Used for flattening the audio features.
Shape: `(batch_size, num_chunks)`
Dimensions:
- b: batch size
- na: number of audios
- afs: audio feature size
- hs: hidden size
"""
class UltravoxAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("b", "na", "afs", "hs")]
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
......@@ -484,26 +496,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
return None
if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}")
if not isinstance(audio_lens, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_lens. "
f"Got type: {type(audio_features)}")
if not isinstance(audio_token_len, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_token_len. "
f"Got type: {type(audio_features)}")
return UltravoxAudioFeatureInputs(type="audio_features",
data=audio_features,
lens=audio_lens,
token_len=audio_token_len)
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")
return UltravoxAudioEmbeddingInputs(type="audio_embeds",
data=audio_embeds)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment