Commit 033f82a9 authored by guobj
init
2025/04/10 15:55:52
parent ef72564b
from .istftnet import Decoder
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from transformers import AlbertConfig
from typing import Dict, Optional, Union
import json
import torch
class KModel(torch.nn.Module):
'''
KModel is a torch.nn.Module with 2 main responsibilities:
1. Init weights, downloading config.json + model.pth from HF if needed
2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)
You likely only need one KModel instance, and it can be reused across
multiple KPipelines to avoid redundant memory allocation.
Unlike KPipeline, KModel is language-blind.
KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
so there is no need to repeatedly download config.json outside of KModel.
'''
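    # For example (illustrative), forward() maps phonemes to ids via self.vocab,
    #   input_ids = [self.vocab[p] for p in phonemes if p in self.vocab]
    # and then wraps the sequence as [[0, *input_ids, 0]] before inference.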
MODEL_NAMES = {
'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
}
def __init__(
self,
repo_id: Optional[str] = None,
config: Union[Dict, str, None] = None,
model: Optional[str] = None,
disable_complex: bool = False
):
super().__init__()
if repo_id is None:
repo_id = 'hexgrad/Kokoro-82M'
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
self.repo_id = repo_id
if not isinstance(config, dict):
if not config:
logger.debug("No config provided, downloading from HF")
config = hf_hub_download(repo_id=repo_id, filename='config.json')
with open(config, 'r', encoding='utf-8') as r:
config = json.load(r)
logger.debug(f"Loaded config: {config}")
self.vocab = config['vocab']
self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
self.context_length = self.bert.config.max_position_embeddings
self.predictor = ProsodyPredictor(
style_dim=config['style_dim'], d_hid=config['hidden_dim'],
nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
)
self.text_encoder = TextEncoder(
channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
depth=config['n_layer'], n_symbols=config['n_token']
)
self.decoder = Decoder(
dim_in=config['hidden_dim'], style_dim=config['style_dim'],
dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
)
if not model:
model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id])
for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
assert hasattr(self, key), key
try:
getattr(self, key).load_state_dict(state_dict)
            except RuntimeError:
                # Likely a checkpoint saved with a 'module.' prefix (e.g. from DataParallel): strip it and retry
                logger.debug(f"Retrying {key} after stripping 'module.' prefix from state_dict keys")
                state_dict = {k[7:]: v for k, v in state_dict.items()}
                getattr(self, key).load_state_dict(state_dict, strict=False)
@property
def device(self):
return self.bert.device
@dataclass
class Output:
audio: torch.FloatTensor
pred_dur: Optional[torch.LongTensor] = None
@torch.no_grad()
def forward_with_tokens(
self,
input_ids: torch.LongTensor,
ref_s: torch.FloatTensor,
speed: float = 1
) -> tuple[torch.FloatTensor, torch.LongTensor]:
input_lengths = torch.full(
(input_ids.shape[0],),
input_ids.shape[-1],
device=input_ids.device,
dtype=torch.long
)
text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
s = ref_s[:, 128:]
d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = self.predictor.lstm(d)
duration = self.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
en = d.transpose(-1, -2) @ pred_aln_trg
F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
t_en = self.text_encoder(input_ids, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()
return audio, pred_dur
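    # Worked example of the expansion above (illustrative): with 3 input tokens
    # and pred_dur = [2, 3, 1], repeat_interleave gives indices = [0, 0, 1, 1, 1, 2],
    # so pred_aln_trg is a 3x6 one-hot alignment and d.transpose(-1, -2) @ pred_aln_trg
    # repeats each token's features for its predicted number of frames. Per the
    # timestamp comments in KPipeline.join_timestamps, one frame corresponds to
    # 600 output samples at 24 kHz.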
def forward(
self,
phonemes: str,
ref_s: torch.FloatTensor,
speed: float = 1,
return_output: bool = False
) -> Union['KModel.Output', torch.FloatTensor]:
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
ref_s = ref_s.to(self.device)
audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed)
audio = audio.squeeze().cpu()
pred_dur = pred_dur.cpu() if pred_dur is not None else None
logger.debug(f"pred_dur: {pred_dur}")
return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio
class KModelForONNX(torch.nn.Module):
def __init__(self, kmodel: KModel):
super().__init__()
self.kmodel = kmodel
def forward(
self,
input_ids: torch.LongTensor,
ref_s: torch.FloatTensor,
speed: float = 1
) -> tuple[torch.FloatTensor, torch.LongTensor]:
waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
return waveform, duration
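# Hedged export sketch (not part of the module): KModelForONNX wraps
# forward_with_tokens so that the inputs are plain tensors. The dummy shapes,
# input/output names, and dynamic axes below are assumptions for illustration,
# not a tested export recipe; ref_s is assumed to be a [1, 256] style vector
# (128 decoder dims + 128 prosody dims, matching the ref_s[:, :128] / ref_s[:, 128:]
# split used above).
if __name__ == '__main__':
    onnx_model = KModelForONNX(KModel(repo_id='hexgrad/Kokoro-82M').eval())
    dummy_input_ids = torch.LongTensor([[0, 10, 20, 30, 0]])
    dummy_ref_s = torch.randn(1, 256)
    torch.onnx.export(
        onnx_model,
        (dummy_input_ids, dummy_ref_s, 1.0),
        'kokoro.onnx',
        input_names=['input_ids', 'ref_s', 'speed'],
        output_names=['waveform', 'duration'],
        dynamic_axes={'input_ids': {1: 'sequence'}, 'waveform': {0: 'samples'}},
        opset_version=17,
    )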
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
from .istftnet import AdainResBlk1d
from torch.nn.utils import weight_norm
from transformers import AlbertModel
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class LinearNorm(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__()
self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))
def forward(self, x):
return self.linear_layer(x)
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class TextEncoder(nn.Module):
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
super().__init__()
self.embedding = nn.Embedding(n_symbols, channels)
padding = (kernel_size - 1) // 2
self.cnn = nn.ModuleList()
for _ in range(depth):
self.cnn.append(nn.Sequential(
weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
LayerNorm(channels),
actv,
nn.Dropout(0.2),
))
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
def forward(self, x, input_lengths, m):
x = self.embedding(x) # [B, T, emb]
x = x.transpose(1, 2) # [B, emb, T]
m = m.unsqueeze(1)
x.masked_fill_(m, 0.0)
for c in self.cnn:
x = c(x)
x.masked_fill_(m, 0.0)
x = x.transpose(1, 2) # [B, T, chn]
lengths = input_lengths if input_lengths.device == torch.device('cpu') else input_lengths.to('cpu')
x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
self.lstm.flatten_parameters()
x, _ = self.lstm(x)
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
x = x.transpose(-1, -2)
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
x_pad[:, :, :x.shape[-1]] = x
x = x_pad
x.masked_fill_(m, 0.0)
return x
class AdaLayerNorm(nn.Module):
def __init__(self, style_dim, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.fc = nn.Linear(style_dim, channels*2)
def forward(self, x, s):
x = x.transpose(-1, -2)
x = x.transpose(1, -1)
h = self.fc(s)
h = h.view(h.size(0), h.size(1), 1)
gamma, beta = torch.chunk(h, chunks=2, dim=1)
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), eps=self.eps)
x = (1 + gamma) * x + beta
return x.transpose(1, -1).transpose(-1, -2)
class ProsodyPredictor(nn.Module):
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
super().__init__()
        self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid, nlayers=nlayers, dropout=dropout)
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
self.duration_proj = LinearNorm(d_hid, max_dur)
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
self.F0 = nn.ModuleList()
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
self.N = nn.ModuleList()
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
def forward(self, texts, style, text_lengths, alignment, m):
d = self.text_encoder(texts, style, text_lengths, m)
m = m.unsqueeze(1)
lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
x = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False)
self.lstm.flatten_parameters()
x, _ = self.lstm(x)
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device)
x_pad[:, :x.shape[1], :] = x
x = x_pad
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False))
en = (d.transpose(-1, -2) @ alignment)
return duration.squeeze(-1), en
def F0Ntrain(self, x, s):
x, _ = self.shared(x.transpose(-1, -2))
F0 = x.transpose(-1, -2)
for block in self.F0:
F0 = block(F0, s)
F0 = self.F0_proj(F0)
N = x.transpose(-1, -2)
for block in self.N:
N = block(N, s)
N = self.N_proj(N)
return F0.squeeze(1), N.squeeze(1)
class DurationEncoder(nn.Module):
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
super().__init__()
self.lstms = nn.ModuleList()
for _ in range(nlayers):
self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout))
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
self.dropout = dropout
self.d_model = d_model
self.sty_dim = sty_dim
def forward(self, x, style, text_lengths, m):
masks = m
x = x.permute(2, 0, 1)
s = style.expand(x.shape[0], x.shape[1], -1)
x = torch.cat([x, s], axis=-1)
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
x = x.transpose(0, 1)
x = x.transpose(-1, -2)
for block in self.lstms:
if isinstance(block, AdaLayerNorm):
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
x = torch.cat([x, s.permute(1, 2, 0)], axis=1)
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
else:
lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
x = x.transpose(-1, -2)
x = nn.utils.rnn.pack_padded_sequence(
x, lengths, batch_first=True, enforce_sorted=False)
block.flatten_parameters()
x, _ = block(x)
x, _ = nn.utils.rnn.pad_packed_sequence(
x, batch_first=True)
x = F.dropout(x, p=self.dropout, training=False)
x = x.transpose(-1, -2)
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
x_pad[:, :, :x.shape[-1]] = x
x = x_pad
return x.transpose(-1, -2)
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
class CustomAlbert(AlbertModel):
def forward(self, *args, **kwargs):
outputs = super().forward(*args, **kwargs)
return outputs.last_hidden_state
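# Minimal shape check (illustrative, not part of the module): TextEncoder maps
# token ids [B, T] to features [B, channels, T], zeroing padded positions.
# The tiny dimensions below are arbitrary test values, not the real config.
if __name__ == '__main__':
    with torch.no_grad():
        enc = TextEncoder(channels=16, kernel_size=5, depth=2, n_symbols=32)
        ids = torch.randint(0, 32, (2, 7))
        lengths = torch.tensor([7, 5])
        mask = torch.arange(7).unsqueeze(0) >= lengths.unsqueeze(1)  # True on padding
        out = enc(ids, lengths, mask)
        print(out.shape)  # torch.Size([2, 16, 7])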
from .model import KModel
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from misaki import en, espeak
from typing import Callable, Generator, List, Optional, Tuple, Union
import re
import torch
ALIASES = {
'en-us': 'a',
'en-gb': 'b',
'es': 'e',
'fr-fr': 'f',
'hi': 'h',
'it': 'i',
'pt-br': 'p',
'ja': 'j',
'zh': 'z',
}
LANG_CODES = dict(
# pip install misaki[en]
a='American English',
b='British English',
# espeak-ng
e='es',
f='fr-fr',
h='hi',
i='it',
p='pt-br',
# pip install misaki[ja]
j='Japanese',
# pip install misaki[zh]
z='Mandarin Chinese',
)
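# Example (illustrative): lang_code is lower-cased and normalized through
# ALIASES in KPipeline.__init__, so KPipeline(lang_code='en-us', ...) resolves
# to 'a' (American English); codes outside LANG_CODES fail the assert there.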
class KPipeline:
'''
KPipeline is a language-aware support class with 2 main responsibilities:
1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
2. Manage and store voices, lazily downloaded from HF if needed
You are expected to have one KPipeline per language. If you have multiple
KPipelines, you should reuse one KModel instance across all of them.
KPipeline is designed to work with a KModel, but this is not required.
There are 2 ways to pass an existing model into a pipeline:
1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
2. On call: us_pipeline(text, voice, model=model)
By default, KPipeline will automatically initialize its own KModel. To
suppress this, construct a "quiet" KPipeline with model=False.
A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
any audio. You can use this to phonemize and chunk your text in advance.
A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
'''
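    # For example (illustrative):
    #   quiet = KPipeline(lang_code='a', model=False)         # yields (graphemes, phonemes, None)
    #   loud = KPipeline(lang_code='a')                        # builds its own KModel
    #   shared = KPipeline(lang_code='b', model=loud.model)    # reuses that KModel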
def __init__(
self,
lang_code: str,
repo_id: Optional[str] = None,
model: Union[KModel, bool] = True,
trf: bool = False,
en_callable: Optional[Callable[[str], str]] = None,
device: Optional[str] = None
):
"""Initialize a KPipeline.
Args:
            lang_code: Language code for G2P processing
            repo_id: HF repo to load config, weights, and voices from (defaults to 'hexgrad/Kokoro-82M' with a warning)
            model: KModel instance, True to create a new model, False for no model
            trf: Whether to use transformer-based G2P
            en_callable: Optional callable passed through to the Chinese G2P (only used when lang_code='z')
device: Override default device selection ('cuda' or 'cpu', or None for auto)
If None, will auto-select cuda if available
If 'cuda' and not available, will explicitly raise an error
"""
if repo_id is None:
repo_id = 'hexgrad/Kokoro-82M'
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
self.repo_id = repo_id
lang_code = lang_code.lower()
lang_code = ALIASES.get(lang_code, lang_code)
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
self.lang_code = lang_code
self.model = None
if isinstance(model, KModel):
self.model = model
elif model:
if device == 'cuda' and not torch.cuda.is_available():
raise RuntimeError("CUDA requested but not available")
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
try:
self.model = KModel(repo_id=repo_id).to(device).eval()
except RuntimeError as e:
if device == 'cuda':
raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
Try setting device='cpu' or check CUDA installation.""")
raise
self.voices = {}
if lang_code in 'ab':
try:
fallback = espeak.EspeakFallback(british=lang_code=='b')
except Exception as e:
logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
logger.warning({str(e)})
fallback = None
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
elif lang_code == 'j':
try:
from misaki import ja
self.g2p = ja.JAG2P()
except ImportError:
logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
raise
elif lang_code == 'z':
try:
from misaki import zh
self.g2p = zh.ZHG2P(
version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
en_callable=en_callable
)
except ImportError:
logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
raise
else:
language = LANG_CODES[lang_code]
logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
self.g2p = espeak.EspeakG2P(language=language)
def load_single_voice(self, voice: str):
if voice in self.voices:
return self.voices[voice]
if voice.endswith('.pt'):
f = voice
else:
f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
if not voice.startswith(self.lang_code):
v = LANG_CODES.get(voice, voice)
p = LANG_CODES.get(self.lang_code, self.lang_code)
logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
pack = torch.load(f, weights_only=True)
self.voices[voice] = pack
return pack
"""
load_voice is a helper function that lazily downloads and loads a voice:
Single voice can be requested (e.g. 'af_bella') or multiple voices (e.g. 'af_bella,af_jessica').
If multiple voices are requested, they are averaged.
Delimiter is optional and defaults to ','.
"""
def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor:
if isinstance(voice, torch.FloatTensor):
return voice
if voice in self.voices:
return self.voices[voice]
logger.debug(f"Loading voice: {voice}")
packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
if len(packs) == 1:
return packs[0]
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
return self.voices[voice]
@staticmethod
def tokens_to_ps(tokens: List[en.MToken]) -> str:
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
@staticmethod
def waterfall_last(
tokens: List[en.MToken],
next_count: int,
waterfall: List[str] = ['!.?…', ':;', ',—'],
bumps: List[str] = [')', '”']
) -> int:
for w in waterfall:
z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
if z is None:
continue
z += 1
if z < len(tokens) and tokens[z].phonemes in bumps:
z += 1
if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
return z
return len(tokens)
@staticmethod
def tokens_to_text(tokens: List[en.MToken]) -> str:
return ''.join(t.text + t.whitespace for t in tokens).strip()
def en_tokenize(
self,
tokens: List[en.MToken]
) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
tks = []
pcount = 0
for t in tokens:
# American English: ɾ => T
t.phonemes = '' if t.phonemes is None else t.phonemes#.replace('ɾ', 'T')
next_ps = t.phonemes + (' ' if t.whitespace else '')
next_pcount = pcount + len(next_ps.rstrip())
if next_pcount > 510:
z = KPipeline.waterfall_last(tks, next_pcount)
text = KPipeline.tokens_to_text(tks[:z])
logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
ps = KPipeline.tokens_to_ps(tks[:z])
yield text, ps, tks[:z]
tks = tks[z:]
pcount = len(KPipeline.tokens_to_ps(tks))
if not tks:
next_ps = next_ps.lstrip()
tks.append(t)
pcount += len(next_ps)
if tks:
text = KPipeline.tokens_to_text(tks)
ps = KPipeline.tokens_to_ps(tks)
            yield text, ps, tks
@staticmethod
def infer(
model: KModel,
ps: str,
pack: torch.FloatTensor,
speed: Union[float, Callable[[int], float]] = 1
) -> KModel.Output:
if callable(speed):
speed = speed(len(ps))
return model(ps, pack[len(ps)-1], speed, return_output=True)
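    # Note on pack[len(ps) - 1] above: a voice pack appears to hold one
    # reference-style row per phoneme-sequence length, so the row matching the
    # length of this chunk is selected. This is an observation from usage, not
    # a guarantee about every voice file.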
def generate_from_tokens(
self,
tokens: Union[str, List[en.MToken]],
voice: str,
speed: float = 1,
model: Optional[KModel] = None
) -> Generator['KPipeline.Result', None, None]:
"""Generate audio from either raw phonemes or pre-processed tokens.
Args:
tokens: Either a phoneme string or list of pre-processed MTokens
voice: The voice to use for synthesis
speed: Speech speed modifier (default: 1)
model: Optional KModel instance (uses pipeline's model if not provided)
Yields:
KPipeline.Result containing the input tokens and generated audio
Raises:
ValueError: If no voice is provided or token sequence exceeds model limits
"""
model = model or self.model
if model and voice is None:
raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')
pack = self.load_voice(voice).to(model.device) if model else None
# Handle raw phoneme string
if isinstance(tokens, str):
logger.debug("Processing phonemes from raw string")
if len(tokens) > 510:
raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
output = KPipeline.infer(model, tokens, pack, speed) if model else None
yield self.Result(graphemes='', phonemes=tokens, output=output)
return
logger.debug("Processing MTokens")
# Handle pre-processed tokens
for gs, ps, tks in self.en_tokenize(tokens):
if not ps:
continue
elif len(ps) > 510:
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
logger.warning("Truncating to 510 characters")
ps = ps[:510]
output = KPipeline.infer(model, ps, pack, speed) if model else None
if output is not None and output.pred_dur is not None:
KPipeline.join_timestamps(tks, output.pred_dur)
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
@staticmethod
def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
# We will count nice round half-frames, so the divisor is 80
MAGIC_DIVISOR = 80
if not tokens or len(pred_dur) < 3:
# We expect at least 3: <bos>, token, <eos>
return
# We track 2 counts, measured in half-frames: (left, right)
# This way we can cut space characters in half
# TODO: Is -3 an appropriate offset?
left = right = 2 * max(0, pred_dur[0].item() - 3)
# Updates:
# left = right + (2 * token_dur) + space_dur
# right = left + space_dur
i = 1
for t in tokens:
if i >= len(pred_dur)-1:
break
if not t.phonemes:
if t.whitespace:
i += 1
left = right + pred_dur[i].item()
right = left + pred_dur[i].item()
i += 1
continue
j = i + len(t.phonemes)
if j >= len(pred_dur):
break
t.start_ts = left / MAGIC_DIVISOR
token_dur = pred_dur[i: j].sum().item()
space_dur = pred_dur[j].item() if t.whitespace else 0
left = right + (2 * token_dur) + space_dur
t.end_ts = left / MAGIC_DIVISOR
right = left + space_dur
i = j + (1 if t.whitespace else 0)
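    # Worked example (illustrative): ignoring the <bos> offset, a token whose
    # phoneme frames sum to 40 contributes 2 * 40 = 80 half-frames, so
    # end_ts - start_ts = 80 / MAGIC_DIVISOR = 1.0 second, matching
    # 40 frames * 600 samples = 24000 samples at 24 kHz.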
@dataclass
class Result:
graphemes: str
phonemes: str
tokens: Optional[List[en.MToken]] = None
output: Optional[KModel.Output] = None
text_index: Optional[int] = None
@property
def audio(self) -> Optional[torch.FloatTensor]:
return None if self.output is None else self.output.audio
@property
def pred_dur(self) -> Optional[torch.LongTensor]:
return None if self.output is None else self.output.pred_dur
### MARK: BEGIN BACKWARD COMPAT ###
def __iter__(self):
yield self.graphemes
yield self.phonemes
yield self.audio
def __getitem__(self, index):
return [self.graphemes, self.phonemes, self.audio][index]
def __len__(self):
return 3
#### MARK: END BACKWARD COMPAT ####
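    # Backward-compat example (illustrative): a Result still unpacks like the
    # old 3-tuple interface, e.g. `gs, ps, audio = result` or `result[2]`.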
def __call__(
self,
text: Union[str, List[str]],
voice: Optional[str] = None,
speed: Union[float, Callable[[int], float]] = 1,
split_pattern: Optional[str] = r'\n+',
model: Optional[KModel] = None
) -> Generator['KPipeline.Result', None, None]:
model = model or self.model
if model and voice is None:
raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
pack = self.load_voice(voice).to(model.device) if model else None
# Convert input to list of segments
if isinstance(text, str):
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
# Process each segment
for graphemes_index, graphemes in enumerate(text):
if not graphemes.strip(): # Skip empty segments
continue
# English processing (unchanged)
if self.lang_code in 'ab':
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
_, tokens = self.g2p(graphemes)
for gs, ps, tks in self.en_tokenize(tokens):
if not ps:
continue
elif len(ps) > 510:
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
ps = ps[:510]
output = KPipeline.infer(model, ps, pack, speed) if model else None
if output is not None and output.pred_dur is not None:
KPipeline.join_timestamps(tks, output.pred_dur)
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output, text_index=graphemes_index)
# Non-English processing with chunking
else:
# Split long text into smaller chunks (roughly 400 characters each)
# Using sentence boundaries when possible
chunk_size = 400
chunks = []
# Try to split on sentence boundaries first
sentences = re.split(r'([.!?]+)', graphemes)
current_chunk = ""
for i in range(0, len(sentences), 2):
sentence = sentences[i]
# Add the punctuation back if it exists
if i + 1 < len(sentences):
sentence += sentences[i + 1]
if len(current_chunk) + len(sentence) <= chunk_size:
current_chunk += sentence
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence
if current_chunk:
chunks.append(current_chunk.strip())
# If no chunks were created (no sentence boundaries), fall back to character-based chunking
if not chunks:
chunks = [graphemes[i:i+chunk_size] for i in range(0, len(graphemes), chunk_size)]
# Process each chunk
for chunk in chunks:
if not chunk.strip():
continue
ps, _ = self.g2p(chunk)
if not ps:
continue
elif len(ps) > 510:
logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
ps = ps[:510]
output = KPipeline.infer(model, ps, pack, speed) if model else None
yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "kokoro"
version = "0.9.4"
description = "TTS"
readme = "README.md"
authors = [
{ name="hexgrad", email="hello@hexgrad.com" }
]
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent"
]
requires-python = ">=3.10, <3.13"
dependencies = [
"huggingface_hub",
"loguru",
"misaki[en]>=0.9.4",
"numpy",
"torch",
"transformers"
]
[project.scripts]
kokoro = "kokoro.__main__:main"
[tool.hatch.build.targets.wheel]
only-include = ["kokoro"]
only-packages = true
[project.urls]
Homepage = "https://github.com/hexgrad/kokoro"
Repository = "https://github.com/hexgrad/kokoro"
from kokoro import KPipeline
from IPython.display import display, Audio
import soundfile as sf
import torch
pipeline = KPipeline(lang_code='a')  # 'a' (American English) matches the English text and the 'af_heart' voice below
text = '''
[Kokoro](/kˈOkəɹO/) is an open-weight TTS model with 82 million parameters. Despite its lightweight architecture, it delivers comparable quality to larger models while being significantly faster and more cost-efficient. With Apache-licensed weights, [Kokoro](/kˈOkəɹO/) can be deployed anywhere from production environments to personal projects.
'''
generator = pipeline(text, voice='af_heart')
for i, (gs, ps, audio) in enumerate(generator):
print(i, gs, ps)
display(Audio(data=audio, rate=24000, autoplay=i==0))
sf.write(f'{i}.wav', audio, 24000)
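# A further sketch, sharing one KModel across two pipelines as the KModel
# docstring recommends. Assumptions: 'af_heart' and 'bf_emma' are voices
# shipped in hexgrad/Kokoro-82M; adjust to whatever voices you actually have.
from kokoro import KModel, KPipeline
import soundfile as sf

shared_model = KModel(repo_id='hexgrad/Kokoro-82M').eval()
us_pipeline = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M', model=shared_model)
gb_pipeline = KPipeline(lang_code='b', repo_id='hexgrad/Kokoro-82M', model=shared_model)
for i, result in enumerate(us_pipeline('Shared model, American voice.', voice='af_heart')):
    sf.write(f'us_{i}.wav', result.audio, 24000)
for i, result in enumerate(gb_pipeline('Shared model, British voice.', voice='bf_emma')):
    sf.write(f'gb_{i}.wav', result.audio, 24000)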
import torch
import numpy as np
import pytest
from kokoro.custom_stft import CustomSTFT
from kokoro.istftnet import TorchSTFT
import torch.nn.functional as F
@pytest.fixture
def sample_audio():
# Generate a sample audio signal (sine wave)
sample_rate = 16000
duration = 1.0 # seconds
t = torch.linspace(0, duration, int(sample_rate * duration))
frequency = 440.0 # Hz
signal = torch.sin(2 * np.pi * frequency * t)
return signal.unsqueeze(0) # Add batch dimension
def test_stft_reconstruction(sample_audio):
# Initialize both STFT implementations
custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)
# Process through both implementations
custom_output = custom_stft(sample_audio)
torch_output = torch_stft(sample_audio)
# Compare outputs
assert torch.allclose(custom_output, torch_output, rtol=1e-3, atol=1e-3)
def test_magnitude_phase_consistency(sample_audio):
custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)
# Get magnitude and phase from both implementations
custom_mag, custom_phase = custom_stft.transform(sample_audio)
torch_mag, torch_phase = torch_stft.transform(sample_audio)
# Compare magnitudes ignoring the boundary frames
custom_mag_center = custom_mag[..., 2:-2]
torch_mag_center = torch_mag[..., 2:-2]
assert torch.allclose(custom_mag_center, torch_mag_center, rtol=1e-2, atol=1e-2)
def test_batch_processing():
# Create a batch of signals
batch_size = 4
sample_rate = 16000
duration = 0.1 # shorter duration for faster testing
t = torch.linspace(0, duration, int(sample_rate * duration))
frequency = 440.0
signals = torch.sin(2 * np.pi * frequency * t).unsqueeze(0).repeat(batch_size, 1)
custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
# Process batch
output = custom_stft(signals)
# Check output shape
assert output.shape[0] == batch_size
assert len(output.shape) == 3 # (batch, 1, time)
def test_different_window_sizes():
signal = torch.randn(1, 16000) # 1 second of random noise
# Test with different window sizes
for filter_length in [512, 1024, 2048]:
custom_stft = CustomSTFT(
filter_length=filter_length,
hop_length=filter_length // 4,
win_length=filter_length,
)
# Forward and backward transform
output = custom_stft(signal)
# Check that output length is reasonable
assert output.shape[-1] >= signal.shape[-1]
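# A small additional check (hedged: it relies only on behavior already
# exercised above, namely that CustomSTFT's forward returns a (batch, 1, time)
# tensor for mono input).
def test_output_is_finite(sample_audio):
    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
    output = custom_stft(sample_audio)
    assert output.shape[1] == 1
    assert torch.isfinite(output).all()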