# Commit b75857fb authored by chenzk — v1.0
# (GitLab web-export header retained as a comment so the file parses as Python.)
import io
import re
import wave
import gradio as gr
from fish_speech.utils.schema import ServeMessage, ServeTextPart, ServeVQPart
from .fish_e2e import FishE2EAgent, FishE2EEventType
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    """Return a complete RIFF/WAVE header (zero data frames) for chunked PCM streaming."""
    with io.BytesIO() as stream:
        # wave finalizes the header when the writer is closed, so read the
        # buffer only after leaving the wave context.
        with wave.open(stream, "wb") as writer:
            writer.setnchannels(channels)
            writer.setsampwidth(bit_depth // 8)
            writer.setframerate(sample_rate)
        return stream.getvalue()
class ChatState:
    """Mutable conversation state shared across the gradio callbacks."""

    def __init__(self):
        self.conversation = []  # list of ServeMessage, oldest first
        self.added_systext = False  # system text prompt already appended?
        self.added_sysaudio = False  # system reference audio already appended?

    def get_history(self):
        """Render the conversation as gradio ``type="messages"`` dicts.

        Assistant replies formatted as ``Question: X\\n\\nAnswer: Y`` are split:
        the transcribed question is appended to the preceding user message and
        the assistant message keeps only the answer part.
        """
        results = [
            {"role": msg.role, "content": self.repr_message(msg)}
            for msg in self.conversation
        ]

        # BUG FIX: the system prompts ask the model to reply with
        # "Question: ...\n\nAnswer: ..." but the original code only matched
        # "Response:", so the split never fired.  Accept both markers.
        for i, msg in enumerate(results):
            if msg["role"] != "assistant":
                continue
            match = re.search(
                r"Question: (.*?)\n\n(Answer|Response): ", msg["content"]
            )
            if match and i > 0 and results[i - 1]["role"] == "user":
                # Move the transcribed question onto the user's message.
                results[i - 1]["content"] += "\n" + match.group(1)
                # Keep only the answer text on the assistant's message.
                msg["content"] = msg["content"].split(
                    f"\n\n{match.group(2)}: ", 1
                )[1]
        return results

    def repr_message(self, msg: ServeMessage) -> str:
        """Flatten a message's parts into display text; VQ audio parts become an
        ``<audio Ns>`` placeholder (21 codes per second of audio)."""
        response = ""
        for part in msg.parts:
            if isinstance(part, ServeTextPart):
                response += part.text
            elif isinstance(part, ServeVQPart):
                response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
        return response
def clear_fn():
    """Reset the UI: empty chat history, fresh state, cleared audio/text widgets."""
    return [], ChatState(), None, None, None
async def process_audio_input(
    sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
):
    """Gradio handler: send one user turn (recorded audio and/or typed text).

    Async generator yielding ``(chat_history, wav_bytes_or_None, None, None)``
    tuples matching the outputs [chatbot, output_audio, audio_input, text_input].
    Raises gr.Error when neither audio nor text was provided.
    """
    if audio_input is None and not text_input:
        raise gr.Error("No input provided")

    agent = FishE2EAgent()  # Create new agent instance for each request

    # Convert audio input to numpy array
    if isinstance(audio_input, tuple):
        sr, audio_data = audio_input
    elif text_input:
        sr = 44100
        audio_data = None
    else:
        raise gr.Error("Invalid audio format")

    # NOTE(review): this overwrites the user audio's sample rate `sr` with the
    # system audio's rate (or 44100) — confirm both inputs share a rate.
    if isinstance(sys_audio_input, tuple):
        sr, sys_audio_data = sys_audio_input
    else:
        sr = 44100
        sys_audio_data = None

    def append_to_chat_ctx(
        part: ServeTextPart | ServeVQPart, role: str = "assistant"
    ) -> None:
        # Extend the last message when the role matches, else start a new one.
        if not state.conversation or state.conversation[-1].role != role:
            state.conversation.append(ServeMessage(role=role, parts=[part]))
        else:
            state.conversation[-1].parts.append(part)

    # Attach the system text prompt only once per conversation.
    if state.added_systext is False and sys_text_input:
        state.added_systext = True
        append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
    if text_input:
        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
        # Typed text takes precedence: drop any recorded audio for this turn.
        audio_data = None

    result_audio = b""  # NOTE(review): never appended to below — appears unused
    async for event in agent.stream(
        sys_audio_data,
        audio_data,
        sr,
        1,
        chat_ctx={
            "messages": state.conversation,
            "added_sysaudio": state.added_sysaudio,
        },
    ):
        if event.type == FishE2EEventType.USER_CODES:
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
        elif event.type == FishE2EEventType.SPEECH_SEGMENT:
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
            # NOTE(review): event.frame.data is a memoryview (CustomAudioFrame);
            # bytes + memoryview raises TypeError — confirm this path works or
            # wrap with bytes(event.frame.data).
            yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            append_to_chat_ctx(ServeTextPart(text=event.text))
            yield state.get_history(), None, None, None

    # Final yield so the UI reflects the completed turn.
    yield state.get_history(), None, None, None
async def process_text_input(
    sys_audio_input, sys_text_input, state: ChatState, text_input: str
):
    """Text-only turn: delegate to process_audio_input with no recorded audio."""
    inner = process_audio_input(
        sys_audio_input, sys_text_input, None, state, text_input
    )
    async for item in inner:
        yield item
def create_demo():
    """Assemble the gradio Blocks UI for the Fish Agent end-to-end voice demo."""
    with gr.Blocks() as demo:
        state = gr.State(ChatState())

        with gr.Row():
            # Left column (70%) for chatbot and notes
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    bubble_full_width=False,
                    height=600,
                    type="messages",
                )

                # BUG FIX (user-facing text): "self-researh" -> "self-research",
                # "gitub" -> "github".  (A commented-out Chinese version of
                # these notes previously lived here; see git history.)
                notes = gr.Markdown(
                    """
                # Fish Agent
                1. This demo is Fish Audio's self-research end-to-end language model, Fish Agent version 3B.
                2. You can find the code and weights in our official repo in [github](https://github.com/fishaudio/fish-speech) and [hugging face](https://huggingface.co/fishaudio/fish-agent-v0.1-3b), but the content is released under a CC BY-NC-SA 4.0 licence.
                3. The demo is an early alpha test version, the inference speed needs to be optimised.

                # Features
                1. The model automatically integrates ASR and TTS parts, no need to plug-in other models, i.e., true end-to-end, not three-stage (ASR+LLM+TTS).
                2. The model can use reference audio to control the speech timbre.
                3. The model can generate speech with strong emotion.
                """
                )

            # Right column (30%) for controls
            with gr.Column(scale=3):
                sys_audio_input = gr.Audio(
                    sources=["upload"],
                    type="numpy",
                    label="Give a timbre for your assistant",
                )
                sys_text_input = gr.Textbox(
                    label="What is your assistant's role?",
                    value="You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user's speech, then answer it in the following format: 'Question: [USER_SPEECH]\n\nAnswer: [YOUR_RESPONSE]\n'. You are required to use the following voice in this conversation.",
                    type="text",
                )
                audio_input = gr.Audio(
                    sources=["microphone"], type="numpy", label="Speak your message"
                )
                text_input = gr.Textbox(label="Or type your message", type="text")
                output_audio = gr.Audio(
                    label="Assistant's Voice",
                    streaming=True,
                    autoplay=True,
                    interactive=False,
                )
                send_button = gr.Button("Send", variant="primary")
                clear_button = gr.Button("Clear")

        # Event handlers
        audio_input.stop_recording(
            process_audio_input,
            inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )
        send_button.click(
            process_text_input,
            inputs=[sys_audio_input, sys_text_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )
        text_input.submit(
            process_text_input,
            inputs=[sys_audio_input, sys_text_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )
        clear_button.click(
            clear_fn,
            inputs=[],
            outputs=[chatbot, state, audio_input, output_audio, text_input],
        )

    return demo
if __name__ == "__main__":
    # Local entry point: build the UI and serve it (share=True also opens a
    # public gradio tunnel link).
    demo = create_demo()
    demo.launch(server_name="127.0.0.1", server_port=7860, share=True)
import onnxruntime
import torch
import torch.nn.functional as F
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from tools.vqgan.extract_vq import get_model
PAD_TOKEN_ID = torch.LongTensor([CODEBOOK_PAD_TOKEN_ID])
class Encoder(torch.nn.Module):
    """Wraps the firefly GAN-VQ analysis path (waveform -> VQ indices) for ONNX export."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        # ONNX cannot represent complex tensors; request a real-valued spectrogram.
        self.model.spec_transform.spectrogram.return_complex = False

    def forward(self, audios):
        # audios: assumed (batch, 1, samples) waveform — TODO confirm with caller.
        mels = self.model.spec_transform(audios)
        encoded_features = self.model.backbone(mels)
        z = self.model.quantizer.downsample(encoded_features)
        _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
        _, b, l, _ = indices.shape
        # (groups, batch, length, residuals) -> (batch, groups*residuals, length):
        # flatten the group/residual axes into a single codebook axis.
        return indices.permute(1, 0, 3, 2).long().view(b, -1, l)
class Decoder(torch.nn.Module):
    """Wraps the firefly GAN-VQ synthesis path (VQ indices -> waveform) for ONNX export.

    Re-implements ``residual_fsq.get_output_from_indices`` with plain tensor ops
    so the graph traces without einops-style rearranges.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Force inference behavior in the head (no training-only paths, no
        # activation checkpointing) so the exported graph is static.
        self.model.head.training = False
        self.model.head.checkpointing = False

    def get_codes_from_indices(self, cur_index, indices):
        """Gather per-level codebook vectors for RVQ group ``cur_index``."""
        # indices: assumed (batch, quantizer_levels, length) — TODO confirm.
        _, quantize_dim, _ = indices.shape

        # Embedding width of this group's codebooks.
        d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]

        # Fewer quantizer levels than configured is only valid when the model
        # was trained with quantize dropout; pad missing levels with -1.
        if (
            quantize_dim
            < self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
        ):
            assert (
                self.model.quantizer.residual_fsq.rvqs[cur_index].quantize_dropout > 0.0
            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            indices = F.pad(
                indices,
                (
                    0,
                    self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
                    - quantize_dim,
                ),
                value=-1,
            )

        # -1 marks dropped levels: substitute index 0 for the gather, then zero
        # those entries out afterwards so they contribute nothing to the sum.
        mask = indices == -1
        indices = indices.masked_fill(mask, 0)

        all_codes = torch.gather(
            self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
            dim=2,
            index=indices.permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
        )
        all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)

        scales = (
            self.model.quantizer.residual_fsq.rvqs[cur_index]
            .scales.unsqueeze(1)
            .unsqueeze(1)
        )
        all_codes = all_codes * scales

        return all_codes

    def get_output_from_indices(self, cur_index, indices):
        """Sum the per-level codebook vectors and project back to feature space."""
        codes = self.get_codes_from_indices(cur_index, indices)
        codes_summed = codes.sum(dim=0)
        return self.model.quantizer.residual_fsq.rvqs[cur_index].project_out(
            codes_summed
        )

    def forward(self, indices) -> torch.Tensor:
        # indices: (batch, groups * residuals, length) as produced by Encoder.forward.
        batch_size, _, length = indices.shape
        dims = self.model.quantizer.residual_fsq.dim
        groups = self.model.quantizer.residual_fsq.groups
        dim_per_group = dims // groups

        # indices = rearrange(indices, "b (g r) l -> g b l r", g=groups)
        indices = indices.view(batch_size, groups, -1, length).permute(1, 0, 3, 2)

        # z_q = self.model.quantizer.residual_fsq.get_output_from_indices(indices)
        # Assemble each group's slice of the latent independently.
        z_q = torch.empty((batch_size, length, dims))
        for i in range(groups):
            z_q[:, :, i * dim_per_group : (i + 1) * dim_per_group] = (
                self.get_output_from_indices(i, indices[i])
            )

        z = self.model.quantizer.upsample(z_q.transpose(1, 2))
        x = self.model.head(z)
        return x
def main(firefly_gan_vq_path, llama_path, export_prefix):
    """Export the firefly GAN-VQ encoder/decoder to ONNX and smoke-test both.

    llama_path is accepted but unused here — presumably reserved; TODO confirm.
    """
    GanModel = get_model("firefly_gan_vq", firefly_gan_vq_path, device="cpu")
    enc = Encoder(GanModel)
    dec = Decoder(GanModel)
    # ~2.2 s of audio at 44.1 kHz used as the tracing example.
    audio_example = torch.randn(1, 1, 96000)
    indices = enc(audio_example)

    torch.onnx.export(
        enc,
        audio_example,
        f"{export_prefix}encoder.onnx",
        dynamic_axes={
            "audio": {0: "batch_size", 2: "audio_length"},
        },
        do_constant_folding=False,
        opset_version=18,
        verbose=False,
        input_names=["audio"],
        output_names=["prompt"],
    )

    torch.onnx.export(
        dec,
        indices,
        f"{export_prefix}decoder.onnx",
        dynamic_axes={
            "prompt": {0: "batch_size", 2: "frame_count"},
        },
        do_constant_folding=False,
        opset_version=18,
        verbose=False,
        input_names=["prompt"],
        output_names=["audio"],
    )

    # Longer input than the tracing example to exercise the dynamic axes.
    test_example = torch.randn(1, 1, 96000 * 5)
    encoder_session = onnxruntime.InferenceSession(f"{export_prefix}encoder.onnx")
    decoder_session = onnxruntime.InferenceSession(f"{export_prefix}decoder.onnx")

    # check graph has no error
    # NOTE(review): outputs are computed but never numerically compared against
    # the torch results — consider an allclose assertion here.
    onnx_enc_out = encoder_session.run(["prompt"], {"audio": test_example.numpy()})[0]
    torch_enc_out = enc(test_example)
    onnx_dec_out = decoder_session.run(["audio"], {"prompt": onnx_enc_out})[0]
    torch_dec_out = dec(torch_enc_out)


if __name__ == "__main__":
    main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")
import click
import torch
from loguru import logger
@click.command()
@click.argument("model_path")
@click.argument("output_path")
def main(model_path, output_path):
    """Extract the bare ``state_dict`` from a training checkpoint.

    MODEL_PATH: input checkpoint file; OUTPUT_PATH: destination .pth file.
    """
    if model_path == output_path:
        logger.error("Model path and output path are the same")
        return

    logger.info(f"Loading model from {model_path}")
    checkpoint = torch.load(model_path, map_location="cpu")
    # ROBUSTNESS: accept both a Lightning-style checkpoint ({"state_dict": ...})
    # and a file that already contains a bare state dict (the original raised
    # KeyError on the latter).
    state_dict = checkpoint.get("state_dict", checkpoint)
    torch.save(state_dict, output_path)
    logger.info(f"Model saved to {output_path}")


if __name__ == "__main__":
    main()
import base64
import ctypes
import io
import json
import os
import struct
from dataclasses import dataclass
from enum import Enum
from typing import AsyncGenerator, Union
import httpx
import numpy as np
import ormsgpack
import soundfile as sf
from fish_speech.utils.schema import (
ServeChatRequest,
ServeMessage,
ServeTextPart,
ServeVQGANDecodeRequest,
ServeVQGANEncodeRequest,
ServeVQPart,
)
class CustomAudioFrame:
    """Minimal livekit-style audio frame: raw int16 PCM bytes plus metadata."""

    def __init__(self, data, sample_rate, num_channels, samples_per_channel):
        # The buffer must hold at least one int16 per sample per channel.
        required = num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16)
        if len(data) < required:
            raise ValueError(
                "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
            )

        self._data = bytearray(data)
        self._sample_rate = sample_rate
        self._num_channels = num_channels
        self._samples_per_channel = samples_per_channel

    @property
    def data(self):
        # Mutable zero-copy view of the buffer as int16 samples.
        return memoryview(self._data).cast("h")

    @property
    def sample_rate(self):
        return self._sample_rate

    @property
    def num_channels(self):
        return self._num_channels

    @property
    def samples_per_channel(self):
        return self._samples_per_channel

    @property
    def duration(self):
        # Seconds of audio represented by this frame.
        return self.samples_per_channel / self.sample_rate

    def __repr__(self):
        return (
            f"CustomAudioFrame(sample_rate={self.sample_rate}, "
            f"num_channels={self.num_channels}, "
            f"samples_per_channel={self.samples_per_channel}, "
            f"duration={self.duration:.3f})"
        )
class FishE2EEventType(Enum):
    """Kinds of events yielded by FishE2EAgent.stream."""

    SPEECH_SEGMENT = 1  # decoded audio chunk: `frame` and `vq_codes` are set
    TEXT_SEGMENT = 2  # incremental text delta: `text` is set
    END_OF_TEXT = 3
    END_OF_SPEECH = 4
    ASR_RESULT = 5  # NOTE(review): never emitted by stream() in this file
    USER_CODES = 6  # VQ codes of the user's own audio: `vq_codes` is set
@dataclass
class FishE2EEvent:
    """One streamed event from FishE2EAgent.stream; unused fields stay None."""

    type: FishE2EEventType
    # Audio payload for SPEECH_SEGMENT events; despite the annotation, stream()
    # fills this with a CustomAudioFrame, not an ndarray — TODO tighten type.
    frame: np.ndarray = None
    text: str = None
    vq_codes: list[list[int]] = None
# Shared module-level HTTP client with unlimited pooling/keepalive.
# NOTE(review): appears unused — FishE2EAgent creates its own AsyncClient;
# confirm no external importers before removing.
client = httpx.AsyncClient(
    timeout=None,
    limits=httpx.Limits(
        max_connections=None,
        max_keepalive_connections=None,
        keepalive_expiry=None,
    ),
)
class FishE2EAgent:
    """HTTP client for a local fish-speech server: VQGAN encode/decode plus
    streaming chat over msgpack."""

    def __init__(self):
        self.llm_url = "http://localhost:8080/v1/chat"
        self.vqgan_url = "http://localhost:8080"
        self.client = httpx.AsyncClient(timeout=None)

    async def get_codes(self, audio_data, sample_rate):
        """Encode a waveform into VQ codes via the server's VQGAN endpoint.

        audio_data: array-like accepted by soundfile.write — TODO confirm
        expected dtype/shape.  Returns the token grid of the single audio.
        """
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
        audio_buffer.seek(0)

        # Step 1: Encode audio using VQGAN
        encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
        encode_request_bytes = ormsgpack.packb(
            encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
        )
        encode_response = await self.client.post(
            f"{self.vqgan_url}/v1/vqgan/encode",
            data=encode_request_bytes,
            headers={"Content-Type": "application/msgpack"},
        )
        encode_response_data = ormsgpack.unpackb(encode_response.content)
        codes = encode_response_data["tokens"][0]
        return codes

    async def stream(
        self,
        system_audio_data: np.ndarray | None,
        user_audio_data: np.ndarray | None,
        sample_rate: int,
        num_channels: int,
        chat_ctx: dict | None = None,
    ) -> AsyncGenerator[FishE2EEvent, None]:
        """Run one chat turn against the LLM, yielding FishE2EEvent objects.

        chat_ctx, when given, carries {"messages": [...], "added_sysaudio": bool}
        and is mutated in place (system audio codes appended once).
        num_channels is currently unused by this method.
        """
        # Encode the optional reference (system) and user audio up front.
        if system_audio_data is not None:
            sys_codes = await self.get_codes(system_audio_data, sample_rate)
        else:
            sys_codes = None
        if user_audio_data is not None:
            user_codes = await self.get_codes(user_audio_data, sample_rate)

        # Step 2: Prepare LLM request
        if chat_ctx is None:
            # Fresh conversation: build a default system message.
            sys_parts = [
                ServeTextPart(
                    text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
                ),
            ]
            if system_audio_data is not None:
                sys_parts.append(ServeVQPart(codes=sys_codes))
            chat_ctx = {
                "messages": [
                    ServeMessage(
                        role="system",
                        parts=sys_parts,
                    ),
                ],
            }
        else:
            # Existing conversation: attach the reference audio to the system
            # message exactly once.
            if chat_ctx["added_sysaudio"] is False and sys_codes:
                chat_ctx["added_sysaudio"] = True
                chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))

        prev_messages = chat_ctx["messages"].copy()
        if user_audio_data is not None:
            # Let the caller record the user's codes in its own history.
            yield FishE2EEvent(
                type=FishE2EEventType.USER_CODES,
                vq_codes=user_codes,
            )
        else:
            user_codes = None

        request = ServeChatRequest(
            messages=prev_messages
            + (
                [
                    ServeMessage(
                        role="user",
                        parts=[ServeVQPart(codes=user_codes)],
                    )
                ]
                if user_codes
                else []
            ),
            streaming=True,
            num_samples=1,
        )

        # Step 3: Stream LLM response and decode audio
        buffer = b""
        vq_codes = []  # accumulated VQ deltas awaiting decoding
        current_vq = False  # True while the tail of the stream is VQ codes

        async def decode_send():
            # Flush accumulated VQ codes: decode to audio and emit one
            # SPEECH_SEGMENT event, then reset the accumulator.
            nonlocal current_vq
            nonlocal vq_codes

            data = np.concatenate(vq_codes, axis=1).tolist()

            # Decode VQ codes to audio
            decode_request = ServeVQGANDecodeRequest(tokens=[data])
            decode_response = await self.client.post(
                f"{self.vqgan_url}/v1/vqgan/decode",
                data=ormsgpack.packb(
                    decode_request,
                    option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
                ),
                headers={"Content-Type": "application/msgpack"},
            )
            decode_data = ormsgpack.unpackb(decode_response.content)

            # Convert float16 audio data to int16
            audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
            audio_data = (audio_data * 32768).astype(np.int16).tobytes()

            audio_frame = CustomAudioFrame(
                data=audio_data,
                samples_per_channel=len(audio_data) // 2,
                sample_rate=44100,
                num_channels=1,
            )
            yield FishE2EEvent(
                type=FishE2EEventType.SPEECH_SEGMENT,
                frame=audio_frame,
                vq_codes=data,
            )

            current_vq = False
            vq_codes = []

        async with self.client.stream(
            "POST",
            self.llm_url,
            data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
            headers={"Content-Type": "application/msgpack"},
        ) as response:
            async for chunk in response.aiter_bytes():
                buffer += chunk

                # Wire format: 4-byte length prefix followed by a msgpack body.
                # NOTE(review): "I" uses native byte order — confirm the server
                # emits the same order (use "<I" if little-endian is intended).
                while len(buffer) >= 4:
                    read_length = struct.unpack("I", buffer[:4])[0]
                    if len(buffer) < 4 + read_length:
                        break

                    body = buffer[4 : 4 + read_length]
                    buffer = buffer[4 + read_length :]
                    data = ormsgpack.unpackb(body)

                    if data["delta"] and data["delta"]["part"]:
                        # A text delta after VQ codes closes the audio segment.
                        if current_vq and data["delta"]["part"]["type"] == "text":
                            async for event in decode_send():
                                yield event
                        if data["delta"]["part"]["type"] == "text":
                            yield FishE2EEvent(
                                type=FishE2EEventType.TEXT_SEGMENT,
                                text=data["delta"]["part"]["text"],
                            )
                        elif data["delta"]["part"]["type"] == "vq":
                            vq_codes.append(np.array(data["delta"]["part"]["codes"]))
                            current_vq = True

        # Flush any trailing audio segment.
        if current_vq and vq_codes:
            async for event in decode_send():
                yield event

        yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
        yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
# Example usage:
async def main():
    """Manual smoke test: stream a local audio file through the agent."""
    import torchaudio

    agent = FishE2EAgent()

    # Load the example audio as a waveform and convert to int16 PCM.
    audio_data, sample_rate = torchaudio.load("uz_story_en.m4a")
    audio_data = (audio_data.numpy() * 32768).astype(np.int16)

    # BUG FIX: stream() takes (system_audio_data, user_audio_data, sample_rate,
    # num_channels).  The original call passed only three positionals, shifting
    # the user audio into the system slot; pass no reference audio explicitly.
    stream = agent.stream(None, audio_data, sample_rate, 1)

    # Start from a clean output file.
    if os.path.exists("audio_segment.wav"):
        os.remove("audio_segment.wav")

    async for event in stream:
        if event.type == FishE2EEventType.SPEECH_SEGMENT:
            # Append raw PCM (header-less; for quick inspection only).
            with open("audio_segment.wav", "ab+") as f:
                f.write(event.frame.data)
        elif event.type == FishE2EEventType.ASR_RESULT:
            print(event.text, flush=True)
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            print(event.text, flush=True, end="")
        elif event.type == FishE2EEventType.END_OF_TEXT:
            print("\nEnd of text reached.")
        elif event.type == FishE2EEventType.END_OF_SPEECH:
            print("End of speech reached.")


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
import itertools
import os
import re
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
from pathlib import Path
import click
import numpy as np
from loguru import logger
from tqdm import tqdm
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.utils.file import load_filelist
# To avoid CPU overload
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
def task_generator_folder(root: Path, text_extension: str):
    """Yield (speaker, [(npy_path, [texts]), ...], "folder") per directory.

    Pairs every ``*.npy`` semantic-token file under ``root`` with transcript
    file(s) sharing its stem.  ``text_extension`` may be a single suffix string
    or an iterable of suffixes; files whose transcripts cannot be read are
    skipped with an error log.
    """
    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
    files = sorted(files)

    # Group by parent directory; the directory name doubles as the speaker id.
    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        p = str(file.parent)
        speaker = file.parent.name

        try:
            if isinstance(text_extension, str):
                texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
            else:
                texts = [
                    file.with_suffix(ext).read_text(encoding="utf-8")
                    for ext in text_extension
                ]
        except Exception as e:
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[p].append((speaker, file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )
    for i in grouped_files.values():
        subset = [(f, t) for _, f, t in i]
        # All files in one directory share a speaker; use the first entry's.
        yield i[0][0], subset, "folder"
def task_generator_filelist(filelist):
    """Yield (speaker, [(path, [text]), ...], "filelist") groups from a filelist."""
    speaker_to_items = defaultdict(list)
    for filename, speaker, _, text in load_filelist(filelist):
        speaker_to_items[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(speaker_to_items)} groups in {filelist}")
    for speaker, items in speaker_to_items.items():
        yield speaker, items, "filelist"
def run_task(task):
    """Build one packed TextData protobuf stream for a (name, files, source) task.

    Skips files with a missing .npy or an unloadable semantic array; returns
    the serialized bytes produced by pack_pb_stream.
    """
    name, subset, source = task

    # Parse the files
    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if np_file.exists() is False:
            logger.warning(f"Can't find {np_file}")
            continue

        new_texts = []
        for text in texts:
            # Simple cleaning: replace { xxx } and < xxx > with space
            text = re.sub(r"\{.*?\}", " ", text)
            text = re.sub(r"<.*?>", " ", text)
            text = re.sub(r"\s+", " ", text)
            new_texts.append(text)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        # Protobuf needs plain lists, not ndarrays.
        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                texts=new_texts,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack the sentences
    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            sentences=sentences,
        )
    )
@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    """Pack .npy semantic tokens plus transcripts into sharded .protos files."""
    generator_fns = []

    for f in input:
        assert f.exists(), f"{f} not found"

        # Directories are walked recursively; anything else is read as a filelist.
        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    dataset_fp = None
    tar_idx = 0
    written_size = 0

    with Pool(num_workers) as p:
        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
            # Lazily open the current shard on first write.
            if dataset_fp is None:
                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

            dataset_fp.write(result)
            written_size += len(result)

            # Roll over to a new shard once this one exceeds shard_size MB.
            if written_size > shard_size * 1024 * 1024:
                logger.info(f"Finished writing {tar_idx} shards to {output}")
                dataset_fp.close()
                dataset_fp = None
                written_size = 0
                tar_idx += 1

    if dataset_fp is not None:
        dataset_fp.close()

    # NOTE(review): when the final shard was already closed in the loop this
    # over-counts by one — confirm intended.
    logger.info(f"Finished writing {tar_idx + 1} shards to {output}")


if __name__ == "__main__":
    main()
import pyrootutils
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from transformers import AutoTokenizer
# register eval resolver and root
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from torch.utils.data import DataLoader
from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
from fish_speech.models.text2semantic.inference import load_model
def smooth(
    scalars: list[float], weight: float
) -> list[float]:  # Weight between 0 and 1
    """Exponential moving average of *scalars*, TensorBoard-style.

    weight=0 returns the input unchanged; weight=1 repeats the first value.
    Raises IndexError on an empty input (seeded from the first element).
    """
    smoothed: list[float] = []
    running = scalars[0]  # seed with the first sample
    for value in scalars:
        running = running * weight + (1 - weight) * value
        smoothed.append(running)
    return smoothed
@torch.inference_mode()
def analyze_one_model(loader, config, weight, max_length):
    """Measure mean per-frame semantic (codebook) loss over the first 10 batches.

    Returns (frame_positions, mean_losses, smoothed_losses) as Python lists.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(
        config,
        weight,
        device,
        torch.bfloat16,
        max_length,
        compile=False,
    )[0]

    current_step = 0
    model.eval()

    # Per-frame accumulators over all sampled batches.
    semantic_loss_sum = torch.zeros(
        max_length,
        dtype=torch.float32,
        device=device,
    )
    counter = torch.zeros(
        max_length,
        dtype=torch.long,
        device=device,
    )

    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        labels = batch["labels"]
        outputs = model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )

        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Token-level loss; computed for parity with training but unused below.
        base_loss = F.cross_entropy(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        base_loss = base_loss.reshape(labels[:, 0].shape)
        semantic_loss = semantic_loss.reshape(codebook_labels.shape)

        # Average the per-codebook losses into one value per frame.
        semantic_loss_frame = semantic_loss.mean(-1)
        # A frame is padding when every codebook label equals the ignore value.
        pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks

        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
            semantic_loss_sum[~pad] += loss_sample[~pad]
            counter[~pad] += 1

        current_step += 1
        if current_step == 10:
            break

    # BUG FIX: move the accumulator (not the last batch's per-frame loss) to
    # the CPU.  The original moved `semantic_loss`, leaving `semantic_loss_sum`
    # on the GPU and mixing devices in the division below when running on CUDA.
    semantic_loss_sum = semantic_loss_sum.cpu()
    counter = counter.cpu()

    xs, ys = [], []
    for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
        if count > 0:
            xs.append(i)
            ys.append((loss / count).item())  # for better loss visualization

    smoothed_ys = smooth(ys, 0.95)

    # Unload model
    del model
    torch.cuda.empty_cache()

    return xs, ys, smoothed_ys
def main():
    """Plot smoothed per-frame semantic loss curves for several checkpoints."""
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    max_length = 4096

    ds = AutoAugTextDataset(
        ["data/protos/sft/云天河"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        max_length=max_length,
    )
    loader = DataLoader(
        ds,
        batch_size=8,
        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
        num_workers=0,
        shuffle=False,
    )

    plt.figure(figsize=(10, 5), dpi=200)
    plt.xlabel("Frame")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.title("Semantic Loss")
    plt.grid(which="both", axis="both")
    plt.xlim(0, max_length)

    # (display name, model config name, checkpoint path) per curve.
    tests = [
        (
            "pertrain-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
        ),
        (
            "sft-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
        ),
        (
            "sft-large",
            "dual_ar_2_codebook_large",
            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
        ),
    ]

    for name, config, weight in tests:
        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
        plt.plot(xs, smoothed_ys, label=name)

    plt.legend()
    plt.savefig("semantic_loss.png")


if __name__ == "__main__":
    main()
import os
import subprocess
import sys
#!/usr/bin/env python
def main():
    """Thin CLI wrapper: forward all args to text2semantic/inference.py."""
    # Make path relative to this file so the wrapper works from any cwd.
    script_path = os.path.join(
        os.path.dirname(__file__), "../../fish_speech/models/text2semantic/inference.py"
    )
    # ROBUSTNESS: use the current interpreter rather than whatever "python"
    # happens to be on PATH (may be absent or a different environment).
    subprocess.run([sys.executable, script_path] + sys.argv[1:])


if __name__ == "__main__":
    main()
import shutil
from copy import deepcopy
from pathlib import Path
import click
import hydra
import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from fish_speech.models.text2semantic.llama import BaseTransformer
from fish_speech.models.text2semantic.lora import get_merged_state_dict
@click.command()
@click.option("--lora-config", type=str, default="r_8_alpha_16")
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
@click.option("--lora-weight", type=str, required=True)
@click.option("--output", type=str, required=True)
def merge(lora_config, base_weight, lora_weight, output):
    """Merge LoRA weights into a base llama checkpoint and sanity-check the result."""
    output = Path(output)
    logger.info(
        f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
    )

    with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
        cfg = compose(config_name=lora_config)

    lora_config = instantiate(cfg)
    logger.info(f"Loaded lora model with config {lora_config}")

    llama_model = BaseTransformer.from_pretrained(
        path=base_weight,
        load_weights=True,
        lora_config=lora_config,
    )
    logger.info(f"Loaded llama model")

    llama_state_dict = llama_model.state_dict()
    # Keep only the base weights; LoRA tensors come from the LoRA checkpoint.
    llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
    # Pristine copy used for the post-merge difference check at the end.
    llama_state_dict_copy = deepcopy(llama_state_dict)
    lora_state_dict = torch.load(lora_weight, map_location="cpu")

    # NOTE(review): `llama_state_dict` is already a plain state dict here, so
    # this first unwrap can never trigger — presumably copied from
    # checkpoint-loading code; confirm before removing.
    if "state_dict" in llama_state_dict:
        llama_state_dict = llama_state_dict["state_dict"]

    if "state_dict" in lora_state_dict:
        lora_state_dict = lora_state_dict["state_dict"]

    # remove prefix model.
    if any(k.startswith("model.") for k in llama_state_dict.keys()):
        llama_state_dict = {
            k.replace("model.", ""): v
            for k, v in llama_state_dict.items()
            if k.startswith("model.")
        }

    if any(k.startswith("model.") for k in lora_state_dict.keys()):
        lora_state_dict = {
            k.replace("model.", ""): v
            for k, v in lora_state_dict.items()
            if k.startswith("model.")
        }

    logger.info(f"Found {len(llama_state_dict)} keys in llama model")
    logger.info(f"Found {len(lora_state_dict)} keys in lora model")

    # LoRA entries overwrite base entries of the same name.
    merged_state_dict = llama_state_dict | lora_state_dict
    llama_model.load_state_dict(merged_state_dict, strict=True)
    logger.info(f"Merged model loaded")

    # Trigger eval mode to merge lora
    llama_model.eval()
    llama_model.save_pretrained(output, drop_lora=True)
    logger.info(f"Saved merged model to {output}, validating")

    new_state_dict = torch.load(output / "model.pth", map_location="cpu")
    original_keys = set(llama_state_dict_copy.keys())

    # Spot-check: the merged weights should differ from the originals somewhere.
    # NOTE(review): `diff_l1` would be unbound below if `original_keys` were empty.
    tolerance = 1e-5
    for key in original_keys:
        diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
        if diff_l1 > tolerance:
            logger.info(f"Significant difference found in key: {key}")
            break

    if diff_l1 <= tolerance:
        logger.warning(
            "Merged model seems identical to the original model. Further validation might be needed."
        )
    else:
        logger.info("Merged model is different from the original model, check passed")


if __name__ == "__main__":
    merge()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import datetime
import shutil
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import time
from pathlib import Path
import click
import torch
import torch.nn as nn
import torch.nn.functional as F
from fish_speech.models.text2semantic.inference import load_model
from fish_speech.models.text2semantic.llama import find_multiple
##### Quantization Primitives ######
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    """Symmetric per-channel (axis 0) quantization of a dense 2-D tensor.

    Returns (quantized tensor in target_dtype, per-row scales, zero_points —
    all zeros under the symmetric scheme).
    TODO(future): relax the axis-0 / dense-format assumptions as needed.
    """
    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # Per-row extrema.
    row_min, row_max = torch.aminmax(x, dim=1)

    # Force the range to straddle zero, then take the larger magnitude side.
    # reference: https://fburl.com/code/srbiybme
    zero = torch.zeros_like(row_min)
    neg_reach = torch.min(row_min, zero)
    pos_reach = torch.max(row_max, zero)

    # reference: https://fburl.com/code/4wll53rk
    magnitude = torch.max(-neg_reach, pos_reach)
    scales = magnitude / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(
        neg_reach.size(), dtype=torch.int64, device=neg_reach.device
    )

    # Quantize: divide by scale, round, add (zero) offset, clamp into range.
    shifted = torch.round(x / scales.unsqueeze(-1)) + zero_points.unsqueeze(-1)
    quant = torch.clamp(shifted, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points
def get_group_qparams(w, n_bit=4, groupsize=128):
    """Per-group (scale, zero) bf16 parameters for n_bit quantization of 2-D w."""
    # needed for GPTQ with padding
    if groupsize > w.shape[-1]:
        groupsize = w.shape[-1]
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert torch.isnan(grouped).sum() == 0

    hi = grouped.amax(dim=1, keepdim=True)
    lo = grouped.amin(dim=1, keepdim=True)
    max_int = 2**n_bit - 1
    # Scale spans the group's range; zero sits mid-range above the minimum.
    scales = (hi - lo).clamp(min=1e-6) / max_int
    zeros = lo + scales * (2 ** (n_bit - 1))

    rows = w.shape[0]
    return (
        scales.to(torch.bfloat16).reshape(rows, -1),
        zeros.to(torch.bfloat16).reshape(rows, -1),
    )
def pack_scales_and_zeros(scales, zeros):
    """Interleave matching bf16 scales/zeros into one (groups, rows, 2) tensor."""
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16

    # Stack along a new trailing axis, then swap the leading two dims.
    stacked = torch.cat(
        [
            scales.reshape(scales.size(0), scales.size(1), 1),
            zeros.reshape(zeros.size(0), zeros.size(1), 1),
        ],
        2,
    )
    return stacked.transpose(0, 1).contiguous()
def unpack_scales_and_zeros(scales_and_zeros):
    """Split a packed (groups, rows, 2) float tensor back into (scales, zeros)."""
    assert scales_and_zeros.dim() == 3 and scales_and_zeros.shape[2] == 2
    assert scales_and_zeros.dtype == torch.float
    # Undo the transpose done by pack_scales_and_zeros, then split the pair axis.
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    """Quantize w to int32 codes in [0, 2**n_bit - 1] using given group qparams."""
    assert groupsize > 1
    # GPTQ single-column case: one group spans the whole (short) row.
    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
        groupsize = w.shape[-1]
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2
    grouped = w.reshape(-1, groupsize)
    assert not torch.isnan(grouped).any()

    scales_col = scales.reshape(-1, 1)
    zeros_col = zeros.reshape(-1, 1)
    # zeros encodes group_min + scale * 2**(n_bit-1); recover the minimum.
    group_min = zeros_col - scales_col * (2 ** (n_bit - 1))
    codes = (grouped - group_min) / scales_col
    return (
        codes.round()
        .clamp_(0, 2**n_bit - 1)
        .to(torch.int32)
        .reshape_as(w)
    )
def group_quantize_tensor(w, n_bit=4, groupsize=128):
    """Quantize w groupwise; return (int32 codes, packed scales_and_zeros)."""
    scales, zeros = get_group_qparams(w, n_bit, groupsize)
    codes = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
    return codes, pack_scales_and_zeros(scales, zeros)
def group_dequantize_tensor_from_qparams(
    w_int32, scales, zeros, n_bit=4, groupsize=128
):
    """Reconstruct real-valued weights from int32 group-quantized codes."""
    assert groupsize > 1
    # GPTQ single-column case: one group spans the whole (short) row.
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2
    grouped = w_int32.reshape(-1, groupsize)
    scales_col = scales.reshape(-1, 1)
    zeros_col = zeros.reshape(-1, 1)
    # Codes are stored offset by the midpoint 2**(n_bit-1).
    centered = grouped - 2 ** (n_bit - 1)
    return (centered * scales_col + zeros_col).reshape_as(w_int32)
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
    """Dequantize int32 codes using a packed scales_and_zeros tensor."""
    scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
    return group_dequantize_tensor_from_qparams(
        w_int32, scales, zeros, n_bit, groupsize
    )
class QuantHandler:
    """Base interface for weight-quantization strategies.

    Subclasses implement both methods; this base provides no behavior.
    """

    def __init__(self, mod):
        # The module whose weights will be quantized.
        self.mod = mod

    def create_quantized_state_dict(self) -> "StateDict":
        """Return a state dict with quantized weights (subclass responsibility)."""
        pass

    def convert_for_runtime(self) -> "nn.Module":
        """Swap submodules for quantized equivalents (subclass responsibility)."""
        pass
##### Weight-only int8 per-channel quantized code ######
def replace_linear_weight_only_int8_per_channel(module):
    """Recursively swap every nn.Linear in `module` for a WeightOnlyInt8Linear (in place)."""
    for child_name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            replace_linear_weight_only_int8_per_channel(child)
            continue
        quantized = WeightOnlyInt8Linear(child.in_features, child.out_features)
        setattr(module, child_name, quantized)
class WeightOnlyInt8QuantHandler:
    """Quantizes nn.Linear weights to int8 with per-channel symmetric scales."""

    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return the state dict with each Linear weight replaced by int8 codes
        plus a per-output-channel `scales` tensor in the original weight dtype."""
        state = self.mod.state_dict()
        for name, submodule in self.mod.named_modules():
            if not isinstance(submodule, torch.nn.Linear):
                continue
            int8_weight, scales, _ = dynamically_quantize_per_channel(
                submodule.weight.float(), -128, 127, torch.int8
            )
            state[f"{name}.weight"] = int8_weight
            state[f"{name}.scales"] = scales.to(submodule.weight.dtype)
        return state

    def convert_for_runtime(self):
        """Replace Linear layers in place and return the converted module."""
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod
class WeightOnlyInt8Linear(torch.nn.Module):
    """Bias-free matmul over an int8 weight with per-output-channel scales.

    The int8 weight is up-cast to the input dtype on the fly and the product
    is rescaled per output channel.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}  # kept for API parity; unused
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # int8 codes, filled in from a quantized state dict.
        self.register_buffer(
            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
        )
        # One bf16 scale per output channel.
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        upcast_weight = self.weight.to(dtype=input.dtype)
        return F.linear(input, upcast_weight) * self.scales
##### weight only int4 per channel groupwise quantized code ######
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    """Group-quantize a bf16 weight to 4 bits and pack it for the int4 matmul kernel."""
    codes, scales_and_zeros = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    packed = torch.ops.aten._convert_weight_to_int4pack(codes, inner_k_tiles)
    return packed, scales_and_zeros
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    """Apply the packed-int4 matmul kernel to x, preserving leading batch dims."""
    lead_shape = x.size()[:-1]
    flat = x.reshape(-1, x.size(-1))
    out = torch.ops.aten._weight_int4pack_mm(
        flat, weight_int4pack, groupsize, scales_and_zeros
    )
    return out.reshape(lead_shape + (out_features,))
def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    """Recursively swap nn.Linear layers for WeightOnlyInt4Linear (in place).

    A Linear is converted directly when its in_features satisfies the int4
    kernel divisibility constraints; an incompatible Linear is converted with
    internal padding when `padding` is True, and left untouched otherwise.

    Args:
        module: root module, modified in place.
        groupsize: quantization group size.
        inner_k_tiles: int4 kernel packing parameter.
        padding: pad incompatible in_features up to a multiple of 1024.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            compatible = _check_linear_int4_k(
                child.in_features, groupsize, inner_k_tiles
            )
            # Single construction site (the original duplicated it per branch):
            # pad only when the shape is incompatible with the kernel.
            if compatible or padding:
                setattr(
                    module,
                    name,
                    WeightOnlyInt4Linear(
                        child.in_features,
                        child.out_features,
                        bias=False,
                        groupsize=groupsize,
                        inner_k_tiles=inner_k_tiles,
                        padding=not compatible,
                    ),
                )
            # Incompatible + no padding: leave the Linear as-is (it has no
            # children, so the original's recursion into it was a no-op).
        else:
            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
class WeightOnlyInt4QuantHandler:
    """Quantizes nn.Linear weights to 4-bit groupwise values packed for the
    int4 matmul kernel.

    NOTE(review): create_quantized_state_dict moves weights to "cuda" for the
    packing op, so quantization appears to require a CUDA device — confirm.
    """

    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        # mod: module whose nn.Linear layers will be quantized.
        # padding: pad incompatible in_features up to a multiple of 1024.
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return the state dict with each eligible Linear's weight replaced by
        a packed int4 tensor plus its scales_and_zeros tensor (both on CPU)."""
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                # The int4 path only supports bias-free Linear layers.
                assert not mod.bias
                out_features = mod.out_features
                in_features = mod.in_features
                assert out_features % 8 == 0, "require out_features % 8 == 0"
                print(f"linear: {fqn}, in={in_features}, out={out_features}")

                weight = mod.weight.data
                if not _check_linear_int4_k(
                    in_features, self.groupsize, self.inner_k_tiles
                ):
                    if self.padding:
                        import torch.nn.functional as F

                        print(
                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                        )
                        padded_in_features = find_multiple(in_features, 1024)
                        # Zero-pad the trailing input dim; WeightOnlyInt4Linear
                        # pads activations to match at forward time.
                        weight = F.pad(
                            weight, pad=(0, padded_in_features - in_features)
                        )
                    else:
                        # Skip layers that cannot be packed without padding.
                        print(
                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                        )
                        continue
                (
                    weight_int4pack,
                    scales_and_zeros,
                ) = prepare_int4_weight_and_scales_and_zeros(
                    weight.to(torch.bfloat16).to("cuda"),
                    self.groupsize,
                    self.inner_k_tiles,
                )
                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")

        return cur_state_dict

    def convert_for_runtime(self):
        """Swap Linear layers for WeightOnlyInt4Linear and return the module."""
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod
class WeightOnlyInt4Linear(torch.nn.Module):
    """Bias-free Linear layer backed by a packed int4 weight and bf16 group qparams.

    forward() casts the input to bfloat16 and dispatches to the packed-int4
    matmul kernel via linear_forward_int4.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        padding: bool = True,
    ) -> None:
        super().__init__()
        self.padding = padding
        if padding:
            # Remember the caller-visible width; the kernel works on the padded one.
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (innerKTiles * 16) == 0"
        # Packed layout expected by aten._weight_int4pack_mm.
        self.register_buffer(
            "weight",
            torch.empty(
                (
                    out_features // 8,
                    in_features // (inner_k_tiles * 16),
                    32,
                    inner_k_tiles // 2,
                ),
                dtype=torch.int32,
            ),
        )
        # One (scale, zero) pair per (group, output channel).
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(
                (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(torch.bfloat16)
        if self.padding:
            import torch.nn.functional as F

            # Zero-pad activations to the padded in_features the weight was packed with.
            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        return linear_forward_int4(
            input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )
def generate_folder_name():
    """Return a timestamp string like 20240131_235959 for naming output folders."""
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def _prepare_output_dir(src: Path, dst: Path, vq_model: str) -> Path:
    # Clone the checkpoint directory, drop the (unquantized) VQ-GAN generator,
    # and return the path where the quantized LLM weights should be written.
    shutil.copytree(str(src.resolve()), str(dst.resolve()))
    if (dst / vq_model).exists():
        (dst / vq_model).unlink()
    return dst / "model.pth"


@click.command()
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.4",
)
@click.option(
    "--mode", type=str, default="int8", help="type of quantization to perform"
)
@click.option(
    "--groupsize", type=int, default=128, help="Group size for int4 quantization."
)
@click.option("--timestamp", type=str, default="None", help="When to do quantization")
def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
    """Quantize a checkpoint to int8 or int4 and write it to a new checkpoint dir."""
    device = "cpu"
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()
    model, _ = load_model(
        checkpoint_path=checkpoint_path,
        device=device,
        precision=precision,
        compile=False,
    )
    vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
    now = timestamp if timestamp != "None" else generate_folder_name()

    if mode == "int8":
        print(
            "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
        )
        quant_handler = WeightOnlyInt8QuantHandler(model)
        quantized_state_dict = quant_handler.create_quantized_state_dict()
        quantize_path = _prepare_output_dir(
            checkpoint_path, Path(f"checkpoints/fs-1.2-int8-{now}"), vq_model
        )
    elif mode == "int4":
        print(
            "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
        )
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        quantized_state_dict = quant_handler.create_quantized_state_dict()
        quantize_path = _prepare_output_dir(
            checkpoint_path,
            Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}"),
            vq_model,
        )
    else:
        # Fixed typo in the supported-modes hint ("int4-gpptq" -> "int4-gptq").
        raise ValueError(
            f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gptq]"
        )

    print(f"Writing quantized weights to {quantize_path}")
    quantize_path.unlink(missing_ok=True)  # remove existing file if one already there
    torch.save(quantized_state_dict, quantize_path)
    print(f"Quantization complete took {time.time() - t0:.02f} seconds")


if __name__ == "__main__":
    quantize()
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

# Build a byte-level BPE tokenizer from scratch.
tokenizer = Tokenizer(models.BPE())

# Byte-level pre-tokenization/decoding so any UTF-8 input round-trips losslessly.
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

# "Train" on an empty corpus: no merges can be learned, so the vocabulary is
# exactly the byte-level alphabet plus the special tokens listed below.
trainer = trainers.BpeTrainer(
    vocab_size=0,
    min_frequency=2,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    special_tokens=[
        "<|begin_of_sequence|>",
        "<|end_of_sequence|>",
        "<|im_start|>",
        "<|im_sep|>",  # system, user, assistant, etc.
        "<|im_end|>",
        "<|semantic|>",  # audio features
        "<|pad|>",
    ],
)

# Expected chat layout:
# <|im_start|>user<|im_sep|>...<|im_end|>
# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
tokenizer.train_from_iterator([], trainer=trainer)

# Smoke-test: vocab size, then round-trip a mixed-language string.
print(len(tokenizer.get_vocab()))
x = tokenizer.encode(
    "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
).ids
print(x, len(x))
print(tokenizer.decode(x, skip_special_tokens=True))

# Wrap in a HF fast tokenizer so special-token roles (pad/bos/eos) are registered.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    pad_token="<|pad|>",
    bos_token="<|begin_of_sequence|>",
    eos_token="<|end_of_sequence|>",
)

# Try tokenizing a new sequence
sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
encoded = tokenizer(sequence).input_ids
print("Test encoding....")
print(f"\tSentence: {sequence}")
print(f"\tEncoded: {encoded}")
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
print(f"\tDecoded: {tokenizer.decode(encoded)}")

# Publish to the Hugging Face Hub (network side effect; requires auth).
tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
import os
from argparse import ArgumentParser
from pathlib import Path
import pyrootutils
import torch
from loguru import logger
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from fish_speech.models.vqgan.inference import load_model as load_decoder_model
from fish_speech.utils.schema import ServeTTSRequest
from tools.webui import build_app
from tools.webui.inference import get_inference_wrapper
# Disable einx's traceback filtering ("make einx happy").
os.environ["EINX_FILTER_TRACEBACK"] = "false"
def parse_args():
    """Build and parse the web-UI command-line arguments."""
    parser = ArgumentParser()
    # (flag, kwargs) pairs keep the option table compact and easy to scan.
    option_table = [
        (
            "--llama-checkpoint-path",
            dict(type=Path, default="checkpoints/fish-speech-1.5"),
        ),
        (
            "--decoder-checkpoint-path",
            dict(
                type=Path,
                default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
            ),
        ),
        ("--decoder-config-name", dict(type=str, default="firefly_gan_vq")),
        ("--device", dict(type=str, default="cuda")),
        ("--half", dict(action="store_true")),
        ("--compile", dict(action="store_true")),
        ("--max-gradio-length", dict(type=int, default=0)),
        ("--theme", dict(type=str, default="light")),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
args.precision = torch.half if args.half else torch.bfloat16
# Check if MPS or CUDA is available
if torch.backends.mps.is_available():
args.device = "mps"
logger.info("mps is available, running on mps.")
elif not torch.cuda.is_available():
logger.info("CUDA is not available, running on CPU.")
args.device = "cpu"
logger.info("Loading Llama model...")
llama_queue = launch_thread_safe_queue(
checkpoint_path=args.llama_checkpoint_path,
device=args.device,
precision=args.precision,
compile=args.compile,
)
logger.info("Loading VQ-GAN model...")
decoder_model = load_decoder_model(
config_name=args.decoder_config_name,
checkpoint_path=args.decoder_checkpoint_path,
device=args.device,
)
logger.info("Decoder model loaded, warming up...")
# Create the inference engine
inference_engine = TTSInferenceEngine(
llama_queue=llama_queue,
decoder_model=decoder_model,
compile=args.compile,
precision=args.precision,
)
# Dry run to check if the model is loaded correctly and avoid the first-time latency
list(
inference_engine.inference(
ServeTTSRequest(
text="Hello world.",
references=[],
reference_id=None,
max_new_tokens=1024,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.5,
temperature=0.7,
format="wav",
)
)
)
logger.info("Warming up done, launching the web UI...")
# Get the inference function with the immutable arguments
inference_fct = get_inference_wrapper(inference_engine)
app = build_app(inference_fct, args.theme)
app.launch(show_api=True)
# FunASR Command Line Interface
This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
## Requirements
- Python >= 3.10
- PyTorch <= 2.3.1
- ffmpeg, pydub, audio-separator[gpu].
## Installation
Install the required packages:
```bash
pip install -e .[stable]
```
Make sure you have `ffmpeg` installed and available in your `PATH`.
## Usage
### Basic Usage
To run the tool with default settings:
```bash
python tools/sensevoice/fun_asr.py --audio-dir <audio_directory> --save-dir <output_directory>
```
## Options
| Option | Description |
| :-----------------------: | :---------------------------------------------------------------------------: |
| --audio-dir | Directory containing audio or video files. |
| --save-dir | Directory to save processed audio files. |
| --device | Device to use for processing. Options: cuda (default) or cpu. |
| --language | Language of the transcription. Default is auto. |
| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
| --punc | Enable punctuation prediction. |
| --denoise | Enable noise reduction (vocal separation). |
## Example
To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
```bash
python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
```
## Additional Notes
- The tool supports both audio and video files; videos are converted to audio automatically.
- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
- The script will automatically create necessary directories in the `--save-dir`.
## Troubleshooting
If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import copy
import json
import logging
import os.path
import random
import re
import string
import time
import numpy as np
import torch
from funasr.download.download_model_from_hub import download_model
from funasr.download.file import download_from_url
from funasr.register import tables
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import export_utils, misc
from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
from funasr.utils.misc import deep_update
from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
from tqdm import tqdm
from .vad_utils import merge_vad, slice_padding_audio_samples
# Optional speaker-verification / clustering dependencies (campplus).
# Imported best-effort: when unavailable, the speaker-related code paths
# that reference these names simply won't be usable.
try:
    from funasr.models.campplus.cluster_backend import ClusterBackend
    from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
except:
    pass
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """Normalize heterogeneous input (URL, path, filelist, list/tuple, bytes,
    raw data) into two parallel lists: (key_list, data_list).

    NOTE(review): input_len is accepted but never used here — confirm intent.
    """
    data_list = []
    key_list = []
    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]

    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str):
        # Remote inputs are downloaded to a local path first.
        if data_in.startswith("http://") or data_in.startswith("https://"):  # url
            data_in = download_from_url(data_in)

    if isinstance(data_in, str) and os.path.exists(
        data_in
    ):  # wav_path; filelist: wav.scp, file.jsonl;text.txt;
        _, file_extension = os.path.splitext(data_in)
        file_extension = file_extension.lower()
        if file_extension in filelist:  # filelist: wav.scp, file.jsonl;text.txt;
            with open(data_in, encoding="utf-8") as fin:
                for line in fin:
                    # Fallback random key; may be replaced by a per-line key below.
                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                    if data_in.endswith(
                        ".jsonl"
                    ):  # file.jsonl: json.dumps({"source": data})
                        lines = json.loads(line.strip())
                        data = lines["source"]
                        key = data["key"] if "key" in data else key
                    else:  # filelist, wav.scp, text.txt: id \t data or data
                        lines = line.strip().split(maxsplit=1)
                        data = lines[1] if len(lines) > 1 else lines[0]
                        key = lines[0] if len(lines) > 1 else key
                    data_list.append(data)
                    key_list.append(key)
        else:
            # Single media/text file: key defaults to the bare filename.
            if key is None:
                # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                key = misc.extract_filename_without_extension(data_in)
            data_list = [data_in]
            key_list = [key]
    elif isinstance(data_in, (list, tuple)):
        if data_type is not None and isinstance(
            data_type, (list, tuple)
        ):  # mutiple inputs
            # Recurse per input stream, then zip streams into per-item tuples.
            data_list_tmp = []
            for data_in_i, data_type_i in zip(data_in, data_type):
                key_list, data_list_i = prepare_data_iterator(
                    data_in=data_in_i, data_type=data_type_i
                )
                data_list_tmp.append(data_list_i)
            data_list = []
            for item in zip(*data_list_tmp):
                data_list.append(item)
        else:
            # [audio sample point, fbank, text]
            data_list = data_in
            key_list = []
            for data_i in data_in:
                if isinstance(data_i, str) and os.path.exists(data_i):
                    key = misc.extract_filename_without_extension(data_i)
                else:
                    if key is None:
                        key = "rand_key_" + "".join(
                            random.choice(chars) for _ in range(13)
                        )
                key_list.append(key)

    else:  # raw text; audio sample point, fbank; bytes
        if isinstance(data_in, bytes):  # audio bytes
            data_in = load_bytes(data_in)
        if key is None:
            key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
        data_list = [data_in]
        key_list = [key]

    return key_list, data_list
class AutoModel:
def __init__(self, **kwargs):
try:
from funasr.utils.version_checker import check_for_update
print(
"Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
)
check_for_update(disable=kwargs.get("disable_update", False))
except:
pass
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
logging.basicConfig(level=log_level)
model, kwargs = self.build_model(**kwargs)
# if vad_model is not None, build vad model else None
vad_model = kwargs.get("vad_model", None)
vad_kwargs = (
{} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
)
if vad_model is not None:
logging.info("Building VAD model.")
vad_kwargs["model"] = vad_model
vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
vad_kwargs["device"] = kwargs["device"]
vad_model, vad_kwargs = self.build_model(**vad_kwargs)
# if punc_model is not None, build punc model else None
punc_model = kwargs.get("punc_model", None)
punc_kwargs = (
{}
if kwargs.get("punc_kwargs", {}) is None
else kwargs.get("punc_kwargs", {})
)
if punc_model is not None:
logging.info("Building punc model.")
punc_kwargs["model"] = punc_model
punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
punc_kwargs["device"] = kwargs["device"]
punc_model, punc_kwargs = self.build_model(**punc_kwargs)
# if spk_model is not None, build spk model else None
spk_model = kwargs.get("spk_model", None)
spk_kwargs = (
{} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
)
if spk_model is not None:
logging.info("Building SPK model.")
spk_kwargs["model"] = spk_model
spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
spk_kwargs["device"] = kwargs["device"]
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
self.cb_model = ClusterBackend().to(kwargs["device"])
spk_mode = kwargs.get("spk_mode", "punc_segment")
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
logging.error(
"spk_mode should be one of default, vad_segment and punc_segment."
)
self.spk_mode = spk_mode
self.kwargs = kwargs
self.model = model
self.vad_model = vad_model
self.vad_kwargs = vad_kwargs
self.punc_model = punc_model
self.punc_kwargs = punc_kwargs
self.spk_model = spk_model
self.spk_kwargs = spk_kwargs
self.model_path = kwargs.get("model_path")
@staticmethod
def build_model(**kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info(
"download models from model hub: {}".format(kwargs.get("hub", "ms"))
)
kwargs = download_model(**kwargs)
set_all_random_seed(kwargs.get("seed", 0))
device = kwargs.get("device", "cuda")
if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
device = "cpu"
kwargs["batch_size"] = 1
kwargs["device"] = device
torch.set_num_threads(kwargs.get("ncpu", 4))
# build tokenizer
tokenizer = kwargs.get("tokenizer", None)
if tokenizer is not None:
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
kwargs["token_list"] = (
tokenizer.token_list if hasattr(tokenizer, "token_list") else None
)
kwargs["token_list"] = (
tokenizer.get_vocab()
if hasattr(tokenizer, "get_vocab")
else kwargs["token_list"]
)
vocab_size = (
len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
)
if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
vocab_size = tokenizer.get_vocab_size()
else:
vocab_size = -1
kwargs["tokenizer"] = tokenizer
# build frontend
frontend = kwargs.get("frontend", None)
kwargs["input_size"] = None
if frontend is not None:
frontend_class = tables.frontend_classes.get(frontend)
frontend = frontend_class(**kwargs.get("frontend_conf", {}))
kwargs["input_size"] = (
frontend.output_size() if hasattr(frontend, "output_size") else None
)
kwargs["frontend"] = frontend
# build model
model_class = tables.model_classes.get(kwargs["model"])
assert model_class is not None, f'{kwargs["model"]} is not registered'
model_conf = {}
deep_update(model_conf, kwargs.get("model_conf", {}))
deep_update(model_conf, kwargs)
model = model_class(**model_conf, vocab_size=vocab_size)
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
if os.path.exists(init_param):
logging.info(f"Loading pretrained params from {init_param}")
load_pretrained_model(
model=model,
path=init_param,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
oss_bucket=kwargs.get("oss_bucket", None),
scope_map=kwargs.get("scope_map", []),
excludes=kwargs.get("excludes", None),
)
else:
print(f"error, init_param does not exist!: {init_param}")
# fp16
if kwargs.get("fp16", False):
model.to(torch.float16)
elif kwargs.get("bf16", False):
model.to(torch.bfloat16)
model.to(device)
if not kwargs.get("disable_log", True):
tables.print()
return model, kwargs
def __call__(self, *args, **cfg):
kwargs = self.kwargs
deep_update(kwargs, cfg)
res = self.model(*args, kwargs)
return res
def generate(self, input, input_len=None, **cfg):
if self.vad_model is None:
return self.inference(input, input_len=input_len, **cfg)
else:
return self.inference_with_vad(input, input_len=input_len, **cfg)
def inference(
self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
):
kwargs = self.kwargs if kwargs is None else kwargs
if "cache" in kwargs:
kwargs.pop("cache")
deep_update(kwargs, cfg)
model = self.model if model is None else model
model.eval()
batch_size = kwargs.get("batch_size", 1)
# if kwargs.get("device", "cpu") == "cpu":
# batch_size = 1
key_list, data_list = prepare_data_iterator(
input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
)
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
disable_pbar = self.kwargs.get("disable_pbar", False)
pbar = (
tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
if not disable_pbar
else None
)
time_speech_total = 0.0
time_escape_total = 0.0
for beg_idx in range(0, num_samples, batch_size):
end_idx = min(num_samples, beg_idx + batch_size)
data_batch = data_list[beg_idx:end_idx]
key_batch = key_list[beg_idx:end_idx]
batch = {"data_in": data_batch, "key": key_batch}
if (end_idx - beg_idx) == 1 and kwargs.get(
"data_type", None
) == "fbank": # fbank
batch["data_in"] = data_batch[0]
batch["data_lengths"] = input_len
time1 = time.perf_counter()
with torch.no_grad():
res = model.inference(**batch, **kwargs)
if isinstance(res, (list, tuple)):
results = res[0] if len(res) > 0 else [{"text": ""}]
meta_data = res[1] if len(res) > 1 else {}
time2 = time.perf_counter()
asr_result_list.extend(results)
# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
batch_data_time = meta_data.get("batch_data_time", -1)
time_escape = time2 - time1
speed_stats["load_data"] = meta_data.get("load_data", 0.0)
speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
speed_stats["forward"] = f"{time_escape:0.3f}"
speed_stats["batch_size"] = f"{len(results)}"
speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
description = f"{speed_stats}, "
if pbar:
pbar.update(end_idx - beg_idx)
pbar.set_description(description)
time_speech_total += batch_data_time
time_escape_total += time_escape
if pbar:
# pbar.update(1)
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
torch.cuda.empty_cache()
return asr_result_list
def vad(self, input, input_len=None, **cfg):
kwargs = self.kwargs
# step.1: compute the vad model
deep_update(self.vad_kwargs, cfg)
beg_vad = time.time()
res = self.inference(
input,
input_len=input_len,
model=self.vad_model,
kwargs=self.vad_kwargs,
**cfg,
)
end_vad = time.time()
# FIX(gcf): concat the vad clips for sense vocie model for better aed
if cfg.get("merge_vad", False):
for i in range(len(res)):
res[i]["value"] = merge_vad(
res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
)
elapsed = end_vad - beg_vad
return elapsed, res
def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
kwargs = self.kwargs
# step.2 compute asr model
model = self.model
deep_update(kwargs, cfg)
batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
kwargs["batch_size"] = batch_size
key_list, data_list = prepare_data_iterator(
input, input_len=input_len, data_type=kwargs.get("data_type", None)
)
results_ret_list = []
time_speech_total_all_samples = 1e-6
beg_total = time.time()
pbar_total = (
tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
if not kwargs.get("disable_pbar", False)
else None
)
for i in range(len(vad_res)):
key = vad_res[i]["key"]
vadsegments = vad_res[i]["value"]
input_i = data_list[i]
fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
speech = load_audio_text_image_video(
input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
)
speech_lengths = len(speech)
n = len(vadsegments)
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
if not len(sorted_data):
results_ret_list.append({"key": key, "text": "", "timestamp": []})
logging.info("decoding, utt: {}, empty speech".format(key))
continue
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
batch_size = max(
batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
)
if kwargs["device"] == "cpu":
batch_size = 0
beg_idx = 0
beg_asr_total = time.time()
time_speech_total_per_sample = speech_lengths / 16000
time_speech_total_all_samples += time_speech_total_per_sample
# pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
all_segments = []
max_len_in_batch = 0
end_idx = 1
for j, _ in enumerate(range(0, n)):
# pbar_sample.update(1)
sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
potential_batch_length = max(max_len_in_batch, sample_length) * (
j + 1 - beg_idx
)
# batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
if (
j < n - 1
and sample_length < batch_size_threshold_ms
and potential_batch_length < batch_size
):
max_len_in_batch = max(max_len_in_batch, sample_length)
end_idx += 1
continue
speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
speech, speech_lengths, sorted_data[beg_idx:end_idx]
)
results = self.inference(
speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
)
for _b in range(len(speech_j)):
results[_b]["interval"] = intervals[_b]
if self.spk_model is not None:
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
for _b in range(len(speech_j)):
vad_segments = [
[
sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
np.array(speech_j[_b]),
]
]
segments = sv_chunk(vad_segments)
all_segments.extend(segments)
speech_b = [i[2] for i in segments]
spk_res = self.inference(
speech_b,
input_len=None,
model=self.spk_model,
kwargs=kwargs,
**cfg,
)
results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
beg_idx = end_idx
end_idx += 1
max_len_in_batch = sample_length
if len(results) < 1:
continue
results_sorted.extend(results)
# end_asr_total = time.time()
# time_escape_total_per_sample = end_asr_total - beg_asr_total
# pbar_sample.update(1)
# pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
# f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
# f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
cur = results_sorted[j]
pattern = r"<\|([^|]+)\|>"
emotion_string = re.findall(pattern, cur["text"])
cur["text"] = re.sub(pattern, "", cur["text"])
cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
if self.punc_model is not None and len(cur["text"].strip()) > 0:
deep_update(self.punc_kwargs, cfg)
punc_res = self.inference(
cur["text"],
model=self.punc_model,
kwargs=self.punc_kwargs,
**cfg,
)
cur["text"] = punc_res[0]["text"]
restored_data[index] = cur
end_asr_total = time.time()
time_escape_total_per_sample = end_asr_total - beg_asr_total
if pbar_total:
pbar_total.update(1)
pbar_total.set_description(
f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
f"time_speech: {time_speech_total_per_sample: 0.3f}, "
f"time_escape: {time_escape_total_per_sample:0.3f}"
)
# end_total = time.time()
# time_escape_total_all_samples = end_total - beg_total
# print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
# f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
# f"time_escape_all: {time_escape_total_all_samples:0.3f}")
return restored_data
def export(self, input=None, **cfg):
    """
    Export the wrapped model for deployment (e.g. to ONNX).

    :param input: sample input(s) used to drive the export; forwarded to
        ``prepare_data_iterator``.
    :param cfg: export options merged into ``self.kwargs``, e.g.
        ``device`` (default ``"cpu"``), ``type`` (export backend,
        default ``"onnx"``), ``quantize``, ``fallback_num``,
        ``calib_num``, ``opset_version``.
    :return: the directory the exported artifacts were written to.
    """
    device = cfg.get("device", "cpu")
    model = self.model.to(device=device)
    kwargs = self.kwargs
    deep_update(kwargs, cfg)
    kwargs["device"] = device
    # The live model object must not be forwarded to the exporter kwargs.
    del kwargs["model"]
    model.eval()
    # NOTE: the previously assigned-but-unused `type` local (which shadowed
    # the builtin) was removed; the "type" key is still passed via **kwargs.
    _, data_list = prepare_data_iterator(
        input, input_len=None, data_type=kwargs.get("data_type", None), key=None
    )
    with torch.no_grad():
        export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
    return export_dir
import gc
import os
import re
from audio_separator.separator import Separator
os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
import json
import subprocess
from pathlib import Path
import click
import torch
from loguru import logger
from pydub import AudioSegment
from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
from tqdm import tqdm
from fish_speech.utils.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
from tools.sensevoice.auto_model import AutoModel
def uvr5_cli(
    audio_dir: Path,
    output_folder: Path,
    audio_files: list[Path] | None = None,
    output_format: str = "flac",
    model: str = "BS-Roformer-Viperx-1297.ckpt",
):
    """Run UVR5 (Roformer) vocal separation over a set of audio files.

    Supported checkpoint aliases: BS-Roformer-Viperx-1297 / -1296 / -1053
    and Mel-Roformer-Viperx-1143.

    Returns ``(outputs, checkpoint_name)`` where ``outputs`` is the flat
    list of files written by the separator and ``checkpoint_name`` is the
    resolved on-disk model file name.
    """
    checkpoint_by_alias = {
        "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
        "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
        "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
        "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
    }
    roformer_model = checkpoint_by_alias[model]

    separator = Separator(
        model_file_dir=os.environ["UVR5_CACHE"],
        output_dir=output_folder,
        output_format=output_format,
    )
    separator.load_model(roformer_model)

    if audio_files is None:
        audio_files = list_files(
            path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
        )
    file_count = len(audio_files)
    print(f"{file_count} audio files found")

    outputs = []
    for audio in tqdm(audio_files, desc="Denoising: "):
        separated = separator.separate(str(audio_dir / audio))
        if isinstance(separated, str):
            outputs.append(separated)
        elif isinstance(separated, list):
            outputs.extend(separated)

    # Release the separator and any cached GPU memory before returning.
    del separator
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return outputs, roformer_model
def get_sample_rate(media_path: Path):
    """Probe *media_path* with ffprobe and return the sample rate (as a
    string) of its first audio stream; "44100" when none is reported.

    :raises subprocess.CalledProcessError: if ffprobe exits non-zero.
    """
    probe = subprocess.run(
        [
            "ffprobe",
            "-v",
            "quiet",
            "-print_format",
            "json",
            "-show_streams",
            str(media_path),
        ],
        capture_output=True,
        text=True,
        check=True,
    )
    streams = json.loads(probe.stdout).get("streams", [])
    first_audio = next(
        (s for s in streams if s.get("codec_type") == "audio"), None
    )
    if first_audio is not None:
        return first_audio.get("sample_rate")
    return "44100"  # Default sample rate if not found
def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
    """Re-encode *src_path* to a mono audio file at its original sample rate.

    Uses 16-bit PCM when ``out_fmt`` is "wav", FLAC otherwise. When source
    and destination resolve to the same file, the result is written next to
    it with a ``_{sample_rate}`` suffix so ffmpeg never overwrites its own
    input while reading it.

    :param src_path: input media file (audio or video).
    :param out_path: requested output location; parent dirs are created.
    :param out_fmt: "wav" selects pcm_s16le, anything else selects FLAC.
    :return: the path of the file that was actually written.
    :raises subprocess.CalledProcessError: if ffprobe/ffmpeg fail.
    """
    sr = get_sample_rate(src_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if src_path.resolve() == out_path.resolve():
        # ffmpeg cannot read and write the same file; divert the output.
        output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
    else:
        output = str(out_path)
    subprocess.run(
        [
            "ffmpeg",
            "-loglevel",
            "error",
            "-i",
            str(src_path),
            "-acodec",
            "pcm_s16le" if out_fmt == "wav" else "flac",
            "-ar",
            sr,
            "-ac",
            "1",
            "-y",
            output,
        ],
        check=True,
    )
    # BUG FIX: previously returned out_path unconditionally, so in the
    # diverted (src == out) case callers received a path that was never
    # written. Return the path of the file ffmpeg actually produced.
    return Path(output)
def convert_video_to_audio(video_path: Path, audio_dir: Path):
    """Extract a mono WAV track from *video_path*, reusing an existing
    "(Vocals)" stem in the same directory when one is already present."""
    target_dir = audio_dir / video_path.relative_to(audio_dir).parent
    existing_vocals = [
        candidate
        for candidate in target_dir.glob(f"{video_path.stem}_(Vocals)*.*")
        if candidate.suffix in AUDIO_EXTENSIONS
    ]
    if existing_vocals:
        return existing_vocals[0]
    wav_path = target_dir / f"{video_path.stem}.wav"
    convert_to_mono(video_path, wav_path)
    return wav_path
@click.command()
@click.option("--audio-dir", required=True, help="Directory containing audio files")
@click.option(
    "--save-dir", required=True, help="Directory to save processed audio files"
)
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
@click.option("--language", default="auto", help="Language of the transcription")
@click.option(
    "--max_single_segment_time",
    default=20000,
    type=int,
    help="Maximum of Output single audio duration(ms)",
)
@click.option("--fsmn-vad/--silero-vad", default=False)
@click.option("--punc/--no-punc", default=False)
@click.option("--denoise/--no-denoise", default=False)
@click.option("--save_emo/--no_save_emo", default=False)
def main(
    audio_dir: str,
    save_dir: str,
    device: str,
    language: str,
    max_single_segment_time: int,
    fsmn_vad: bool,
    punc: bool,
    denoise: bool,
    save_emo: bool,
):
    """Slice and transcribe a directory of audio/video files with SenseVoice.

    Pipeline: convert videos to mono WAV -> optionally denoise with UVR5
    (keeping only the "(Vocals)" stems) -> VAD (fsmn or silero) ->
    SenseVoice transcription -> export each segment as ``<stem>_NNN.<ext>``
    plus a ``.lab`` transcript (and optionally an ``.emo`` emotion file).
    """
    audios_path = Path(audio_dir)
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)
    # Extract mono audio tracks from any video files found under audio_dir.
    video_files = list_files(
        path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
    )
    v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
    if denoise:
        VOCAL = "_(Vocals)"
        # Everything that is not already a separated vocal stem.
        original_files = [
            p
            for p in audios_path.glob("**/*")
            if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
        ]
        _, cur_model = uvr5_cli(
            audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
        )
        # Drop the instrumental stems and the originals; keep vocals only.
        need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
        need_remove.extend(original_files)
        for _ in need_remove:
            _.unlink()
        vocal_files = [
            p
            for p in audios_path.glob("**/*")
            if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
        ]
        for f in vocal_files:
            fn, ext = f.stem, f.suffix
            # Strip the model name the separator appends after "(Vocals)".
            v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
            if v_pos != -1:
                new_fn = fn[: v_pos + len(VOCAL)]
                new_f = f.with_name(new_fn + ext)
                f = f.rename(new_f)
                # Re-encode to mono FLAC (written under a "_{sr}" name —
                # see convert_to_mono), then remove the renamed source.
                convert_to_mono(f, f, "flac")
                f.unlink()
    audio_files = list_files(
        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
    )
    logger.info("Loading / Downloading Funasr model...")
    model_dir = "iic/SenseVoiceSmall"
    vad_model = "fsmn-vad" if fsmn_vad else None
    vad_kwargs = {"max_single_segment_time": max_single_segment_time}
    punc_model = "ct-punc" if punc else None
    manager = AutoModel(
        model=model_dir,
        trust_remote_code=False,
        vad_model=vad_model,
        vad_kwargs=vad_kwargs,
        punc_model=punc_model,
        device=device,
    )
    # When fsmn-vad is not requested, fall back to the silero VAD model.
    if not fsmn_vad and vad_model is None:
        vad_model = load_silero_vad()
    logger.info("Model loaded.")
    # Files named like "xxx_000.ext" were produced by a previous run.
    pattern = re.compile(r"_\d{3}\.")
    for file_path in tqdm(audio_files, desc="Processing audio file"):
        if pattern.search(file_path.name):
            # logger.info(f"Skipping {file_path} as it has already been processed.")
            continue
        file_stem = file_path.stem
        file_suffix = file_path.suffix
        rel_path = Path(file_path).relative_to(audio_dir)
        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
        audio = AudioSegment.from_file(file_path)
        cfg = dict(
            cache={},
            language=language,  # "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=False,
            batch_size_s=60,
        )
        if fsmn_vad:
            elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
        else:
            # Silero path: compute speech timestamps ourselves and shape
            # them like the fsmn-vad result (ms intervals per key).
            wav = read_audio(
                str(file_path)
            )  # backend (sox, soundfile, or ffmpeg) required!
            audio_key = file_path.stem
            audio_val = []
            speech_timestamps = get_speech_timestamps(
                wav,
                vad_model,
                max_speech_duration_s=max_single_segment_time // 1000,
                return_seconds=True,
            )
            audio_val = [
                [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
                for timestamp in speech_timestamps
            ]
            vad_res = []
            vad_res.append(dict(key=audio_key, value=audio_val))
        res = manager.inference_with_vadres(
            input=str(file_path), vad_res=vad_res, **cfg
        )
        # Export each recognized segment plus its transcript (and emotion).
        for i, info in enumerate(res):
            [start_ms, end_ms] = info["interval"]
            text = info["text"]
            emo = info["emo"]
            sliced_audio = audio[start_ms:end_ms]
            audio_save_path = (
                save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
            )
            sliced_audio.export(audio_save_path, format=file_suffix[1:])
            print(f"Exported {audio_save_path}: {text}")
            transcript_save_path = (
                save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
            )
            with open(
                transcript_save_path,
                "w",
                encoding="utf-8",
            ) as f:
                f.write(text)
            if save_emo:
                emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
                with open(
                    emo_save_path,
                    "w",
                    encoding="utf-8",
                ) as f:
                    f.write(emo)
        # In-place mode: the sliced outputs replace the source file.
        if audios_path.resolve() == save_path.resolve():
            file_path.unlink()
if __name__ == "__main__":
    main()
    # NOTE(review): unreachable in practice — click's main() raises
    # SystemExit when the command completes, so control never returns here.
    exit(0)
# --- Ad-hoc SenseVoice smoke test ----------------------------------------
# NOTE(review): this section looks like leftover scratch code — it follows
# the __main__ guard, hard-codes a Windows path, and `SenseVoiceSmall` is
# not imported anywhere in the visible file — confirm whether it should be
# kept or deleted.
from funasr.utils.postprocess_utils import rich_transcription_postprocess

# Load the audio file
audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
model_dir = "iic/SenseVoiceSmall"
m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
m.eval()
res = m.inference(
    data_in=f"{kwargs['model_path']}/example/zh.mp3",
    language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
    use_itn=False,
    ban_emo_unk=False,
    **kwargs,
)
print(res)
# Strip/normalize the rich tags from the raw transcription output.
text = rich_transcription_postprocess(res[0][0]["text"])
print(text)
import torch
from torch.nn.utils.rnn import pad_sequence
def slice_padding_fbank(speech, speech_lengths, vad_segments):
    """Cut VAD segments out of a batched tensor and right-pad them into one
    batch.

    Segment times are in milliseconds; indices are obtained as ms * 16
    (i.e. 16 units per millisecond — presumably 16 kHz samples/frames; TODO
    confirm against the caller). Returns ``(padded, lengths)`` where
    ``padded`` is the zero-padded stack of slices and ``lengths`` an int
    tensor of their true lengths.
    """
    slices = []
    slice_lengths = []
    for segment in vad_segments:
        start = int(segment[0][0] * 16)
        # Clamp to the utterance end so a segment never reads past it.
        stop = min(int(segment[0][1] * 16), speech_lengths[0])
        slices.append(speech[0, start:stop])
        slice_lengths.append(stop - start)
    padded = pad_sequence(slices, batch_first=True, padding_value=0.0)
    lengths = torch.Tensor(slice_lengths).int()
    return padded, lengths
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
    """Slice raw audio samples by VAD segments.

    Segment times are in milliseconds; indices are ms * 16 (16 units per
    millisecond). Returns the per-segment slices, their lengths, and the
    intervals actually covered, converted back to milliseconds.
    """
    slices = []
    slice_lengths = []
    covered_intervals = []
    for segment in vad_segments:
        start = int(segment[0][0] * 16)
        # Clamp to the end of the available samples.
        stop = min(int(segment[0][1] * 16), speech_lengths)
        slices.append(speech[start:stop])
        slice_lengths.append(stop - start)
        covered_intervals.append([start // 16, stop // 16])
    return slices, slice_lengths, covered_intervals
def merge_vad(vad_result, max_length=15000, min_length=0):
    """Merge adjacent VAD segments into chunks of roughly max_length ms.

    Boundaries are the sorted union of all segment start/end times. A chunk
    is closed at boundary t whenever extending to the next boundary would
    reach max_length, and is kept only if longer than min_length; the final
    chunk always runs to the last boundary. Inputs with at most one segment
    are returned unchanged.
    """
    if len(vad_result) <= 1:
        return vad_result

    boundaries = sorted({t for seg in vad_result for t in (seg[0], seg[1])})
    if not boundaries:
        return []

    merged = []
    chunk_start = 0
    for current, upcoming in zip(boundaries, boundaries[1:]):
        # Keep growing the chunk while the next boundary still fits.
        if upcoming - chunk_start < max_length:
            continue
        if current - chunk_start > min_length:
            merged.append([chunk_start, current])
        chunk_start = current
    merged.append([chunk_start, boundaries[-1]])
    return merged
import struct
from functools import partial
import ormsgpack
from tools.server.agent.generate import generate_responses
from tools.server.agent.pre_generation_utils import prepare_messages
def execute_request(input_queue, tokenizer, config, request, device):
    """
    Prepare the conversation, encode the request, submit it for generation,
    and relay the decoded/streamed results.

    Yields ServeResponse or ServeStreamResponse objects.
    """
    encoded_prompt, end_token_id = prepare_messages(request, tokenizer, config)
    responses = generate_responses(
        input_queue, tokenizer, config, request, encoded_prompt, end_token_id, device
    )
    for response in responses:
        yield response
def response_generator(req, llama_queue, tokenizer, config, device):
    """
    Non-streaming wrapper for the chat endpoint: runs the request and
    returns only the first (final) result from the generator.
    """
    return next(execute_request(llama_queue, tokenizer, config, req, device))
async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
    """
    Streaming wrapper for the chat endpoint.

    Emits either SSE "data:" frames (json_mode) or length-prefixed
    msgpack frames, one per generated chunk.
    """
    for chunk in execute_request(llama_queue, tokenizer, config, req, device):
        if json_mode:
            payload = chunk.model_dump_json().encode("utf-8")
            yield b"data: " + payload + b"\n\n"
        else:
            payload = ormsgpack.packb(chunk, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
            yield struct.pack("I", len(payload)) + payload
def get_response_generator(
    llama_queue, tokenizer, config, req, device, json_mode
) -> partial:
    """
    Pick the streaming or non-streaming response wrapper for the request.
    """
    if req.streaming:
        return partial(
            streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
        )
    return partial(response_generator, req, llama_queue, tokenizer, config, device)
import time
from fish_speech.utils.schema import ServeMessage, ServeResponse, ServeStreamResponse
from tools.server.agent.generation_utils import (
initialize_decode_buffers,
process_response_tokens,
send_reset_buffer,
)
from tools.server.agent.pre_generation_utils import (
create_generation_request,
send_generation_request,
)
def generate_responses(
    input_queue, tokenizer, config, request, prompt, im_end_id, device
):
    """
    Main generation function that handles the conversation, encodes the request,
    sends the generation request, and handles decoding/streaming.
    It returns a response generator (ServeResponse or ServeStreamResponse).
    """
    # Per-request timing/token stats, threaded through the decode helpers.
    stats = {}
    start = time.time()
    stats["start_time"] = start
    stats["tokens_count"] = 0
    # Prepare and send the generation request
    req = create_generation_request(prompt, request, im_end_id, device)
    response_queue = send_generation_request(input_queue, req)
    decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)
    while True:
        # Blocks until the worker publishes the next batch of tokens.
        response = response_queue.get()
        # Handle abnormal finish or error
        if response in ["stop", "error"]:
            finish_reason = response
            break
        # Process the response tokens
        is_first_token = stats["tokens_count"] == 0
        responses = process_response_tokens(
            response,
            tokenizer,
            config,
            request,
            decode_buffer,
            parts,
            finished,
            im_end_id,
            stats,
            start,
            is_first_token,
        )
        # Yield the responses if streaming
        if request.streaming and responses:
            for r in responses:
                yield r
        stats["tokens_count"] += 1
        # Check if all samples are finished
        if all(finished):
            finish_reason = "stop"
            break
    # Finalize the response: flush buffers and emit the terminal message(s).
    final_responses = finalize_response(
        request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
    )
    for fr in final_responses:
        yield fr
def finalize_response(
    request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
):
    """
    Flush the remaining text buffers and build the terminal response(s).
    """
    responses = []
    # Flush any undecoded text still sitting in the per-sample buffers.
    for sample_id in range(request.num_samples):
        responses.extend(
            send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
        )

    # Final timing/token stats.
    stats["total_time"] = (time.time() - stats["start_time"]) * 1000
    stats["total_tokens"] = stats["tokens_count"]

    if not request.streaming:
        # Non-streaming: one ServeResponse carrying every sample's message.
        responses.append(
            ServeResponse(
                messages=[
                    ServeMessage(role="assistant", parts=parts[i])
                    for i in range(request.num_samples)
                ],
                finish_reason=finish_reason,
                stats=stats,
            )
        )
    else:
        # Streaming: a closing chunk for each sample not already finished.
        for sample_id in range(request.num_samples):
            if not finished[sample_id]:
                responses.append(
                    ServeStreamResponse(
                        finish_reason=finish_reason, stats=stats, sample_id=sample_id
                    )
                )
    return responses
import time
from fish_speech.utils.schema import (
ServeStreamDelta,
ServeStreamResponse,
ServeTextPart,
ServeVQPart,
)
def initialize_decode_buffers(num_samples):
    """Create fresh per-sample decode state.

    Returns (token buffers, message parts, finished flags) — one
    independent list/flag per sample.
    """
    token_buffers = [[] for _ in range(num_samples)]
    message_parts = [[] for _ in range(num_samples)]
    done_flags = [False] * num_samples
    return token_buffers, message_parts, done_flags
def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
    """Decode and flush the pending text tokens of one sample.

    Returns a single-element list with a streaming delta when streaming;
    otherwise appends the decoded text part to the sample's accumulated
    parts and returns an empty list. The sample's buffer is reset.
    """
    pending = decode_buffer[sample_id]
    if not pending:
        return []
    text_part = ServeTextPart(text=tokenizer.decode(pending))
    if request.streaming:
        out = [ServeStreamResponse(delta=ServeStreamDelta(part=text_part))]
    else:
        parts[sample_id].append(text_part)
        out = []
    decode_buffer[sample_id] = []
    return out
def handle_semantic_tokens(tokens, config, sample_id, parts, request):
    """Convert raw semantic tokens into VQ parts (streamed or accumulated)."""
    responses = []
    # Drop the first row and work on a copy so the source tensor is intact.
    codes = tokens[1:].clone()
    if not config.share_codebook_embeddings:
        # Undo the per-codebook id offset applied during generation.
        for codebook_id in range(len(codes)):
            codes[codebook_id] -= config.codebook_size * codebook_id

    if request.streaming:
        # Streaming: emit each VQ delta immediately.
        responses.append(
            ServeStreamResponse(
                sample_id=sample_id,
                delta=ServeStreamDelta(part=ServeVQPart(codes=codes.tolist())),
            )
        )
        return responses

    sample_parts = parts[sample_id]
    if sample_parts and isinstance(sample_parts[-1], ServeVQPart):
        # Extend the currently open VQ part code by code.
        for codebook_id, value in enumerate(codes):
            sample_parts[-1].codes[codebook_id].append(value.item())
    else:
        # No open VQ part: start a new one.
        sample_parts.append(ServeVQPart(codes=codes.tolist()))
    return responses
def process_response_tokens(
    response,
    tokenizer,
    config,
    request,
    decode_buffer,
    parts,
    finished,
    im_end_id,
    stats,
    start,
    is_first_token,
):
    """Process the response tokens returned by the model.

    For each still-active sample: finish it on im_end_id, route semantic
    tokens into VQ parts, and buffer plain text tokens for later decoding.
    Mutates decode_buffer / parts / finished / stats in place and returns
    any streaming responses produced along the way.
    """
    responses = []
    for sample_id, tokens in enumerate(response):
        if finished[sample_id]:
            continue
        # End of the conversation
        if tokens[0] == im_end_id:
            finished[sample_id] = True
            # Send the remaining text buffer
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            if request.streaming:
                responses.append(
                    ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                )
            continue
        # Check if the token is semantic (falls in the semantic id range)
        is_semantic = (
            tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
        )
        if is_semantic:
            # Before the semantic tokens, send the remaining text buffer
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            responses.extend(
                handle_semantic_tokens(tokens, config, sample_id, parts, request)
            )
        else:
            # Accumulate the text tokens (not implemented?)
            decode_buffer[sample_id].append(tokens[0, 0])
    # Record time-to-first-token once, on the first processed batch.
    if is_first_token:
        stats["time_to_first_token"] = (time.time() - start) * 1000
    return responses
import queue
from fish_speech.conversation import Conversation, Message
from fish_speech.models.text2semantic.inference import GenerateRequest
from fish_speech.tokenizer import IM_END_TOKEN
def prepare_messages(request, tokenizer, config):
    """
    Reorganise the provided list of messages into a conversation and
    encode it for inference.

    Returns (encoded prompt, im_end token id).
    """
    # Convert the incoming messages to ConversationMessage objects.
    conversation_messages = [msg.to_conversation_message() for msg in request.messages]
    if not conversation_messages:
        raise ValueError("At least one message is required")

    # Decide the next step from the role of the last message.
    last_role = conversation_messages[-1].role
    if last_role == "user":
        # User spoke last: open an empty assistant voice message to fill.
        conversation_messages.append(
            Message(role="assistant", parts=[], add_im_end=False, modality="voice")
        )
    elif last_role == "raw":
        # Raw text: let the assistant complete it without im_start/im_end.
        last_message = conversation_messages[-1]
        last_message.add_im_start = False
        last_message.add_im_end = False
        last_message.modality = "voice"
    elif last_role == "assistant":
        # Assistant spoke last: continue its unterminated message.
        conversation_messages[-1].add_im_end = False
    else:
        # Anything else is unsupported.
        raise ValueError("The last message must be from the assistant, user or raw")

    # Build the conversation and encode it for inference.
    conv = Conversation(messages=conversation_messages)
    prompt = conv.encode_for_inference(
        tokenizer=tokenizer, num_codebooks=config.num_codebooks
    )
    return prompt, tokenizer.get_token_id(IM_END_TOKEN)
def create_generation_request(prompt, request, im_end_id, device):
    """
    Build the dict payload consumed by the generation worker, moving the
    prompt tensor onto the target device.
    """
    return {
        "prompt": prompt.to(device),
        "max_new_tokens": request.max_new_tokens,
        "im_end_id": im_end_id,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "num_samples": request.num_samples,
        "early_stop_threshold": request.early_stop_threshold,
    }
def send_generation_request(input_queue, req):
    """
    Enqueue a generation request and hand back the queue on which the
    worker will publish its responses.
    """
    result_queue = queue.Queue()
    input_queue.put(GenerateRequest(req, result_queue))
    return result_queue
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment