Unverified Commit cb55ad86 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate ultravox inputs to TensorSchema (#23503)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent 712b273f
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union from typing import Annotated, Any, Literal, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
...@@ -43,26 +44,37 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>" ...@@ -43,26 +44,37 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
_MAX_ENCODER_BATCH_SIZE = 16 _MAX_ENCODER_BATCH_SIZE = 16
class UltravoxAudioFeatureInputs(TypedDict): class UltravoxAudioFeatureInputs(TensorSchema):
type: Literal["audio_features"]
data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
"""Shape: `(batch_size, num_chunks, 80, M)`"""
lens: Union[torch.Tensor, list[torch.Tensor]]
""" """
Length of the audio frames. Used for attention mask in WhisperEncoder. Dimensions:
Shape: `(batch_size, num_chunks)` - b: batch size
- n: number of chunks
- t: Time frames (M)
- nmb: Number of mel bins
""" """
token_len: Union[torch.Tensor, list[torch.Tensor]] type: Literal["audio_features"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor],
list[list[torch.Tensor]]],
TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"})]
lens: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("b", "n", dynamic_dims={"n"})]
"""Length of the audio frames. Used for attention mask in WhisperEncoder."""
token_len: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("b", "n", dynamic_dims={"n"})]
"""Length of the audio tokens. Used for flattening the audio features."""
class UltravoxAudioEmbeddingInputs(TensorSchema):
""" """
Length of the audio tokens. Used for flattening the audio features. Dimensions:
Shape: `(batch_size, num_chunks)` - b: batch size
- na: number of audios
- afs: audio feature size
- hs: hidden size
""" """
class UltravoxAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"] type: Literal["audio_embeds"]
data: NestedTensors data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`""" TensorShape("b", "na", "afs", "hs")]
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
...@@ -484,26 +496,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -484,26 +496,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
return None return None
if audio_features is not None: if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}")
if not isinstance(audio_lens, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_lens. "
f"Got type: {type(audio_features)}")
if not isinstance(audio_token_len, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_token_len. "
f"Got type: {type(audio_features)}")
return UltravoxAudioFeatureInputs(type="audio_features", return UltravoxAudioFeatureInputs(type="audio_features",
data=audio_features, data=audio_features,
lens=audio_lens, lens=audio_lens,
token_len=audio_token_len) token_len=audio_token_len)
if audio_embeds is not None: if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")
return UltravoxAudioEmbeddingInputs(type="audio_embeds", return UltravoxAudioEmbeddingInputs(type="audio_embeds",
data=audio_embeds) data=audio_embeds)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment