Unverified Commit 21dfb842 authored by AllenDou's avatar AllenDou Committed by GitHub
Browse files

[model] support FunASR model (#33247)


Signed-off-by: default avatarzixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: default avatarzixiao <shunli.dsl@alibaba-inc.com>
parent d1b837f0
......@@ -790,6 +790,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|-------------------|----------------------|---------------------------|
| `FunASRForConditionalGeneration` | FunASR | `allendou/Fun-ASR-Nano-2512-vllm`, etc. | | |
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-speech-3.3-2b`, `ibm-granite/granite-speech-3.3-8b`, etc. | ✅︎ | ✅︎ |
......
......@@ -26,7 +26,9 @@ from openai import AsyncOpenAI, OpenAI
from vllm.assets.audio import AudioAsset
def sync_openai(audio_path: str, client: OpenAI, model: str):
def sync_openai(
audio_path: str, client: OpenAI, model: str, *, repetition_penalty: float = 1.3
):
"""
Perform synchronous transcription using OpenAI-compatible API.
"""
......@@ -40,7 +42,7 @@ def sync_openai(audio_path: str, client: OpenAI, model: str):
# Additional sampling params not provided by OpenAI API.
extra_body=dict(
seed=4419,
repetition_penalty=1.3,
repetition_penalty=repetition_penalty,
),
)
print("transcription result [sync]:", transcription.text)
......@@ -129,7 +131,12 @@ def main(args):
print(f"Using model: {model}")
# Run the synchronous function
sync_openai(args.audio_path if args.audio_path else mary_had_lamb, client, model)
sync_openai(
audio_path=args.audio_path if args.audio_path else mary_had_lamb,
client=client,
model=model,
repetition_penalty=args.repetition_penalty,
)
# Run the asynchronous function
if "openai" in model:
......@@ -161,5 +168,11 @@ if __name__ == "__main__":
default=None,
help="The path to the audio file to transcribe.",
)
parser.add_argument(
"--repetition_penalty",
type=float,
default=1.3,
help="repetition penalty",
)
args = parser.parse_args()
main(args)
......@@ -713,6 +713,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
trust_remote_code=True,
),
"FunASRForConditionalGeneration": _HfExamplesInfo(
"allendou/Fun-ASR-Nano-2512-vllm",
is_available_online=False,
),
"FunAudioChatForConditionalGeneration": _HfExamplesInfo(
"funaudiochat", is_available_online=False
),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, cast
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
BatchFeature,
Qwen3Config,
)
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.attention.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.whisper_utils import (
ISO639_1_SUPPORTED_LANGS,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
SupportsTranscription,
_require_is_multimodal,
)
from .qwen3 import Qwen3Model
from .utils import (
AutoWeightsLoader,
WeightsMapper,
_merge_multimodal_embeddings,
maybe_prefix,
)
logger = init_logger(__name__)
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
if maxlen is None:
maxlen = lengths.max()
row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
mask = mask.detach()
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
class LayerNorm(torch.nn.LayerNorm):
def __init__(self, nout, dim=-1):
super().__init__(nout, eps=1e-12)
self.dim = dim
def forward(self, x: torch.Tensor):
if self.dim == -1:
return super().forward(x)
return super().forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)
class EncoderLayerSANM(nn.Module):
def __init__(
self,
in_size: int,
size: int,
self_attn: nn.Module,
feed_forward: nn.Module,
normalize_before=True,
):
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
def forward(
self,
hidden_states: torch.Tensor,
mask: torch.Tensor | None = None,
cache=None,
mask_shfit_chunk=None,
mask_att_chunk_encoder=None,
):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
if self.in_size == self.size:
hidden_states = residual + self.self_attn(
hidden_states,
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
)
else:
hidden_states = self.self_attn(
hidden_states,
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
)
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = residual + self.feed_forward(hidden_states)
return hidden_states, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class MultiHeadedAttentionSANM(nn.Module):
def __init__(
self,
n_head: int,
in_feat: int,
n_feat: int,
kernel_size: int,
sanm_shift: int = 0,
):
super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.out_proj = ReplicatedLinear(
input_size=n_feat,
output_size=n_feat,
bias=True,
)
self.linear_q_k_v = ReplicatedLinear(
input_size=in_feat,
output_size=n_feat * 3,
bias=True,
)
self.attn = None
self.fsmn_block = nn.Conv1d(
n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
)
# padding
left_padding = (kernel_size - 1) // 2
if sanm_shift > 0:
left_padding = left_padding + sanm_shift
right_padding = kernel_size - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
def forward_fsmn(
self,
inputs: torch.Tensor,
mask: torch.Tensor,
mask_shfit_chunk: torch.Tensor = None,
):
b, t, d = inputs.size()
if mask is not None:
mask = torch.reshape(mask, (b, -1, 1))
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
if mask is not None:
x = x * mask
return x
def forward_qkv(self, x: torch.Tensor):
b, t, d = x.size()
q_k_v, _ = self.linear_q_k_v(x)
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)
k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)
v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)
return q_h, k_h, v_h, v
def forward_attention(
self,
value: torch.Tensor,
scores: torch.Tensor,
mask: torch.Tensor,
mask_att_chunk_encoder: torch.Tensor = None,
):
n_batch = value.size(0)
if mask is not None:
if mask_att_chunk_encoder is not None:
mask = mask * mask_att_chunk_encoder
mask = mask.unsqueeze(1).eq(0)
min_value = -float("inf")
scores = scores.masked_fill(mask, min_value)
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
else:
attn = torch.softmax(scores, dim=-1)
p_attn = attn
x = torch.matmul(p_attn, value)
x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
out, _ = self.out_proj(x)
return out
def forward(
self,
hidden_states: torch.Tensor,
mask: torch.Tensor,
mask_shfit_chunk: torch.Tensor = None,
mask_att_chunk_encoder: torch.Tensor = None,
):
q_h, k_h, v_h, v = self.forward_qkv(hidden_states)
fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
q_h = q_h * self.d_k ** (-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs + fsmn_memory
class SinusoidalPositionEncoder(torch.nn.Module):
def __init__(self, d_model=80):
super().__init__()
def encode(
self,
positions: torch.Tensor = None,
depth: int = None,
dtype: torch.dtype = torch.float32,
):
batch_size = positions.size(0)
positions = positions.type(dtype)
device = positions.device
log_timescale_increment = torch.log(
torch.tensor([10000], dtype=dtype, device=device)
) / (depth / 2 - 1)
inv_timescales = torch.exp(
torch.arange(depth / 2, device=device).type(dtype)
* (-log_timescale_increment)
)
inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
inv_timescales, [1, 1, -1]
)
encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
return encoding.type(dtype)
def forward(self, hidden_states: torch.Tensor):
batch_size, timesteps, input_dim = hidden_states.size()
positions = torch.arange(1, timesteps + 1, device=hidden_states.device)[None, :]
position_encoding = self.encode(positions, input_dim, hidden_states.dtype).to(
hidden_states.device
)
return hidden_states + position_encoding
class SenseVoiceEncoderSmall(nn.Module):
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
tp_blocks: int = 0,
attention_dropout_rate: float = 0.0,
normalize_before: bool = True,
kernel_size: int = 11,
sanm_shift: int = 0,
**kwargs,
):
super().__init__()
self._output_size = output_size
self.embed = SinusoidalPositionEncoder()
self.normalize_before = normalize_before
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
)
encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
kernel_size,
sanm_shift,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
kernel_size,
sanm_shift,
)
self.encoders0 = nn.ModuleList(
[
EncoderLayerSANM(
input_size,
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
)
for i in range(1)
]
)
self.encoders = nn.ModuleList(
[
EncoderLayerSANM(
output_size,
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
)
for i in range(num_blocks - 1)
]
)
self.tp_encoders = nn.ModuleList(
[
EncoderLayerSANM(
output_size,
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
)
for i in range(tp_blocks)
]
)
self.after_norm = LayerNorm(output_size)
self.tp_norm = LayerNorm(output_size)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
):
maxlen = xs_pad.shape[1]
masks = sequence_mask(
ilens, maxlen=maxlen, dtype=ilens.dtype, device=ilens.device
)[:, None, :]
xs_pad *= self.output_size() ** 0.5
xs_pad = self.embed(xs_pad)
for layer_idx, encoder_layer in enumerate(self.encoders0):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1).int()
for layer_idx, encoder_layer in enumerate(self.tp_encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.tp_norm(xs_pad)
return xs_pad, olens
class PositionwiseFeedForward(nn.Module):
def __init__(self, idim: int, hidden_units: int):
super().__init__()
self.w_1 = ColumnParallelLinear(
input_size=idim,
output_size=hidden_units,
bias=True,
)
self.w_2 = RowParallelLinear(
input_size=hidden_units,
output_size=idim,
bias=True,
)
self.activation = _ACTIVATION_REGISTRY["relu"]
def forward(self, hidden_states: torch.Tensor):
hidden_states, _ = self.w_1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states, _ = self.w_2(hidden_states)
return hidden_states
class EncoderLayer(nn.Module):
def __init__(
self,
size: int,
self_attn: nn.Module,
feed_forward: nn.Module,
):
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
def forward(self, hidden_states: torch.Tensor):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = residual + self.self_attn(hidden_states, None, None)
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = residual + self.feed_forward(hidden_states)
return hidden_states
class FunASRAudioAttention(nn.Module):
def __init__(
self,
num_heads: int,
embed_dim: int,
prefix: str = "",
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = self.embed_dim // self.num_heads
tp_size = get_tensor_model_parallel_world_size()
self.num_local_heads = self.num_heads // tp_size
if (self.head_dim * self.num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: "
f"{self.embed_dim} and `num_heads`: {self.num_heads})."
)
self.scaling = self.head_dim**-0.5
self.qkv = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
total_num_kv_heads=self.num_heads,
bias=True,
prefix=f"{prefix}.qkv",
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
bias=True,
prefix=f"{prefix}.out_proj",
)
self.attn = MMEncoderAttention(
num_heads=self.num_local_heads,
head_size=self.head_dim,
scale=self.scaling,
)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor | None,
) -> torch.Tensor:
bs, seq_length, _ = hidden_states.size()
qkv, _ = self.qkv(hidden_states)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(bs, seq_length, -1, self.head_dim)
k = k.view(bs, seq_length, -1, self.head_dim)
v = v.view(bs, seq_length, -1, self.head_dim)
attn_output = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
attn_output = attn_output.view(bs, seq_length, -1)
output, _ = self.out_proj(attn_output)
return output
class Transformer(nn.Module):
def __init__(
self,
downsample_rate=2,
encoder_dim=1280,
llm_dim=4096,
ffn_dim: int = 2048,
prefix: str = "",
**kwargs,
):
super().__init__()
self.k = downsample_rate
self.encoder_dim = encoder_dim
self.llm_dim = llm_dim
self.linear1 = ColumnParallelLinear(
input_size=self.encoder_dim * self.k,
output_size=ffn_dim,
bias=True,
)
self.relu = nn.ReLU()
self.linear2 = RowParallelLinear(
input_size=ffn_dim,
output_size=self.llm_dim,
bias=True,
)
self.blocks = None
if kwargs.get("n_layer", 2) > 0:
self.blocks = nn.ModuleList(
[
EncoderLayer(
llm_dim,
FunASRAudioAttention(
kwargs.get("attention_heads", 8),
llm_dim,
prefix=f"{prefix}.self_attn",
),
PositionwiseFeedForward(
llm_dim,
llm_dim // 4,
),
)
for _ in range(kwargs.get("n_layer", 2))
]
)
def forward(self, hidden_states: torch.Tensor, ilens: int = 0):
batch_size, seq_len, dim = hidden_states.size()
chunk_num = (seq_len - 1) // self.k + 1
pad_num = chunk_num * self.k - seq_len
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_num, 0, 0), value=0.0)
seq_len = hidden_states.size(1)
hidden_states = hidden_states.contiguous()
hidden_states = hidden_states.view(batch_size, chunk_num, dim * self.k)
hidden_states, _ = self.linear1(hidden_states)
hidden_states = self.relu(hidden_states)
hidden_states, _ = self.linear2(hidden_states)
olens = None
olens = (ilens - 1) // self.k + 1
if self.blocks is not None:
for layer, block in enumerate(self.blocks):
hidden_states = block(hidden_states)
return hidden_states, olens
class FunASRAudioInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- nmb: Number of mel bins
- t: Time frames (M)
"""
input_features: Annotated[
list[torch.Tensor] | None,
TensorShape("b", "nmb", "t"),
]
speech_lengths: Annotated[
list[torch.Tensor] | None,
TensorShape("b"),
]
class FunASREncoder(nn.Module):
def __init__(
self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
):
super().__init__()
self.audio_encoder = SenseVoiceEncoderSmall(
input_size=560, **vllm_config.model_config.hf_config.audio_encoder_conf
)
self.audio_adaptor = Transformer(
downsample_rate=1,
use_low_frame_rate=True,
ffn_dim=2048,
llm_dim=1024,
encoder_dim=512,
n_layer=2,
freeze=True,
prefix=maybe_prefix(prefix, "audio_encoder"),
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights with mapping from HuggingFace format."""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("self_attn.qkv.", "self_attn.q_proj.", "q"),
("self_attn.qkv.", "self_attn.k_proj.", "k"),
("self_attn.qkv.", "self_attn.v_proj.", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict.get(name)
if param is not None:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class FunASRModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.encoder = FunASREncoder(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "encoder")
)
self.decoder = Qwen3Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "decoder")
)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
decoder_outputs = self.decoder(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
return decoder_outputs
def get_encoder_outputs(
self,
speech: torch.Tensor | list[torch.Tensor] | None,
speech_lengths: torch.Tensor | list[torch.Tensor] | None,
) -> torch.Tensor | None:
self.feat_permute = False
if self.feat_permute:
encoder_out, encoder_out_lens = self.encoder.audio_encoder(
speech.permute(0, 2, 1), speech_lengths
)
else:
encoder_out, encoder_out_lens = self.encoder.audio_encoder(
speech, speech_lengths
)
encoder_out, encoder_out_lens = self.encoder.audio_adaptor(
encoder_out, encoder_out_lens
)
return encoder_out
class FunASRProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> Qwen3Config:
return self.ctx.get_hf_config(Qwen3Config)
@property
def skip_prompt_length_check(self) -> bool:
return True # Because the encoder prompt is padded
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 1}
def get_feature_extractor(self, **kwargs: object) -> FunASRFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, FunASRFeatureExtractor)
return feature_extractor
def get_target_channels(self) -> int:
return 1
def get_num_audio_tokens(self) -> int:
return self.get_hf_config().max_source_positions
class FunASRDummyInputsBuilder(BaseDummyInputsBuilder[FunASRProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
return "<|AUDIO|>" * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides
)
}
class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
target_channels=self.info.get_target_channels(),
)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
if "labels" in processed_outputs:
processed_outputs["input_ids"] = processed_outputs.pop("labels")
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
input_features=MultiModalFieldConfig.batched("audio"),
speech_lengths=MultiModalFieldConfig.batched("audio"),
fake_token_len=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
# Use getattr with default to be compatible with transformers<4.48
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
audio_token_id = vocab[audio_token]
out_mm_data = out_mm_kwargs.get_data()
fake_token_len = out_mm_data.get("fake_token_len")
if fake_token_len is None:
audio_output_lengths = []
else:
assert isinstance(fake_token_len, torch.Tensor)
audio_output_lengths = fake_token_len.tolist()
def get_replacement_qwen2_audio(item_idx: int):
if audio_output_lengths:
num_features = audio_output_lengths[item_idx]
else:
audio_embeds = out_mm_data["audio_embeds"][item_idx]
assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
num_features = audio_embeds.shape[0]
audio_tokens = [audio_token_id] * num_features
return PromptUpdateDetails.select_token_id(
audio_tokens,
embed_token_id=audio_token_id,
)
return [
PromptReplacement(
modality="audio",
target=audio_token,
replacement=get_replacement_qwen2_audio,
)
]
@MULTIMODAL_REGISTRY.register_processor(
FunASRMultiModalProcessor,
info=FunASRProcessingInfo,
dummy_inputs=FunASRDummyInputsBuilder,
)
class FunASRForConditionalGeneration(
nn.Module, SupportsTranscription, SupportsMultiModal
):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
],
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"linear_q.": "q_proj.",
"linear_k.": "k_proj.",
"linear_v.": "v_proj.",
"linear_out.": "out_proj.",
}
)
supports_transcription_only = True
supports_segment_timestamp = True
supported_languages = ISO639_1_SUPPORTED_LANGS
@classmethod
def validate_language(cls, language: str | None) -> str | None:
if language is None:
# TODO language should be optional and can be guessed.
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
logger.warning(
"Defaulting to language='en'. If you wish to transcribe "
"audio in a different language, pass the `language` field "
"in the TranscriptionRequest."
)
language = "en"
return super().validate_language(language)
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
if language is None:
raise ValueError(
"Language must be specified when creating the funasr prompt"
)
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501
prompt = {
"prompt": funasr_prompt,
"multi_modal_data": {
"audio": (audio, stt_config.sample_rate),
},
}
return cast(PromptType, prompt)
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
processor = cached_processor_from_config(model_config)
return SpeechToTextConfig(
max_audio_clip_s=processor.feature_extractor.chunk_length,
sample_rate=processor.feature_extractor.sampling_rate,
)
@classmethod
def get_num_audio_tokens(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
) -> int | None:
processor = cached_processor_from_config(model_config)
hop_length = processor.feature_extractor.hop_length
assert hop_length is not None
return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.dtype = vllm_config.model_config.dtype
self.model = FunASRModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
if config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
decoder_outputs = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
return decoder_outputs
def get_language_model(self) -> torch.nn.Module:
return self.model.decoder
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
speech = audio_input["input_features"]
speech_lengths = audio_input["speech_lengths"]
enc_output = self.model.get_encoder_outputs(
speech=speech, speech_lengths=speech_lengths
)
return enc_output
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.model.decoder.embed_input_ids(input_ids)
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=_require_is_multimodal(is_multimodal),
)
def _parse_and_validate_audio_input(self, **kwargs: object) -> FunASRAudioInputs:
input_features = kwargs.pop("input_features", None)
speech_lengths = kwargs.pop("speech_lengths", None)
if input_features is not None:
input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)
if speech_lengths is not None:
speech_lengths = json_map_leaves(lambda x: x.to(self.dtype), speech_lengths)
return FunASRAudioInputs(
input_features=input_features, speech_lengths=speech_lengths
)
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
)
# add fake zeros bias for k_proj to state_dict
weights = _create_fake_bias_for_k_proj(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def _create_fake_bias_for_k_proj(
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
"""
Create full zeros bias for k_proj weight in self-attn and x-attn layers.
So that the bias for k_proj in qkv_proj can be initialized with zeros.
"""
for name, weight in weights:
if name.endswith(".k_proj.weight"):
bias = torch.zeros(weight.size(0))
bias_name = name.replace("weight", "bias")
yield from [(name, weight), (bias_name, bias)]
else:
yield name, weight
......@@ -325,6 +325,7 @@ _MULTIMODAL_MODELS = {
"ernie45_vl",
"Ernie4_5_VLMoeForConditionalGeneration",
),
"FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"), # noqa: E501
"FunAudioChatForConditionalGeneration": (
"funaudiochat",
"FunAudioChatForConditionalGeneration",
......
......@@ -10,6 +10,7 @@ reasons:
from vllm.transformers_utils.processors.bagel import BagelProcessor
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
from vllm.transformers_utils.processors.funasr_processor import FunASRProcessor
from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor
from vllm.transformers_utils.processors.ovis import OvisProcessor
......@@ -18,6 +19,7 @@ from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
__all__ = [
"BagelProcessor",
"DeepseekVLV2Processor",
"FunASRProcessor",
"HunYuanVLProcessor",
"HunYuanVLImageProcessor",
"OvisProcessor",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
import torch.nn as nn
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
from transformers import (
AutoFeatureExtractor,
AutoProcessor,
BatchFeature,
)
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.processing_utils import ProcessorMixin
from transformers.utils import TensorType
from vllm.logger import init_logger
logger = init_logger(__name__)
def apply_cmvn(inputs, cmvn): # noqa
"""
Apply CMVN with mvn data
"""
device = inputs.device
# dtype = inputs.dtype
frame, dim = inputs.shape
means = cmvn[0:1, :dim]
vars = cmvn[1:2, :dim]
inputs += means.to(device)
inputs *= vars.to(device)
return inputs.type(torch.float32)
def apply_lfr(inputs, lfr_m, lfr_n):
# LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
inputs = torch.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
feat_dim = inputs.shape[-1]
strides = (lfr_n * feat_dim, 1)
sizes = (T_lfr, lfr_m * feat_dim)
last_idx = (T - lfr_m) // lfr_n + 1
num_padding = lfr_m - (T - last_idx * lfr_n)
if num_padding > 0:
num_padding = (
(2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n)
/ 2
* (T_lfr - last_idx)
)
inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
LFR_outputs = inputs.as_strided(sizes, strides)
return LFR_outputs.clone().type(torch.float32)
def load_cmvn(cmvn_file):
with open(cmvn_file, encoding="utf-8") as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == "<AddShift>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
add_shift_line = line_item[3 : (len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == "<Rescale>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
rescale_line = line_item[3 : (len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float32)
vars = np.array(vars_list).astype(np.float32)
cmvn = np.array([means, vars])
cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
return cmvn
class WavFrontend(nn.Module):
"""Conventional frontend structure for ASR."""
def __init__(
self,
cmvn_file: str = "null",
fs: int = 16000,
window: str = "hamming",
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
**kwargs,
):
super().__init__()
self.fs = fs
self.window = window
self.n_mels = n_mels
self.frame_length = frame_length
self.frame_shift = frame_shift
self.filter_length_min = filter_length_min
self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
def output_size(self) -> int:
return self.n_mels * self.lfr_m
def forward(
self,
input: torch.Tensor,
input_lengths,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
if self.upsacle_samples:
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(
waveform,
num_mel_bins=self.n_mels,
frame_length=min(self.frame_length, waveform_length / self.fs * 1000),
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs,
snip_edges=self.snip_edges,
)
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn is not None:
mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
if batch_size == 1:
feats_pad = feats[0][None, :, :]
else:
feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
return feats_pad, feats_lens
def forward_fbank(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(
waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs,
)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
return feats_pad, feats_lens
def forward_lfr_cmvn(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
mat = input[i, : input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn is not None:
mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
return feats_pad, feats_lens
class FunASRFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a FunASR feature extractor.
This feature extractor inherits from [`~feature_extraction_sequence_
utils.SequenceFeatureExtractor`] which contains most of the main
methods. Users should refer to this superclass for more information
regarding those methods.
This class extracts mel-filter bank features from raw speech using a custom
numpy implementation of the `Short Time Fourier Transform` which should
match pytorch's `torch.stft` equivalent.
Args:
feature_size (`int`, *optional*, defaults to 80):
The feature dimension of the extracted features.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio files should be digitalized
expressed in hertz (Hz).
hop_length (`int`, *optional*, defaults to 160):
Length of the overlapping windows for the STFT used to obtain the
Mel Frequency coefficients.
chunk_length (`int`, *optional*, defaults to 30):
The maximum number of chunks of `sampling_rate` samples used to
trim and pad longer or shorter audio sequences.
n_fft (`int`, *optional*, defaults to 400):
Size of the Fourier transform.
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 0.0001 to add dithering with a normal distribution centered
around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range
of raw_speech). The value 0.0 means no dithering.
Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces
the high log_mel_fbank values for signals with hard-zero sections,
when VAD cutoff is present in the signal.
"""
model_input_names = ["input_features"]
def __init__(
self,
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
padding_value=0.0,
dither=0.0,
return_attention_mask=False,
**kwargs,
):
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=return_attention_mask,
**kwargs,
)
self.frontend_conf = kwargs.get("frontend_conf", {})
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate
self.dither = dither
def extract_fbank(
self, data, data_len=None, data_type: str = "sound", frontend=None, **kwargs
):
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, torch.Tensor):
if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, (list, tuple)):
data_list, data_len = [], []
for data_i in data:
if isinstance(data_i, np.ndarray):
data_i = torch.from_numpy(data_i)
data_list.append(data_i)
data_len.append(data_i.shape[0])
data = pad_sequence(data_list, batch_first=True)
data, data_len = frontend(data, data_len, **kwargs)
if isinstance(data_len, (list, tuple)):
data_len = torch.tensor([data_len])
return data.to(torch.float32), data_len.to(torch.int32)
def __call__(
self,
raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
truncation: bool = True,
pad_to_multiple_of: int | None = None,
return_tensors: str | TensorType | None = None,
return_attention_mask: bool | None = None,
padding: str | None = "max_length",
max_length: int | None = None,
sampling_rate: int | None = None,
do_normalize: bool | None = None,
device: str | None = "cpu",
return_token_timestamps: bool | None = None,
**kwargs,
) -> BatchFeature:
is_batched = isinstance(raw_speech, (list, tuple)) and (
isinstance(raw_speech[0], (np.ndarray, tuple, list))
)
if is_batched:
raw_speech = [
np.asarray([speech], dtype=np.float32).T for speech in raw_speech
]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float32)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
np.float64
):
raw_speech = raw_speech.astype(np.float32)
if not is_batched:
raw_speech = [np.asarray([raw_speech]).T]
batched_speech = BatchFeature({"input_features": raw_speech})
padded_inputs = self.pad(
batched_speech,
padding=padding,
max_length=max_length if max_length else self.n_samples,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask or do_normalize,
)
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
self.frontend = WavFrontend(**self.frontend_conf)
input_features, speech_lengths = self.extract_fbank(
input_features[0],
data_type=kwargs.get("data_type", "sound"),
frontend=self.frontend,
is_final=True,
)
olens = 1 + (speech_lengths - 3 + 2 * 1) // 2
olens = 1 + (olens - 3 + 2 * 1) // 2
fake_token_len = (olens - 1) // 2 + 1
if isinstance(input_features[0], list):
padded_inputs["input_features"] = [
np.asarray(feature, dtype=np.float32) for feature in input_features
]
else:
padded_inputs["input_features"] = input_features
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
padded_inputs["speech_lengths"] = speech_lengths
padded_inputs["fake_token_len"] = fake_token_len
return padded_inputs
class FunASRProcessor(ProcessorMixin):
r"""
Constructs a FunASR processor which wraps a FunASR feature extractor and
a FunASR tokenizer into a single processor.
[`FunASRProcessor`] offers all the functionalities of
[`FunASRFeatureExtractor`] and [`Qwen2Tokenizer`]. See the
[`~FunASRProcessor.__call__`] and [`~FunASRProcessor.decode`] for more
information.
Args:
feature_extractor (`FunASRFeatureExtractor`): An instance of
[`FunASRFeatureExtractor`].
The feature extractor is a required input.
tokenizer (`Qwen2Tokenizer`):
An instance of [`Qwen2Tokenizer`]. The tokenizer is a required
input.
"""
feature_extractor_class = "FunASRFeatureExtractor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(
self,
feature_extractor,
tokenizer,
audio_token="<|AUDIO|>",
):
super().__init__(feature_extractor, tokenizer)
self.current_processor = self.feature_extractor
self._in_target_context_manager = False
self.audio_token = (
tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
)
self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
return self.tokenizer.get_decoder_prompt_ids(
task=task, language=language, no_timestamps=no_timestamps
)
def __call__(self, *args, **kwargs):
"""
Forwards the `audio` argument to FunASRFeatureExtractor's
[`~FunASRFeatureExtractor.__call__`] and the `text` argument to
[`~Qwen2Tokenizer.__call__`]. Please refer to the docstring of the
above two methods for more information.
"""
if self._in_target_context_manager:
return self.current_processor(*args, **kwargs)
audio = kwargs.pop("audio", None)
sampling_rate = kwargs.pop("sampling_rate", None)
text = kwargs.pop("text", None)
if len(args) > 0:
audio = args[0]
args = args[1:]
if text is None:
raise ValueError("You need to specify `text` input to process.")
elif isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError(
"Invalid input text. Please provide a string, or a list of strings"
)
if audio is not None:
# ensure we have as much audios as audio tokens
num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
num_audios = 1 if type(audio) is np.ndarray else len(audio)
if num_audio_tokens != num_audios:
raise ValueError(
f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}" # noqa: E501
)
inputs = self.feature_extractor(
audio, *args, sampling_rate=sampling_rate, **kwargs
)
expanded_text = []
for sample in text:
replace_str = []
while self.audio_token in sample:
num_audio_tokens = inputs["fake_token_len"].item()
expanded_audio_token = self.audio_token * num_audio_tokens
replace_str.append(expanded_audio_token)
sample = sample.replace(self.audio_token, "<placeholder>", 1)
while "<placeholder>" in sample:
sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
expanded_text.append(sample)
text = expanded_text
if text is not None:
encodings = self.tokenizer(text, **kwargs)
if text is None:
return inputs
elif audio is None:
return encodings
else:
inputs["labels"] = encodings["input_ids"]
return inputs
def get_prompt_ids(self, text: str, return_tensors="np"):
return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)
AutoFeatureExtractor.register("FunASRFeatureExtractor", FunASRFeatureExtractor)
AutoProcessor.register("FunASRProcessor", FunASRProcessor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment