Commit b75857fb authored by chenzk

v1.0

from lightning.pytorch.utilities import rank_zero_only
from fish_speech.utils import logger as log
@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
"""Controls which config parts are saved by lightning loggers.
Additionally saves:
- Number of model parameters
"""
hparams = {}
cfg = object_dict["cfg"]
model = object_dict["model"]
trainer = object_dict["trainer"]
if not trainer.logger:
log.warning("Logger not found! Skipping hyperparameter logging...")
return
hparams["model"] = cfg["model"]
# save number of model parameters
hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
hparams["model/params/trainable"] = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
hparams["model/params/non_trainable"] = sum(
p.numel() for p in model.parameters() if not p.requires_grad
)
hparams["data"] = cfg["data"]
hparams["trainer"] = cfg["trainer"]
hparams["callbacks"] = cfg.get("callbacks")
hparams["extras"] = cfg.get("extras")
hparams["task_name"] = cfg.get("task_name")
hparams["tags"] = cfg.get("tags")
hparams["ckpt_path"] = cfg.get("ckpt_path")
hparams["seed"] = cfg.get("seed")
# send hparams to all loggers
for logger in trainer.loggers:
logger.log_hyperparams(hparams)
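# Usage sketch (hypothetical demo): in a real run, `cfg`, `model`, and `trainer`
# come from a Hydra-driven training entry point; the stand-ins below only show
# the expected shape of `object_dict`.
if __name__ == "__main__":
    import lightning.pytorch as pl
    import torch.nn as nn

    class DemoModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 2)

    demo_cfg = {
        "model": {"name": "demo-linear"},
        "data": {"batch_size": 8},
        "trainer": {"max_epochs": 1},
    }
    trainer = pl.Trainer(max_epochs=1)  # default logger (TensorBoard or CSV)
    log_hyperparameters({"cfg": demo_cfg, "model": DemoModel(), "trainer": trainer})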
from pathlib import Path
from typing import Sequence
import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt
from fish_speech.utils import logger as log
@rank_zero_only
def print_config_tree(
cfg: DictConfig,
print_order: Sequence[str] = (
"data",
"model",
"callbacks",
"logger",
"trainer",
"paths",
"extras",
),
resolve: bool = False,
save_to_file: bool = False,
) -> None:
"""Prints content of DictConfig using Rich library and its tree structure.
Args:
cfg (DictConfig): Configuration composed by Hydra.
print_order (Sequence[str], optional): Determines in what order config components are printed.
resolve (bool, optional): Whether to resolve reference fields of DictConfig.
save_to_file (bool, optional): Whether to export config to the hydra output folder.
""" # noqa: E501
style = "dim"
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
queue = []
# add fields from `print_order` to queue
for field in print_order:
(
queue.append(field)
if field in cfg
else log.warning(
f"Field '{field}' not found in config. "
+ f"Skipping '{field}' config printing..."
)
)
# add all the other fields to queue (not specified in `print_order`)
for field in cfg:
if field not in queue:
queue.append(field)
# generate config tree from queue
for field in queue:
branch = tree.add(field, style=style, guide_style=style)
config_group = cfg[field]
if isinstance(config_group, DictConfig):
branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
else:
branch_content = str(config_group)
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
# print config tree
rich.print(tree)
# save config tree to file
if save_to_file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
rich.print(tree, file=file)
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
"""Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
if not cfg.get("tags"):
if "id" in HydraConfig().cfg.hydra.job:
raise ValueError("Specify tags before launching a multirun!")
log.warning("No tags provided in config. Prompting user to input tags...")
tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
tags = [t.strip() for t in tags.split(",") if t != ""]
with open_dict(cfg):
cfg.tags = tags
log.info(f"Tags: {cfg.tags}")
if save_to_file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
rich.print(cfg.tags, file=file)
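# Usage sketch (hypothetical demo): compose a small DictConfig by hand instead
# of through Hydra and render it with print_config_tree. enforce_tags is not
# exercised here because it expects an active Hydra run.
if __name__ == "__main__":
    demo_cfg = OmegaConf.create(
        {
            "model": {"lr": 1e-4, "hidden_dim": 256},
            "data": {"batch_size": 8, "num_workers": 2},
            "trainer": {"max_epochs": 10},
        }
    )
    print_config_tree(demo_cfg, resolve=True, save_to_file=False)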
import base64
import os
import queue
from dataclasses import dataclass
from typing import Literal
import torch
from pydantic import BaseModel, Field, conint, conlist, model_validator
from pydantic.functional_validators import SkipValidation
from typing_extensions import Annotated
from fish_speech.conversation import Message, TextPart, VQPart
class ServeVQPart(BaseModel):
type: Literal["vq"] = "vq"
codes: SkipValidation[list[list[int]]]
class ServeTextPart(BaseModel):
type: Literal["text"] = "text"
text: str
class ServeAudioPart(BaseModel):
type: Literal["audio"] = "audio"
audio: bytes
@dataclass
class ASRPackRequest:
audio: torch.Tensor
result_queue: queue.Queue
language: str
class ServeASRRequest(BaseModel):
# The audio should be uncompressed PCM float16
audios: list[bytes]
sample_rate: int = 44100
language: Literal["zh", "en", "ja", "auto"] = "auto"
class ServeASRTranscription(BaseModel):
text: str
duration: float
huge_gap: bool
class ServeASRSegment(BaseModel):
text: str
start: float
end: float
class ServeTimedASRResponse(BaseModel):
text: str
segments: list[ServeASRSegment]
duration: float
class ServeASRResponse(BaseModel):
transcriptions: list[ServeASRTranscription]
class ServeMessage(BaseModel):
role: Literal["system", "assistant", "user"]
parts: list[ServeVQPart | ServeTextPart]
def to_conversation_message(self):
new_message = Message(role=self.role, parts=[])
if self.role == "assistant":
new_message.modality = "voice"
for part in self.parts:
if isinstance(part, ServeTextPart):
new_message.parts.append(TextPart(text=part.text))
elif isinstance(part, ServeVQPart):
new_message.parts.append(
VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
)
else:
raise ValueError(f"Unsupported part type: {part}")
return new_message
class ServeChatRequest(BaseModel):
messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
max_new_tokens: int = 1024
top_p: float = 0.7
repetition_penalty: float = 1.2
temperature: float = 0.7
streaming: bool = False
num_samples: int = 1
early_stop_threshold: float = 1.0
class ServeVQGANEncodeRequest(BaseModel):
# The audio here should be in wav, mp3, etc
audios: list[bytes]
class ServeVQGANEncodeResponse(BaseModel):
tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeRequest(BaseModel):
tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeResponse(BaseModel):
# The audio here should be in PCM float16 format
audios: list[bytes]
class ServeForwardMessage(BaseModel):
role: str
content: str
class ServeResponse(BaseModel):
messages: list[ServeMessage]
finish_reason: Literal["stop", "error"] | None = None
stats: dict[str, int | float | str] = {}
class ServeStreamDelta(BaseModel):
role: Literal["system", "assistant", "user"] | None = None
part: ServeVQPart | ServeTextPart | None = None
class ServeStreamResponse(BaseModel):
sample_id: int = 0
delta: ServeStreamDelta | None = None
finish_reason: Literal["stop", "error"] | None = None
stats: dict[str, int | float | str] | None = None
class ServeReferenceAudio(BaseModel):
audio: bytes
text: str
@model_validator(mode="before")
def decode_audio(cls, values):
audio = values.get("audio")
if (
isinstance(audio, str) and len(audio) > 255
): # Check if audio is a string (Base64)
try:
values["audio"] = base64.b64decode(audio)
except Exception:
# If the audio is not a valid base64 string, we will just ignore it and let the server handle it
pass
return values
def __repr__(self) -> str:
return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
class ServeTTSRequest(BaseModel):
text: str
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
# Audio format
format: Literal["wav", "pcm", "mp3"] = "wav"
# References audios for in-context learning
references: list[ServeReferenceAudio] = []
# Reference id
# For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
reference_id: str | None = None
seed: int | None = None
use_memory_cache: Literal["on", "off"] = "off"
# Normalize text for en & zh; this increases stability for numbers
normalize: bool = True
# The parameters below are not usually used
streaming: bool = False
max_new_tokens: int = 1024
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
class Config:
# Allow arbitrary types for pytorch related types
arbitrary_types_allowed = True
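# Usage sketch (hypothetical demo): build a ServeTTSRequest in-process. The
# reference audio bytes are placeholders; a real client sends actual audio
# bytes, or a base64 string longer than 255 characters, which
# ServeReferenceAudio.decode_audio turns into bytes automatically.
if __name__ == "__main__":
    ref = ServeReferenceAudio(audio=b"\x00\x01", text="reference transcript")
    request = ServeTTSRequest(
        text="Hello from Fish Speech.",
        references=[ref],
        format="wav",
        chunk_length=200,
        temperature=0.7,
    )
    print(repr(ref))
    print(request.model_dump(exclude={"references"}))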
import torch
import torchaudio.functional as F
from torch import Tensor, nn
from torchaudio.transforms import MelScale
class LinearSpectrogram(nn.Module):
def __init__(
self,
n_fft=2048,
win_length=2048,
hop_length=512,
center=False,
mode="pow2_sqrt",
):
super().__init__()
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.mode = mode
self.return_complex = True
self.register_buffer("window", torch.hann_window(win_length), persistent=False)
def forward(self, y: Tensor) -> Tensor:
if y.ndim == 3:
y = y.squeeze(1)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(
(self.win_length - self.hop_length) // 2,
(self.win_length - self.hop_length + 1) // 2,
),
mode="reflect",
).squeeze(1)
spec = torch.stft(
y,
self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=self.return_complex,
)
if self.return_complex:
spec = torch.view_as_real(spec)
if self.mode == "pow2_sqrt":
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
class LogMelSpectrogram(nn.Module):
def __init__(
self,
sample_rate=44100,
n_fft=2048,
win_length=2048,
hop_length=512,
n_mels=128,
center=False,
f_min=0.0,
f_max=None,
):
super().__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.n_mels = n_mels
self.f_min = f_min
self.f_max = f_max or float(sample_rate // 2)
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
fb = F.melscale_fbanks(
n_freqs=self.n_fft // 2 + 1,
f_min=self.f_min,
f_max=self.f_max,
n_mels=self.n_mels,
sample_rate=self.sample_rate,
norm="slaney",
mel_scale="slaney",
)
self.register_buffer(
"fb",
fb,
persistent=False,
)
def compress(self, x: Tensor) -> Tensor:
return torch.log(torch.clamp(x, min=1e-5))
def decompress(self, x: Tensor) -> Tensor:
return torch.exp(x)
def apply_mel_scale(self, x: Tensor) -> Tensor:
return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
def forward(
self, x: Tensor, return_linear: bool = False, sample_rate: int | None = None
) -> Tensor:
if sample_rate is not None and sample_rate != self.sample_rate:
x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
linear = self.spectrogram(x)
x = self.apply_mel_scale(linear)
x = self.compress(x)
if return_linear:
return x, self.compress(linear)
return x
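# Shape check (hypothetical demo): push one second of random 44.1 kHz audio
# through LogMelSpectrogram. With the defaults above (n_fft=2048, hop_length=512,
# n_mels=128, center=False) the output is (batch, n_mels, frames), roughly 86
# frames for a 44100-sample input.
if __name__ == "__main__":
    mel_transform = LogMelSpectrogram()
    waveform = torch.randn(1, 44100)  # (batch, samples)
    mel, linear = mel_transform(waveform, return_linear=True)
    print(mel.shape, linear.shape)  # expected: (1, 128, 86) and (1, 1025, 86)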
import random
import warnings
from importlib.util import find_spec
from typing import Callable
import numpy as np
import torch
from omegaconf import DictConfig
from .logger import RankedLogger
from .rich_utils import enforce_tags, print_config_tree
log = RankedLogger(__name__, rank_zero_only=True)
def extras(cfg: DictConfig) -> None:
"""Applies optional utilities before the task is started.
Utilities:
- Ignoring python warnings
- Setting tags from command line
- Rich config printing
"""
# return if no `extras` config
if not cfg.get("extras"):
log.warning("Extras config not found! <cfg.extras=null>")
return
# disable python warnings
if cfg.extras.get("ignore_warnings"):
log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
warnings.filterwarnings("ignore")
# prompt user to input tags from command line if none are provided in the config
if cfg.extras.get("enforce_tags"):
log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
enforce_tags(cfg, save_to_file=True)
# pretty print config tree using Rich library
if cfg.extras.get("print_config"):
log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
print_config_tree(cfg, resolve=True, save_to_file=True)
def task_wrapper(task_func: Callable) -> Callable:
"""Optional decorator that controls the failure behavior when executing the task function.
This wrapper can be used to:
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- save the exception to a `.log` file
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
```
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
...
return metric_dict, object_dict
```
""" # noqa: E501
def wrap(cfg: DictConfig):
# execute the task
try:
metric_dict, object_dict = task_func(cfg=cfg)
# things to do if exception occurs
except Exception as ex:
# save exception to `.log` file
log.exception("")
# some hyperparameter combinations might be invalid or
# cause out-of-memory errors so when using hparam search
# plugins like Optuna, you might want to disable
# raising the below exception to avoid multirun failure
raise ex
# things to always do after either success or exception
finally:
# display output dir path in terminal
log.info(f"Output dir: {cfg.paths.run_dir}")
# always close wandb run (even if exception occurs so multirun won't fail)
if find_spec("wandb"): # check if wandb is installed
import wandb
if wandb.run:
log.info("Closing wandb!")
wandb.finish()
return metric_dict, object_dict
return wrap
def get_metric_value(metric_dict: dict, metric_name: str) -> float:
"""Safely retrieves value of the metric logged in LightningModule."""
if not metric_name:
log.info("Metric name is None! Skipping metric value retrieval...")
return None
if metric_name not in metric_dict:
raise Exception(
f"Metric value not found! <metric_name={metric_name}>\n"
"Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
)
metric_value = metric_dict[metric_name].item()
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
return metric_value
def set_seed(seed: int):
if seed < 0:
seed = -seed
if seed > (1 << 31):
seed = 1 << 31
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if torch.backends.cudnn.is_available():
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
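# Usage sketch (hypothetical demo): wrap a stand-in task with @task_wrapper so
# loggers are closed even on failure, then read the optimized metric back with
# get_metric_value. The wrapper's finally block reads cfg.paths.run_dir, so the
# demo config defines it.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    demo_cfg = OmegaConf.create({"paths": {"run_dir": "logs/demo_run"}})

    @task_wrapper
    def demo_task(cfg: DictConfig) -> tuple[dict, dict]:
        # stand-in for a real training loop
        return {"val/loss": torch.tensor(0.123)}, {"cfg": cfg}

    metric_dict, object_dict = demo_task(cfg=demo_cfg)
    print(get_metric_value(metric_dict, "val/loss"))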
icon.png (binary image, 64.4 KB)
# Generate the prompt from reference audio
python fish_speech/models/vqgan/inference.py \
-i "example.wav" \
--checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
# Generate semantic tokens from text
python fish_speech/models/text2semantic/inference.py \
--text "富人优先考虑的都是利益,而穷人优先考虑的永远都是感情和面子,穷人是小心翼翼的大方,而富人却是大大方方的小气。 " \
--prompt-text "The text corresponding to reference audio" \
--prompt-tokens "fake.npy" \
--checkpoint-path "checkpoints/fish-speech-1.5" \
--num-samples 1
# --compile
# Generate speech from semantic tokens
python fish_speech/models/vqgan/inference.py \
-i "temp/codes_0.npy" \
--checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fish Speech"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### For Windows User / win用户"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "bat"
}
},
"outputs": [],
"source": [
"!chcp 65001"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### For Linux User / Linux 用户"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import locale\n",
"locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For Chinese users, you probably want to use mirror to accelerate downloading\n",
"# !set HF_ENDPOINT=https://hf-mirror.com\n",
"# !export HF_ENDPOINT=https://hf-mirror.com \n",
"\n",
"!huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## WebUI Inference\n",
"\n",
"> You can use --compile to fuse CUDA kernels for faster inference (10x)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"!python tools/run_webui.py \\\n",
" --llama-checkpoint-path checkpoints/fish-speech-1.5 \\\n",
" --decoder-checkpoint-path checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
" # --compile"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Break-down CLI Inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Encode reference audio: / 从语音生成 prompt: \n",
"\n",
"You should get a `fake.npy` file.\n",
"\n",
"你应该能得到一个 `fake.npy` 文件."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"## Enter the path to the audio file here\n",
"src_audio = r\"D:\\PythonProject\\vo_hutao_draw_appear.wav\"\n",
"\n",
"!python fish_speech/models/vqgan/inference.py \\\n",
" -i {src_audio} \\\n",
" --checkpoint-path \"checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
"\n",
"from IPython.display import Audio, display\n",
"audio = Audio(filename=\"fake.wav\")\n",
"display(audio)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Generate semantic tokens from text: / 从文本生成语义 token:\n",
"\n",
"> This command will create a codes_N file in the working directory, where N is an integer starting from 0.\n",
"\n",
"> You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~300 tokens/second).\n",
"\n",
"> 该命令会在工作目录下创建 codes_N 文件, 其中 N 是从 0 开始的整数.\n",
"\n",
"> 您可以使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 tokens/秒 -> ~300 tokens/秒)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"!python fish_speech/models/text2semantic/inference.py \\\n",
" --text \"hello world\" \\\n",
" --prompt-text \"The text corresponding to reference audio\" \\\n",
" --prompt-tokens \"fake.npy\" \\\n",
" --checkpoint-path \"checkpoints/fish-speech-1.5\" \\\n",
" --num-samples 2\n",
" # --compile"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Generate speech from semantic tokens: / 从语义 token 生成人声:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"!python fish_speech/models/vqgan/inference.py \\\n",
" -i \"codes_0.npy\" \\\n",
" --checkpoint-path \"checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
"\n",
"from IPython.display import Audio, display\n",
"audio = Audio(filename=\"fake.wav\")\n",
"display(audio)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@echo off
chcp 65001
set USE_MIRROR=true
echo "USE_MIRROR: %USE_MIRROR%"
setlocal enabledelayedexpansion
cd /D "%~dp0"
set PATH="%PATH%";%SystemRoot%\system32
echo %PATH%
echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
echo.
echo The current path contains special characters. Please move fish-speech to a path without special characters before running. && (
goto end
)
)
set TMP=%CD%\fishenv
set TEMP=%CD%\fishenv
(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul
set INSTALL_DIR=%cd%\fishenv
set CONDA_ROOT_PREFIX=%cd%\fishenv\conda
set INSTALL_ENV_DIR=%cd%\fishenv\env
set PIP_CMD=%cd%\fishenv\env\python -m pip
set PYTHON_CMD=%cd%\fishenv\env\python
set API_FLAG_PATH=%~dp0API_FLAGS.txt
set MINICONDA_DOWNLOAD_URL=https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Windows-x86_64.exe
if "!USE_MIRROR!" == "true" (
set MINICONDA_DOWNLOAD_URL=https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py310_23.3.1-0-Windows-x86_64.exe
)
set MINICONDA_CHECKSUM=307194e1f12bbeb52b083634e89cc67db4f7980bd542254b43d3309eaf7cb358
set conda_exists=F
call "%CONDA_ROOT_PREFIX%\_conda.exe" --version >nul 2>&1
if "%ERRORLEVEL%" EQU "0" set conda_exists=T
if "%conda_exists%" == "F" (
echo.
echo Downloading Miniconda...
mkdir "%INSTALL_DIR%" 2>nul
call curl -Lk "%MINICONDA_DOWNLOAD_URL%" > "%INSTALL_DIR%\miniconda_installer.exe"
if errorlevel 1 (
echo.
echo Failed to download miniconda.
goto end
)
for /f %%a in ('
certutil -hashfile "%INSTALL_DIR%\miniconda_installer.exe" sha256
^| find /i /v " "
^| find /i "%MINICONDA_CHECKSUM%"
') do (
set "hash=%%a"
)
if not defined hash (
echo.
echo Miniconda hash mismatch!
del "%INSTALL_DIR%\miniconda_installer.exe"
goto end
) else (
echo.
echo Miniconda hash matched successfully.
)
echo Downloaded "%CONDA_ROOT_PREFIX%"
start /wait "" "%INSTALL_DIR%\miniconda_installer.exe" /InstallationType=JustMe /NoShortcuts=1 /AddToPath=0 /RegisterPython=0 /NoRegistry=1 /S /D=%CONDA_ROOT_PREFIX%
call "%CONDA_ROOT_PREFIX%\_conda.exe" --version
if errorlevel 1 (
echo.
echo Cannot install Miniconda.
goto end
) else (
echo.
echo Miniconda installed successfully.
)
del "%INSTALL_DIR%\miniconda_installer.exe"
)
if not exist "%INSTALL_ENV_DIR%" (
echo.
echo Creating Conda Environment...
if "!USE_MIRROR!" == "true" (
call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ python=3.10
) else (
call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.10
)
if errorlevel 1 (
echo.
echo Failed to Create Environment.
goto end
)
)
if not exist "%INSTALL_ENV_DIR%\python.exe" (
echo.
echo Conda Env does not exist.
goto end
)
set PYTHONNOUSERSITE=1
set PYTHONPATH=
set PYTHONHOME=
set "CUDA_PATH=%INSTALL_ENV_DIR%"
set "CUDA_HOME=%CUDA_PATH%"
call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%"
if errorlevel 1 (
echo.
echo Failed to activate Env.
goto end
) else (
echo.
echo Successfully created the environment.
)
set "HF_ENDPOINT=https://huggingface.co"
set "no_proxy="
if "%USE_MIRROR%"=="true" (
set "HF_ENDPOINT=https://hf-mirror.com"
set "no_proxy=localhost,127.0.0.1,0.0.0.0"
)
echo "HF_ENDPOINT: !HF_ENDPOINT!"
echo "NO_PROXY: !no_proxy!"
if "!USE_MIRROR!" == "true" (
%PIP_CMD% install torch torchvision torchaudio -U --extra-index-url https://mirrors.bfsu.edu.cn/pypi/web/simple
) else (
%PIP_CMD% install torch torchvision torchaudio -U --index-url https://download.pytorch.org/whl/cu121
)
%PIP_CMD% install -e . --upgrade-strategy only-if-needed
call :download_and_install "triton_windows-0.1.0-py3-none-any.whl" ^
"%HF_ENDPOINT%/datasets/SpicyqSama007/windows_compile/resolve/main/triton_windows-0.1.0-py3-none-any.whl?download=true" ^
"2cc998638180f37cf5025ab65e48c7f629aa5a369176cfa32177d2bd9aa26a0a"
endlocal
echo "Environment Check: Success."
:end
pause
goto :EOF
:download_and_install
setlocal
set "WHEEL_FILE=%1"
set "URL=%2"
set "CHKSUM=%3"
:DOWNLOAD
if not exist "%WHEEL_FILE%" (
call curl -Lk "%URL%" --output "%WHEEL_FILE%"
)
for /f "delims=" %%I in ("certutil -hashfile %WHEEL_FILE% SHA256 ^| find /i %CHKSUM%") do (
set "FILE_VALID=true"
)
if not defined FILE_VALID (
echo File checksum does not match, re-downloading...
del "%WHEEL_FILE%"
goto DOWNLOAD
)
echo "OK for %WHEEL_FILE%"
%PIP_CMD% install "%WHEEL_FILE%" --no-warn-script-location
del "%WHEEL_FILE%"
endlocal
goto :EOF
site_name: Fish Speech
site_description: Targeting SOTA TTS solutions.
site_url: https://speech.fish.audio
# Repository
repo_name: fishaudio/fish-speech
repo_url: https://github.com/fishaudio/fish-speech
edit_uri: blob/main/docs
# Copyright
copyright: Copyright &copy; 2023-2024 by Fish Audio
theme:
name: material
favicon: assets/figs/logo-circle.png
language: en
features:
- content.action.edit
- content.action.view
- navigation.tracking
- navigation.footer
# - navigation.tabs
- search
- search.suggest
- search.highlight
- search.share
- content.code.copy
icon:
logo: fontawesome/solid/fish
palette:
# Palette toggle for automatic mode
- media: "(prefers-color-scheme)"
toggle:
icon: material/brightness-auto
name: Switch to light mode
# Palette toggle for light mode
- media: "(prefers-color-scheme: light)"
scheme: default
toggle:
icon: material/brightness-7
name: Switch to dark mode
primary: black
font:
code: Roboto Mono
# Palette toggle for dark mode
- media: "(prefers-color-scheme: dark)"
scheme: slate
toggle:
icon: material/brightness-4
name: Switch to light mode
primary: black
font:
code: Roboto Mono
nav:
- Introduction: index.md
- Finetune: finetune.md
- Inference: inference.md
- Start Agent: start_agent.md
- Samples: samples.md
# Plugins
plugins:
- search:
separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])'
lang:
- en
- zh
- ja
- pt
- ko
- i18n:
docs_structure: folder
languages:
- locale: en
name: English
default: true
build: true
- locale: zh
name: 简体中文
build: true
nav:
- 介绍: zh/index.md
- 微调: zh/finetune.md
- 推理: zh/inference.md
- 启动Agent: zh/start_agent.md
- 例子: zh/samples.md
- locale: ja
name: 日本語
build: true
nav:
- Fish Speech の紹介: ja/index.md
- 微調整: ja/finetune.md
- 推論: ja/inference.md
- スタートエージェント: ja/start_agent.md
- サンプル: ja/samples.md
- locale: pt
name: Português (Brasil)
build: true
nav:
- Introdução: pt/index.md
- Ajuste Fino: pt/finetune.md
- Inferência: pt/inference.md
- Agente inicial: pt/start_agent.md
- Amostras: pt/samples.md
- locale: ko
name: 한국어
build: true
nav:
- 소개: ko/index.md
- 파인튜닝: ko/finetune.md
- 추론: ko/inference.md
- 샘플: ko/samples.md
markdown_extensions:
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
- admonition
- pymdownx.details
- pymdownx.superfences
- attr_list
- md_in_html
- pymdownx.superfences
extra_css:
- stylesheets/extra.css
extra:
social:
- icon: fontawesome/brands/discord
link: https://discord.gg/Es5qTB9BcN
- icon: fontawesome/brands/docker
link: https://hub.docker.com/r/fishaudio/fish-speech
- icon: fontawesome/brands/qq
link: http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=jCKlUP7QgSm9kh95UlBoYv6s1I-Apl1M&authKey=xI5ttVAp3do68IpEYEalwXSYZFdfxZSkah%2BctF5FIMyN2NqAa003vFtLqJyAVRfF&noverify=0&group_code=593946093
homepage: https://speech.fish.audio
# Model code
modelCode=1416
# Model name
modelName=fish-speech_pytorch
# Model description
modelDescription=Highly faithful reproduction of the source timbre, outperforming F5 and CosySense, with excellent voice cloning across multiple languages.
# Application scenarios
appScenario=inference, speech synthesis, broadcast media, film and TV, animation, healthcare, smart home, education
# Framework type
frameType=pytorch
[project]
name = "fish-speech"
version = "0.1.0"
authors = [
{name = "Lengyue", email = "lengyue@lengyue.me"},
]
description = "Fish Speech"
readme = "README.md"
requires-python = ">=3.10"
keywords = ["TTS", "Speech"]
license = {text = "CC BY-NC-SA 4.0"}
classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = [
"numpy<=1.26.4",
"transformers>=4.45.2",
"datasets==2.18.0",
"lightning>=2.1.0",
"hydra-core>=1.3.2",
"tensorboard>=2.14.1",
"natsort>=8.4.0",
"einops>=0.7.0",
"librosa>=0.10.1",
"rich>=13.5.3",
"gradio>5.0.0",
"wandb>=0.15.11",
"grpcio>=1.58.0",
"kui>=1.6.0",
"uvicorn>=0.30.0",
"loguru>=0.6.0",
"loralib>=0.1.2",
"pyrootutils>=1.0.4",
"vector_quantize_pytorch==1.14.24",
"resampy>=0.4.3",
"einx[torch]==0.2.2",
"zstandard>=0.22.0",
"pydub",
"pyaudio",
"faster_whisper",
"modelscope==1.17.1",
"funasr==1.1.5",
"opencc-python-reimplemented==0.1.7",
"silero-vad",
"ormsgpack",
"tiktoken>=0.8.0",
"pydantic==2.9.2",
"cachetools",
]
[project.optional-dependencies]
stable = [
"torch<=2.4.1",
"torchaudio",
]
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
packages = ["fish_speech", "tools"]
{
"exclude": [
"data",
"filelists"
]
}
numpy<=1.26.4
transformers>=4.45.2
datasets==2.18.0
lightning>=2.1.0
hydra-core>=1.3.2
tensorboard>=2.14.1
natsort>=8.4.0
einops>=0.7.0
librosa>=0.10.1
rich>=13.5.3
gradio>5.0.0
wandb>=0.15.11
grpcio>=1.58.0
kui>=1.6.0
uvicorn>=0.30.0
loguru>=0.6.0
loralib>=0.1.2
pyrootutils>=1.0.4
vector_quantize_pytorch==1.14.24
resampy>=0.4.3
einx[torch]==0.2.2
zstandard>=0.22.0
pydub
pyaudio
faster_whisper
modelscope==1.17.1
funasr==1.1.5
opencc-python-reimplemented==0.1.7
silero-vad
ormsgpack
tiktoken>=0.8.0
pydantic==2.9.2
cachetools
setuptools-scm
@echo off
chcp 65001
set "no_proxy=127.0.0.1, 0.0.0.0, localhost"
setlocal enabledelayedexpansion
cd /D "%~dp0"
set PATH="%PATH%";%SystemRoot%\system32
echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
echo.
echo The current path contains special characters. Please move fish-speech to a path without special characters before running. && (
goto end
)
)
set TMP=%CD%\fishenv
set TEMP=%CD%\fishenv
(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul
set CONDA_ROOT_PREFIX=%cd%\fishenv\conda
set INSTALL_ENV_DIR=%cd%\fishenv\env
set PYTHONNOUSERSITE=1
set PYTHONPATH=%~dp0
set PYTHONHOME=
call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%"
if errorlevel 1 (
echo.
echo Environment activation failed.
goto end
) else (
echo.
echo Environment activation succeeded.
)
cmd /k "%*"
:end
pause
@echo off
chcp 65001
set USE_MIRROR=true
set PYTHONPATH=%~dp0
set PYTHON_CMD=python
if exist "fishenv" (
set PYTHON_CMD=%cd%\fishenv\env\python
)
set API_FLAG_PATH=%~dp0API_FLAGS.txt
set KMP_DUPLICATE_LIB_OK=TRUE
setlocal enabledelayedexpansion
set "HF_ENDPOINT=https://huggingface.co"
set "no_proxy="
if "%USE_MIRROR%" == "true" (
set "HF_ENDPOINT=https://hf-mirror.com"
set "no_proxy=localhost, 127.0.0.1, 0.0.0.0"
)
echo "HF_ENDPOINT: !HF_ENDPOINT!"
echo "NO_PROXY: !no_proxy!"
echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
echo.
echo The current path contains special characters. Please move fish-speech to a path without special characters before running. && (
goto end
)
)
%PYTHON_CMD% .\tools\download_models.py
set "API_FLAGS="
set "flags="
if exist "%API_FLAG_PATH%" (
for /f "usebackq tokens=*" %%a in ("%API_FLAG_PATH%") do (
set "line=%%a"
if not "!line:~0,1!"=="#" (
set "line=!line: =<SPACE>!"
set "line=!line:\=!"
set "line=!line:<SPACE>= !"
if not "!line!"=="" (
set "API_FLAGS=!API_FLAGS!!line! "
)
)
)
)
if not "!API_FLAGS!"=="" set "API_FLAGS=!API_FLAGS:~0,-1!"
set "flags="
echo !API_FLAGS! | findstr /C:"--api" >nul 2>&1
if !errorlevel! equ 0 (
echo.
echo Start HTTP API...
set "mode=api"
goto process_flags
)
echo !API_FLAGS! | findstr /C:"--infer" >nul 2>&1
if !errorlevel! equ 0 (
echo.
echo Start WebUI Inference...
set "mode=infer"
goto process_flags
)
:process_flags
for %%p in (!API_FLAGS!) do (
if not "%%p"=="--!mode!" (
set "flags=!flags! %%p"
)
)
if not "!flags!"=="" set "flags=!flags:~1!"
echo Debug: flags = !flags!
if "!mode!"=="api" (
%PYTHON_CMD% -m tools.api_server !flags!
) else if "!mode!"=="infer" (
%PYTHON_CMD% -m tools.webui !flags!
)
echo.
echo Launching the management page next...
%PYTHON_CMD% fish_speech\webui\manage.py
:end
endlocal
pause
import argparse
import base64
import wave
import ormsgpack
import pyaudio
import requests
from pydub import AudioSegment
from pydub.playback import play
from fish_speech.utils.file import audio_to_bytes, read_ref_text
from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
def parse_args():
parser = argparse.ArgumentParser(
description="Send a WAV file and text to a server and receive synthesized audio.",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--url",
"-u",
type=str,
default="http://127.0.0.1:8080/v1/tts",
help="URL of the server",
)
parser.add_argument(
"--text", "-t", type=str, required=True, help="Text to be synthesized"
)
parser.add_argument(
"--reference_id",
"-id",
type=str,
default=None,
help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)",
)
parser.add_argument(
"--reference_audio",
"-ra",
type=str,
nargs="+",
default=None,
help="Path to the audio file",
)
parser.add_argument(
"--reference_text",
"-rt",
type=str,
nargs="+",
default=None,
help="Reference text for voice synthesis",
)
parser.add_argument(
"--output",
"-o",
type=str,
default="generated_audio",
help="Output audio file name",
)
parser.add_argument(
"--play",
action=argparse.BooleanOptionalAction,
default=True,
help="Whether to play audio after receiving data",
)
parser.add_argument("--normalize", type=bool, default=True)
parser.add_argument(
"--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
)
parser.add_argument(
"--latency",
type=str,
default="normal",
choices=["normal", "balanced"],
help="Used in api.fish.audio/v1/tts",
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=1024,
help="Maximum new tokens to generate. \n0 means no limit.",
)
parser.add_argument(
"--chunk_length", type=int, default=200, help="Chunk length for synthesis"
)
parser.add_argument(
"--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
)
parser.add_argument(
"--repetition_penalty",
type=float,
default=1.2,
help="Repetition penalty for synthesis",
)
parser.add_argument(
"--temperature", type=float, default=0.7, help="Temperature for sampling"
)
parser.add_argument(
"--streaming", action=argparse.BooleanOptionalAction, default=False, help="Enable streaming response"
)
parser.add_argument(
"--channels", type=int, default=1, help="Number of audio channels"
)
parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
parser.add_argument(
"--use_memory_cache",
type=str,
default="off",
choices=["on", "off"],
help="Cache encoded references codes in memory.\n",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="`None` means randomized inference, otherwise deterministic.\n"
"It can't be used for fixing a timbre.",
)
parser.add_argument(
"--api_key",
type=str,
default="YOUR_API_KEY",
help="API key for authentication",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
idstr: str | None = args.reference_id
# priority: ref_id > [{text, audio},...]
if idstr is None:
ref_audios = args.reference_audio
ref_texts = args.reference_text
if ref_audios is None:
byte_audios = []
else:
byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
if ref_texts is None:
ref_texts = []
else:
ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
else:
byte_audios = []
ref_texts = []
pass # in api.py
data = {
"text": args.text,
"references": [
ServeReferenceAudio(
audio=ref_audio if ref_audio is not None else b"", text=ref_text
)
for ref_text, ref_audio in zip(ref_texts, byte_audios)
],
"reference_id": idstr,
"normalize": args.normalize,
"format": args.format,
"max_new_tokens": args.max_new_tokens,
"chunk_length": args.chunk_length,
"top_p": args.top_p,
"repetition_penalty": args.repetition_penalty,
"temperature": args.temperature,
"streaming": args.streaming,
"use_memory_cache": args.use_memory_cache,
"seed": args.seed,
}
pydantic_data = ServeTTSRequest(**data)
response = requests.post(
args.url,
data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
stream=args.streaming,
headers={
"authorization": f"Bearer {args.api_key}",
"content-type": "application/msgpack",
},
)
if response.status_code == 200:
if args.streaming:
p = pyaudio.PyAudio()
audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
stream = p.open(
format=audio_format, channels=args.channels, rate=args.rate, output=True
)
wf = wave.open(f"{args.output}.wav", "wb")
wf.setnchannels(args.channels)
wf.setsampwidth(p.get_sample_size(audio_format))
wf.setframerate(args.rate)
stream_stopped_flag = False
try:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
stream.write(chunk)
wf.writeframesraw(chunk)
else:
if not stream_stopped_flag:
stream.stop_stream()
stream_stopped_flag = True
finally:
stream.close()
p.terminate()
wf.close()
else:
audio_content = response.content
audio_path = f"{args.output}.{args.format}"
with open(audio_path, "wb") as audio_file:
audio_file.write(audio_content)
audio = AudioSegment.from_file(audio_path, format=args.format)
if args.play:
play(audio)
print(f"Audio has been saved to '{audio_path}'.")
else:
print(f"Request failed with status code {response.status_code}")
print(response.json())
import re
from threading import Lock
import pyrootutils
import uvicorn
from kui.asgi import (
Depends,
FactoryClass,
HTTPException,
HttpRoute,
Kui,
OpenAPI,
Routes,
)
from kui.cors import CORSConfig
from kui.openapi.specification import Info
from kui.security import bearer_auth
from loguru import logger
from typing_extensions import Annotated
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from tools.server.api_utils import MsgPackRequest, parse_args
from tools.server.exception_handler import ExceptionHandler
from tools.server.model_manager import ModelManager
from tools.server.views import routes
class API(ExceptionHandler):
def __init__(self):
self.args = parse_args()
self.routes = routes
def api_auth(endpoint):
async def verify(token: Annotated[str, Depends(bearer_auth)]):
if token != self.args.api_key:
raise HTTPException(401, None, "Invalid token")
return await endpoint()
async def passthrough():
return await endpoint()
if self.args.api_key is not None:
return verify
else:
return passthrough
self.openapi = OpenAPI(
Info(
{
"title": "Fish Speech API",
"version": "1.5.0",
}
),
).routes
# Initialize the app
self.app = Kui(
routes=self.routes + self.openapi[1:], # Remove the default route
exception_handlers={
HTTPException: self.http_exception_handler,
Exception: self.other_exception_handler,
},
factory_class=FactoryClass(http=MsgPackRequest),
cors_config=CORSConfig(),
)
# Add the state variables
self.app.state.lock = Lock()
self.app.state.device = self.args.device
self.app.state.max_text_length = self.args.max_text_length
# Associate the app with the model manager
self.app.on_startup(self.initialize_app)
async def initialize_app(self, app: Kui):
# Make the ModelManager available to the views
app.state.model_manager = ModelManager(
mode=self.args.mode,
device=self.args.device,
half=self.args.half,
compile=self.args.compile,
asr_enabled=self.args.load_asr_model,
llama_checkpoint_path=self.args.llama_checkpoint_path,
decoder_checkpoint_path=self.args.decoder_checkpoint_path,
decoder_config_name=self.args.decoder_config_name,
)
logger.info(f"Startup done, listening server at http://{self.args.listen}")
# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.
# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.
if __name__ == "__main__":
api = API()
# IPv6 address format is [xxxx:xxxx::xxxx]:port
match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen)
if match:
host, port = match.groups() # IPv6
else:
host, port = api.args.listen.split(":") # IPv4
uvicorn.run(
api.app,
host=host,
port=int(port),
workers=api.args.workers,
log_level="info",
)
import os
from huggingface_hub import hf_hub_download
# Download
def check_and_download_files(repo_id, file_list, local_dir):
os.makedirs(local_dir, exist_ok=True)
for file in file_list:
file_path = os.path.join(local_dir, file)
if not os.path.exists(file_path):
print(f"{file} 不存在,从 Hugging Face 仓库下载...")
hf_hub_download(
repo_id=repo_id,
filename=file,
resume_download=True,
local_dir=local_dir,
local_dir_use_symlinks=False,
)
else:
print(f"{file} 已存在,跳过下载。")
# 1st
repo_id_1 = "fishaudio/fish-speech-1.5"
local_dir_1 = "./checkpoints/fish-speech-1.5"
files_1 = [
".gitattributes",
"model.pth",
"README.md",
"special_tokens.json",
"tokenizer.tiktoken",
"config.json",
"firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
]
# 3rd
repo_id_3 = "fishaudio/fish-speech-1"
local_dir_3 = "./"
files_3 = [
"ffmpeg.exe",
"ffprobe.exe",
]
# 4th
repo_id_4 = "SpicyqSama007/fish-speech-packed"
local_dir_4 = "./"
files_4 = [
"asr-label-win-x64.exe",
]
check_and_download_files(repo_id_1, files_1, local_dir_1)
check_and_download_files(repo_id_3, files_3, local_dir_3)
check_and_download_files(repo_id_4, files_4, local_dir_4)