Unverified Commit 36d450e3 authored by ramos's avatar ramos Committed by GitHub
Browse files

Adds FunAudioChat multimodal audio model support (#2) (#33058)


Signed-off-by: default avatarramos <49182011+nemoramo@users.noreply.github.com>
Signed-off-by: default avatarmayufeng <mayufeng@example.com>
Co-authored-by: default avatarmayufeng <mayufeng@example.com>
parent a2b877df
...@@ -117,6 +117,31 @@ def run_glmasr(question: str, audio_count: int) -> ModelRequestData: ...@@ -117,6 +117,31 @@ def run_glmasr(question: str, audio_count: int) -> ModelRequestData:
) )
# FunAudioChat
def run_funaudiochat(question: str, audio_count: int) -> ModelRequestData:
# NOTE: FunAudioChat is not available on the HuggingFace Hub at the time of
# writing. Pass a local model path via `--model`.
model_name = "funaudiochat"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
enforce_eager=True,
)
audio_in_prompt = "".join(
["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for _ in range(audio_count)]
)
prompt = f"{audio_in_prompt}{question}"
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Granite Speech # Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the setting in this example are somewhat different from what is # NOTE - the setting in this example are somewhat different from what is
...@@ -410,6 +435,7 @@ model_example_map = { ...@@ -410,6 +435,7 @@ model_example_map = {
"audioflamingo3": run_audioflamingo3, "audioflamingo3": run_audioflamingo3,
"gemma3n": run_gemma3n, "gemma3n": run_gemma3n,
"glmasr": run_glmasr, "glmasr": run_glmasr,
"funaudiochat": run_funaudiochat,
"granite_speech": run_granite_speech, "granite_speech": run_granite_speech,
"midashenglm": run_midashenglm, "midashenglm": run_midashenglm,
"minicpmo": run_minicpmo, "minicpmo": run_minicpmo,
...@@ -435,6 +461,12 @@ def parse_args(): ...@@ -435,6 +461,12 @@ def parse_args():
choices=model_example_map.keys(), choices=model_example_map.keys(),
help='Huggingface "model_type".', help='Huggingface "model_type".',
) )
parser.add_argument(
"--model",
type=str,
default=None,
help="Model ID or local path override. Required for funaudiochat.",
)
parser.add_argument( parser.add_argument(
"--num-prompts", type=int, default=1, help="Number of prompts to run." "--num-prompts", type=int, default=1, help="Number of prompts to run."
) )
...@@ -467,6 +499,9 @@ def main(args): ...@@ -467,6 +499,9 @@ def main(args):
if model not in model_example_map: if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.") raise ValueError(f"Model type {model} is not supported.")
if model == "funaudiochat" and not args.model:
raise ValueError("--model is required when --model-type=funaudiochat")
if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1: if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
raise ValueError( raise ValueError(
f"tensor_parallel_size must be a positive integer, " f"tensor_parallel_size must be a positive integer, "
...@@ -477,6 +512,8 @@ def main(args): ...@@ -477,6 +512,8 @@ def main(args):
req_data = model_example_map[model]( req_data = model_example_map[model](
question_per_audio_count[audio_count], audio_count question_per_audio_count[audio_count], audio_count
) )
if model == "funaudiochat":
req_data.engine_args.model = args.model
# Disable other modalities to save memory # Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0} default_limits = {"image": 0, "video": 0, "audio": 0}
......
...@@ -692,6 +692,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -692,6 +692,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"baidu/ERNIE-4.5-VL-28B-A3B-PT", "baidu/ERNIE-4.5-VL-28B-A3B-PT",
trust_remote_code=True, trust_remote_code=True,
), ),
"FunAudioChatForConditionalGeneration": _HfExamplesInfo(
"funaudiochat", is_available_online=False
),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"), "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only FunAudioChat model compatible with HuggingFace weights.
FunAudioChat is a Qwen3 text model augmented with:
- a continuous audio encoder (Whisper-mel frontend + transformer)
- a discrete audio encoder (speech tokenizer + projector)
In the HF implementation, audio features are scattered into `<|AUDIO|>` token
positions via `inputs_embeds`, while `position_ids` (RoPE) remains standard 1D.
"""
from __future__ import annotations
import os
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizerFast, WhisperFeatureExtractor
from transformers.activations import get_activation
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import _has_module
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
class _SinusoidsPositionEmbedding(nn.Module):
def __init__(self, length: int, channels: int, max_timescale: float = 10000.0):
super().__init__()
if channels % 2 != 0:
raise ValueError("SinusoidsPositionEmbedding needs even channels input")
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(
-log_timescale_increment * torch.arange(channels // 2).float()
)
scaled_time = (
torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
)
self.register_buffer(
"positional_embedding",
torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
persistent=False,
)
class FunAudioChatAudioAttention(nn.Module):
"""Multi-headed attention used inside the continuous audio tower."""
def __init__(self, config: Any):
super().__init__()
self.embed_dim = int(config.d_model)
self.total_num_heads = int(config.encoder_attention_heads)
self.dropout = float(getattr(config, "attention_dropout", 0.0))
self.head_dim = self.embed_dim // self.total_num_heads
self.num_key_value_groups = 1 # needed for eager attention
self.config = config
if self.head_dim * self.total_num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got embed_dim={self.embed_dim}, "
f"num_heads={self.total_num_heads})."
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = 0.0
self.is_decoder = False
self.is_causal = False
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.total_num_heads,
bias=True,
)
self.num_heads = self.qkv_proj.num_heads
self.num_kv_heads = self.qkv_proj.num_kv_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn = MMEncoderAttention(
num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
prefix="funaudiochat_audio_tower.attn",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
with torch.no_grad():
if self.qkv_proj.bias is not None:
# HF FunAudioChat uses bias=False for k_proj. Ensure the missing
# shard starts as zeros, while allowing q/v shards to load.
self.qkv_proj.bias.zero_()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor:
del kwargs
del attention_mask
seq_length, _ = hidden_states.size()
qkv, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv.split(
[self.q_size, self.kv_size, self.kv_size], dim=-1
)
max_seqlen: torch.Tensor | None = None
if cu_seqlens is not None:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
attn_output = self.attn(
query_states.reshape(1, seq_length, self.q_size),
key_states.reshape(1, seq_length, self.kv_size),
value_states.reshape(1, seq_length, self.kv_size),
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
).reshape(seq_length, -1)
output, _ = self.out_proj(attn_output)
return output
class FunAudioChatAudioEncoderLayer(nn.Module):
def __init__(self, config: Any):
super().__init__()
self.embed_dim = int(config.d_model)
self.self_attn = FunAudioChatAudioAttention(config)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = float(config.dropout)
self.activation_fn = get_activation(str(config.activation_function))
self.activation_dropout = float(config.activation_dropout)
self.fc1 = nn.Linear(self.embed_dim, int(config.encoder_ffn_dim))
self.fc2 = nn.Linear(int(config.encoder_ffn_dim), self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs: object,
) -> tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(
hidden_states, p=self.activation_dropout, training=self.training
)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
return (hidden_states,)
class FunAudioChatAudioEncoder(nn.Module):
"""Continuous audio tower."""
def __init__(self, config: Any):
super().__init__()
self.config = config
embed_dim = int(config.d_model)
self.num_mel_bins = int(config.num_mel_bins)
self.max_source_positions = int(config.max_source_positions)
self.embed_scale = (embed_dim**0.5) if bool(config.scale_embedding) else 1.0
self.n_window = int(config.n_window)
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
self.layers = nn.ModuleList(
[
FunAudioChatAudioEncoderLayer(config)
for _ in range(int(config.encoder_layers))
]
)
self.ln_post = nn.LayerNorm(embed_dim)
self.avg_pooler = nn.AvgPool1d(2, stride=2)
self.proj = nn.Linear(embed_dim, int(config.output_dim))
self.positional_embedding = _SinusoidsPositionEmbedding(
self.max_source_positions, embed_dim
)
# Present in HF weights even if unused during S2T.
self.audio_bos_eos_token = nn.Embedding(2, int(config.output_dim))
@property
def dtype(self) -> torch.dtype:
return self.conv1.weight.dtype
def _prepare_attention_mask(
self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor
) -> torch.Tensor | None:
if getattr(self.config, "_attn_implementation", "eager") == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
(1, 1, seq_length, seq_length),
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
start = int(cu_seqlens[i - 1].item())
end = int(cu_seqlens[i].item())
attention_mask[..., start:end, start:end] = 0
return attention_mask
def forward(
self,
input_features: torch.Tensor,
feature_lens: torch.Tensor,
aftercnn_lens: torch.Tensor,
speech_maxlen: int,
**kwargs: object,
) -> BaseModelOutput:
# For max-length audio (300s => ~7500 speech frames at 25Hz), the
# Torch SDPA path can be prohibitively memory hungry (~O(n^2) inside the
# longest chunks). Require FlashAttention for such inputs to avoid OOM
# and performance cliffs.
if int(speech_maxlen) >= 7500:
if not _has_module("flash_attn"):
raise RuntimeError(
"FunAudioChat long audio (~300s) requires FlashAttention-2 "
"for the continuous audio tower, but `flash_attn` is not "
"installed in the runtime environment."
)
if not getattr(
self.layers[0].self_attn.attn, "is_flash_attn_backend", False
):
raise RuntimeError(
"FunAudioChat long audio (~300s) requires FlashAttention for the "
"continuous audio tower, but the selected MM encoder attention "
"backend is not FlashAttention."
)
# Handle empty / invalid items (feature_lens == 0) without crashing.
original_batch_size = int(feature_lens.size(0))
device = input_features.device
valid_mask = feature_lens > 0
valid_indices = torch.where(valid_mask)[0]
if valid_indices.numel() == 0:
output_dim = int(self.proj.out_features)
return BaseModelOutput(
last_hidden_state=torch.zeros(
(original_batch_size, speech_maxlen, output_dim),
device=device,
dtype=self.proj.weight.dtype,
)
)
input_features_list = input_features.split(feature_lens.tolist(), dim=1)
valid_input_features_list = [input_features_list[int(i)] for i in valid_indices]
valid_input_features = torch.cat(valid_input_features_list, dim=1)
valid_feature_lens = feature_lens[valid_mask]
valid_aftercnn_lens = aftercnn_lens[valid_mask]
chunk_num = torch.ceil(valid_feature_lens / (self.n_window * 2)).long()
chunk_lengths_list: list[int] = []
full_chunk_len = self.n_window * 2
for i, length in enumerate(valid_feature_lens):
num_chunks_for_sample = int(chunk_num[i].item())
if num_chunks_for_sample == 0:
continue
chunk_lengths_list.extend([full_chunk_len] * (num_chunks_for_sample - 1))
last_chunk_len = int(length.item()) % full_chunk_len
if last_chunk_len == 0:
last_chunk_len = full_chunk_len
chunk_lengths_list.append(last_chunk_len)
chunk_lengths = torch.tensor(
chunk_lengths_list, dtype=torch.long, device=device
)
chunk_list = valid_input_features.split(chunk_lengths.tolist(), dim=1)
padded_feature, padded_mask, padded_mask_after_cnn = (
self.padded_and_mask_function(
chunk_list, chunk_lengths, padding_value=0, padding_side="right"
)
)
padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask
padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2)
padded_embed = padded_embed + self.positional_embedding.positional_embedding[
: padded_embed.shape[1], :
].unsqueeze(0).to(padded_embed.dtype)
hidden_states = padded_embed[padded_mask_after_cnn]
cu_seqlens = torch.cat(
(
torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32),
padded_mask_after_cnn.sum(1).cumsum(0),
)
).to(torch.int32)
for encoder_layer in self.layers:
(hidden_states,) = encoder_layer(
hidden_states,
cu_seqlens=cu_seqlens,
**kwargs,
)
hidden_states_list = hidden_states.split(valid_aftercnn_lens.tolist(), dim=0)
pooled_list: list[torch.Tensor] = []
pooled_lengths: list[int] = []
for each_audio_states in hidden_states_list:
seq_len = int(each_audio_states.shape[0])
if seq_len >= 2:
pooled = nn.functional.avg_pool1d(
each_audio_states.transpose(0, 1), kernel_size=2, stride=2
).transpose(0, 1)
else:
pooled = each_audio_states
pooled_list.append(pooled)
pooled_lengths.append(int(pooled.shape[0]))
pooled_concat = torch.cat(pooled_list, dim=0)
processed_concat = self.proj(self.ln_post(pooled_concat))
processed_audio_list = list(processed_concat.split(pooled_lengths, dim=0))
output_dim = (
int(processed_audio_list[0].shape[-1])
if processed_audio_list
else int(self.proj.out_features)
)
output_hidden_states = torch.zeros(
(original_batch_size, speech_maxlen, output_dim),
dtype=processed_audio_list[0].dtype
if processed_audio_list
else self.proj.weight.dtype,
device=device,
)
for valid_idx, processed in zip(valid_indices, processed_audio_list):
seq_len = min(int(processed.shape[0]), int(speech_maxlen))
output_hidden_states[int(valid_idx), :seq_len] = processed[:seq_len]
return BaseModelOutput(last_hidden_state=output_hidden_states)
def padded_and_mask_function(
self,
tensor_list: Sequence[torch.Tensor],
tensor_len: torch.Tensor,
padding_value: float = 0.0,
padding_side: str = "right",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
max_len = int(tensor_len.max().item())
dim = int(tensor_list[0].shape[0])
padded_tensor = torch.full(
size=(len(tensor_list), dim, max_len),
fill_value=padding_value,
dtype=self.dtype,
device=tensor_list[0].device,
)
batch_mask = torch.zeros(
(len(tensor_len), max_len), dtype=torch.long, device=padded_tensor.device
)
for i, length in enumerate(tensor_len):
length_val = int(length.item())
batch_mask[i, :length_val] = 1
padded_tensor[i, :, :length_val] = tensor_list[i]
feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
max_len_after_cnn = int(feature_lens_after_cnn.max().item())
batch_mask_after_cnn = torch.zeros(
(len(tensor_len), max_len_after_cnn),
dtype=torch.long,
device=padded_tensor.device,
)
for i, length in enumerate(feature_lens_after_cnn):
batch_mask_after_cnn[i, : int(length.item())] = 1
if padding_side != "right":
raise NotImplementedError("Only right padding is supported.")
return (
padded_tensor,
batch_mask.unsqueeze(1).to(padded_tensor.dtype),
batch_mask_after_cnn.bool(),
)
# From the HF FunAudioChat implementation.
def _get_feat_extract_output_lengths(
self, input_lengths: torch.LongTensor
) -> tuple[torch.LongTensor, torch.LongTensor]:
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
return input_lengths, output_lengths
class FunAudioChatDiscreteEncoder(nn.Module):
"""Discrete audio encoder (speech tokenizer -> grouped embeddings)."""
def __init__(self, config: Any):
super().__init__()
self.padding_idx = int(config.pad_token_id)
self.group_size = int(config.group_size)
self.hidden_size = int(config.output_dim)
self.continuous_features_mode = getattr(
config, "continuous_features_mode", "add"
)
self.embed_tokens = nn.Embedding(
int(config.codebook_size), self.hidden_size, self.padding_idx
)
self.output_matching = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.continual_output_matching = nn.Linear(
self.hidden_size, self.hidden_size, bias=False
)
def forward(
self,
audio_ids: torch.Tensor,
continuous_audio_features: torch.Tensor | None = None,
continuous_audio_output_lengths: torch.Tensor | None = None,
feature_exist_mask: torch.Tensor | None = None,
) -> torch.Tensor:
del continuous_audio_output_lengths
inputs_embeds = self.embed_tokens(audio_ids)
hidden_states = inputs_embeds.reshape(
inputs_embeds.shape[0], -1, self.group_size * self.hidden_size
)
hidden_states = hidden_states.reshape(
hidden_states.shape[0], -1, self.group_size, self.hidden_size
).mean(dim=2)
hidden_states = self.output_matching(hidden_states)
if continuous_audio_features is not None:
continuous_audio_features = continuous_audio_features.reshape(
continuous_audio_features.shape[0],
-1,
self.group_size,
self.hidden_size,
).mean(dim=2)
continuous_audio_hidden_states = self.continual_output_matching(
continuous_audio_features
)
if feature_exist_mask is None:
feature_exist_mask = torch.ones(
(hidden_states.shape[0],),
dtype=torch.bool,
device=hidden_states.device,
)
if self.continuous_features_mode == "add":
hidden_states[feature_exist_mask] += continuous_audio_hidden_states
else:
hidden_states[feature_exist_mask] = continuous_audio_hidden_states
return hidden_states
def _get_feat_extract_output_lengths(
self, input_lengths: torch.LongTensor
) -> tuple[torch.LongTensor, torch.LongTensor]:
output_lengths = (input_lengths + self.group_size - 1) // self.group_size
return input_lengths, output_lengths
class FunAudioChatProcessingInfo(BaseProcessingInfo):
token_fps: int = 25
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
def get_target_channels(self) -> int:
return 1
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int] | None:
# The discrete audio encoder downsamples 25Hz frames with group_size=5,
# so for a 300s clip the max number of `<|AUDIO|>` placeholders is 1500.
cfg = self.get_hf_config()
audio_cfg = getattr(cfg, "audio_config", None)
max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500))
return {"audio": max_audio_tokens}
@cached_property
def feature_extractor(self) -> WhisperFeatureExtractor:
return WhisperFeatureExtractor.from_pretrained(self.model_id)
@cached_property
def speech_tokenizer(self) -> PreTrainedTokenizerFast:
return PreTrainedTokenizerFast.from_pretrained(
self.model_id, subfolder="speech_tokenizer"
)
def get_feature_extractor(self) -> WhisperFeatureExtractor:
return self.feature_extractor
def get_speech_tokenizer(self) -> PreTrainedTokenizerFast:
return self.speech_tokenizer
def get_audio_group_size(self) -> int:
cfg = self.get_hf_config()
audio_cfg = getattr(cfg, "audio_config", None)
return int(getattr(audio_cfg, "group_size", 5))
class FunAudioChatDummyInputsBuilder(
BaseDummyInputsBuilder[FunAudioChatProcessingInfo]
):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
return "<|audio_bos|><|AUDIO|><|audio_eos|>" * int(num_audios)
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = int(feature_extractor.sampling_rate)
# Dummy inputs are used for profiling; construct the worst-case audio
# length that maximizes the number of encoder tokens.
cfg = self.info.get_hf_config()
audio_cfg = getattr(cfg, "audio_config", None)
max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500))
group_size = self.info.get_audio_group_size()
token_fps = int(getattr(self.info, "token_fps", 25))
target_num_frames = max(1, max_audio_tokens) * max(1, group_size)
audio_len = max(
1,
(target_num_frames * sampling_rate + token_fps - 1) // token_fps,
)
num_audios = int(mm_counts.get("audio", 0))
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"audio": self._get_dummy_audios(
length=audio_len,
num_audios=num_audios,
overrides=audio_overrides,
)
}
class FunAudioChatMultiModalProcessor(
BaseMultiModalProcessor[FunAudioChatProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(
target_sr=int(feature_extractor.sampling_rate),
target_channels=self.info.get_target_channels(),
)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)])
audios = mm_data.get("audios", [])
if not audios:
return BatchFeature({"input_ids": input_ids})
feature_extractor = self.info.get_feature_extractor()
sr = int(feature_extractor.sampling_rate)
min_samples = int(getattr(feature_extractor, "n_fft", 400) or 400)
wavs: list[np.ndarray] = []
speech_strs: list[str] = []
speech_tokenizer = self.info.get_speech_tokenizer()
pad_token = speech_tokenizer.pad_token or "<|audio_pad|>"
for audio in audios:
if isinstance(audio, torch.Tensor):
audio = audio.detach().cpu().numpy()
audio_np = np.asarray(audio, dtype=np.float32)
if min_samples > 0 and audio_np.shape[0] < min_samples:
audio_np = np.pad(
audio_np, (0, min_samples - audio_np.shape[0]), mode="constant"
)
wavs.append(audio_np)
num_frames = int(
(float(audio_np.shape[0]) / float(sr)) * float(self.info.token_fps)
)
speech_strs.append(pad_token * max(1, int(num_frames)))
audio_group_size = self.info.get_audio_group_size()
speech_inputs = speech_tokenizer(
speech_strs,
return_attention_mask=True,
return_token_type_ids=False,
padding=True,
pad_to_multiple_of=audio_group_size,
return_tensors="pt",
)
wav_inputs = feature_extractor(
wavs,
sampling_rate=sr,
return_attention_mask=True,
padding="max_length",
return_tensors="pt",
)
mm_inputs: dict[str, torch.Tensor] = {
"speech_ids": speech_inputs["input_ids"],
"speech_attention_mask": speech_inputs["attention_mask"],
"input_features": wav_inputs["input_features"],
"feature_attention_mask": wav_inputs["attention_mask"],
"feature_exist_mask": torch.ones((len(wavs),), dtype=torch.bool),
}
return BatchFeature({"input_ids": input_ids, **mm_inputs})
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return {
"speech_ids": MultiModalFieldConfig.batched("audio"),
"speech_attention_mask": MultiModalFieldConfig.batched("audio"),
"input_features": MultiModalFieldConfig.batched("audio"),
"feature_attention_mask": MultiModalFieldConfig.batched("audio"),
"feature_exist_mask": MultiModalFieldConfig.batched("audio"),
}
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
audio_token = "<|AUDIO|>"
audio_token_id = vocab[audio_token]
out_mm_data = out_mm_kwargs.get_data()
speech_attention_mask = out_mm_data.get("speech_attention_mask")
if speech_attention_mask is None:
audio_output_lengths: list[int] = []
else:
assert isinstance(speech_attention_mask, torch.Tensor)
speech_lengths = speech_attention_mask.sum(-1)
group_size = self.info.get_audio_group_size()
audio_output_lengths = (
(speech_lengths + group_size - 1) // group_size
).tolist()
def get_replacement_funaudiochat(item_idx: int):
num_features = (
int(audio_output_lengths[item_idx]) if audio_output_lengths else 1
)
if num_features <= 0:
audios = mm_items.get_items("audio", AudioProcessorItems)
audio_len = audios.get_audio_length(item_idx)
raise ValueError(
f"The audio (len={audio_len}) is too short to be "
"represented inside the model"
)
audio_tokens = [audio_token_id] * num_features
return PromptUpdateDetails.select_token_id(
audio_tokens,
embed_token_id=audio_token_id,
)
return [
PromptReplacement(
modality="audio",
target=audio_token,
replacement=get_replacement_funaudiochat,
)
]
@MULTIMODAL_REGISTRY.register_processor(
FunAudioChatMultiModalProcessor,
info=FunAudioChatProcessingInfo,
dummy_inputs=FunAudioChatDummyInputsBuilder,
)
class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("audio"):
return "<|audio_bos|><|AUDIO|><|audio_eos|>"
raise ValueError("Only audio modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.quant_config = quant_config
with self._mark_tower_model(vllm_config, "audio"):
self.continuous_audio_tower = FunAudioChatAudioEncoder(config.audio_config)
self.audio_tower = FunAudioChatDiscreteEncoder(config.audio_config)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen3ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def _get_continuous_audio_features(
self,
input_features: torch.Tensor,
feature_attention_mask: torch.Tensor,
speech_maxlen: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# Align mask and features to avoid indexing errors when padding differs.
if (
input_features.dim() == 3
and feature_attention_mask.shape[1] != input_features.shape[-1]
):
min_len = min(
int(feature_attention_mask.shape[1]), int(input_features.shape[-1])
)
feature_attention_mask = feature_attention_mask[:, :min_len]
input_features = input_features[:, :, :min_len]
feature_lens = torch.sum(feature_attention_mask, dim=1)
flat_features = input_features.permute(0, 2, 1)[
feature_attention_mask.bool()
].permute(1, 0)
audio_feat_lengths, audio_output_lengths = (
self.continuous_audio_tower._get_feat_extract_output_lengths(feature_lens)
)
audio_outputs = self.continuous_audio_tower(
flat_features,
feature_lens=feature_lens,
aftercnn_lens=audio_feat_lengths,
speech_maxlen=speech_maxlen,
)
return audio_outputs.last_hidden_state, audio_output_lengths
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
speech_ids = kwargs.get("speech_ids")
speech_attention_mask = kwargs.get("speech_attention_mask")
input_features = kwargs.get("input_features")
feature_attention_mask = kwargs.get("feature_attention_mask")
feature_exist_mask = kwargs.get("feature_exist_mask")
if speech_ids is None:
return []
pad_id = int(getattr(self.audio_tower, "padding_idx", 0))
if not isinstance(speech_ids, torch.Tensor):
if (
isinstance(speech_ids, (list, tuple))
and len(speech_ids) > 0
and all(isinstance(t, torch.Tensor) for t in speech_ids)
):
speech_ids_tensors = []
for t in speech_ids:
if t.dim() == 2 and t.shape[0] == 1:
t = t.squeeze(0)
if t.dim() != 1:
raise TypeError(
"FunAudioChat speech_ids must be a 1D tensor per item "
f"(got shape={tuple(t.shape)})"
)
speech_ids_tensors.append(t)
speech_ids = nn.utils.rnn.pad_sequence(
speech_ids_tensors,
batch_first=True,
padding_value=pad_id,
)
else:
raise TypeError(
"FunAudioChat speech_ids must be a Tensor or a sequence of Tensors "
f"(got {type(speech_ids)})"
)
if speech_attention_mask is None:
speech_attention_mask = speech_ids.ne(pad_id).to(dtype=torch.int64)
if not isinstance(speech_attention_mask, torch.Tensor):
if (
isinstance(speech_attention_mask, (list, tuple))
and len(speech_attention_mask) > 0
and all(isinstance(t, torch.Tensor) for t in speech_attention_mask)
):
mask_tensors = []
for t in speech_attention_mask:
if t.dim() == 2 and t.shape[0] == 1:
t = t.squeeze(0)
if t.dim() != 1:
raise TypeError(
"FunAudioChat speech_attention_mask must be a 1D tensor "
f"per item (got shape={tuple(t.shape)})"
)
mask_tensors.append(t)
speech_attention_mask = nn.utils.rnn.pad_sequence(
mask_tensors,
batch_first=True,
padding_value=0,
)
else:
raise TypeError(
"FunAudioChat speech_attention_mask must be a Tensor or a "
f"sequence of Tensors (got {type(speech_attention_mask)})"
)
debug = os.getenv("VLLM_FUN_AUDIOCHAT_DEBUG", "") == "1"
if debug:
print(
f"[FunAudioChat] embed_multimodal speech_ids={tuple(speech_ids.shape)} "
f"speech_attention_mask={tuple(speech_attention_mask.shape)}",
flush=True,
)
attn_impl = getattr(
self.continuous_audio_tower.config, "_attn_implementation", None
)
print(
f"[FunAudioChat] audio_attn_impl={attn_impl}",
flush=True,
)
if hasattr(self.continuous_audio_tower, "conv1"):
conv1_w = self.continuous_audio_tower.conv1.weight
print(
f"[FunAudioChat] conv1_w_norm={float(conv1_w.norm().item()):.6g}",
flush=True,
)
try:
attn0 = self.continuous_audio_tower.layers[0].self_attn
q_norm = float(attn0.q_proj.weight.norm().item())
k_norm = float(attn0.k_proj.weight.norm().item())
v_norm = float(attn0.v_proj.weight.norm().item())
o_norm = float(attn0.out_proj.weight.norm().item())
print(
f"[FunAudioChat] attn0_q_norm={q_norm:.6g} "
f"k_norm={k_norm:.6g} "
f"v_norm={v_norm:.6g} "
f"o_norm={o_norm:.6g}",
flush=True,
)
except Exception:
pass
if isinstance(input_features, torch.Tensor):
print(
f"[FunAudioChat] input_features={tuple(input_features.shape)}",
flush=True,
)
if isinstance(feature_attention_mask, torch.Tensor):
print(
"[FunAudioChat] feature_attention_mask="
f"{tuple(feature_attention_mask.shape)}",
flush=True,
)
group_size = int(self.audio_tower.group_size)
speech_maxlen = int(speech_ids.shape[-1])
# Ensure token length is divisible by group_size.
target_len = ((speech_maxlen + group_size - 1) // group_size) * group_size
if target_len > speech_maxlen:
pad_id = int(self.audio_tower.padding_idx)
pad_len = target_len - speech_maxlen
speech_ids = nn.functional.pad(speech_ids, (0, pad_len), value=pad_id)
speech_attention_mask = nn.functional.pad(
speech_attention_mask, (0, pad_len), value=0
)
speech_maxlen = int(speech_ids.shape[-1])
continuous_audio_features = None
continuous_audio_output_lengths = None
if input_features is not None and feature_attention_mask is not None:
assert isinstance(input_features, torch.Tensor)
assert isinstance(feature_attention_mask, torch.Tensor)
continuous_audio_features, continuous_audio_output_lengths = (
self._get_continuous_audio_features(
input_features=input_features,
feature_attention_mask=feature_attention_mask,
speech_maxlen=speech_maxlen,
)
)
if feature_exist_mask is None:
feature_exist_mask = torch.ones(
(speech_ids.shape[0],), dtype=torch.bool, device=speech_ids.device
)
assert isinstance(feature_exist_mask, torch.Tensor)
audio_features = self.audio_tower(
speech_ids,
continuous_audio_features=continuous_audio_features,
continuous_audio_output_lengths=continuous_audio_output_lengths,
feature_exist_mask=feature_exist_mask,
)
_, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
speech_attention_mask.sum(-1)
)
lengths = audio_output_lengths.tolist()
embeds = tuple(
audio_features[i, : int(length)] for i, length in enumerate(lengths)
)
if debug:
embed_lens = [int(t.shape[0]) for t in embeds]
print(f"[FunAudioChat] embed_multimodal out_lens={embed_lens}", flush=True)
if embeds:
t0 = embeds[0]
print(
f"[FunAudioChat] embed0 dtype={t0.dtype} device={t0.device} "
f"nan={bool(torch.isnan(t0).any())} "
f"norm={float(t0.norm().item()):.6g}",
flush=True,
)
dump_path = os.getenv("VLLM_FUN_AUDIOCHAT_DUMP_PATH", "")
if (
dump_path
and speech_ids.shape[0] == 1
and len(embeds) == 1
and embed_lens[0] > 10
):
if not os.path.exists(dump_path):
np.save(dump_path, embeds[0].detach().float().cpu().numpy())
print(f"[FunAudioChat] dumped embeds to {dump_path}", flush=True)
cont_path = dump_path.replace(".npy", "_cont.npy")
if continuous_audio_features is not None and not os.path.exists(
cont_path
):
np.save(
cont_path,
continuous_audio_features.detach().float().cpu().numpy(),
)
print(
f"[FunAudioChat] dumped continuous to {cont_path}", flush=True
)
return embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
del kwargs
if intermediate_tensors is not None:
inputs_embeds = None
return self.language_model.model(
input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
)
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["audio_invert_tower."])
return loader.load_weights(weights)
...@@ -312,6 +312,10 @@ _MULTIMODAL_MODELS = { ...@@ -312,6 +312,10 @@ _MULTIMODAL_MODELS = {
"ernie45_vl", "ernie45_vl",
"Ernie4_5_VLMoeForConditionalGeneration", "Ernie4_5_VLMoeForConditionalGeneration",
), ),
"FunAudioChatForConditionalGeneration": (
"funaudiochat",
"FunAudioChatForConditionalGeneration",
),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"Gemma3nForConditionalGeneration": ( "Gemma3nForConditionalGeneration": (
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Literal from typing import Literal
...@@ -195,6 +196,13 @@ class AudioResampler: ...@@ -195,6 +196,13 @@ class AudioResampler:
raise RuntimeError( raise RuntimeError(
"Audio resampling is not supported when `target_sr` is not provided" "Audio resampling is not supported when `target_sr` is not provided"
) )
if math.isclose(
float(orig_sr),
float(self.target_sr),
rel_tol=0.0,
abs_tol=1e-6,
):
return audio
if self.method == "librosa": if self.method == "librosa":
return resample_audio_librosa( return resample_audio_librosa(
audio, orig_sr=orig_sr, target_sr=self.target_sr audio, orig_sr=orig_sr, target_sr=self.target_sr
......
...@@ -77,6 +77,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( ...@@ -77,6 +77,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
deepseek_vl_v2="DeepseekVLV2Config", deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v32="DeepseekV3Config", deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig", flex_olmo="FlexOlmoConfig",
funaudiochat="FunAudioChatConfig",
hunyuan_vl="HunYuanVLConfig", hunyuan_vl="HunYuanVLConfig",
isaac="IsaacConfig", isaac="IsaacConfig",
kimi_linear="KimiLinearConfig", kimi_linear="KimiLinearConfig",
......
...@@ -22,6 +22,8 @@ _CLASS_TO_MODULE: dict[str, str] = { ...@@ -22,6 +22,8 @@ _CLASS_TO_MODULE: dict[str, str] = {
"DotsOCRConfig": "vllm.transformers_utils.configs.dotsocr", "DotsOCRConfig": "vllm.transformers_utils.configs.dotsocr",
"EAGLEConfig": "vllm.transformers_utils.configs.eagle", "EAGLEConfig": "vllm.transformers_utils.configs.eagle",
"FlexOlmoConfig": "vllm.transformers_utils.configs.flex_olmo", "FlexOlmoConfig": "vllm.transformers_utils.configs.flex_olmo",
"FunAudioChatConfig": "vllm.transformers_utils.configs.funaudiochat",
"FunAudioChatAudioEncoderConfig": "vllm.transformers_utils.configs.funaudiochat",
"HunYuanVLConfig": "vllm.transformers_utils.configs.hunyuan_vl", "HunYuanVLConfig": "vllm.transformers_utils.configs.hunyuan_vl",
"HunYuanVLTextConfig": "vllm.transformers_utils.configs.hunyuan_vl", "HunYuanVLTextConfig": "vllm.transformers_utils.configs.hunyuan_vl",
"HunYuanVLVisionConfig": "vllm.transformers_utils.configs.hunyuan_vl", "HunYuanVLVisionConfig": "vllm.transformers_utils.configs.hunyuan_vl",
...@@ -65,6 +67,8 @@ __all__ = [ ...@@ -65,6 +67,8 @@ __all__ = [
"DotsOCRConfig", "DotsOCRConfig",
"EAGLEConfig", "EAGLEConfig",
"FlexOlmoConfig", "FlexOlmoConfig",
"FunAudioChatConfig",
"FunAudioChatAudioEncoderConfig",
"HunYuanVLConfig", "HunYuanVLConfig",
"HunYuanVLTextConfig", "HunYuanVLTextConfig",
"HunYuanVLVisionConfig", "HunYuanVLVisionConfig",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from transformers import PretrainedConfig
# NOTE: Temporary shim for FunAudioChat checkpoints.
# These checkpoints use `model_type="funaudiochat"`, which is not currently
# recognized by released Transformers, and the public checkpoint does not
# provide an `auto_map` to enable `trust_remote_code=True`.
# Remove this file once Transformers adds native support (or the checkpoint
# provides an `auto_map`) and vLLM can rely on `AutoConfig.from_pretrained()`.
class FunAudioChatAudioEncoderConfig(PretrainedConfig):
model_type = "funaudiochat_audio_encoder"
def __init__(
self,
_attn_implementation: str | None = None,
num_mel_bins: int = 128,
encoder_layers: int = 32,
encoder_attention_heads: int = 20,
encoder_ffn_dim: int = 5120,
d_model: int = 1280,
dropout: float = 0.0,
attention_dropout: float = 0.0,
activation_function: str = "gelu",
activation_dropout: float = 0.0,
scale_embedding: bool = False,
initializer_range: float = 0.02,
max_source_positions: int = 1500,
n_window: int = 100,
output_dim: int = 3584,
bos_token_id: int | None = None,
codebook_size: int | None = None,
continuous_features_mode: str = "replace",
crq_transformer_config: dict | None = None,
eos_token_id: int | None = None,
group_size: int = 5,
enable_audio_invert_tower: bool = True,
pad_token_id: int | None = None,
**kwargs,
) -> None:
attn_impl = kwargs.pop("_attn_implementation", None) or _attn_implementation
super().__init__(**kwargs)
# Match HF default for attention implementation selection.
self._attn_implementation = attn_impl or "sdpa"
self.num_mel_bins = num_mel_bins
self.d_model = d_model
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.encoder_ffn_dim = encoder_ffn_dim
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_function = activation_function
self.activation_dropout = activation_dropout
self.num_hidden_layers = encoder_layers
self.initializer_range = initializer_range
self.scale_embedding = scale_embedding
self.max_source_positions = max_source_positions
self.n_window = n_window
self.output_dim = output_dim
self.bos_token_id = bos_token_id
self.codebook_size = codebook_size
self.continuous_features_mode = continuous_features_mode
self.crq_transformer_config = crq_transformer_config
self.eos_token_id = eos_token_id
self.group_size = group_size
self.enable_audio_invert_tower = enable_audio_invert_tower
self.pad_token_id = pad_token_id
class FunAudioChatConfig(PretrainedConfig):
model_type = "funaudiochat"
attribute_map = {
"audio_token_id": "audio_token_index",
}
def __init__(
self,
audio_config: PretrainedConfig | dict | None = None,
text_config: PretrainedConfig | dict | None = None,
audio_token_index: int = 151646,
ignore_index: int = -100,
hidden_size: int | None = None,
**kwargs,
) -> None:
self.audio_token_index = audio_token_index
self.ignore_index = ignore_index
if isinstance(audio_config, dict):
audio_config.setdefault(
"model_type", FunAudioChatAudioEncoderConfig.model_type
)
audio_config = FunAudioChatAudioEncoderConfig(**audio_config)
elif audio_config is None:
audio_config = FunAudioChatAudioEncoderConfig()
self.audio_config = audio_config
if isinstance(text_config, dict):
# Default to qwen2 for backwards compatibility; FunAudioChat uses
# qwen3 in practice for recent checkpoints.
text_config.setdefault("model_type", "qwen2")
import transformers
text_cls = transformers.CONFIG_MAPPING[text_config["model_type"]]
text_config = text_cls(**text_config)
elif text_config is None:
import transformers
text_config = transformers.CONFIG_MAPPING["qwen2"]()
self.text_config = text_config
self.hidden_size = (
int(self.text_config.hidden_size)
if hidden_size is None
else int(hidden_size)
)
super().__init__(**kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment