Unverified commit e858bc4d, authored by Peter Salas and committed by GitHub
Browse files

[Model] Add support for transformer-based Ultravox v0.7 projector (#30089)


Signed-off-by: Peter Salas <peter@fixie.ai>
parent e3fbb6f1
......@@ -4,15 +4,21 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import copy
from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias
import torch
from torch import nn
from torch.nn import functional as F
from transformers import BatchFeature, ProcessorMixin
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from transformers.models.whisper.modeling_whisper import (
WhisperEncoder,
WhisperEncoderLayer,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
......@@ -282,7 +288,7 @@ class StackAudioFrames(nn.Module):
return audio_embeds
class UltravoxProjector(nn.Module):
class UltravoxFeedForwardProjector(nn.Module):
def __init__(self, config: UltravoxConfig):
super().__init__()
self.hidden_dim = config.hidden_size
......@@ -310,7 +316,9 @@ class UltravoxProjector(nn.Module):
self.ln_mid = nn.Identity()
self.ln_post = RMSNorm(dim_out)
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
def forward(
self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
) -> torch.Tensor:
audio_features = self._pad_and_stack(audio_features)
audio_features = self.ln_pre(audio_features)
hidden_states = self.linear_1(audio_features)
......@@ -321,6 +329,70 @@ class UltravoxProjector(nn.Module):
return hidden_states
class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
    """Transformer-based projector from audio features to text embeddings.

    The pipeline is: stack audio frames, RMS-normalize, project to the
    projector width (``audio_config.d_model``), add learned position
    embeddings, run ``config.num_projector_layers`` ``WhisperEncoderLayer``
    blocks under a length-derived padding mask, RMS-normalize again, and
    project to ``config.text_config.hidden_size``.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        # Minimal stand-in config for the ModuleUtilsMixin helpers used in
        # forward(); presumably they consult ``self.config.is_decoder`` --
        # TODO confirm against the installed transformers version.
        self.config = SimpleNamespace(is_decoder=False)
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        # Stacking concatenates ``stack_factor`` frames along the feature
        # axis, hence the multiplied input width.
        dim_in = config.audio_config.hidden_size * config.stack_factor

        # Private copy so the projector layers do not share the audio
        # tower's config object.
        projector_audio_config = copy.deepcopy(config.audio_config)

        self.ln_pre = RMSNorm(dim_in)
        self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)

        self.embed_positions = nn.Embedding(
            projector_audio_config.max_source_positions,
            projector_audio_config.d_model,
        )

        self.layers = nn.ModuleList(
            [
                WhisperEncoderLayer(projector_audio_config)
                for _ in range(config.num_projector_layers)
            ]
        )

        self.ln_post = RMSNorm(projector_audio_config.d_model)
        self.linear_out = nn.Linear(
            projector_audio_config.d_model, config.text_config.hidden_size
        )

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        """Project audio features into the text model's embedding space.

        Args:
            audio_features: Audio encoder output for a batch; assumed to be
                (batch, frames, hidden) before stacking -- TODO confirm.
            audio_token_len: 1-D tensor with, per batch item, the number of
                stacked positions that are valid; positions at or beyond
                this length are masked out of attention.

        Returns:
            The projected hidden states produced by ``linear_out``, one
            vector per stacked position (padded positions included).
        """
        audio_features = self._pad_and_stack(audio_features)

        # Padding mask: True where the stacked position index is below the
        # item's valid length.
        max_len_stacked = audio_features.shape[1]
        attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
            None, :
        ].lt(audio_token_len[:, None])
        # Pass dtype by keyword. In transformers releases where the third
        # positional parameter is the deprecated ``device``, a positional
        # dtype would be bound to it (emitting a FutureWarning) and the
        # requested dtype would be ignored.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, attention_mask.shape, dtype=audio_features.dtype
        )

        hidden_states = self.ln_pre(audio_features)
        hidden_states = self.linear_in(hidden_states)

        # Learned absolute position embeddings, broadcast over the batch.
        positions = self.embed_positions(
            torch.arange(hidden_states.size(1), device=hidden_states.device)
        )
        hidden_states = hidden_states + positions

        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                layer_head_mask=None,
            )
            # Each layer returns a tuple; element 0 is the hidden states.
            hidden_states = layer_outputs[0]

        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.linear_out(hidden_states)
        return hidden_states
class ModifiedWhisperEncoder(WhisperEncoder):
"""
Encoder portion of OpenAI's Whisper model.
......@@ -464,7 +536,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
prefix="audio_tower.",
)
)
self.multi_modal_projector = UltravoxProjector(config)
if config.num_projector_layers > 0:
self.multi_modal_projector = UltravoxTransformerProjector(config)
else:
self.multi_modal_projector = UltravoxFeedForwardProjector(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.wrapped_model_config,
......@@ -496,7 +571,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
)
def _audio_features_to_embeddings(
self, input_features: torch.Tensor, audio_lens: torch.Tensor
self,
input_features: torch.Tensor,
audio_lens: torch.Tensor,
audio_token_len: torch.Tensor,
) -> torch.Tensor:
audio_features = input_features.to(self.audio_tower.dtype)
batch_size = audio_features.size(0)
......@@ -512,7 +590,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
batch_features = batch_features.to(self.audio_tower.dtype)
# Process through projector
batch_embeddings = self.multi_modal_projector(batch_features)
batch_embeddings = self.multi_modal_projector(
batch_features, audio_token_len[start:end]
)
audio_embeddings.append(batch_embeddings)
# Concatenate results
......@@ -559,7 +639,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
audio_lens = audio_input["lens"]
audio_token_len = audio_input["token_len"]
embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
embeddings = self._audio_features_to_embeddings(
audio_features, audio_lens, audio_token_len
)
# We should flatten and concatenate embeddings based on token lengths
# For example, with token_len = [4, 2, 3], flattened_embeddings will be
......
......@@ -61,6 +61,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
norm_init: float = 0.4,
projector_act: str = "swiglu",
projector_ln_mid: bool = False,
num_projector_layers: int = 0,
**kwargs,
):
self.ignore_index = ignore_index
......@@ -71,6 +72,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
self.norm_init = norm_init
self.projector_act = projector_act
self.projector_ln_mid = projector_ln_mid
self.num_projector_layers = num_projector_layers
# N.B. May set the wrapped_model_config below.
self.text_model_id = text_model_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment