Commit b75857fb authored by chenzk

v1.0
from argparse import ArgumentParser
from http import HTTPStatus
from typing import Annotated, Any
import ormsgpack
from baize.datastructures import ContentType
from kui.asgi import HTTPException, HttpRequest
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.utils.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference
def parse_args():
parser = ArgumentParser()
parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
parser.add_argument("--load-asr-model", action="store_true")
parser.add_argument(
"--llama-checkpoint-path",
type=str,
default="checkpoints/fish-speech-1.5",
)
parser.add_argument(
"--decoder-checkpoint-path",
type=str,
default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--half", action="store_true")
parser.add_argument("--compile", action="store_true")
parser.add_argument("--max-text-length", type=int, default=0)
parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--api-key", type=str, default=None)
return parser.parse_args()
class MsgPackRequest(HttpRequest):
async def data(
self,
) -> Annotated[
Any, ContentType("application/msgpack"), ContentType("application/json")
]:
if self.content_type == "application/msgpack":
return ormsgpack.unpackb(await self.body)
elif self.content_type == "application/json":
return await self.json
raise HTTPException(
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
headers={"Accept": "application/msgpack, application/json"},
)
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
for chunk in inference(req, engine):
if isinstance(chunk, bytes):
yield chunk
async def buffer_to_async_generator(buffer):
yield buffer
def get_content_type(audio_format):
if audio_format == "wav":
return "audio/wav"
elif audio_format == "flac":
return "audio/flac"
elif audio_format == "mp3":
return "audio/mpeg"
else:
return "application/octet-stream"
import traceback
from http import HTTPStatus
from kui.asgi import HTTPException, JSONResponse
class ExceptionHandler:
async def http_exception_handler(self, exc: HTTPException):
return JSONResponse(
dict(
statusCode=exc.status_code,
message=exc.content,
error=HTTPStatus(exc.status_code).phrase,
),
exc.status_code,
exc.headers,
)
async def other_exception_handler(self, exc: Exception):
traceback.print_exc()
status = HTTPStatus.INTERNAL_SERVER_ERROR
return JSONResponse(
dict(statusCode=status, message=str(exc), error=status.phrase),
status,
)
from http import HTTPStatus
import numpy as np
from kui.asgi import HTTPException
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.utils.schema import ServeTTSRequest
AMPLITUDE = 32768  # 2**15: int16 full scale, used to convert float audio in [-1, 1] to 16-bit PCM (see the sketch below)
def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
"""
Wrapper for the inference function.
Used in the API server.
"""
count = 0
for result in engine.inference(req):
match result.code:
case "header":
if isinstance(result.audio, tuple):
yield result.audio[1]
case "error":
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
content=str(result.error),
)
case "segment":
count += 1
if isinstance(result.audio, tuple):
yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
case "final":
count += 1
if isinstance(result.audio, tuple):
yield result.audio[1]
return None # Stop the generator
if count == 0:
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
content="No audio generated, please check the input text.",
)
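# A minimal sketch of the float-to-PCM conversion applied to "segment" chunks above:
# float audio in [-1.0, 1.0] is scaled by AMPLITUDE (2**15) and cast to int16, so the
# result can be streamed as raw 16-bit PCM bytes. The sample values are illustrative.
_demo_float_audio = np.array([0.0, 0.5, -0.5], dtype=np.float32)
_demo_pcm_bytes = (_demo_float_audio * AMPLITUDE).astype(np.int16).tobytes()
assert np.frombuffer(_demo_pcm_bytes, dtype=np.int16).tolist() == [0, 16384, -16384]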
import torch
from funasr import AutoModel
from loguru import logger
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.text2semantic.inference import (
launch_thread_safe_queue,
launch_thread_safe_queue_agent,
)
from fish_speech.models.vqgan.inference import load_model as load_decoder_model
from fish_speech.utils.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference
ASR_MODEL_NAME = "iic/SenseVoiceSmall"
class ModelManager:
def __init__(
self,
mode: str,
device: str,
half: bool,
compile: bool,
asr_enabled: bool,
llama_checkpoint_path: str,
decoder_checkpoint_path: str,
decoder_config_name: str,
) -> None:
self.mode = mode
self.device = device
self.half = half
self.compile = compile
self.precision = torch.half if half else torch.bfloat16
# Check if MPS or CUDA is available
if torch.backends.mps.is_available():
self.device = "mps"
logger.info("mps is available, running on mps.")
elif not torch.cuda.is_available():
self.device = "cpu"
logger.info("CUDA is not available, running on CPU.")
# Load the ASR model if enabled
if asr_enabled:
self.load_asr_model(self.device)
# Load the TTS models
self.load_llama_model(
llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
)
self.load_decoder_model(
decoder_config_name, decoder_checkpoint_path, self.device
)
self.tts_inference_engine = TTSInferenceEngine(
llama_queue=self.llama_queue,
decoder_model=self.decoder_model,
precision=self.precision,
compile=self.compile,
)
# Warm up the models
if self.mode == "tts":
self.warm_up(self.tts_inference_engine)
def load_asr_model(self, device, hub="ms") -> None:
self.asr_model = AutoModel(
model=ASR_MODEL_NAME,
device=device,
disable_pbar=True,
hub=hub,
)
logger.info("ASR model loaded.")
def load_llama_model(
self, checkpoint_path, device, precision, compile, mode
) -> None:
if mode == "tts":
self.llama_queue = launch_thread_safe_queue(
checkpoint_path=checkpoint_path,
device=device,
precision=precision,
compile=compile,
)
elif mode == "agent":
self.llama_queue, self.tokenizer, self.config = (
launch_thread_safe_queue_agent(
checkpoint_path=checkpoint_path,
device=device,
precision=precision,
compile=compile,
)
)
else:
raise ValueError(f"Invalid mode: {mode}")
logger.info("LLAMA model loaded.")
def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
self.decoder_model = load_decoder_model(
config_name=config_name,
checkpoint_path=checkpoint_path,
device=device,
)
logger.info("Decoder model loaded.")
def warm_up(self, tts_inference_engine) -> None:
request = ServeTTSRequest(
text="Hello world.",
references=[],
reference_id=None,
max_new_tokens=1024,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.2,
temperature=0.7,
format="wav",
)
list(inference(request, tts_inference_engine))
logger.info("Models warmed up.")
import io
import re
import librosa
import torch
import torchaudio
from cachetools import LRUCache, cached
CACHE_MAXSIZE = 10000
MICRO_BATCH_SIZE = 8
ASR_SAMPLE_RATE = 16000
HUGE_GAP_THRESHOLD = 4000
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_encode(model, audios_list: list[bytes]):
audios: list[torch.Tensor] = [
(
torch.from_numpy(
librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
)[None]
if isinstance(audio, bytes)
else audio
)
for audio in audios_list
]
lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
max_length = lengths.max().item()
print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
padded = torch.stack(
[
torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
for audio in audios
]
).to(model.device)
features, feature_lengths = model.encode(padded, audio_lengths=lengths)
features, feature_lengths = features.cpu(), feature_lengths.cpu()
return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
@cached(
cache=LRUCache(maxsize=CACHE_MAXSIZE),
key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
return batch_encode(model, audios)
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_vqgan_decode(model, features):
lengths = torch.tensor(
[feature.shape[-1] for feature in features], device=model.device
)
max_length = lengths.max().item()
padded = torch.stack(
[
torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
for feature in features
]
).to(model.device)
# If the batch is too large, decode in micro-batches of MICRO_BATCH_SIZE to limit GPU memory use
audios, audio_lengths = [], []
for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
audio, audio_length = model.decode(
padded[i : i + MICRO_BATCH_SIZE],
feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
)
audios.append(audio)
audio_lengths.append(audio_length)
audios = torch.cat(audios, dim=0)
audio_lengths = torch.cat(audio_lengths, dim=0)
audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
@torch.no_grad()
def batch_asr(model, lock, audios, sr, language="auto"):
resampled_audios = []
for audio in audios:
audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
assert audio.ndim == 1
resampled_audios.append(audio)
with lock:
res = model.generate(
input=resampled_audios,
batch_size=len(resampled_audios),
language=language,
use_itn=True,
)
results = []
for r, audio in zip(res, audios):
text = r["text"]
text = re.sub(r"<\|.*?\|>", "", text)
duration = len(audio) / sr * 1000
huge_gap = False
if "timestamp" in r and len(r["timestamp"]) > 2:
for timestamp_a, timestamp_b in zip(
r["timestamp"][:-1], r["timestamp"][1:]
):
# If there is a gap of more than 4 seconds, we consider it as a huge gap
if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
huge_gap = True
break
# Doesn't make sense to have a huge gap at the end
if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
huge_gap = True
results.append(
{
"text": text,
"duration": duration,
"huge_gap": huge_gap,
}
)
return results
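# A small illustrative sketch of why cached_vqgan_batch_encode keys its LRU cache on
# (model.device, tuple(audios)): cachetools cache keys must be hashable, and a list of
# bytes is not, so the list is converted to a tuple. The demo function below is hypothetical.
_demo_cache = LRUCache(maxsize=2)

@cached(cache=_demo_cache, key=lambda audios: tuple(audios))
def _demo_batch_size(audios: list[bytes]) -> int:
    return len(audios)

_demo_batch_size([b"a", b"b"])  # computed and stored under the key (b"a", b"b")
_demo_batch_size([b"a", b"b"])  # second call with the same bytes is served from the cache
assert len(_demo_cache) == 1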
import io
import os
import time
from http import HTTPStatus
import numpy as np
import ormsgpack
import soundfile as sf
import torch
from kui.asgi import Body, HTTPException, JSONResponse, Routes, StreamResponse, request
from loguru import logger
from typing_extensions import Annotated
from fish_speech.utils.schema import (
ServeASRRequest,
ServeASRResponse,
ServeChatRequest,
ServeTTSRequest,
ServeVQGANDecodeRequest,
ServeVQGANDecodeResponse,
ServeVQGANEncodeRequest,
ServeVQGANEncodeResponse,
)
from tools.server.agent import get_response_generator
from tools.server.api_utils import (
buffer_to_async_generator,
get_content_type,
inference_async,
)
from tools.server.inference import inference_wrapper as inference
from tools.server.model_manager import ModelManager
from tools.server.model_utils import (
batch_asr,
batch_vqgan_decode,
cached_vqgan_batch_encode,
)
MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
routes = Routes()
@routes.http.post("/v1/health")
async def health():
return JSONResponse({"status": "ok"})
@routes.http.post("/v1/vqgan/encode")
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
# Get the model from the app
model_manager: ModelManager = request.app.state.model_manager
decoder_model = model_manager.decoder_model
# Encode the audio
start_time = time.time()
tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
# Return the response
return ormsgpack.packb(
ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
@routes.http.post("/v1/vqgan/decode")
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
# Get the model from the app
model_manager: ModelManager = request.app.state.model_manager
decoder_model = model_manager.decoder_model
# Decode the audio
tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
start_time = time.time()
audios = batch_vqgan_decode(decoder_model, tokens)
logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
audios = [audio.astype(np.float16).tobytes() for audio in audios]
# Return the response
return ormsgpack.packb(
ServeVQGANDecodeResponse(audios=audios),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
@routes.http.post("/v1/asr")
async def asr(req: Annotated[ServeASRRequest, Body(exclusive=True)]):
# Get the model from the app
model_manager: ModelManager = request.app.state.model_manager
asr_model = model_manager.asr_model
lock = request.app.state.lock
# Perform ASR
start_time = time.time()
audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
audios = [torch.from_numpy(audio).float() for audio in audios]
if any(audio.shape[-1] >= 30 * req.sample_rate for audio in audios):
raise HTTPException(HTTPStatus.BAD_REQUEST, content="Audio length is too long")
transcriptions = batch_asr(
asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
)
logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
# Return the response
return ormsgpack.packb(
ServeASRResponse(transcriptions=transcriptions),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
# Get the model from the app
app_state = request.app.state
model_manager: ModelManager = app_state.model_manager
engine = model_manager.tts_inference_engine
sample_rate = engine.decoder_model.spec_transform.sample_rate
# Check if the text is too long
if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
raise HTTPException(
HTTPStatus.BAD_REQUEST,
content=f"Text is too long, max length is {app_state.max_text_length}",
)
# Check if streaming is enabled
if req.streaming and req.format != "wav":
raise HTTPException(
HTTPStatus.BAD_REQUEST,
content="Streaming only supports WAV format",
)
# Perform TTS
if req.streaming:
return StreamResponse(
iterable=inference_async(req, engine),
headers={
"Content-Disposition": f"attachment; filename=audio.{req.format}",
},
content_type=get_content_type(req.format),
)
else:
fake_audios = next(inference(req, engine))
buffer = io.BytesIO()
sf.write(
buffer,
fake_audios,
sample_rate,
format=req.format,
)
return StreamResponse(
iterable=buffer_to_async_generator(buffer.getvalue()),
headers={
"Content-Disposition": f"attachment; filename=audio.{req.format}",
},
content_type=get_content_type(req.format),
)
@routes.http.post("/v1/chat")
async def chat(req: Annotated[ServeChatRequest, Body(exclusive=True)]):
# Check that the number of samples requested is correct
if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
raise HTTPException(
HTTPStatus.BAD_REQUEST,
content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
)
# Get the type of content provided
content_type = request.headers.get("Content-Type", "application/json")
json_mode = "application/json" in content_type
# Get the models from the app
model_manager: ModelManager = request.app.state.model_manager
llama_queue = model_manager.llama_queue
tokenizer = model_manager.tokenizer
config = model_manager.config
device = request.app.state.device
# Get the response generators
response_generator = get_response_generator(
llama_queue, tokenizer, config, req, device, json_mode
)
# Return the response in the correct format
if req.streaming is False:
result = response_generator()
if json_mode:
return JSONResponse(result.model_dump())
else:
return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
return StreamResponse(
iterable=response_generator(), content_type="text/event-stream"
)
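# A minimal client-side sketch (not part of the server) for the /v1/tts route above. It
# assumes the server is running on the default --listen address 127.0.0.1:8080, that no
# --api-key is enforced, and that the `requests` package is installed. The request fields
# mirror ServeTTSRequest as used in ModelManager.warm_up; adjust them to your needs.
#
#     import requests
#
#     response = requests.post(
#         "http://127.0.0.1:8080/v1/tts",
#         data=ormsgpack.packb({"text": "Hello world.", "format": "wav", "streaming": False}),
#         headers={"Content-Type": "application/msgpack"},
#     )
#     response.raise_for_status()
#     with open("audio.wav", "wb") as f:
#         f.write(response.content)  # WAV bytes written by soundfile on the server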
import random
from multiprocessing import Pool
from pathlib import Path
import click
import librosa
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
threshold = 10 ** (-50 / 20.0)  # -50 dBFS silence threshold expressed as linear amplitude
def process(file):
waveform, sample_rate = torchaudio.load(str(file), backend="sox")
if waveform.size(0) > 1:
waveform = waveform.mean(dim=0, keepdim=True)
loudness = librosa.feature.rms(
y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
)[0]
for i in range(len(loudness) - 1, 0, -1):
if loudness[i] > threshold:
break
end_silent_time = (len(loudness) - i) * 512 / sample_rate
if end_silent_time <= 0.3:
random_time = random.uniform(0.3, 0.7) - end_silent_time
waveform = F.pad(
waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
)
for i in range(len(loudness)):
if loudness[i] > threshold:
break
start_silent_time = i * 512 / sample_rate
if start_silent_time > 0.02:
waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
@click.command()
@click.argument("source", type=Path)
@click.option("--num-workers", type=int, default=12)
def main(source, num_workers):
files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
with Pool(num_workers) as p:
list(tqdm(p.imap_unordered(process, files), total=len(files)))
if __name__ == "__main__":
main()
import math
from pathlib import Path
from random import Random
import click
from loguru import logger
from pydub import AudioSegment
from tqdm import tqdm
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
@click.command()
@click.argument("root", type=click.Path(exists=True, path_type=Path))
@click.option("--val-ratio", type=float, default=None)
@click.option("--val-count", type=int, default=None)
@click.option("--filelist", default=None, type=Path)
@click.option("--min-duration", default=None, type=float)
@click.option("--max-duration", default=None, type=float)
def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
if filelist:
files = [i[0] for i in load_filelist(filelist)]
else:
files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
if min_duration is None and max_duration is None:
filtered_files = list(map(str, [file.relative_to(root) for file in files]))
else:
filtered_files = []
for file in tqdm(files):
try:
audio = AudioSegment.from_file(str(file))
duration = len(audio) / 1000.0
if min_duration is not None and duration < min_duration:
logger.info(
f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
)
continue
if max_duration is not None and duration > max_duration:
logger.info(
f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
)
continue
filtered_files.append(str(file.relative_to(root)))
except Exception as e:
logger.info(f"Error processing {file}: {e}")
logger.info(
f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
)
Random(42).shuffle(filtered_files)
if val_count is None and val_ratio is None:
logger.info("Validation ratio and count not specified, using min(20%, 100)")
val_size = min(100, math.ceil(len(filtered_files) * 0.2))
elif val_count is not None and val_ratio is not None:
logger.error("Cannot specify both val_count and val_ratio")
return
elif val_count is not None:
if val_count < 1 or val_count > len(filtered_files):
logger.error("val_count must be between 1 and number of files")
return
val_size = val_count
else:
val_size = math.ceil(len(filtered_files) * val_ratio)
logger.info(f"Using {val_size} files for validation")
with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
f.write("\n".join(filtered_files[val_size:]))
with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
f.write("\n".join(filtered_files[:val_size]))
logger.info("Done")
if __name__ == "__main__":
main()
import os
import subprocess as sp
import sys
import time
from datetime import timedelta
from functools import lru_cache
from pathlib import Path
from random import Random
import click
import numpy as np
import torch
import torchaudio
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from omegaconf import OmegaConf
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
# register eval resolver
OmegaConf.register_new_resolver("eval", eval)
# This file is used to convert audio files into VQ token files (.npy) using the VQ-GAN encoder.
# It's mainly used to generate the training data for the text2semantic model.
backends = torchaudio.list_audio_backends()
if "ffmpeg" in backends:
backend = "ffmpeg"
else:
backend = "soundfile"
RANK = int(os.environ.get("SLURM_PROCID", 0))
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
logger_format = (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
"{extra[rank]} - <level>{message}</level>"
)
logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
logger.remove()
logger.add(sys.stderr, format=logger_format)
@lru_cache(maxsize=1)
def get_model(
config_name: str = "firefly_gan_vq",
checkpoint_path: str = "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
device: str | torch.device = "cuda",
):
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
cfg = compose(config_name=config_name)
model = instantiate(cfg)
state_dict = torch.load(
checkpoint_path,
map_location=device,
)
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
if any("generator" in k for k in state_dict):
state_dict = {
k.replace("generator.", ""): v
for k, v in state_dict.items()
if "generator." in k
}
model.load_state_dict(state_dict, strict=False)
model.eval()
model.to(device)
logger.info(f"Loaded model")
return model
@torch.inference_mode()
def process_batch(files: list[Path], model) -> float:
wavs = []
audio_lengths = []
new_files = []
max_length = total_time = 0
for file in files:
try:
wav, sr = torchaudio.load(
str(file), backend=backend
) # Need to install libsox-dev
except Exception as e:
logger.error(f"Error reading {file}: {e}")
continue
if wav.shape[0] > 1:
wav = wav.mean(dim=0, keepdim=True)
wav = torchaudio.functional.resample(
wav.cuda(), sr, model.spec_transform.sample_rate
)[0]
total_time += len(wav) / model.spec_transform.sample_rate
max_length = max(max_length, len(wav))
wavs.append(wav)
audio_lengths.append(len(wav))
new_files.append(file)
files = new_files
# Pad to max length
for i, wav in enumerate(wavs):
wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
audios = torch.stack(wavs, dim=0)[:, None]
audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
# Calculate lengths
indices, feature_lengths = model.encode(audios, audio_lengths)
# Save to disk
outputs = indices.cpu().numpy()
for file, length, feature, audio_length in zip(
files, feature_lengths, outputs, audio_lengths
):
feature = feature[:, :length]  # (num_codebooks, T)
with open(file.with_suffix(".npy"), "wb") as f:
np.save(f, feature)
return total_time
@click.command()
@click.argument("folder")
@click.option("--num-workers", default=1)
@click.option("--config-name", default="firefly_gan_vq")
@click.option(
"--checkpoint-path",
default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
@click.option("--batch-size", default=64)
@click.option("--filelist", default=None, type=Path)
def main(
folder: str,
num_workers: int,
config_name: str,
checkpoint_path: str,
batch_size: int,
filelist: Path,
):
if num_workers > 1 and WORLD_SIZE != num_workers:
assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
logger.info(f"Spawning {num_workers} workers")
if torch.cuda.is_available():
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if visible_devices is None:
visible_devices = list(range(torch.cuda.device_count()))
else:
visible_devices = visible_devices.split(",")
else:
# Set to empty string to avoid using GPU
visible_devices = [""]
processes = []
for i in range(num_workers):
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
env["SLURM_PROCID"] = str(i)
env["SLURM_NTASKS"] = str(num_workers)
processes.append(
sp.Popen(
[sys.executable] + sys.argv.copy(),
env=env,
)
)
for p in processes:
p.wait()
logger.info(f"All workers finished")
return
# This is a worker
logger.info(f"Starting worker")
if filelist:
files = [i[0] for i in load_filelist(filelist)]
else:
files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
print(f"Found {len(files)} files")
files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
total_files = len(files)
files = files[RANK::WORLD_SIZE]
logger.info(f"Processing {len(files)}/{total_files} files")
# Batch processing
total_time = 0
begin_time = time.time()
processed_files = 0
model = get_model(config_name, checkpoint_path)
for n_batch, idx in enumerate(range(0, len(files), batch_size)):
batch = files[idx : idx + batch_size]
batch_time = process_batch(batch, model)
total_time += batch_time
processed_files += len(batch)
if (n_batch + 1) % 10 == 0:
eta = (
(time.time() - begin_time)
/ processed_files
* (len(files) - processed_files)
)
logger.info(
f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+ f"ETA: {timedelta(seconds=round(eta))}s"
)
logger.info(
f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
)
if __name__ == "__main__":
main()
#!/usr/bin/env python
import os
import subprocess
import sys
def main():
# Make path relative to this file
script_path = os.path.join(
os.path.dirname(__file__), "../../fish_speech/models/vqgan/inference.py"
)
subprocess.run(["python", script_path] + sys.argv[1:])
if __name__ == "__main__":
main()
from typing import Callable
import gradio as gr
from fish_speech.i18n import i18n
from fish_speech.inference_engine.utils import normalize_text
from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER
def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
with gr.Blocks(theme=gr.themes.Base()) as app:
gr.Markdown(HEADER_MD)
# Use light theme by default
app.load(
None,
None,
js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
% theme,
)
# Inference
with gr.Row():
with gr.Column(scale=3):
text = gr.Textbox(
label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
)
refined_text = gr.Textbox(
label=i18n("Realtime Transform Text"),
placeholder=i18n(
"Normalization Result Preview (Currently Only Chinese)"
),
lines=5,
interactive=False,
)
with gr.Row():
normalize = gr.Checkbox(
label=i18n("Text Normalization"),
value=False,
)
with gr.Row():
with gr.Column():
with gr.Tab(label=i18n("Advanced Config")):
with gr.Row():
chunk_length = gr.Slider(
label=i18n("Iterative Prompt Length, 0 means off"),
minimum=0,
maximum=300,
value=200,
step=8,
)
max_new_tokens = gr.Slider(
label=i18n(
"Maximum tokens per batch, 0 means no limit"
),
minimum=0,
maximum=2048,
value=0,
step=8,
)
with gr.Row():
top_p = gr.Slider(
label="Top-P",
minimum=0.6,
maximum=0.9,
value=0.7,
step=0.01,
)
repetition_penalty = gr.Slider(
label=i18n("Repetition Penalty"),
minimum=1,
maximum=1.5,
value=1.2,
step=0.01,
)
with gr.Row():
temperature = gr.Slider(
label="Temperature",
minimum=0.6,
maximum=0.9,
value=0.7,
step=0.01,
)
seed = gr.Number(
label="Seed",
info="0 means randomized inference, otherwise deterministic",
value=0,
)
with gr.Tab(label=i18n("Reference Audio")):
with gr.Row():
gr.Markdown(
i18n(
"5 to 10 seconds of reference audio, useful for specifying speaker."
)
)
with gr.Row():
reference_id = gr.Textbox(
label=i18n("Reference ID"),
placeholder="Leave empty to use uploaded references",
)
with gr.Row():
use_memory_cache = gr.Radio(
label=i18n("Use Memory Cache"),
choices=["on", "off"],
value="on",
)
with gr.Row():
reference_audio = gr.Audio(
label=i18n("Reference Audio"),
type="filepath",
)
with gr.Row():
reference_text = gr.Textbox(
label=i18n("Reference Text"),
lines=1,
placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
value="",
)
with gr.Column(scale=3):
with gr.Row():
error = gr.HTML(
label=i18n("Error Message"),
visible=True,
)
with gr.Row():
audio = gr.Audio(
label=i18n("Generated Audio"),
type="numpy",
interactive=False,
visible=True,
)
with gr.Row():
with gr.Column(scale=3):
generate = gr.Button(
value="\U0001F3A7 " + i18n("Generate"),
variant="primary",
)
text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])
# Submit
generate.click(
inference_fct,
[
refined_text,
normalize,
reference_id,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
seed,
use_memory_cache,
],
[audio, error],
concurrency_limit=1,
)
return app
import html
from functools import partial
from typing import Any, Callable
from fish_speech.i18n import i18n
from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
def inference_wrapper(
text,
normalize,
reference_id,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
seed,
use_memory_cache,
engine,
):
"""
Wrapper for the inference function.
Used in the Gradio interface.
"""
if reference_audio:
references = get_reference_audio(reference_audio, reference_text)
else:
references = []
req = ServeTTSRequest(
text=text,
normalize=normalize,
reference_id=reference_id if reference_id else None,
references=references,
max_new_tokens=max_new_tokens,
chunk_length=chunk_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
seed=int(seed) if seed else None,
use_memory_cache=use_memory_cache,
)
for result in engine.inference(req):
match result.code:
case "final":
return result.audio, None
case "error":
return None, build_html_error_message(i18n(result.error))
case _:
pass
return None, i18n("No audio generated")
def get_reference_audio(reference_audio: str, reference_text: str) -> list:
"""
Get the reference audio bytes.
"""
with open(reference_audio, "rb") as audio_file:
audio_bytes = audio_file.read()
return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)]
def build_html_error_message(error: Any) -> str:
error = error if isinstance(error, Exception) else Exception(str(error) if error else "Unknown error")
return f"""
<div style="color: red;
font-weight: bold;">
{html.escape(str(error))}
</div>
"""
def get_inference_wrapper(engine) -> Callable:
"""
Get the inference function with the immutable arguments.
"""
return partial(
inference_wrapper,
engine=engine,
)
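# A rough sketch (the WebUI entrypoint itself is not shown in this commit) of how this
# wrapper is meant to be combined with tools.webui.build_app:
#
#     engine = TTSInferenceEngine(...)  # constructed as in tools/server/model_manager.py
#     app = build_app(get_inference_wrapper(engine), theme="light")
#     app.launch()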
from fish_speech.i18n import i18n
HEADER_MD = f"""# Fish Speech
{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}
{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
"""
TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
"""
Used to transcribe all audio files in one folder into another folder.
e.g.
Directory structure:
--pre_data_root
----SP_1
------01.wav
------02.wav
------......
----SP_2
------01.wav
------02.wav
------......
Use
python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
to transcribe the first speaker.
Use
python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
to transcribe the second speaker.
Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
"""
import re
from pathlib import Path
import click
import soundfile as sf
from faster_whisper import WhisperModel
from loguru import logger
from pydub import AudioSegment
from tqdm import tqdm
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
@click.command()
@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
@click.option(
"--compute-type",
default="float16",
help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
)
@click.option("--audio-dir", required=True, help="Directory containing audio files")
@click.option(
"--save-dir", required=True, help="Directory to save processed audio files"
)
@click.option(
"--sample-rate",
default=44100,
type=int,
help="Output sample rate, default to input sample rate",
)
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
@click.option("--language", default="auto", help="Language of the transcription")
@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
def main(
model_size,
compute_type,
audio_dir,
save_dir,
sample_rate,
device,
language,
initial_prompt,
):
logger.info("Loading / Downloading Faster Whisper model...")
model = WhisperModel(
model_size,
device=device,
compute_type=compute_type,
download_root="faster_whisper",
)
logger.info("Model loaded.")
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)
audio_files = list_files(
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
)
for file_path in tqdm(audio_files, desc="Processing audio file"):
file_stem = file_path.stem
file_suffix = file_path.suffix
rel_path = Path(file_path).relative_to(audio_dir)
(save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
audio = AudioSegment.from_file(file_path)
segments, info = model.transcribe(
file_path,
beam_size=5,
language=None if language == "auto" else language,
initial_prompt=initial_prompt,
)
print(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
print("Total len(ms): ", len(audio))
whole_text = None
for segment in segments:
id, start, end, text = (
segment.id,
segment.start,
segment.end,
segment.text,
)
print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
if not whole_text:
whole_text = text
else:
whole_text += ", " + text
whole_text += "."
audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
audio.export(audio_save_path, format=file_suffix[1:])
print(f"Exported {audio_save_path}")
transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
with open(
transcript_save_path,
"w",
encoding="utf-8",
) as f:
f.write(whole_text)
if __name__ == "__main__":
main()
exit(0)
audio = AudioSegment.from_wav(
r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
)
model_size = "large-v3"
model = WhisperModel(
model_size,
device="cuda",
compute_type="float16",
download_root="faster_whisper",
)
segments, info = model.transcribe(
r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
beam_size=5,
)
print(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
print("Total len(ms): ", len(audio))
for i, segment in enumerate(segments):
print(
"Segment %03d [%.2fs -> %.2fs] %s"
% (i, segment.start, segment.end, segment.text)
)
start_ms = int(segment.start * 1000)
end_ms = int(segment.end * 1000)
segment_audio = audio[start_ms:end_ms]
segment_audio.export(f"segment_{i:03d}.wav", format="wav")
print(f"Exported segment_{i:03d}.wav")
print("All segments have been exported.")