Unverified Commit e858bc4d authored by Peter Salas's avatar Peter Salas Committed by GitHub
Browse files

[Model] Add support for transformer-based Ultravox v0.7 projector (#30089)


Signed-off-by: default avatarPeter Salas <peter@fixie.ai>
parent e3fbb6f1
...@@ -4,15 +4,21 @@ ...@@ -4,15 +4,21 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import copy
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from transformers import BatchFeature, ProcessorMixin from transformers import BatchFeature, ProcessorMixin
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder from transformers.models.whisper.modeling_whisper import (
WhisperEncoder,
WhisperEncoderLayer,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
...@@ -282,7 +288,7 @@ class StackAudioFrames(nn.Module): ...@@ -282,7 +288,7 @@ class StackAudioFrames(nn.Module):
return audio_embeds return audio_embeds
class UltravoxProjector(nn.Module): class UltravoxFeedForwardProjector(nn.Module):
def __init__(self, config: UltravoxConfig): def __init__(self, config: UltravoxConfig):
super().__init__() super().__init__()
self.hidden_dim = config.hidden_size self.hidden_dim = config.hidden_size
...@@ -310,7 +316,9 @@ class UltravoxProjector(nn.Module): ...@@ -310,7 +316,9 @@ class UltravoxProjector(nn.Module):
self.ln_mid = nn.Identity() self.ln_mid = nn.Identity()
self.ln_post = RMSNorm(dim_out) self.ln_post = RMSNorm(dim_out)
def forward(self, audio_features: torch.Tensor) -> torch.Tensor: def forward(
self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
) -> torch.Tensor:
audio_features = self._pad_and_stack(audio_features) audio_features = self._pad_and_stack(audio_features)
audio_features = self.ln_pre(audio_features) audio_features = self.ln_pre(audio_features)
hidden_states = self.linear_1(audio_features) hidden_states = self.linear_1(audio_features)
...@@ -321,6 +329,70 @@ class UltravoxProjector(nn.Module): ...@@ -321,6 +329,70 @@ class UltravoxProjector(nn.Module):
return hidden_states return hidden_states
class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
def __init__(self, config: UltravoxConfig):
super().__init__()
self.config = SimpleNamespace(is_decoder=False)
self._pad_and_stack = StackAudioFrames(config.stack_factor)
dim_in = config.audio_config.hidden_size * config.stack_factor
projector_audio_config = copy.deepcopy(config.audio_config)
self.ln_pre = RMSNorm(dim_in)
self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)
self.embed_positions = nn.Embedding(
projector_audio_config.max_source_positions,
projector_audio_config.d_model,
)
self.layers = nn.ModuleList(
[
WhisperEncoderLayer(projector_audio_config)
for _ in range(config.num_projector_layers)
]
)
self.ln_post = RMSNorm(projector_audio_config.d_model)
self.linear_out = nn.Linear(
projector_audio_config.d_model, config.text_config.hidden_size
)
def forward(
self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
) -> torch.Tensor:
audio_features = self._pad_and_stack(audio_features)
max_len_stacked = audio_features.shape[1]
attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
None, :
].lt(audio_token_len[:, None])
extended_attention_mask = self.get_extended_attention_mask(
attention_mask, attention_mask.shape, audio_features.dtype
)
hidden_states = self.ln_pre(audio_features)
hidden_states = self.linear_in(hidden_states)
positions = self.embed_positions(
torch.arange(hidden_states.size(1), device=hidden_states.device)
)
hidden_states = hidden_states + positions
for layer in self.layers:
layer_outputs = layer(
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=None,
)
hidden_states = layer_outputs[0]
hidden_states = self.ln_post(hidden_states)
hidden_states = self.linear_out(hidden_states)
return hidden_states
class ModifiedWhisperEncoder(WhisperEncoder): class ModifiedWhisperEncoder(WhisperEncoder):
""" """
Encoder portion of OpenAI's Whisper model. Encoder portion of OpenAI's Whisper model.
...@@ -464,7 +536,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -464,7 +536,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
prefix="audio_tower.", prefix="audio_tower.",
) )
) )
self.multi_modal_projector = UltravoxProjector(config) if config.num_projector_layers > 0:
self.multi_modal_projector = UltravoxTransformerProjector(config)
else:
self.multi_modal_projector = UltravoxFeedForwardProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=config.wrapped_model_config, hf_config=config.wrapped_model_config,
...@@ -496,7 +571,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -496,7 +571,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
) )
def _audio_features_to_embeddings( def _audio_features_to_embeddings(
self, input_features: torch.Tensor, audio_lens: torch.Tensor self,
input_features: torch.Tensor,
audio_lens: torch.Tensor,
audio_token_len: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
audio_features = input_features.to(self.audio_tower.dtype) audio_features = input_features.to(self.audio_tower.dtype)
batch_size = audio_features.size(0) batch_size = audio_features.size(0)
...@@ -512,7 +590,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -512,7 +590,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
batch_features = batch_features.to(self.audio_tower.dtype) batch_features = batch_features.to(self.audio_tower.dtype)
# Process through projector # Process through projector
batch_embeddings = self.multi_modal_projector(batch_features) batch_embeddings = self.multi_modal_projector(
batch_features, audio_token_len[start:end]
)
audio_embeddings.append(batch_embeddings) audio_embeddings.append(batch_embeddings)
# Concatenate results # Concatenate results
...@@ -559,7 +639,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -559,7 +639,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
audio_lens = audio_input["lens"] audio_lens = audio_input["lens"]
audio_token_len = audio_input["token_len"] audio_token_len = audio_input["token_len"]
embeddings = self._audio_features_to_embeddings(audio_features, audio_lens) embeddings = self._audio_features_to_embeddings(
audio_features, audio_lens, audio_token_len
)
# We should flatten and concatenate embeddings based on token lengths # We should flatten and concatenate embeddings based on token lengths
# For example, with token_len = [4, 2, 3], flattened_embeddings will be # For example, with token_len = [4, 2, 3], flattened_embeddings will be
......
...@@ -61,6 +61,7 @@ class UltravoxConfig(transformers.PretrainedConfig): ...@@ -61,6 +61,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
norm_init: float = 0.4, norm_init: float = 0.4,
projector_act: str = "swiglu", projector_act: str = "swiglu",
projector_ln_mid: bool = False, projector_ln_mid: bool = False,
num_projector_layers: int = 0,
**kwargs, **kwargs,
): ):
self.ignore_index = ignore_index self.ignore_index = ignore_index
...@@ -71,6 +72,7 @@ class UltravoxConfig(transformers.PretrainedConfig): ...@@ -71,6 +72,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
self.norm_init = norm_init self.norm_init = norm_init
self.projector_act = projector_act self.projector_act = projector_act
self.projector_ln_mid = projector_ln_mid self.projector_ln_mid = projector_ln_mid
self.num_projector_layers = num_projector_layers
# N.B. May set the wrapped_model_config below. # N.B. May set the wrapped_model_config below.
self.text_model_id = text_model_id self.text_model_id = text_model_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment