Unverified Commit 967146e7 authored by PatchyTIS's avatar PatchyTIS Committed by GitHub
Browse files

[model] support FireRedLID (#39290)


Signed-off-by: default avatarPatchouliTaisa <patchychen@tencent.com>
Co-authored-by: default avatarPatchouliTaisa <patchychen@tencent.com>
parent 8e8a3bec
...@@ -661,6 +661,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition. ...@@ -661,6 +661,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
| ------------ | ------ | ----------------- | -------------------- | ------------------------- | | ------------ | ------ | ----------------- | -------------------- | ------------------------- |
| `CohereAsrForConditionalGeneration` | Cohere-Transcribe | `CohereLabs/cohere-transcribe-03-2026` | | | | `CohereAsrForConditionalGeneration` | Cohere-Transcribe | `CohereLabs/cohere-transcribe-03-2026` | | |
| `FireRedASR2ForConditionalGeneration` | FireRedASR2 | `allendou/FireRedASR2-LLM-vllm`, etc. | | | | `FireRedASR2ForConditionalGeneration` | FireRedASR2 | `allendou/FireRedASR2-LLM-vllm`, etc. | | |
| `FireRedLIDForConditionalGeneration` | FireRedLID | `PatchyTisa/FireRedLID-vllm`, etc. | | |
| `FunASRForConditionalGeneration` | FunASR | `allendou/Fun-ASR-Nano-2512-vllm`, etc. | | | | `FunASRForConditionalGeneration` | FunASR | `allendou/Fun-ASR-Nano-2512-vllm`, etc. | | |
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ | | `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
......
...@@ -537,9 +537,30 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: ...@@ -537,9 +537,30 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
) )
# FireRedLID
def run_fireredlid(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, "FireRedLID only supports single audio input per prompt"
model_name = "PatchyTisa/FireRedLID-vllm"
prompt = "<sos>"
engine_args = EngineArgs(
model=model_name,
max_model_len=8,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
model_example_map = { model_example_map = {
"audioflamingo3": run_audioflamingo3, "audioflamingo3": run_audioflamingo3,
"cohere_asr": run_cohere_asr, "cohere_asr": run_cohere_asr,
"fireredlid": run_fireredlid,
"funaudiochat": run_funaudiochat, "funaudiochat": run_funaudiochat,
"gemma3n": run_gemma3n, "gemma3n": run_gemma3n,
"glmasr": run_glmasr, "glmasr": run_glmasr,
......
...@@ -55,7 +55,91 @@ def run_whisper(): ...@@ -55,7 +55,91 @@ def run_whisper():
) )
def run_fireredasr2():
"""
FireRedASR2 – Automatic Speech Recognition model.
This model uses a Conformer encoder + Qwen2 LLM decoder architecture
for speech-to-text transcription. Audio is passed via the implicit
prompt format with the ``<|AUDIO|>`` placeholder token.
"""
engine_args = EngineArgs(
model="allendou/FireRedASR2-LLM-vllm",
max_model_len=448,
max_num_seqs=16,
limit_mm_per_prompt={"audio": 1},
)
prompt_str = (
"<|im_start|>user\n<|AUDIO|>请转写音频为文字<|im_end|>\n<|im_start|>assistant\n"
)
prompts = [
{ # Implicit prompt with audio
"prompt": prompt_str,
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
},
},
{ # Another audio sample
"prompt": prompt_str,
"multi_modal_data": {
"audio": AudioAsset("winning_call").audio_and_sample_rate,
},
},
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
def run_fireredlid():
"""
FireRedLID – Language Identification model.
This encoder-decoder model identifies the spoken language of an audio
clip. It outputs at most 2 tokens representing the detected language
(e.g. "en", "zh mandarin").
"""
engine_args = EngineArgs(
model="PatchyTisa/FireRedLID-vllm",
max_model_len=8,
max_num_seqs=16,
limit_mm_per_prompt={"audio": 1},
)
prompts = [
{ # Test explicit encoder/decoder prompt
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
},
},
"decoder_prompt": "<sos>",
},
{ # Another audio sample
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": AudioAsset("winning_call").audio_and_sample_rate,
},
},
"decoder_prompt": "<sos>",
},
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
model_example_map = { model_example_map = {
"fireredasr2": run_fireredasr2,
"fireredlid": run_fireredlid,
"whisper": run_whisper, "whisper": run_whisper,
} }
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Language Identification (LID) demo using the FireRedLID model on vLLM.
FireRedLID is an audio encoder-decoder model that identifies the spoken
language of an audio clip. Unlike ASR models that output full transcriptions,
FireRedLID outputs at most 2 tokens representing the detected language
(e.g. "en", "zh mandarin").
Start the vLLM server:
vllm serve PatchyTisa/FireRedLID-vllm
Then run this script:
# Use the built-in sample audio
python examples/online_serving/openai_lid_client.py
# Use your own audio file(s)
python examples/online_serving/openai_lid_client.py \
--audio_paths audio_en.wav audio_zh.wav audio_fr.wav
# Batch-identify multiple files in one run
python examples/online_serving/openai_lid_client.py \
--audio_paths /path/to/dir/*.wav
Requirements:
- vLLM with audio support
- openai Python SDK
- kaldi_native_fbank (pulled in by the model)
"""
import argparse
import json
import os
from openai import OpenAI
from vllm.assets.audio import AudioAsset
# ──────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────
def identify_language(
audio_path: str,
client: OpenAI,
model: str,
) -> str:
"""
Send a single audio file to the vLLM transcription endpoint and return
the detected language tag.
FireRedLID re-uses the OpenAI-compatible ``/v1/audio/transcriptions``
endpoint. The "transcription" it returns is actually the language label
(e.g. ``"en"`` or ``"zh mandarin"``).
"""
with open(audio_path, "rb") as f:
result = client.audio.transcriptions.create(
file=f,
model=model,
response_format="json",
temperature=0.0,
)
return result.text.strip()
def identify_language_raw(
audio_path: str,
model: str,
api_base: str,
) -> str:
"""
Same as :func:`identify_language` but uses raw HTTP so that the demo
works without the ``openai`` SDK (useful for quick debugging).
"""
import requests
url = f"{api_base}/audio/transcriptions"
with open(audio_path, "rb") as f:
files = {"file": (os.path.basename(audio_path), f)}
data = {
"model": model,
"response_format": "json",
}
resp = requests.post(url, files=files, data=data)
resp.raise_for_status()
return resp.json()["text"].strip()
def identify_language_streaming(
audio_path: str,
model: str,
api_base: str,
) -> str:
"""
Streaming variant – demonstrates the streaming transcription endpoint.
For a 1-2 token output the stream finishes almost instantly, but this
shows that the API path works end-to-end.
"""
import requests
url = f"{api_base}/audio/transcriptions"
with open(audio_path, "rb") as f:
files = {"file": (os.path.basename(audio_path), f)}
data = {
"stream": "true",
"model": model,
"response_format": "json",
}
response = requests.post(url, files=files, data=data, stream=True)
response.raise_for_status()
tokens: list[str] = []
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\n"
):
if not chunk:
continue
payload = json.loads(chunk[len("data: ") :].decode("utf-8"))
choice = payload["choices"][0]
delta = choice.get("delta", {}).get("content", "")
if delta:
tokens.append(delta)
if choice.get("finish_reason") is not None:
break
return "".join(tokens).strip()
# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────
def main(args: argparse.Namespace) -> None:
api_base = args.api_base.rstrip("/")
client = OpenAI(api_key="EMPTY", base_url=api_base)
model = client.models.list().data[0].id
print(f"Model : {model}")
print(f"Server: {api_base}\n")
# Resolve audio paths ------------------------------------------------
if args.audio_paths:
audio_paths = args.audio_paths
else:
# Fall back to the built-in vLLM sample audios (both are English).
audio_paths = [
str(AudioAsset("mary_had_lamb").get_local_path()),
str(AudioAsset("winning_call").get_local_path()),
]
# Run LID for each file ----------------------------------------------
print(f"{'Audio File':<50} {'Language (sync)':<20} {'Language (stream)'}")
print("-" * 90)
for path in audio_paths:
basename = os.path.basename(path)
# 1) Synchronous via OpenAI SDK
lang_sync = identify_language(path, client, model)
# 2) Streaming via raw HTTP
lang_stream = identify_language_streaming(path, model, api_base)
print(f"{basename:<50} {lang_sync:<20} {lang_stream}")
print()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="FireRedLID – Language Identification demo via vLLM",
)
parser.add_argument(
"--audio_paths",
nargs="+",
default=None,
help=(
"One or more audio files to identify. "
"If omitted, uses vLLM's built-in sample audios."
),
)
parser.add_argument(
"--api_base",
type=str,
default="http://localhost:8000/v1",
help="vLLM API base URL (default: http://localhost:8000/v1)",
)
args = parser.parse_args()
main(args)
...@@ -820,6 +820,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -820,6 +820,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"FireRedASR2ForConditionalGeneration": _HfExamplesInfo( "FireRedASR2ForConditionalGeneration": _HfExamplesInfo(
"allendou/FireRedASR2-LLM-vllm", "allendou/FireRedASR2-LLM-vllm",
), ),
"FireRedLIDForConditionalGeneration": _HfExamplesInfo(
"PatchyTisa/FireRedLID-vllm",
),
"FunASRForConditionalGeneration": _HfExamplesInfo( "FunASRForConditionalGeneration": _HfExamplesInfo(
"allendou/Fun-ASR-Nano-2512-vllm", "allendou/Fun-ASR-Nano-2512-vllm",
), ),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Shared Conformer encoder components for FireRedASR2 and FireRedLID.
Both models use the same Conformer-based audio encoder architecture
(Conv2dSubsampling → RelPositionalEncoding → N × RelPosEmbConformerBlock).
This module factors out the common building blocks to avoid duplication.
"""
import torch
import torch.nn.functional as F
from torch import nn
from vllm.model_executor.layers.linear import ReplicatedLinear
class Conv2dSubsampling(nn.Module):
def __init__(self, idim: int, d_model: int, out_channels: int = 32):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, out_channels, 3, 2),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, 3, 2),
nn.ReLU(),
)
subsample_idim = ((idim - 1) // 2 - 1) // 2
self.out = ReplicatedLinear(
input_size=out_channels * subsample_idim,
output_size=d_model,
bias=True,
)
self.subsampling = 4
left_context = right_context = 3 # both exclude current frame
self.context = left_context + 1 + right_context # 7
def forward(
self, x: torch.Tensor, x_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = x.unsqueeze(1)
x = self.conv(x)
N, C, T, D = x.size()
x, _ = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
mask = x_mask[:, :, :-2:2][:, :, :-2:2]
input_lengths = mask[:, -1, :].sum(dim=-1)
return x, input_lengths, mask
class Swish(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class RelPositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_len: int = 5000):
super().__init__()
pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(
torch.arange(0, d_model, 2).float()
* -(torch.log(torch.tensor(10000.0)).item() / d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
self.pe = torch.cat([pe_positive, pe_negative], dim=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Tmax = 2 * max_len - 1
Tmax, T = self.pe.size(1), x.size(1)
pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
return pos_emb
class ConformerFeedForward(nn.Module):
def __init__(self, d_model: int):
super().__init__()
self.pre_layer_norm = nn.LayerNorm(d_model)
self.linear_expand = ReplicatedLinear(
input_size=d_model,
output_size=d_model * 4,
bias=True,
)
self.nonlinear = Swish()
self.linear_project = ReplicatedLinear(
input_size=d_model * 4,
output_size=d_model,
bias=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.pre_layer_norm(x)
x, _ = self.linear_expand(x)
x = self.nonlinear(x)
x, _ = self.linear_project(x)
return x + residual
class EncoderMultiHeadAttention(nn.Module):
def __init__(self, n_head: int, d_model: int):
super().__init__()
assert d_model % n_head == 0
self.n_head = n_head
self.d_k = d_model // n_head
self.d_v = self.d_k
self.w_qs = ReplicatedLinear(d_model, n_head * self.d_k, bias=False)
self.w_ks = ReplicatedLinear(d_model, n_head * self.d_k, bias=False)
self.w_vs = ReplicatedLinear(d_model, n_head * self.d_v, bias=False)
self.layer_norm_q = nn.LayerNorm(d_model)
self.layer_norm_k = nn.LayerNorm(d_model)
self.layer_norm_v = nn.LayerNorm(d_model)
self.fc = ReplicatedLinear(n_head * self.d_v, d_model, bias=False)
def forward_qkv(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
q = self.layer_norm_q(q)
k = self.layer_norm_k(k)
v = self.layer_norm_v(v)
q = self.w_qs(q)[0].view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k)[0].view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v)[0].view(sz_b, len_v, n_head, d_v)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
return q, k, v
def forward_output(
self,
output: torch.Tensor,
residual: torch.Tensor,
sz_b: int,
len_q: int,
) -> torch.Tensor:
output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
fc_out, _ = self.fc(output)
return fc_out + residual
def forward_attention(
self,
attn: torch.Tensor,
v: torch.Tensor,
mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if mask is not None:
mask = mask.unsqueeze(1)
mask = mask.eq(0)
attn = attn.masked_fill(mask, -float("inf"))
attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
else:
attn = torch.softmax(attn, dim=-1)
output = torch.matmul(attn, v)
return output, attn
class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
def __init__(self, n_head: int, d_model: int):
super().__init__(n_head, d_model)
d_k = d_model // n_head
self.scale = 1.0 / (d_k**0.5)
self.linear_pos = ReplicatedLinear(d_model, n_head * d_k, bias=False)
self.pos_bias_u = nn.Parameter(torch.empty([n_head, d_k]))
self.pos_bias_v = nn.Parameter(torch.empty([n_head, d_k]))
def _rel_shift(self, x):
N, H, T1, T2 = x.size()
zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(N, H, T2 + 1, T1)
x = x_padded[:, :, 1:].view_as(x)
x = x[:, :, :, : x.size(-1) // 2 + 1]
return x
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
pos_emb: torch.Tensor,
mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
sz_b, len_q = q.size(0), q.size(1)
residual = q
q, k, v = self.forward_qkv(q, k, v)
q = q.transpose(1, 2)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb)[0].view(n_batch_pos, -1, self.n_head, self.d_k)
p = p.transpose(1, 2)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self._rel_shift(matrix_bd)
attn_scores = matrix_ac + matrix_bd
attn_scores.mul_(self.scale)
output, attn = self.forward_attention(attn_scores, v, mask=mask)
output = self.forward_output(output, residual, sz_b, len_q)
return output, attn
class ConformerConvolution(nn.Module):
def __init__(self, d_model: int, kernel_size: int = 33):
super().__init__()
assert kernel_size % 2 == 1
self.pre_layer_norm = nn.LayerNorm(d_model)
self.pointwise_conv1 = nn.Conv1d(
d_model, d_model * 4, kernel_size=1, bias=False
)
self.padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
d_model * 2,
d_model * 2,
kernel_size,
stride=1,
padding=self.padding,
groups=d_model * 2,
bias=False,
)
self.batch_norm = nn.LayerNorm(d_model * 2)
self.swish = Swish()
self.pointwise_conv2 = nn.Conv1d(
d_model * 2, d_model, kernel_size=1, bias=False
)
def forward(
self, x: torch.Tensor, mask: torch.Tensor | None = None
) -> torch.Tensor:
residual = x
out = self.pre_layer_norm(x)
out = out.transpose(1, 2)
if mask is not None:
out.masked_fill_(mask.ne(1), 0.0)
out = self.pointwise_conv1(out)
out = F.glu(out, dim=1)
out = self.depthwise_conv(out)
out = out.transpose(1, 2)
out = self.swish(self.batch_norm(out))
out = out.transpose(1, 2)
out = self.pointwise_conv2(out)
if mask is not None:
out.masked_fill_(mask.ne(1), 0.0)
out = out.transpose(1, 2)
return out + residual
class RelPosEmbConformerBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, kernel_size: int = 33):
super().__init__()
self.ffn1 = ConformerFeedForward(d_model)
self.mhsa = RelPosMultiHeadAttention(n_head, d_model)
self.conv = ConformerConvolution(d_model, kernel_size)
self.ffn2 = ConformerFeedForward(d_model)
self.layer_norm = nn.LayerNorm(d_model)
def forward(
self,
x: torch.Tensor,
pos_emb: torch.Tensor,
slf_attn_mask: torch.Tensor | None = None,
pad_mask: torch.Tensor | None = None,
) -> torch.Tensor:
out = 0.5 * x + 0.5 * self.ffn1(x)
out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
out = self.conv(out, pad_mask)
out = 0.5 * out + 0.5 * self.ffn2(out)
out = self.layer_norm(out)
return out
class ConformerEncoder(nn.Module):
"""
Conformer encoder shared by FireRedASR2 and FireRedLID.
"""
def __init__(
self,
idim: int,
n_layers_enc: int,
n_head: int,
d_model: int,
kernel_size: int = 33,
pe_maxlen: int = 5000,
):
super().__init__()
self.odim = d_model
self.input_preprocessor = Conv2dSubsampling(idim, d_model)
self.positional_encoding = RelPositionalEncoding(d_model, max_len=pe_maxlen)
self.layer_stack = nn.ModuleList()
for _ in range(n_layers_enc):
block = RelPosEmbConformerBlock(d_model, n_head, kernel_size)
self.layer_stack.append(block)
def forward(
self,
padded_input: torch.Tensor,
input_lengths: torch.Tensor,
pad: bool = True,
):
if pad:
padded_input = F.pad(
padded_input,
(0, 0, 0, self.input_preprocessor.context - 1),
"constant",
0.0,
)
src_mask = self.padding_position_is_0(padded_input, input_lengths)
embed_output, input_lengths, src_mask = self.input_preprocessor(
padded_input, src_mask
)
enc_output = embed_output
pos_emb = self.positional_encoding(embed_output)
for enc_layer in self.layer_stack:
enc_output = enc_layer(
enc_output, pos_emb, slf_attn_mask=src_mask, pad_mask=src_mask
)
return enc_output, input_lengths, src_mask
def padding_position_is_0(
self, padded_input: torch.Tensor, input_lengths: torch.Tensor
) -> torch.Tensor:
N, T = padded_input.size()[:2]
# Use broadcasting instead of a Python loop for efficiency.
positions = torch.arange(T, device=padded_input.device).unsqueeze(0)
mask = (positions < input_lengths.unsqueeze(1)).to(torch.uint8)
return mask.unsqueeze(1)
...@@ -6,7 +6,6 @@ from typing import Annotated, Literal, cast ...@@ -6,7 +6,6 @@ from typing import Annotated, Literal, cast
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import ( from transformers import (
BatchFeature, BatchFeature,
...@@ -45,6 +44,7 @@ from vllm.transformers_utils.processors.fireredasr2 import ( ...@@ -45,6 +44,7 @@ from vllm.transformers_utils.processors.fireredasr2 import (
) )
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .conformer_encoder import ConformerEncoder
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsMultiModal, SupportsMultiModal,
...@@ -84,352 +84,6 @@ class FireRedASR2AudioInputs(TensorSchema): ...@@ -84,352 +84,6 @@ class FireRedASR2AudioInputs(TensorSchema):
] ]
class Swish(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class Conv2dSubsampling(nn.Module):
def __init__(self, idim: int, d_model: int, out_channels: int = 32):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, out_channels, 3, 2),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, 3, 2),
nn.ReLU(),
)
subsample_idim = ((idim - 1) // 2 - 1) // 2
self.out = ReplicatedLinear(
input_size=out_channels * subsample_idim,
output_size=d_model,
bias=True,
)
self.subsampling = 4
left_context = right_context = 3 # both exclude current frame
self.context = left_context + 1 + right_context # 7
def forward(
self, x: torch.Tensor, x_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = x.unsqueeze(1)
x = self.conv(x)
N, C, T, D = x.size()
x, _ = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
mask = x_mask[:, :, :-2:2][:, :, :-2:2]
input_lengths = mask[:, -1, :].sum(dim=-1)
return x, input_lengths, mask
class RelPositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_len: int = 5000):
super().__init__()
pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(
torch.arange(0, d_model, 2).float()
* -(torch.log(torch.tensor(10000.0)).item() / d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
self.pe = torch.cat([pe_positive, pe_negative], dim=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Tmax = 2 * max_len - 1
Tmax, T = self.pe.size(1), x.size(1)
pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
return pos_emb
class ConformerFeedForward(nn.Module):
def __init__(self, d_model: int):
super().__init__()
self.pre_layer_norm = nn.LayerNorm(d_model)
self.linear_expand = ReplicatedLinear(
input_size=d_model,
output_size=d_model * 4,
bias=True,
)
self.nonlinear = Swish()
self.linear_project = ReplicatedLinear(
input_size=d_model * 4,
output_size=d_model,
bias=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.pre_layer_norm(x)
x, _ = self.linear_expand(x)
x = self.nonlinear(x)
x, _ = self.linear_project(x)
output = x + residual
return output
class EncoderMultiHeadAttention(nn.Module):
def __init__(self, n_head: int, d_model: int):
super().__init__()
assert d_model % n_head == 0
self.n_head = n_head
self.d_k = d_model // n_head
self.d_v = self.d_k
self.w_qs = ReplicatedLinear(
input_size=d_model, output_size=n_head * self.d_k, bias=False
)
self.w_ks = ReplicatedLinear(
input_size=d_model, output_size=n_head * self.d_k, bias=False
)
self.w_vs = ReplicatedLinear(
input_size=d_model, output_size=n_head * self.d_v, bias=False
)
self.layer_norm_q = nn.LayerNorm(d_model)
self.layer_norm_k = nn.LayerNorm(d_model)
self.layer_norm_v = nn.LayerNorm(d_model)
self.fc = ReplicatedLinear(
input_size=n_head * self.d_v, output_size=d_model, bias=False
)
def forward_qkv(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
q = self.layer_norm_q(q)
k = self.layer_norm_k(k)
v = self.layer_norm_v(v)
q = self.w_qs(q)[0].view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k)[0].view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v)[0].view(sz_b, len_v, n_head, d_v)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
return q, k, v
def forward_output(
self, output: torch.Tensor, residual: torch.Tensor, sz_b: int, len_q: int
) -> torch.Tensor:
output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
fc_out, _ = self.fc(output)
output = fc_out
output = output + residual
return output
def forward_attention(
self, attn: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
if mask is not None:
mask = mask.unsqueeze(1)
mask = mask.eq(0)
attn = attn.masked_fill(mask, -float("inf"))
attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
else:
attn = torch.softmax(attn, dim=-1)
d_attn = attn
output = torch.matmul(d_attn, v)
return output, attn
class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
def __init__(self, n_head: int, d_model: int):
super().__init__(n_head, d_model)
d_k = d_model // n_head
self.scale = 1.0 / (d_k**0.5)
self.linear_pos = ReplicatedLinear(
input_size=d_model, output_size=n_head * d_k, bias=False
)
self.pos_bias_u = nn.Parameter(torch.empty([n_head, d_k]))
self.pos_bias_v = nn.Parameter(torch.empty([n_head, d_k]))
def _rel_shift(self, x):
N, H, T1, T2 = x.size()
zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(N, H, T2 + 1, T1)
x = x_padded[:, :, 1:].view_as(x)
x = x[:, :, :, : x.size(-1) // 2 + 1]
return x
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
pos_emb: torch.Tensor,
mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
sz_b, len_q = q.size(0), q.size(1)
residual = q
q, k, v = self.forward_qkv(q, k, v)
q = q.transpose(1, 2)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb)[0].view(n_batch_pos, -1, self.n_head, self.d_k)
p = p.transpose(1, 2)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self._rel_shift(matrix_bd)
attn_scores = matrix_ac + matrix_bd
attn_scores.mul_(self.scale)
output, attn = self.forward_attention(attn_scores, v, mask=mask)
output = self.forward_output(output, residual, sz_b, len_q)
return output, attn
class ConformerConvolution(nn.Module):
def __init__(self, d_model: int, kernel_size: int = 33):
super().__init__()
assert kernel_size % 2 == 1
self.pre_layer_norm = nn.LayerNorm(d_model)
self.pointwise_conv1 = nn.Conv1d(
d_model, d_model * 4, kernel_size=1, bias=False
)
self.padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
d_model * 2,
d_model * 2,
kernel_size,
stride=1,
padding=self.padding,
groups=d_model * 2,
bias=False,
)
self.batch_norm = nn.LayerNorm(d_model * 2)
self.swish = Swish()
self.pointwise_conv2 = nn.Conv1d(
d_model * 2, d_model, kernel_size=1, bias=False
)
def forward(
self, x: torch.Tensor, mask: torch.Tensor | None = None
) -> torch.Tensor:
residual = x
out = self.pre_layer_norm(x)
out = out.transpose(1, 2)
if mask is not None:
out.masked_fill_(mask.ne(1), 0.0)
out = self.pointwise_conv1(out)
out = F.glu(out, dim=1)
out = self.depthwise_conv(out)
out = out.transpose(1, 2)
out = self.swish(self.batch_norm(out))
out = out.transpose(1, 2)
out = self.pointwise_conv2(out)
if mask is not None:
out.masked_fill_(mask.ne(1), 0.0)
out = out.transpose(1, 2)
return out + residual
class RelPosEmbConformerBlock(nn.Module):
def __init__(self, d_model, n_head, kernel_size=33):
super().__init__()
self.ffn1 = ConformerFeedForward(d_model)
self.mhsa = RelPosMultiHeadAttention(n_head, d_model)
self.conv = ConformerConvolution(d_model, kernel_size)
self.ffn2 = ConformerFeedForward(d_model)
self.layer_norm = nn.LayerNorm(d_model)
def forward(
self,
x: torch.Tensor,
pos_emb: torch.Tensor,
slf_attn_mask: torch.Tensor | None = None,
pad_mask: torch.Tensor | None = None,
) -> torch.Tensor:
out = 0.5 * x + 0.5 * self.ffn1(x)
out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
out = self.conv(out, pad_mask)
out = 0.5 * out + 0.5 * self.ffn2(out)
out = self.layer_norm(out)
return out
class ConformerEncoder(nn.Module):
def __init__(
self,
idim: int,
n_layers_enc: int,
n_head: int,
d_model: int,
kernel_size: int = 33,
pe_maxlen: int = 5000,
):
super().__init__()
self.odim = d_model
self.input_preprocessor = Conv2dSubsampling(idim, d_model)
self.positional_encoding = RelPositionalEncoding(d_model)
self.layer_stack = nn.ModuleList()
for _ in range(n_layers_enc):
block = RelPosEmbConformerBlock(d_model, n_head, kernel_size)
self.layer_stack.append(block)
def forward(
self, padded_input: torch.Tensor, input_lengths: torch.Tensor, pad: bool = True
):
if pad:
padded_input = F.pad(
padded_input,
(0, 0, 0, self.input_preprocessor.context - 1),
"constant",
0.0,
)
src_mask = self.padding_position_is_0(padded_input, input_lengths)
embed_output, input_lengths, src_mask = self.input_preprocessor(
padded_input, src_mask
)
enc_output = embed_output
pos_emb = self.positional_encoding(embed_output)
enc_outputs = []
for enc_layer in self.layer_stack:
enc_output = enc_layer(
enc_output, pos_emb, slf_attn_mask=src_mask, pad_mask=src_mask
)
enc_outputs.append(enc_output)
return enc_output, input_lengths, src_mask
def padding_position_is_0(
self, padded_input: torch.Tensor, input_lengths: torch.Tensor
) -> torch.Tensor:
N, T = padded_input.size()[:2]
mask = torch.ones((N, T)).to(padded_input.device)
for i in range(N):
mask[i, input_lengths[i] :] = 0
mask = mask.unsqueeze(dim=1)
return mask.to(torch.uint8)
class FireRedASR2Adapter(nn.Module): class FireRedASR2Adapter(nn.Module):
def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 2): def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 2):
super().__init__() super().__init__()
......
This diff is collapsed.
...@@ -380,6 +380,10 @@ _MULTIMODAL_MODELS = { ...@@ -380,6 +380,10 @@ _MULTIMODAL_MODELS = {
"FireRedASR2ForConditionalGeneration", "FireRedASR2ForConditionalGeneration",
), ),
"FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"), "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),
"FireRedLIDForConditionalGeneration": (
"fireredlid",
"FireRedLIDForConditionalGeneration",
),
"FunAudioChatForConditionalGeneration": ( "FunAudioChatForConditionalGeneration": (
"funaudiochat", "funaudiochat",
"FunAudioChatForConditionalGeneration", "FunAudioChatForConditionalGeneration",
......
...@@ -90,6 +90,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( ...@@ -90,6 +90,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
deepseek_vl_v2="DeepseekVLV2Config", deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v32="DeepseekV3Config", deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig", flex_olmo="FlexOlmoConfig",
fireredlid="FireRedLIDConfig",
funaudiochat="FunAudioChatConfig", funaudiochat="FunAudioChatConfig",
hunyuan_vl="HunYuanVLConfig", hunyuan_vl="HunYuanVLConfig",
isaac="IsaacConfig", isaac="IsaacConfig",
......
...@@ -28,6 +28,7 @@ _CLASS_TO_MODULE: dict[str, str] = { ...@@ -28,6 +28,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"DeepseekVLV2Config": "vllm.transformers_utils.configs.deepseek_vl2", "DeepseekVLV2Config": "vllm.transformers_utils.configs.deepseek_vl2",
"DotsOCRConfig": "vllm.transformers_utils.configs.dotsocr", "DotsOCRConfig": "vllm.transformers_utils.configs.dotsocr",
"EAGLEConfig": "vllm.transformers_utils.configs.eagle", "EAGLEConfig": "vllm.transformers_utils.configs.eagle",
"FireRedLIDConfig": "vllm.transformers_utils.configs.fireredlid",
"FlexOlmoConfig": "vllm.transformers_utils.configs.flex_olmo", "FlexOlmoConfig": "vllm.transformers_utils.configs.flex_olmo",
"FunAudioChatConfig": "vllm.transformers_utils.configs.funaudiochat", "FunAudioChatConfig": "vllm.transformers_utils.configs.funaudiochat",
"FunAudioChatAudioEncoderConfig": "vllm.transformers_utils.configs.funaudiochat", "FunAudioChatAudioEncoderConfig": "vllm.transformers_utils.configs.funaudiochat",
...@@ -88,6 +89,7 @@ __all__ = [ ...@@ -88,6 +89,7 @@ __all__ = [
"DotsOCRConfig", "DotsOCRConfig",
"EAGLEConfig", "EAGLEConfig",
"FlexOlmoConfig", "FlexOlmoConfig",
"FireRedLIDConfig",
"FunAudioChatConfig", "FunAudioChatConfig",
"FunAudioChatAudioEncoderConfig", "FunAudioChatAudioEncoderConfig",
"HunYuanVLConfig", "HunYuanVLConfig",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import contextlib
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
class FireRedLIDConfig(PretrainedConfig):
"""Minimal config class for native vLLM FireRedLID support."""
model_type = "fireredlid"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size: int = 120,
lid_odim: int = 120,
idim: int = 80,
d_model: int = 1280,
n_head: int = 20,
n_layers_enc: int = 16,
n_layers_lid_dec: int = 6,
kernel_size: int = 33,
residual_dropout: float = 0.05,
dropout_rate: float = 0.05,
pe_maxlen: int = 5000,
pad_token_id: int = 2,
bos_token_id: int = 3,
eos_token_id: int = 4,
decoder_start_token_id: int = 3,
tie_word_embeddings: bool = True,
is_encoder_decoder: bool = True,
architectures: list[str] | None = None,
**kwargs,
):
self.vocab_size = vocab_size
self.lid_odim = lid_odim
self.idim = idim
self.d_model = d_model
self.hidden_size = d_model
self.n_head = n_head
self.num_attention_heads = n_head
self.n_layers_enc = n_layers_enc
self.encoder_layers = n_layers_enc
self.n_layers_lid_dec = n_layers_lid_dec
self.decoder_layers = n_layers_lid_dec
self.num_hidden_layers = n_layers_lid_dec
self.kernel_size = kernel_size
self.residual_dropout = residual_dropout
self.dropout_rate = dropout_rate
self.pe_maxlen = pe_maxlen
self.tie_word_embeddings = tie_word_embeddings
self.is_encoder_decoder = is_encoder_decoder
self.architectures = architectures or ["FireRedLIDForConditionalGeneration"]
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
tie_word_embeddings=tie_word_embeddings,
is_encoder_decoder=is_encoder_decoder,
architectures=self.architectures,
**kwargs,
)
with contextlib.suppress(ValueError):
AutoConfig.register(FireRedLIDConfig.model_type, FireRedLIDConfig)
...@@ -16,6 +16,7 @@ __all__ = [ ...@@ -16,6 +16,7 @@ __all__ = [
"CohereASRProcessor", "CohereASRProcessor",
"DeepseekVLV2Processor", "DeepseekVLV2Processor",
"FireRedASR2Processor", "FireRedASR2Processor",
"FireRedLIDProcessor",
"FunASRProcessor", "FunASRProcessor",
"GLM4VProcessor", "GLM4VProcessor",
"H2OVLProcessor", "H2OVLProcessor",
...@@ -44,6 +45,7 @@ _CLASS_TO_MODULE: dict[str, str] = { ...@@ -44,6 +45,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"CohereASRProcessor": "vllm.transformers_utils.processors.cohere_asr", "CohereASRProcessor": "vllm.transformers_utils.processors.cohere_asr",
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2", "DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2", "FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
"FireRedLIDProcessor": "vllm.transformers_utils.processors.fireredlid",
"FunASRProcessor": "vllm.transformers_utils.processors.funasr", "FunASRProcessor": "vllm.transformers_utils.processors.funasr",
"GLM4VProcessor": "vllm.transformers_utils.processors.glm4v", "GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
"H2OVLProcessor": "vllm.transformers_utils.processors.h2ovl", "H2OVLProcessor": "vllm.transformers_utils.processors.h2ovl",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
FireRedLID feature extractor and processor.
The FeatureExtractor handles:
- Raw waveform → 80-dim log-mel filterbank (via kaldi_native_fbank)
- CMVN normalization (means / inverse_std_variences from preprocessor_config)
- Padding + length tracking
The Processor wraps the FeatureExtractor and a tokenizer.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
AutoFeatureExtractor,
BatchFeature,
)
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.processing_utils import ProcessorMixin
from transformers.utils import TensorType
from vllm.logger import init_logger
from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
import kaldi_native_fbank as knf
else:
knf = LazyLoader("knf", globals(), "kaldi_native_fbank")
logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# Helpers (shared with FireRedASR2 processor)
# ---------------------------------------------------------------------------
class CMVN:
def __init__(self, dim, means, inverse_std_variences):
self.dim = dim
self.means = np.array(means)
self.inverse_std_variences = np.array(inverse_std_variences)
def __call__(self, x):
assert x.shape[-1] == self.dim, "CMVN dim mismatch"
out = x - self.means
out = out * self.inverse_std_variences
return out
class KaldifeatFbank:
def __init__(
self,
num_mel_bins: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
dither: float = 0.0,
):
self.dither = dither
opts = knf.FbankOptions()
opts.frame_opts.dither = dither
opts.mel_opts.num_bins = num_mel_bins
opts.frame_opts.snip_edges = True
opts.mel_opts.debug_mel = False
self.opts = opts
def __call__(self, sample_rate, wav_np, is_train=False):
dither = self.dither if is_train else 0.0
self.opts.frame_opts.dither = dither
fbank = knf.OnlineFbank(self.opts)
fbank.accept_waveform(sample_rate, wav_np.tolist())
feat = []
for i in range(fbank.num_frames_ready):
feat.append(fbank.get_frame(i))
if len(feat) == 0:
return np.zeros((0, self.opts.mel_opts.num_bins))
return np.vstack(feat)
# ---------------------------------------------------------------------------
# Feature Extractor
# ---------------------------------------------------------------------------
class FireRedLIDFeatureExtractor(SequenceFeatureExtractor):
"""
Extracts 80-dim log-mel filterbank features from raw waveforms,
applies CMVN, and returns padded feature tensors with lengths.
Also computes ``fake_token_lengths`` — the actual encoder output
length for each audio — so that vLLM can allocate the correct
number of cross-attention KV cache slots.
"""
model_input_names = ["input_features"]
def __init__(
self,
feature_size=80,
sampling_rate=16000,
chunk_length=30,
padding_value=0.0,
return_attention_mask=False,
dim=80,
means=None,
inverse_std_variences=None,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
left_context=3,
right_context=3,
**kwargs,
):
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=return_attention_mask,
**kwargs,
)
self.chunk_length = chunk_length
self.dim = dim
self.means = means
self.inverse_std_variences = inverse_std_variences
self.num_mel_bins = num_mel_bins
self.frame_length = frame_length
self.frame_shift = frame_shift
self.dither = dither
self.sampling_rate = sampling_rate
self.context = left_context + 1 + right_context
def __call__(
self,
raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
truncation: bool = True,
pad_to_multiple_of: int | None = None,
return_tensors: str | TensorType | None = None,
return_attention_mask: bool | None = None,
padding: str | None = "max_length",
max_length: int | None = None,
sampling_rate: int | None = None,
do_normalize: bool | None = None,
**kwargs,
) -> BatchFeature:
if sampling_rate is not None and sampling_rate != self.sampling_rate:
raise ValueError(
f"FireRedLIDFeatureExtractor expects sampling_rate="
f"{self.sampling_rate}, got {sampling_rate}."
)
# Initialize helpers
cmvn = CMVN(self.dim, self.means, self.inverse_std_variences)
fbank = KaldifeatFbank(
num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
)
def padding_position_is_0(padded_input, input_lengths):
N, T = padded_input.size()[:2]
mask = torch.ones((N, T)).to(padded_input.device)
for i in range(N):
mask[i, input_lengths[i] :] = 0
mask = mask.unsqueeze(dim=1)
return mask.to(torch.uint8)
feats = []
speech_lengths = []
fake_token_lengths = []
for speech in raw_speech:
# vLLM loads audio via librosa (float32 in [-1,1]),
# but kaldi_native_fbank expects int16-scale values.
speech_scaled = speech * 32768
feat = fbank(self.sampling_rate, speech_scaled)
feat = cmvn(feat)
feat = torch.from_numpy(feat).float()
length = feat.size(0)
feats.append(feat)
speech_lengths.append(length)
# Compute the actual Conv2dSubsampling output length.
# This mirrors the mask logic in Conv2dSubsampling.forward:
# pad context frames, then mask[:, :, :-2:2][:, :, :-2:2].sum()
padded_input = F.pad(feat, (0, 0, 0, self.context - 1), "constant", 0.0)
src_mask = padding_position_is_0(
padded_input[None, :, :],
torch.tensor([length], dtype=torch.int32),
)
mask = src_mask[:, :, :-2:2][:, :, :-2:2]
enc_len = mask[:, -1, :].sum(dim=-1)
fake_token_len = torch.clamp(enc_len, min=1)
fake_token_lengths.append(fake_token_len)
if len(feats) == 0:
return BatchFeature()
# Pad to uniform length
max_feat_len = max(f.size(0) for f in feats)
padded = feats[0].new_zeros(len(feats), max_feat_len, feats[0].size(1))
for i, feat in enumerate(feats):
padded[i, : feat.size(0)] = feat
result = BatchFeature({"input_features": padded})
if return_tensors is not None:
result = result.convert_to_tensors(return_tensors)
result["speech_lengths"] = torch.tensor(speech_lengths, dtype=torch.long)
result["fake_token_lengths"] = torch.concat(fake_token_lengths)
return result
# ---------------------------------------------------------------------------
# Processor
# ---------------------------------------------------------------------------
class FireRedLIDProcessor(ProcessorMixin):
"""
Wraps FireRedLIDFeatureExtractor + a tokenizer.
"""
feature_extractor_class = "FireRedLIDFeatureExtractor"
tokenizer_class = ("PreTrainedTokenizer", "PreTrainedTokenizerFast")
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
self.current_processor = self.feature_extractor
self._in_target_context_manager = False
def __call__(self, *args, **kwargs):
if self._in_target_context_manager:
return self.current_processor(*args, **kwargs)
audio = kwargs.pop("audio", None)
sampling_rate = kwargs.pop("sampling_rate", None)
text = kwargs.pop("text", None)
if len(args) > 0:
audio = args[0]
args = args[1:]
if audio is not None:
inputs = self.feature_extractor(
audio, *args, sampling_rate=sampling_rate, **kwargs
)
else:
inputs = BatchFeature()
if text is not None:
if isinstance(text, str):
text = [text]
encodings = self.tokenizer(text, **kwargs)
if audio is not None:
inputs["labels"] = encodings["input_ids"]
else:
return encodings
return inputs
# ---------------------------------------------------------------------------
# Registration
# ---------------------------------------------------------------------------
AutoFeatureExtractor.register("FireRedLIDFeatureExtractor", FireRedLIDFeatureExtractor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment