Commit ab9c00af authored by yangzhong

init submission

Pipeline #3176 failed with stages in 0 seconds
import gradio as gr
def html_center(text, label='p'):
return f"""<div style="text-align: center; margin: 100; padding: 50;">
<{label} style="margin: 0; padding: 0;">{text}</{label}>
</div>"""
def html_left(text, label='p'):
return f"""<div style="text-align: left; margin: 0; padding: 0;">
<{label} style="margin: 0; padding: 0;">{text}</{label}>
</div>"""
def next_page(page_number,sentences):
new_page_number = int(page_number) + 1
update_page_number = gr.update(value=str(new_page_number))
update_prev_page = gr.update(visible=True, interactive=True)
if len(sentences.values) <= new_page_number * 20:
update_next_page = gr.update(visible=False, interactive=False)
else:
update_next_page = gr.update(visible=True, interactive=True)
return update_page_number, update_next_page, update_prev_page
def prev_page(page_number):
new_page_number = int(page_number) - 1
update_page_number = gr.update(value=str(new_page_number))
if new_page_number == 1:
update_prev_page = gr.update(visible=False, interactive=False)
else:
update_prev_page = gr.update(visible=True, interactive=True)
update_next_page = gr.update(visible=True, interactive=True)
return update_page_number, update_next_page, update_prev_page
def update_current_texts(page_number,sentences):
start_index = (int(page_number) - 1) * 20
end_index = int(page_number) * 20
current_texts = sentences.values[start_index:end_index]  # slicing clamps past the end automatically
return gr.update(values=current_texts)
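# A minimal wiring sketch for the pagination callbacks above (not part of the
# original file). Assumptions: `sentences` is a gr.State holding a pandas
# DataFrame paginated 20 rows at a time; the component names are hypothetical.
def _demo_pagination_ui():
    import pandas as pd
    data = pd.DataFrame({"text": [f"sentence {i}" for i in range(100)]})
    with gr.Blocks() as demo:
        sentences = gr.State(data)
        page_number = gr.Textbox(value="1", label="page")
        prev_btn = gr.Button("prev", visible=False, interactive=False)
        next_btn = gr.Button("next")
        next_btn.click(next_page, inputs=[page_number, sentences],
                       outputs=[page_number, next_btn, prev_btn])
        prev_btn.click(prev_page, inputs=[page_number],
                       outputs=[page_number, next_btn, prev_btn])
    return demo  # call .launch() on the returned Blocks to serve the UI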
import math
from collections import namedtuple
from functools import partial
from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [
'pre_softmax_attn',
'post_softmax_attn'
])
LayerIntermediates = namedtuple('LayerIntermediates', [
'hiddens',
'attn_intermediates',
'past_key_values',
])
# helpers
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth
class always():
def __init__(self, val):
self.val = val
def __call__(self, *args, **kwargs):
return self.val
class not_equals():
def __init__(self, val):
self.val = val
def __call__(self, x, *args, **kwargs):
return x != self.val
class equals():
def __init__(self, val):
self.val = val
def __call__(self, x, *args, **kwargs):
return x == self.val
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max
def l2norm(t):
return F.normalize(t, p=2, dim=-1)
# init helpers
def init_zero_(layer):
nn.init.constant_(layer.weight, 0.)
if exists(layer.bias):
nn.init.constant_(layer.bias, 0.)
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
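# Hedged example of the kwargs-routing helpers above: keyword arguments with a
# given prefix are split off and the prefix is stripped, which is how
# AttentionLayers later forwards 'attn_*' / 'ff_*' kwargs to its sub-modules.
def _demo_groupby_prefix_and_trim():
    attn_kwargs, rest = groupby_prefix_and_trim('attn_', {'attn_dropout': 0.1, 'ff_mult': 4})
    assert attn_kwargs == {'dropout': 0.1}
    assert rest == {'ff_mult': 4}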
# activations
class ReluSquared(nn.Module):
def forward(self, x):
return F.relu(x) ** 2
# positional embeddings
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.scale = dim ** -0.5
self.emb = nn.Embedding(max_seq_len, dim)
def forward(self, x):
n = torch.arange(x.shape[1], device=x.device)
pos_emb = self.emb(n)
pos_emb = rearrange(pos_emb, 'n d -> () n d')
return pos_emb * self.scale
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return rearrange(emb, 'n d -> () n d')
class RelativePositionBias(nn.Module):
def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
super().__init__()
self.scale = scale
self.causal = causal
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
ret = 0
n = -relative_position
if not causal:
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
else:
n = torch.max(n, torch.zeros_like(n))
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, qk_dots):
i, j, device = *qk_dots.shape[-2:], qk_dots.device
q_pos = torch.arange(i, dtype=torch.long, device=device)
k_pos = torch.arange(j, dtype=torch.long, device=device)
rel_pos = k_pos[None, :] - q_pos[:, None]
rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
max_distance=self.max_distance)
values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j h -> () h i j')
return qk_dots + (bias * self.scale)
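# Usage sketch (illustrative only): RelativePositionBias maps the relative
# distance between query and key positions to learned per-head biases that are
# added to the attention logits before the softmax.
def _demo_relative_position_bias():
    rel_pos = RelativePositionBias(scale=1.0, causal=True, heads=4)
    qk_dots = torch.randn(2, 4, 10, 10)     # (batch, heads, query_len, key_len)
    biased = rel_pos(qk_dots)
    assert biased.shape == qk_dots.shape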
class AlibiPositionalBias(nn.Module):
def __init__(self, heads, **kwargs):
super().__init__()
self.heads = heads
slopes = torch.Tensor(self._get_slopes(heads))
slopes = rearrange(slopes, 'h -> () h () ()')
self.register_buffer('slopes', slopes, persistent=False)
self.register_buffer('bias', None, persistent=False)
@staticmethod
def _get_slopes(heads):
def get_slopes_power_of_2(n):
start = (2 ** (-2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
if math.log2(heads).is_integer():
return get_slopes_power_of_2(heads)
closest_power_of_2 = 2 ** math.floor(math.log2(heads))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
:heads - closest_power_of_2]
def forward(self, qk_dots):
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
if exists(self.bias) and self.bias.shape[-1] >= j:
return qk_dots + self.bias[..., :j]
bias = torch.arange(j, device=device)
bias = rearrange(bias, 'j -> () () () j')
bias = bias * self.slopes
num_heads_unalibied = h - bias.shape[1]
bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
self.register_buffer('bias', bias, persistent=False)
return qk_dots + self.bias
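# Usage sketch (illustrative only): ALiBi adds a per-head linear distance
# penalty to the attention logits; the slope shrinks as the head index grows.
def _demo_alibi_positional_bias():
    alibi = AlibiPositionalBias(heads=8)
    qk_dots = torch.randn(1, 8, 6, 6)       # (batch, heads, query_len, key_len)
    biased = alibi(qk_dots)
    assert biased.shape == qk_dots.shape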
class LearnedAlibiPositionalBias(AlibiPositionalBias):
def __init__(self, heads, bidirectional=False):
super().__init__(heads)
log_slopes = torch.log(self.slopes)
self.learned_logslopes = nn.Parameter(log_slopes)
self.bidirectional = bidirectional
if self.bidirectional:
self.learned_logslopes_future = nn.Parameter(log_slopes)
def forward(self, qk_dots):
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
def get_slopes(param):
return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))
if exists(self.bias) and self.bias.shape[-1] >= j:
bias = self.bias[..., :i, :j]
else:
i_arange = torch.arange(i, device=device)
j_arange = torch.arange(j, device=device)
bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
self.register_buffer('bias', bias, persistent=False)
if self.bidirectional:
past_slopes = get_slopes(self.learned_logslopes)
future_slopes = get_slopes(self.learned_logslopes_future)
bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
else:
slopes = get_slopes(self.learned_logslopes)
bias = bias * slopes
return qk_dots + bias
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, max_seq_len, device):
t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return rearrange(emb, 'n d -> () () n d')
def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs):
seq_len = t.shape[-2]
freqs = freqs[:, :, -seq_len:]
return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
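# Usage sketch (illustrative only): RotaryEmbedding produces per-position
# frequencies, and apply_rotary_pos_emb rotates query/key feature pairs by a
# position-dependent angle, encoding relative position into the dot product.
def _demo_rotary_embedding():
    rotary = RotaryEmbedding(dim=32)
    freqs = rotary(max_seq_len=8, device=torch.device('cpu'))   # (1, 1, 8, 32)
    q = torch.randn(2, 4, 8, 32)            # (batch, heads, seq_len, dim_head)
    q_rotated = apply_rotary_pos_emb(q, freqs)
    assert q_rotated.shape == q.shape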
# norms
class Scale(nn.Module):
def __init__(self, value, fn):
super().__init__()
self.value = value
self.fn = fn
def forward(self, x, **kwargs):
out = self.fn(x, **kwargs)
scale_fn = lambda t: t * self.value
if not isinstance(out, tuple):
return scale_fn(out)
return (scale_fn(out[0]), *out[1:])
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.g = nn.Parameter(torch.zeros(1))
def forward(self, x, **kwargs):
out = self.fn(x, **kwargs)
rezero_fn = lambda t: t * self.g
if not isinstance(out, tuple):
return rezero_fn(out)
return (rezero_fn(out[0]), *out[1:])
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class RMSScaleShiftNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
def forward(self, x, norm_scale_shift_inp):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
norm = x / norm.clamp(min=self.eps) * self.g
ss_emb = self.scale_shift_process(norm_scale_shift_inp)
scale, shift = torch.chunk(ss_emb, 2, dim=1)
h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return h
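# Usage sketch (illustrative only): RMSScaleShiftNorm is an RMS norm modulated
# by an external conditioning vector of size 2 * dim, which is projected into a
# per-channel scale and shift. Shapes below are assumptions.
def _demo_rms_scale_shift_norm():
    norm = RMSScaleShiftNorm(dim=64)
    x = torch.randn(2, 10, 64)              # (batch, seq_len, dim)
    cond = torch.randn(2, 128)              # (batch, 2 * dim)
    out = norm(x, cond)
    assert out.shape == x.shape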
# residual and residual gates
class Residual(nn.Module):
def __init__(self, dim, scale_residual=False):
super().__init__()
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
def forward(self, x, residual):
if exists(self.residual_scale):
residual = residual * self.residual_scale
return x + residual
class GRUGating(nn.Module):
def __init__(self, dim, scale_residual=False):
super().__init__()
self.gru = nn.GRUCell(dim, dim)
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
def forward(self, x, residual):
if exists(self.residual_scale):
residual = residual * self.residual_scale
gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d')
)
return gated_output.reshape_as(x)
# token shifting
def shift(t, amount, mask=None):
if amount == 0:
return t
if exists(mask):
t = t.masked_fill(~mask[..., None], 0.)
return F.pad(t, (0, 0, amount, -amount), value=0.)
class ShiftTokens(nn.Module):
def __init__(self, shifts, fn):
super().__init__()
self.fn = fn
self.shifts = tuple(shifts)
def forward(self, x, **kwargs):
mask = kwargs.get('mask', None)
shifts = self.shifts
segments = len(shifts)
feats_per_shift = x.shape[-1] // segments
splitted = x.split(feats_per_shift, dim=-1)
segments_to_shift, rest = splitted[:segments], splitted[segments:]
segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
x = torch.cat((*segments_to_shift, *rest), dim=-1)
return self.fn(x, **kwargs)
# feedforward
class GLU(nn.Module):
def __init__(self, dim_in, dim_out, activation):
super().__init__()
self.act = activation
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * self.act(gate)
class FeedForward(nn.Module):
def __init__(
self,
dim,
dim_out=None,
mult=4,
glu=False,
relu_squared=False,
post_act_ln=False,
dropout=0.,
zero_init_output=False
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
activation = ReluSquared() if relu_squared else nn.GELU()
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
activation
) if not glu else GLU(dim, inner_dim, activation)
self.net = nn.Sequential(
project_in,
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
# init last linear layer to 0
if zero_init_output:
init_zero_(self.net[-1])
def forward(self, x):
return self.net(x)
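# Usage sketch (illustrative only): the position-wise feed-forward block with
# the GLU variant enabled; input and output keep the model dimension.
def _demo_feedforward():
    ff = FeedForward(dim=64, mult=4, glu=True, dropout=0.1)
    x = torch.randn(2, 10, 64)              # (batch, seq_len, dim)
    assert ff(x).shape == x.shape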
# attention.
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
talking_heads=False,
head_scale=False,
collab_heads=False,
collab_compression=.3,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
on_attn=False,
gate_values=False,
zero_init_output=False,
max_attend_past=None,
qk_norm=False,
scale_init_value=None,
rel_pos_bias=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
self.causal = causal
self.max_attend_past = max_attend_past
qk_dim = v_dim = dim_head * heads
# collaborative heads
self.collab_heads = collab_heads
if self.collab_heads:
qk_dim = int(collab_compression * qk_dim)
self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
self.to_q = nn.Linear(dim, qk_dim, bias=False)
self.to_k = nn.Linear(dim, qk_dim, bias=False)
self.to_v = nn.Linear(dim, v_dim, bias=False)
self.dropout = nn.Dropout(dropout)
# add GLU gating for aggregated values, from alphafold2
self.to_v_gate = None
if gate_values:
self.to_v_gate = nn.Linear(dim, v_dim)
nn.init.constant_(self.to_v_gate.weight, 0)
nn.init.constant_(self.to_v_gate.bias, 1)
# cosine sim attention
self.qk_norm = qk_norm
if qk_norm:
scale_init_value = default(scale_init_value,
-3) # if not provided, initialize as though it were sequence length of 1024
self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
# talking heads
self.talking_heads = talking_heads
if talking_heads:
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
# head scaling
self.head_scale = head_scale
if head_scale:
self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
# explicit topk sparse attention
self.sparse_topk = sparse_topk
# entmax
self.attn_fn = F.softmax
# add memory key / values
self.num_mem_kv = num_mem_kv
if num_mem_kv > 0:
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
# attention on attention
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
self.rel_pos_bias = rel_pos_bias
if rel_pos_bias:
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
# init output projection 0
if zero_init_output:
init_zero_(self.to_out)
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
attn_mask=None,
sinusoidal_emb=None,
rotary_pos_emb=None,
prev_attn=None,
mem=None,
layer_past=None,
):
b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
context)
kv_input = default(context, x)
q_input = x
k_input = kv_input
v_input = kv_input
if exists(mem):
k_input = torch.cat((mem, k_input), dim=-2)
v_input = torch.cat((mem, v_input), dim=-2)
if exists(sinusoidal_emb):
# in shortformer, the query would start at a position offset depending on the past cached memory
offset = k_input.shape[-2] - q_input.shape[-2]
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
k_input = k_input + sinusoidal_emb(k_input)
q = self.to_q(q_input)
k = self.to_k(k_input)
v = self.to_v(v_input)
if not collab_heads:
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
else:
q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
k = rearrange(k, 'b n d -> b () n d')
v = rearrange(v, 'b n (h d) -> b h n d', h=h)
if layer_past is not None:
past_key, past_value = layer_past
k = torch.cat([past_key, k], dim=-2)
v = torch.cat([past_value, v], dim=-2)
k_cache = k
v_cache = v
if exists(rotary_pos_emb) and not has_context:
l = rotary_pos_emb.shape[-1]
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
if collab_heads:
k = k.expand(-1, h, -1, -1)
if self.qk_norm:
q, k = map(l2norm, (q, k))
scale = 1 / (self.scale.exp().clamp(min=1e-2))
dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
mask_value = max_neg_value(dots)
if exists(prev_attn):
dots = dots + prev_attn
pre_softmax_attn = dots.clone()
if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
if self.rel_pos_bias:
dots = self.rel_pos(dots)
if exists(input_mask):
dots.masked_fill_(~input_mask, mask_value)
del input_mask
if exists(attn_mask):
assert 2 <= attn_mask.ndim <= 4, 'attention mask must have between 2 and 4 dimensions'
if attn_mask.ndim == 2:
attn_mask = rearrange(attn_mask, 'i j -> () () i j')
elif attn_mask.ndim == 3:
attn_mask = rearrange(attn_mask, 'h i j -> () h i j')
dots.masked_fill_(~attn_mask, mask_value)
if exists(self.max_attend_past):
i, j = dots.shape[-2:]
range_q = torch.arange(j - i, j, device=device)
range_k = torch.arange(j, device=device)
dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
mask = dist > self.max_attend_past
dots.masked_fill_(mask, mask_value)
del mask
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
top, _ = dots.topk(self.sparse_topk, dim=-1)
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
mask = dots < vk
dots.masked_fill_(mask, mask_value)
del mask
attn = self.attn_fn(dots, dim=-1)
post_softmax_attn = attn.clone()
attn = self.dropout(attn)
if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v)
if head_scale:
out = out * self.head_scale_params
out = rearrange(out, 'b h n d -> b n (h d)')
if exists(self.to_v_gate):
gates = self.to_v_gate(x)
out = out * gates.sigmoid()
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn
)
return self.to_out(out), intermediates, k_cache, v_cache
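# Usage sketch (illustrative only): a causal self-attention block with default
# options. The forward pass also returns pre/post-softmax attention maps and
# the key/value cache used for incremental decoding.
def _demo_attention():
    attn = Attention(dim=64, heads=8, causal=True)
    x = torch.randn(2, 10, 64)              # (batch, seq_len, dim)
    out, intermediates, k_cache, v_cache = attn(x)
    assert out.shape == x.shape
    assert k_cache.shape == (2, 8, 10, 64)  # (batch, heads, seq_len, dim_head)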
class AttentionLayers(nn.Module):
def __init__(
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rms_scaleshift_norm=False,
use_rmsnorm=False,
use_rezero=False,
alibi_pos_bias=False,
alibi_num_heads=None,
alibi_learned=False,
position_infused_attn=False,
rotary_pos_emb=False,
rotary_emb_dim=None,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
scale_residual=False,
shift_tokens=0,
sandwich_norm=False,
use_qk_norm_attn=False,
qk_norm_attn_seq_len=None,
zero_init_branch_output=False,
**kwargs
):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
self.dim = dim
self.depth = depth
self.layers = nn.ModuleList([])
self.causal = causal
rel_pos_bias = 'rel_pos_bias' in attn_kwargs
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
assert not (
alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
if alibi_pos_bias:
alibi_num_heads = default(alibi_num_heads, heads)
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
else:
self.rel_pos = None
assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
self.pre_norm = pre_norm
self.sandwich_norm = sandwich_norm
self.residual_attn = residual_attn
self.cross_residual_attn = cross_residual_attn
self.cross_attend = cross_attend
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
norm_class = RMSNorm if use_rmsnorm else norm_class
norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
norm_fn = partial(norm_class, dim)
norm_fn = nn.Identity if use_rezero else norm_fn
branch_fn = Rezero if use_rezero else None
if cross_attend and not only_cross:
default_block = ('a', 'c', 'f')
elif cross_attend and only_cross:
default_block = ('c', 'f')
else:
default_block = ('a', 'f')
if macaron:
default_block = ('f',) + default_block
# qk normalization
if use_qk_norm_attn:
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
qk_norm_attn_seq_len) else None
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
# zero init
if zero_init_branch_output:
attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
# calculate layer block order
if exists(custom_layers):
layer_types = custom_layers
elif exists(par_ratio):
par_depth = depth * len(default_block)
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block))
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (par_width - len(default_block))
par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
else:
layer_types = default_block * depth
self.layer_types = layer_types
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
# calculate token shifting
shift_tokens = cast_tuple(shift_tokens, len(layer_types))
# iterate and construct layers
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
is_last_layer = ind == (len(self.layer_types) - 1)
if layer_type == 'a':
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer)
else:
raise Exception(f'invalid layer type {layer_type}')
if layer_shift_tokens > 0:
shift_range_upper = layer_shift_tokens + 1
shift_range_lower = -layer_shift_tokens if not causal else 0
layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
if exists(branch_fn):
layer = branch_fn(layer)
residual_fn = GRUGating if gate_residual else Residual
residual = residual_fn(dim, scale_residual=scale_residual)
layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
norms = nn.ModuleList([
pre_branch_norm,
post_branch_norm,
post_main_norm
])
self.layers.append(nn.ModuleList([
norms,
layer,
residual
]))
def forward(
self,
x,
context=None,
full_context=None, # for passing a list of hidden states from an encoder
mask=None,
context_mask=None,
attn_mask=None,
mems=None,
return_hiddens=False,
norm_scale_shift_inp=None,
past_key_values=None,
expected_seq_len=None,
):
assert not (self.cross_attend ^ (exists(context) or exists(
full_context))), 'context must be passed in if cross_attend is set to True'
assert context is None or full_context is None, 'only one of full_context or context can be provided'
hiddens = []
intermediates = []
prev_attn = None
prev_cross_attn = None
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
norm_args = {}
if exists(norm_scale_shift_inp):
norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp
rotary_pos_emb = None
if exists(self.rotary_pos_emb):
if not self.training and self.causal:
assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
elif expected_seq_len is None:
expected_seq_len = 0
seq_len = x.shape[1]
if past_key_values is not None:
seq_len += past_key_values[0][0].shape[-2]
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
present_key_values = []
cross_attn_count = 0
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
if layer_type == 'a':
layer_mem = mems.pop(0) if mems else None
residual = x
pre_branch_norm, post_branch_norm, post_main_norm = norm
if exists(pre_branch_norm):
x = pre_branch_norm(x, **norm_args)
if layer_type == 'a' or layer_type == 'c':
if past_key_values is not None:
layer_kv = past_key_values.pop(0)
layer_past = tuple(s.to(x.device) for s in layer_kv)
else:
layer_past = None
if layer_type == 'a':
out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
prev_attn, layer_mem, layer_past)
elif layer_type == 'c':
if exists(full_context):
out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
None, prev_attn, None, layer_past)
else:
out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
elif layer_type == 'f':
out = block(x)
if (layer_type == 'a' or layer_type == 'c') and present_key_values is not None:
present_key_values.append((k.detach(), v.detach()))
if exists(post_branch_norm):
out = post_branch_norm(out, **norm_args)
x = residual_fn(out, residual)
if layer_type in ('a', 'c'):
intermediates.append(inter)
if layer_type == 'a' and self.residual_attn:
prev_attn = inter.pre_softmax_attn
elif layer_type == 'c' and self.cross_residual_attn:
prev_cross_attn = inter.pre_softmax_attn
if exists(post_main_norm):
x = post_main_norm(x, **norm_args)
if layer_type == 'c':
cross_attn_count += 1
if layer_type == 'f':
hiddens.append(x)
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens,
attn_intermediates=intermediates,
past_key_values=present_key_values
)
return x, intermediates
return x
class Encoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on encoder'
super().__init__(causal=False, **kwargs)
class Decoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on decoder'
super().__init__(causal=True, **kwargs)
class CrossAttender(AttentionLayers):
def __init__(self, **kwargs):
super().__init__(cross_attend=True, only_cross=True, **kwargs)
class ViTransformerWrapper(nn.Module):
def __init__(
self,
*,
image_size,
patch_size,
attn_layers,
num_classes=None,
dropout=0.,
emb_dropout=0.
):
super().__init__()
assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
dim = attn_layers.dim
num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size ** 2
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None
def forward(
self,
img,
return_embeddings=False
):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.attn_layers(x)
x = self.norm(x)
if not exists(self.mlp_head) or return_embeddings:
return x
return self.mlp_head(x[:, 0])
class TransformerWrapper(nn.Module):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
shift_mem_down=0,
emb_dropout=0.,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True
):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
self.shift_mem_down = shift_mem_down
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
def init_(self):
nn.init.kaiming_normal_(self.token_emb.weight)
def forward(
self,
x,
return_embeddings=False,
mask=None,
return_hiddens=False,
return_attn=False,
mems=None,
use_cache=False,
**kwargs
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
x = self.token_emb(x)
x = x + self.pos_emb(x)
x = self.emb_dropout(x)
x = self.project_emb(x)
if num_mem > 0:
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
x = torch.cat((mem, x), dim=1)
# auto-handle masking after appending memory tokens
if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True)
if self.shift_mem_down and exists(mems):
mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
mems = [*mems_r, *mems_l]
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
out = self.to_logits(x) if not return_embeddings else x
if return_hiddens:
hiddens = intermediates.hiddens
return out, hiddens
res = [out]
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
res.append(attn_maps)
if use_cache:
res.append(intermediates.past_key_values)
if len(res) > 1:
return tuple(res)
return res[0]
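# Usage sketch (illustrative only): a small decoder-only language model built
# from the pieces above; hyperparameters are arbitrary.
def _demo_transformer_wrapper():
    model = TransformerWrapper(
        num_tokens=256,
        max_seq_len=128,
        attn_layers=Decoder(dim=64, depth=2, heads=4),
    )
    tokens = torch.randint(0, 256, (2, 32))
    logits = model(tokens)
    assert logits.shape == (2, 32, 256)     # (batch, seq_len, num_tokens)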
class ContinuousTransformerWrapper(nn.Module):
def __init__(
self,
*,
max_seq_len,
attn_layers,
dim_in=None,
dim_out=None,
emb_dim=None,
emb_dropout=0.,
use_pos_emb=True
):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim
self.max_seq_len = max_seq_len
self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
def forward(
self,
x,
return_embeddings=False,
mask=None,
return_attn=False,
mems=None,
use_cache=False,
**kwargs
):
b, n, _, device = *x.shape, x.device
x = self.project_in(x)
x = x + self.pos_emb(x)
x = self.emb_dropout(x)
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x = self.norm(x)
out = self.project_out(x) if not return_embeddings else x
res = [out]
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
res.append(attn_maps)
if use_cache:
res.append(intermediates.past_key_values)
if len(res) > 1:
return tuple(res)
return res[0]
import functools
from math import sqrt
import torch
import torch.distributed as distributed
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from einops import rearrange
def default(val, d):
return val if val is not None else d
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
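# Hedged example: eval_decorator temporarily puts the wrapped model into eval()
# for the duration of the call and then restores its previous training mode.
def _demo_eval_decorator():
    model = nn.Linear(4, 4)

    @eval_decorator
    def run(m, x):
        return m(x)

    model.train()
    _ = run(model, torch.randn(2, 4))
    assert model.training                   # original training mode restored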
def dvae_wav_to_mel(
wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
):
mel_stft = torchaudio.transforms.MelSpectrogram(
n_fft=1024,
hop_length=256,
win_length=1024,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
norm="slaney",
).to(device)
wav = wav.to(device)
mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None:
mel_norms = torch.load(mel_norms_file, map_location=device)
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel
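# Usage sketch (illustrative only): convert a fake 1-second waveform to a
# normalized log-mel spectrogram. mel_norms is passed explicitly here so the
# demo does not depend on the checkpoint file referenced by mel_norms_file.
def _demo_dvae_wav_to_mel():
    wav = torch.randn(1, 22050)             # (batch, samples) at 22.05 kHz
    mel = dvae_wav_to_mel(wav, mel_norms=torch.ones(80))
    assert mel.shape[1] == 80               # (batch, n_mels, frames)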
class Quantize(nn.Module):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
super().__init__()
self.dim = dim
self.n_embed = n_embed
self.decay = decay
self.eps = eps
self.balancing_heuristic = balancing_heuristic
self.codes = None
self.max_codes = 64000
self.codes_full = False
self.new_return_order = new_return_order
embed = torch.randn(dim, n_embed)
self.register_buffer("embed", embed)
self.register_buffer("cluster_size", torch.zeros(n_embed))
self.register_buffer("embed_avg", embed.clone())
def forward(self, input, return_soft_codes=False):
if self.balancing_heuristic and self.codes_full:
h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
mask = torch.logical_or(h > 0.9, h < 0.01).unsqueeze(1)
ep = self.embed.permute(1, 0)
ea = self.embed_avg.permute(1, 0)
rand_embed = torch.randn_like(ep) * mask
self.embed = (ep * ~mask + rand_embed).permute(1, 0)
self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0)
self.cluster_size = self.cluster_size * ~mask.squeeze()
if torch.any(mask):
print(f"Reset {torch.sum(mask)} embedding codes.")
self.codes = None
self.codes_full = False
flatten = input.reshape(-1, self.dim)
dist = flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + self.embed.pow(2).sum(0, keepdim=True)
soft_codes = -dist
_, embed_ind = soft_codes.max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = self.embed_code(embed_ind)
if self.balancing_heuristic:
if self.codes is None:
self.codes = embed_ind.flatten()
else:
self.codes = torch.cat([self.codes, embed_ind.flatten()])
if len(self.codes) > self.max_codes:
self.codes = self.codes[-self.max_codes :]
self.codes_full = True
if self.training:
embed_onehot_sum = embed_onehot.sum(0)
embed_sum = flatten.transpose(0, 1) @ embed_onehot
if distributed.is_initialized() and distributed.get_world_size() > 1:
distributed.all_reduce(embed_onehot_sum)
distributed.all_reduce(embed_sum)
self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
n = self.cluster_size.sum()
cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
self.embed.data.copy_(embed_normalized)
diff = (quantize.detach() - input).pow(2).mean()
quantize = input + (quantize - input).detach()
if return_soft_codes:
return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
elif self.new_return_order:
return quantize, embed_ind, diff
else:
return quantize, diff, embed_ind
def embed_code(self, embed_id):
return F.embedding(embed_id, self.embed.transpose(0, 1))
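# Usage sketch (illustrative only): vector-quantize a batch of feature vectors.
# With the default return order the forward pass yields (quantized features,
# commitment loss, codebook indices); eval() skips the EMA codebook update.
def _demo_quantize():
    vq = Quantize(dim=64, n_embed=512)
    vq.eval()
    feats = torch.randn(2, 10, 64)          # (batch, seq_len, dim)
    quantized, commit_loss, codes = vq(feats)
    assert quantized.shape == feats.shape
    assert codes.shape == (2, 10)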
# Fits a soft-discretized input to a normal PDF across the specified dimension.
# In other words, it encourages the discretization function to utilize all discrete
# values equally on average, with the specified expected variance.
class DiscretizationLoss(nn.Module):
def __init__(self, discrete_bins, dim, expected_variance, store_past=0):
super().__init__()
self.discrete_bins = discrete_bins
self.dim = dim
self.dist = torch.distributions.Normal(0, scale=expected_variance)
if store_past > 0:
self.record_past = True
self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device="cpu"))
self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device="cpu"))
self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins))
else:
self.record_past = False
def forward(self, x):
other_dims = set(range(len(x.shape))) - set([self.dim])
averaged = x.sum(dim=tuple(other_dims)) / x.sum()
averaged = averaged - averaged.mean()
if self.record_past:
acc_count = self.accumulator.shape[0]
avg = averaged.detach().clone()
if self.accumulator_filled > 0:
averaged = torch.mean(self.accumulator, dim=0) * (acc_count - 1) / acc_count + averaged / acc_count
# Also push averaged into the accumulator.
self.accumulator[self.accumulator_index] = avg
self.accumulator_index += 1
if self.accumulator_index >= acc_count:
self.accumulator_index *= 0
if self.accumulator_filled <= 0:
self.accumulator_filled += 1
return torch.sum(-self.dist.log_prob(averaged))
class ResBlock(nn.Module):
def __init__(self, chan, conv, activation):
super().__init__()
self.net = nn.Sequential(
conv(chan, chan, 3, padding=1),
activation(),
conv(chan, chan, 3, padding=1),
activation(),
conv(chan, chan, 1),
)
def forward(self, x):
return self.net(x) + x
class UpsampledConv(nn.Module):
def __init__(self, conv, *args, **kwargs):
super().__init__()
assert "stride" in kwargs.keys()
self.stride = kwargs["stride"]
del kwargs["stride"]
self.conv = conv(*args, **kwargs)
def forward(self, x):
up = nn.functional.interpolate(x, scale_factor=self.stride, mode="nearest")
return self.conv(up)
# DiscreteVAE partially derived from lucidrains DALLE implementation
# Credit: https://github.com/lucidrains/DALLE-pytorch
class DiscreteVAE(nn.Module):
def __init__(
self,
positional_dims=2,
num_tokens=512,
codebook_dim=512,
num_layers=3,
num_resnet_blocks=0,
hidden_dim=64,
channels=3,
stride=2,
kernel_size=4,
use_transposed_convs=True,
encoder_norm=False,
activation="relu",
smooth_l1_loss=False,
straight_through=False,
normalization=None, # ((0.5,) * 3, (0.5,) * 3),
record_codes=False,
discretization_loss_averaging_steps=100,
lr_quantizer_args={},
):
super().__init__()
has_resblocks = num_resnet_blocks > 0
self.num_tokens = num_tokens
self.num_layers = num_layers
self.straight_through = straight_through
self.positional_dims = positional_dims
self.discrete_loss = DiscretizationLoss(
num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps
)
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
if positional_dims == 2:
conv = nn.Conv2d
conv_transpose = nn.ConvTranspose2d
else:
conv = nn.Conv1d
conv_transpose = nn.ConvTranspose1d
if not use_transposed_convs:
conv_transpose = functools.partial(UpsampledConv, conv)
if activation == "relu":
act = nn.ReLU
elif activation == "silu":
act = nn.SiLU
else:
raise NotImplementedError(f"unsupported activation: {activation}")
enc_layers = []
dec_layers = []
if num_layers > 0:
enc_chans = [hidden_dim * 2**i for i in range(num_layers)]
dec_chans = list(reversed(enc_chans))
enc_chans = [channels, *enc_chans]
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
dec_chans = [dec_init_chan, *dec_chans]
enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
pad = (kernel_size - 1) // 2
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act()))
if encoder_norm:
enc_layers.append(nn.GroupNorm(8, enc_out))
dec_layers.append(
nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride=stride, padding=pad), act())
)
dec_out_chans = dec_chans[-1]
innermost_dim = dec_chans[0]
else:
enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
dec_out_chans = hidden_dim
innermost_dim = hidden_dim
for _ in range(num_resnet_blocks):
dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
enc_layers.append(ResBlock(innermost_dim, conv, act))
if num_resnet_blocks > 0:
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
enc_layers.append(conv(innermost_dim, codebook_dim, 1))
dec_layers.append(conv(dec_out_chans, channels, 1))
self.encoder = nn.Sequential(*enc_layers)
self.decoder = nn.Sequential(*dec_layers)
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
# take care of normalization within class
self.normalization = normalization
self.record_codes = record_codes
if record_codes:
self.codes = torch.zeros((1228800,), dtype=torch.long)
self.code_ind = 0
self.total_codes = 0
self.internal_step = 0
def norm(self, images):
if self.normalization is None:
return images
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
images = images.clone()
images.sub_(means).div_(stds)
return images
def get_debug_values(self, step, __):
if self.record_codes and self.total_codes > 0:
# Report annealing schedule
return {"histogram_codes": self.codes[: self.total_codes]}
else:
return {}
@torch.no_grad()
@eval_decorator
def get_codebook_indices(self, images):
img = self.norm(images)
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, _ = self.codebook(logits)
self.log_codes(codes)
return codes
def decode(self, img_seq):
self.log_codes(img_seq)
if hasattr(self.codebook, "embed_code"):
image_embeds = self.codebook.embed_code(img_seq)
else:
image_embeds = F.embedding(img_seq, self.codebook.codebook)
b, n, d = image_embeds.shape
kwargs = {}
if self.positional_dims == 1:
arrange = "b n d -> b d n"
else:
h = w = int(sqrt(n))
arrange = "b (h w) d -> b d h w"
kwargs = {"h": h, "w": w}
image_embeds = rearrange(image_embeds, arrange, **kwargs)
images = [image_embeds]
for layer in self.decoder:
images.append(layer(images[-1]))
return images[-1], images[-2]
def infer(self, img):
img = self.norm(img)
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, commitment_loss = self.codebook(logits)
return self.decode(codes)
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
# evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
# more lossy (but useful for determining network performance).
def forward(self, img):
img = self.norm(img)
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, commitment_loss = self.codebook(logits)
sampled = sampled.permute((0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1))
if self.training:
out = sampled
for d in self.decoder:
out = d(out)
self.log_codes(codes)
else:
# This is non-differentiable, but gives a better idea of how the network is actually performing.
out, _ = self.decode(codes)
# reconstruction loss
out = out[..., :img.shape[-1]]
recon_loss = self.loss_fn(img, out, reduction="mean")
ssim_loss = torch.zeros(size=(1,), device=img.device)
return recon_loss, ssim_loss, commitment_loss, out
def log_codes(self, codes):
# This is so we can debug the distribution of codes being learned.
if self.record_codes and self.internal_step % 10 == 0:
codes = codes.flatten()
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i : i + l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
self.code_ind = 0
self.total_codes += 1
self.internal_step += 1
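# Usage sketch (illustrative only): encode a fake mel-spectrogram batch to
# discrete codebook indices with a small 1-d DiscreteVAE. Hyperparameters and
# shapes are arbitrary; record_codes=True so log_codes has its buffers.
def _demo_discrete_vae():
    dvae = DiscreteVAE(
        positional_dims=1,
        num_tokens=256,
        codebook_dim=256,
        num_layers=2,
        num_resnet_blocks=1,
        hidden_dim=64,
        channels=80,
        record_codes=True,
    )
    mel = torch.randn(2, 80, 64)            # (batch, n_mels, frames)
    codes = dvae.get_codebook_indices(mel)
    assert codes.shape == (2, 16)           # two stride-2 convs -> frames / 4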
2026-01-07 09:55:53.359 | INFO | indextts.infer_vllm_v2:__init__:105 - >> GPT weights restored from: checkpoints/IndexTTS-2-vLLM/gpt.pth
2026-01-07 09:55:53.372 | INFO | indextts.infer_vllm_v2:__init__:116 - >> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.
2026-01-07 09:55:56.908 | INFO | indextts.infer_vllm_v2:__init__:137 - >> semantic_codec weights restored from: checkpoints/IndexTTS-2-vLLM/semantic_codec/model.safetensors
2026-01-07 10:08:12.331 | INFO | indextts.infer_vllm_v2:__init__:105 - >> GPT weights restored from: checkpoints/IndexTTS-2-vLLM/gpt.pth
2026-01-07 10:08:12.340 | INFO | indextts.infer_vllm_v2:__init__:116 - >> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.
2026-01-07 10:08:15.863 | INFO | indextts.infer_vllm_v2:__init__:137 - >> semantic_codec weights restored from: checkpoints/IndexTTS-2-vLLM/semantic_codec/model.safetensors
2026-01-07 10:08:18.419 | INFO | indextts.infer_vllm_v2:__init__:152 - >> s2mel weights restored from: checkpoints/IndexTTS-2-vLLM/s2mel.pth
2026-01-07 10:08:18.781 | INFO | indextts.infer_vllm_v2:__init__:163 - >> campplus_model weights restored from: checkpoints/IndexTTS-2-vLLM/campplus/campplus_cn_common.bin
2026-01-07 10:08:22.673 | INFO | indextts.infer_vllm_v2:__init__:171 - >> bigvgan weights restored from: nvidia/bigvgan_v2_22khz_80band_256x
2026-01-07 10:13:42.833 | INFO | indextts.infer_vllm_v2:__init__:105 - >> GPT weights restored from: checkpoints/IndexTTS-2-vLLM/gpt.pth
2026-01-07 10:13:42.842 | INFO | indextts.infer_vllm_v2:__init__:116 - >> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.
2026-01-07 10:13:46.408 | INFO | indextts.infer_vllm_v2:__init__:137 - >> semantic_codec weights restored from: checkpoints/IndexTTS-2-vLLM/semantic_codec/model.safetensors
2026-01-07 10:13:48.547 | INFO | indextts.infer_vllm_v2:__init__:152 - >> s2mel weights restored from: checkpoints/IndexTTS-2-vLLM/s2mel.pth
2026-01-07 10:13:48.892 | INFO | indextts.infer_vllm_v2:__init__:163 - >> campplus_model weights restored from: checkpoints/IndexTTS-2-vLLM/campplus/campplus_cn_common.bin
2026-01-07 10:13:52.163 | INFO | indextts.infer_vllm_v2:__init__:171 - >> bigvgan weights restored from: nvidia/bigvgan_v2_22khz_80band_256x
2026-01-07 11:35:01.871 | INFO | indextts.infer_vllm_v2:__init__:105 - >> GPT weights restored from: checkpoints/IndexTTS-2-vLLM/gpt.pth
2026-01-07 11:35:01.882 | INFO | indextts.infer_vllm_v2:__init__:116 - >> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.
2026-01-07 11:35:05.240 | INFO | indextts.infer_vllm_v2:__init__:137 - >> semantic_codec weights restored from: checkpoints/IndexTTS-2-vLLM/semantic_codec/model.safetensors
2026-01-07 11:35:07.533 | INFO | indextts.infer_vllm_v2:__init__:152 - >> s2mel weights restored from: checkpoints/IndexTTS-2-vLLM/s2mel.pth
2026-01-07 11:35:07.892 | INFO | indextts.infer_vllm_v2:__init__:163 - >> campplus_model weights restored from: checkpoints/IndexTTS-2-vLLM/campplus/campplus_cn_common.bin
2026-01-07 11:35:11.961 | INFO | indextts.infer_vllm_v2:__init__:171 - >> bigvgan weights restored from: nvidia/bigvgan_v2_22khz_80band_256x
2026-01-07 11:53:45.877 | INFO | indextts.infer_vllm_v2:__init__:105 - >> GPT weights restored from: checkpoints/IndexTTS-2-vLLM/gpt.pth
2026-01-07 11:53:45.889 | INFO | indextts.infer_vllm_v2:__init__:116 - >> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.
2026-01-07 11:53:49.144 | INFO | indextts.infer_vllm_v2:__init__:137 - >> semantic_codec weights restored from: checkpoints/IndexTTS-2-vLLM/semantic_codec/model.safetensors
2026-01-07 11:53:51.373 | INFO | indextts.infer_vllm_v2:__init__:152 - >> s2mel weights restored from: checkpoints/IndexTTS-2-vLLM/s2mel.pth
2026-01-07 11:53:51.762 | INFO | indextts.infer_vllm_v2:__init__:163 - >> campplus_model weights restored from: checkpoints/IndexTTS-2-vLLM/campplus/campplus_cn_common.bin
2026-01-07 11:53:55.438 | INFO | indextts.infer_vllm_v2:__init__:171 - >> bigvgan weights restored from: nvidia/bigvgan_v2_22khz_80band_256x
2026-01-07 11:54:32.456 | INFO | indextts.infer_vllm_v2:__init__:176 - >> TextNormalizer loaded
2026-01-07 11:54:32.484 | INFO | indextts.infer_vllm_v2:__init__:178 - >> bpe model loaded from: checkpoints/IndexTTS-2-vLLM/bpe.model
2026-01-07 12:26:15.303 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 12:26:33.498 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 12:26:34.437 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [9e538135d39f4ee1b990f7940806b044] [prefill time: 0.9366]
2026-01-07 12:26:46.639 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [9e538135d39f4ee1b990f7940806b044] [decode time: 12.2017] [decode len: 1761]
2026-01-07 12:26:47.435 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 12:26:47.855 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 12:26:47.874 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [1c403009d06c4c3c9811620525f339aa] [prefill time: 0.0177]
2026-01-07 12:26:59.764 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [1c403009d06c4c3c9811620525f339aa] [decode time: 11.8893] [decode len: 1761]
2026-01-07 12:27:00.146 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 12:27:00.777 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 12:27:00.801 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [d8300dd0b6b749b08352ea5468732605] [prefill time: 0.0174]
2026-01-07 12:27:12.761 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [d8300dd0b6b749b08352ea5468732605] [decode time: 11.9596] [decode len: 1761]
2026-01-07 12:27:13.116 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 12:27:14.408 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 12:27:14.432 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [f0d11d0427cd4ac1a83f89d987d137e3] [prefill time: 0.0176]
2026-01-07 12:27:15.221 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [f0d11d0427cd4ac1a83f89d987d137e3] [decode time: 0.7887] [decode len: 135]
2026-01-07 15:39:20.324 | INFO | indextts.infer_vllm_v2:__init__:105 - >> GPT weights restored from: checkpoints/IndexTTS-2-vLLM/gpt.pth
2026-01-07 15:39:20.332 | INFO | indextts.infer_vllm_v2:__init__:116 - >> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.
2026-01-07 15:39:23.893 | INFO | indextts.infer_vllm_v2:__init__:137 - >> semantic_codec weights restored from: checkpoints/IndexTTS-2-vLLM/semantic_codec/model.safetensors
2026-01-07 15:39:26.238 | INFO | indextts.infer_vllm_v2:__init__:152 - >> s2mel weights restored from: checkpoints/IndexTTS-2-vLLM/s2mel.pth
2026-01-07 15:39:26.591 | INFO | indextts.infer_vllm_v2:__init__:163 - >> campplus_model weights restored from: checkpoints/IndexTTS-2-vLLM/campplus/campplus_cn_common.bin
2026-01-07 15:39:30.100 | INFO | indextts.infer_vllm_v2:__init__:171 - >> bigvgan weights restored from: nvidia/bigvgan_v2_22khz_80band_256x
2026-01-07 15:39:32.076 | INFO | indextts.infer_vllm_v2:__init__:176 - >> TextNormalizer loaded
2026-01-07 15:39:32.108 | INFO | indextts.infer_vllm_v2:__init__:178 - >> bpe model loaded from: checkpoints/IndexTTS-2-vLLM/bpe.model
2026-01-07 15:40:39.322 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 15:40:45.193 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 15:40:46.224 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [2c7b38cf2739452c9ebdab36aa5c4db0] [prefill time: 1.0290]
2026-01-07 15:40:48.352 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [2c7b38cf2739452c9ebdab36aa5c4db0] [decode time: 2.1280] [decode len: 336]
2026-01-07 15:40:48.576 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 15:40:48.999 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 15:40:49.018 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [00454b8a816048a998374d8b6dd2df81] [prefill time: 0.0177]
2026-01-07 15:40:53.017 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [00454b8a816048a998374d8b6dd2df81] [decode time: 3.9981] [decode len: 616]
2026-01-07 15:40:53.215 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 15:40:53.674 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 15:40:53.696 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [eced97d4f81d4023a9406bf25f2af326] [prefill time: 0.0200]
2026-01-07 15:40:56.550 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [eced97d4f81d4023a9406bf25f2af326] [decode time: 2.8534] [decode len: 439]
2026-01-07 15:40:56.712 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 15:40:57.999 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 15:40:58.018 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [36cdd0e5c0fa4b31b80e1733fc2aa77b] [prefill time: 0.0171]
2026-01-07 15:40:58.186 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [36cdd0e5c0fa4b31b80e1733fc2aa77b] [decode time: 0.1675] [decode len: 29]
2026-01-07 17:16:04.466 | INFO | indextts.infer_vllm_v2:__init__:105 - >> GPT weights restored from: checkpoints/IndexTTS-2-vLLM/gpt.pth
2026-01-07 17:16:04.473 | INFO | indextts.infer_vllm_v2:__init__:116 - >> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.
2026-01-07 17:16:07.821 | INFO | indextts.infer_vllm_v2:__init__:137 - >> semantic_codec weights restored from: checkpoints/IndexTTS-2-vLLM/semantic_codec/model.safetensors
2026-01-07 17:16:09.645 | INFO | indextts.infer_vllm_v2:__init__:152 - >> s2mel weights restored from: checkpoints/IndexTTS-2-vLLM/s2mel.pth
2026-01-07 17:16:09.974 | INFO | indextts.infer_vllm_v2:__init__:163 - >> campplus_model weights restored from: checkpoints/IndexTTS-2-vLLM/campplus/campplus_cn_common.bin
2026-01-07 17:16:13.027 | INFO | indextts.infer_vllm_v2:__init__:171 - >> bigvgan weights restored from: nvidia/bigvgan_v2_22khz_80band_256x
2026-01-07 17:16:15.973 | INFO | indextts.infer_vllm_v2:__init__:176 - >> TextNormalizer loaded
2026-01-07 17:16:16.002 | INFO | indextts.infer_vllm_v2:__init__:178 - >> bpe model loaded from: checkpoints/IndexTTS-2-vLLM/bpe.model
2026-01-07 17:17:34.294 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 17:17:40.179 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 17:17:41.249 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [22f7281eaa15416488635de965f49159] [prefill time: 1.0658]
2026-01-07 17:17:52.958 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [22f7281eaa15416488635de965f49159] [decode time: 11.7088] [decode len: 1761]
2026-01-07 17:18:06.446 | INFO | indextts.infer_vllm_v2:infer:458 - >> gpt_gen_time: 12.86 seconds
2026-01-07 17:18:06.446 | INFO | indextts.infer_vllm_v2:infer:459 - >> gpt_forward_time: 0.18 seconds
2026-01-07 17:18:06.447 | INFO | indextts.infer_vllm_v2:infer:460 - >> s2mel_time: 11.15 seconds
2026-01-07 17:18:06.448 | INFO | indextts.infer_vllm_v2:infer:461 - >> bigvgan_time: 2.13 seconds
2026-01-07 17:18:06.448 | INFO | indextts.infer_vllm_v2:infer:462 - >> Total inference time: 32.13 seconds
2026-01-07 17:18:06.449 | INFO | indextts.infer_vllm_v2:infer:463 - >> Generated audio length: 35.12 seconds
2026-01-07 17:18:06.449 | INFO | indextts.infer_vllm_v2:infer:464 - >> RTF: 0.9149
2026-01-07 17:18:06.465 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 17:18:06.856 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 17:18:06.875 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [4bee66e70af64239a75fa180dbe00801] [prefill time: 0.0163]
2026-01-07 17:18:18.407 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [4bee66e70af64239a75fa180dbe00801] [decode time: 11.5326] [decode len: 1761]
2026-01-07 17:18:29.940 | INFO | indextts.infer_vllm_v2:infer:458 - >> gpt_gen_time: 11.62 seconds
2026-01-07 17:18:29.941 | INFO | indextts.infer_vllm_v2:infer:459 - >> gpt_forward_time: 0.18 seconds
2026-01-07 17:18:29.942 | INFO | indextts.infer_vllm_v2:infer:460 - >> s2mel_time: 9.59 seconds
2026-01-07 17:18:29.942 | INFO | indextts.infer_vllm_v2:infer:461 - >> bigvgan_time: 0.08 seconds
2026-01-07 17:18:29.942 | INFO | indextts.infer_vllm_v2:infer:462 - >> Total inference time: 23.45 seconds
2026-01-07 17:18:29.942 | INFO | indextts.infer_vllm_v2:infer:463 - >> Generated audio length: 35.12 seconds
2026-01-07 17:18:29.942 | INFO | indextts.infer_vllm_v2:infer:464 - >> RTF: 0.6677
2026-01-07 17:18:29.960 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 17:18:30.355 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 17:18:30.377 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [991824af0b334ef797bfd3a59d5cced6] [prefill time: 0.0148]
2026-01-07 17:18:35.654 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [991824af0b334ef797bfd3a59d5cced6] [decode time: 5.2770] [decode len: 876]
2026-01-07 17:18:42.032 | INFO | indextts.infer_vllm_v2:infer:458 - >> gpt_gen_time: 5.34 seconds
2026-01-07 17:18:42.032 | INFO | indextts.infer_vllm_v2:infer:459 - >> gpt_forward_time: 0.03 seconds
2026-01-07 17:18:42.034 | INFO | indextts.infer_vllm_v2:infer:460 - >> s2mel_time: 5.48 seconds
2026-01-07 17:18:42.034 | INFO | indextts.infer_vllm_v2:infer:461 - >> bigvgan_time: 0.84 seconds
2026-01-07 17:18:42.034 | INFO | indextts.infer_vllm_v2:infer:462 - >> Total inference time: 12.04 seconds
2026-01-07 17:18:42.034 | INFO | indextts.infer_vllm_v2:infer:463 - >> Generated audio length: 17.45 seconds
2026-01-07 17:18:42.035 | INFO | indextts.infer_vllm_v2:infer:464 - >> RTF: 0.6902
2026-01-07 17:18:42.049 | INFO | indextts.infer_vllm_v2:infer:243 - >> start inference...
2026-01-07 17:18:43.081 | INFO | indextts.gpt.model_vllm_v2:inference_speech:222 - Use the specified emotion vector
2026-01-07 17:18:43.099 | INFO | indextts.gpt.model_vllm_v2:inference_speech:248 - [8382f6c343604dcab57888aa85574f6c] [prefill time: 0.0162]
2026-01-07 17:18:54.859 | INFO | indextts.gpt.model_vllm_v2:inference_speech:251 - [8382f6c343604dcab57888aa85574f6c] [decode time: 11.7604] [decode len: 1761]
2026-01-07 17:19:06.392 | INFO | indextts.infer_vllm_v2:infer:458 - >> gpt_gen_time: 11.84 seconds
2026-01-07 17:19:06.393 | INFO | indextts.infer_vllm_v2:infer:459 - >> gpt_forward_time: 0.04 seconds
2026-01-07 17:19:06.393 | INFO | indextts.infer_vllm_v2:infer:460 - >> s2mel_time: 9.74 seconds
2026-01-07 17:19:06.394 | INFO | indextts.infer_vllm_v2:infer:461 - >> bigvgan_time: 0.08 seconds
2026-01-07 17:19:06.395 | INFO | indextts.infer_vllm_v2:infer:462 - >> Total inference time: 24.32 seconds
2026-01-07 17:19:06.395 | INFO | indextts.infer_vllm_v2:infer:463 - >> Generated audio length: 35.12 seconds
2026-01-07 17:19:06.395 | INFO | indextts.infer_vllm_v2:infer:464 - >> RTF: 0.6925
Inference generates output.wav
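For reference, the RTF (real-time factor) in the log above is simply the total inference time divided by the generated audio duration; values below 1.0 mean synthesis runs faster than real time. A minimal check against the first run above that reports an RTF (numbers copied from the log):

total_inference_time = 32.13    # seconds, from the log above
generated_audio_length = 35.12  # seconds, from the log above
rtf = total_inference_time / generated_audio_length
print(f"RTF: {rtf:.4f}")  # -> RTF: 0.9149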
import torch
import time
from typing import Any, List, Optional, Tuple, Union
from packaging import version
import importlib
vllm_version = version.parse(importlib.import_module("vllm").__version__)
# Register the custom GPT2TTSModel with vLLM
from vllm import ModelRegistry
from indextts.gpt.index_tts_gpt2_vllm_v1 import GPT2TTSModel
ModelRegistry.register_model("GPT2InferenceModel", GPT2TTSModel)
print("✅ Registry GPT2TTSModel to vllm")
# 将 position_ids 减去 prefill 的长度再加 1,以便正确计算每一步 decode 的 position embedding
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
import numpy as np
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
np.ndarray, Optional[CommonAttentionMetadata], int]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
logits_indices, spec_decode_metadata
]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit_block_table(num_reqs)
# Get the number of scheduled tokens for each request.
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
# Get positions.
positions_np = self.positions.np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
out=self.input_ids.cpu[:total_num_scheduled_tokens])
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
# GPT2TTSModel position ids support
model = self.get_model()
if isinstance(model, GPT2TTSModel):
# req_ids_in_batch = self.input_batch.req_ids[:num_reqs]
prompt_tokens_offset = []
for req_id in self.input_batch.req_ids:
prompt_tokens_offset.append(-(len(self.requests[req_id].prompt_token_ids) - 1))
# print(f"[{idx}] self.requests[req_id].prompt_token_ids:", len(self.requests[req_id].prompt_token_ids), positions_np)
np.add(np.array(prompt_tokens_offset)[req_indices],
positions_np,
out=positions_np)
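# Illustration: for a request whose prompt holds 8 tokens, the offset is -(8 - 1) = -7,
# so the first decode step (which vLLM would normally assign position 8) gets position 1,
# and later decode steps get positions 2, 3, ..., matching the comment at the top of this file.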
# Prepare the attention metadata.
self.query_start_loc.np[0] = 0
self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1])
self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
self.seq_lens.np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
# Fill unused with 0 for full cuda graph mode.
self.seq_lens.np[num_reqs:].fill(0)
self.seq_lens.copy_to_gpu()
seq_lens = self.seq_lens.gpu[:num_reqs]
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
# Copy the tensors to the GPU.
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
else:
# Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
num_draft_tokens = None
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
self.num_draft_tokens.np[num_reqs:].fill(0)
self.num_draft_tokens.copy_to_gpu()
logits_indices_padded = None
if self.cache_config.kv_sharing_fast_prefill:
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices)
attn_metadata: dict[str, Any] = {}
# Used in the below loop.
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
spec_decode_common_attn_metadata = None
if use_spec_decode:
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
encoder_seq_lens = self._get_encoder_seq_lens(
scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs)
if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
blk_table_tensor = torch.zeros(
(num_reqs, 1),
dtype=torch.int32,
device=self.device,
)
slot_mapping = torch.zeros(
(total_num_scheduled_tokens, ),
dtype=torch.int64,
device=self.device,
)
num_common_prefix_blocks = 0
else:
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
slot_mapping = blk_table.slot_mapping[:
total_num_scheduled_tokens]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
num_common_prefix_blocks = (
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id])
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
max_seq_len=max_seq_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
causal=True,
encoder_seq_lens=encoder_seq_lens,
)
if self.speculative_config and \
spec_decode_common_attn_metadata is None:
spec_decode_common_attn_metadata = common_attn_metadata
for attn_group in self.attn_groups[kv_cache_group_id]:
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = attn_group.metadata_builder
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
num_common_prefix_blocks,
kv_cache_group_spec.kv_cache_spec,
builder,
)
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder,
GDNAttentionMetadataBuilder):
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.
gpu[:num_reqs],
num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
)
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return (attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata,
max_num_scheduled_tokens)
GPUModelRunner._prepare_inputs = _prepare_inputs
print("✅ GPUModelRunner._prepare_inputs Patched")
vllm==0.10.2
descript-audiotools==0.7.2
matplotlib==3.8.2
omegaconf
sentencepiece
librosa
gradio
ninja
modelscope
munch==4.0.0
loguru
WeTextProcessing; platform_machine != "Darwin"
wetext; platform_system == "Darwin"
## Basic Information
### Test Target
IndexTTS-2-vLLM (GPT component)
### Hardware
- RTX 4090
### Environment
- vllm==0.10.2
### Settings
- gpu_memory_utilization = 0.5
- max_num_seqs = 50
- vLLM log: (EngineCore_DP0 pid=61946) INFO 10-24 01:36:20 [kv_cache_utils.py:868] Maximum concurrency for 1,818 tokens per request: 51.20x
## Test Strategy
- Concurrency levels: 1 → 4 → 8 → 16 → 32 → 64 users
- Per-user behavior: each user sends 10 consecutive requests (run_user_simulation); each request's max_tokens is drawn randomly from 400-1200 (≈8-24 s of audio) to simulate realistic concurrent traffic (see the sketch after this list)
- Request content: random vectors drawn from a normal distribution, so requests do not hit the KV cache
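A minimal sketch of the per-user loop described above (the full benchmark script, including `run_single_inference`, appears later in this commit):

```python
# Sketch only: mirrors run_user_simulation from the benchmark script below.
# run_single_inference (defined in that script) picks max_tokens randomly in 400-1200.
async def simulate_user(llm, inputs_embeds, num_runs: int = 10):
    results = []
    for _ in range(num_runs):
        results.append(await run_single_inference(llm, inputs_embeds))
    return results
```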
## Results
## Concurrency Level: 1
* **Total Requests:** 10
* **Total Time:** 26.95 s
* **Total Throughput:** 271.99 tokens/s
| Metric | Min | Max | Mean | P50 | P95 | P99 |
|------------------------|----------|-----------|-----------|-----------|-----------|-----------|
| TTFT (ms) | 9.18 | 21.22 | 12.38 | 11.42 | 17.90 | 20.55 |
| Latency (ms) | 1448.54 | 4625.49 | 2694.31 | 2148.91 | 4409.40 | 4582.28 |
| Num Generated Tokens | 423 | 1157 | 732.90 | 637 | 1155 | 1157 |
----------------------------------------
## Concurrency Level: 4
* **Total Requests:** 40
* **Total Time:** 30.66 s
* **Total Throughput:** 1007.03 tokens/s
| Metric | Min | Max | Mean | P50 | P95 | P99 |
|------------------------|----------|-----------|-----------|-----------|-----------|-----------|
| TTFT (ms) | 8.30 | 41.17 | 14.21 | 11.45 | 38.04 | 40.52 |
| Latency (ms) | 1541.11 | 4761.30 | 2945.13 | 2857.07 | 4398.97 | 4726.94 |
| Num Generated Tokens | 404 | 1179 | 771.77 | 756 | 1157 | 1177 |
----------------------------------------
## Concurrency Level: 8
* **Total Requests:** 80
* **Total Time:** 36.89 s
* **Total Throughput:** 1741.31 tokens/s
| Metric | Min | Max | Mean | P50 | P95 | P99 |
|------------------------|----------|-----------|-----------|-----------|-----------|-----------|
| TTFT (ms) | 8.31 | 43.53 | 14.39 | 12.65 | 35.73 | 41.83 |
| Latency (ms) | 1759.68 | 5286.03 | 3336.76 | 3138.05 | 4887.24 | 5250.43 |
| Num Generated Tokens | 430 | 1192 | 802.91 | 774 | 1144 | 1190 |
----------------------------------------
## Concurrency Level: 16
* **Total Requests:** 160
* **Total Time:** 44.64 s
* **Total Throughput:** 2883.50 tokens/s
| Metric | Min | Max | Mean | P50 | P95 | P99 |
|------------------------|----------|-----------|-----------|-----------|-----------|-----------|
| TTFT (ms) | 9.40 | 41.57 | 14.47 | 11.44 | 33.03 | 39.86 |
| Latency (ms) | 2043.53 | 6397.30 | 4126.63 | 4217.07 | 6131.03 | 6342.06 |
| Num Generated Tokens | 398 | 1195 | 804.56 | 830 | 1166 | 1183 |
----------------------------------------
## Concurrency Level: 32
* **Total Requests:** 320
* **Total Time:** 62.69 s
* **Total Throughput:** 3998.44 tokens/s
| Metric | Min | Max | Mean | P50 | P95 | P99 |
|------------------------|----------|-----------|-----------|-----------|-----------|-----------|
| TTFT (ms) | 10.76 | 83.00 | 20.20 | 15.16 | 61.46 | 78.49 |
| Latency (ms) | 2626.30 | 8564.35 | 5374.98 | 5188.72 | 7984.32 | 8338.44 |
| Num Generated Tokens | 399 | 1196 | 783.34 | 770 | 1144 | 1169 |
----------------------------------------
## Concurrency Level: 64
* **Total Requests:** 640
* **Total Time:** 102.13 s
* **Total Throughput:** 5090.80 tokens/s
| Metric | Min | Max | Mean | P50 | P95 | P99 |
|------------------------|----------|-----------|-----------|-----------|-----------|-----------|
| TTFT (ms) | 10.83 | 5527.29 | 1901.63 | 1962.80 | 3217.76 | 4771.17 |
| Latency (ms) | 3278.74 | 16685.37 | 9343.78 | 9434.96 | 12897.00 | 14347.68 |
| Num Generated Tokens | 398 | 1197 | 812.35 | 828 | 1167 | 1189 |
----------------------------------------
## Analysis
- At a concurrency of 64 the run hits the concurrency cap (max_num_seqs=50), so excess requests have to queue; raising gpu_memory_utilization and max_num_seqs would increase total throughput further (see the sketch below)
- The workload is both memory- and compute-intensive
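Below is a hedged sketch of the engine configuration change suggested above; the parameter names match the `AsyncEngineArgs` call used in the benchmark script later in this commit, and the larger values are illustrative assumptions rather than benchmarked settings.

```python
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

# Illustrative values only: trading more GPU memory for more concurrent sequences.
engine_args = AsyncEngineArgs(
    model="checkpoints/IndexTTS-2-vLLM/gpt",
    tensor_parallel_size=1,
    dtype="auto",
    gpu_memory_utilization=0.8,  # 0.5 in the run above
    max_num_seqs=128,            # 50 in the run above
)
llm_engine = AsyncLLM.from_engine_args(engine_args)
```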
import asyncio
import random
import time
import uuid
import torch
import os
from typing import AsyncGenerator, List, Dict, Any
import numpy as np
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# # The settings below may slightly reduce performance but guarantee full reproducibility
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
set_seed(42)
import sys
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, root_dir)
import patch_vllm # ⚠️ Monkey Patch, do not delete this line
from indextts.gpt.index_tts_gpt2_vllm_v1 import PLACEHOLDER_TOKEN, PLACEHOLDER_TOKEN_ID
from vllm import SamplingParams, TokensPrompt
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
model_dir = os.path.join(root_dir, "checkpoints/IndexTTS-2-vLLM")
vllm_dir = os.path.join(model_dir, "gpt")
async def run_single_inference(llm: AsyncLLM, inputs_embeds: torch.Tensor):
request_id = uuid.uuid4().hex
sampling_params = SamplingParams(
temperature=1.0,
top_p=0.8,
top_k=30,
repetition_penalty=10.0,
max_tokens=random.randint(400, 1200), # 8s - 24s # 1818
ignore_eos=True,
stop_token_ids=[], # 8193
)
multi_modal_data = {"audio": {"audio_embeds": [inputs_embeds.squeeze(0).cpu()]}}
fake_inputs = PLACEHOLDER_TOKEN * 1
tokens_prompt = TokensPrompt(prompt=fake_inputs, multi_modal_data=multi_modal_data)
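# Note: the prompt itself is just a single placeholder token; the GPT conditioning
# (a precomputed conds + text embedding tensor, random in this benchmark) is handed
# to the engine via multi_modal_data as "audio_embeds".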
start_time = time.time()
output_generator = llm.generate(tokens_prompt, sampling_params, request_id=request_id)
prefill_flag = True
first_token_time = None
ttft = 0.0
output_tokens = []
async for output in output_generator:
if prefill_flag:
first_token_time = time.time()
ttft = first_token_time - start_time
prefill_flag = False
final_output = output
end_time = time.time()
latency = end_time - start_time
decode_time = end_time - first_token_time
generated_tokens = final_output.outputs[0].token_ids[:-2]
num_generated_tokens = len(generated_tokens)
decode_speed = 0
if decode_time > 0:
decode_speed = num_generated_tokens / decode_time
return {
"prefill_time": ttft,
"decode_time": decode_time,
"latency": latency,
"num_generated_tokens": num_generated_tokens,
"decode_speed_tokens_per_sec": decode_speed,
}
async def run_user_simulation(llm: AsyncLLM, inputs_embeds: torch.Tensor, num_runs: int = 10) -> List[Dict[str, Any]]:
user_results = []
for _ in range(num_runs):
result = await run_single_inference(llm, inputs_embeds)
user_results.append(result)
return user_results
async def benchmark(llm_engine, concurrency_levels, runs_per_user):
hidden_dim = 1280
conds_len = 34
max_text_tokens = 200
# warm up
fake_inputs_embeds = torch.randn(
1,
conds_len + max_text_tokens,
hidden_dim,
dtype=torch.float16,
device="cpu"
)
await run_single_inference(llm_engine, fake_inputs_embeds)
for concurrency in concurrency_levels:
# Prepare independent input data for each concurrent user
user_inputs = [torch.randn(
1,
conds_len + max_text_tokens,
hidden_dim,
dtype=torch.float16,
device="cpu"
) for _ in range(concurrency)]
# Create concurrent tasks; each task represents one user issuing 10 consecutive requests
tasks = [
run_user_simulation(llm_engine, user_inputs[c_idx], runs_per_user)
for c_idx in range(concurrency)
]
benchmark_start_time = time.time()
results = await asyncio.gather(*tasks)
benchmark_total_time = time.time() - benchmark_start_time
# --- Aggregate and compute statistics ---
all_ttfts_ms = []
all_latencies_ms = []
all_num_generated_tokens = []
total_requests = 0
total_generated_tokens = 0
for user_results in results:
for res in user_results:
total_requests += 1
total_generated_tokens += res["num_generated_tokens"]
all_ttfts_ms.append(res["prefill_time"] * 1000)
all_latencies_ms.append(res["latency"] * 1000)
all_num_generated_tokens.append(res["num_generated_tokens"])
# Convert to numpy arrays for statistics
ttfts_np = np.array(all_ttfts_ms)
latencies_np = np.array(all_latencies_ms)
tokens_np = np.array(all_num_generated_tokens)
# Total throughput = total tokens generated over the whole test duration
total_throughput = total_generated_tokens / benchmark_total_time
print(f"## Concurrency Level: {concurrency}")
print(f"\n* **Total Requests:** {concurrency * runs_per_user}")
print(f"* **Total Time:** {benchmark_total_time:.2f} s")
print(f"* **Total Throughput:** {total_throughput:.2f} tokens/s\n")
print("| Metric | Min | Max | Mean | P50 | P95 | P99 |")
print("|------------------------|----------|-----------|-----------|-----------|-----------|-----------|")
if total_requests > 0:
# TTFT
print(f"| TTFT (ms) | {np.min(ttfts_np):<8.2f} | {np.max(ttfts_np):<9.2f} | {np.mean(ttfts_np):<9.2f} | {np.percentile(ttfts_np, 50):<9.2f} | {np.percentile(ttfts_np, 95):<9.2f} | {np.percentile(ttfts_np, 99):<9.2f} |")
# Latency
print(f"| Latency (ms) | {np.min(latencies_np):<8.2f} | {np.max(latencies_np):<9.2f} | {np.mean(latencies_np):<9.2f} | {np.percentile(latencies_np, 50):<9.2f} | {np.percentile(latencies_np, 95):<9.2f} | {np.percentile(latencies_np, 99):<9.2f} |")
# Num Generated Tokens
print(f"| Num Generated Tokens | {np.min(tokens_np):<8.0f} | {np.max(tokens_np):<9.0f} | {np.mean(tokens_np):<9.2f} | {np.percentile(tokens_np, 50):<9.0f} | {np.percentile(tokens_np, 95):<9.0f} | {np.percentile(tokens_np, 99):<9.0f} |")
print("\n" + "-"*40 + "\n")
if __name__ == "__main__":
gpu_memory_utilization = 0.5 # 0.25
concurrency_levels = [1, 4, 8, 16, 32, 64] # concurrency levels to test
runs_per_user = 10 # number of requests per concurrent user
engine_args = AsyncEngineArgs(
model=vllm_dir,
tensor_parallel_size=1,
dtype="auto",
gpu_memory_utilization=gpu_memory_utilization,
max_num_seqs=50
)
llm_engine: AsyncLLM = AsyncLLM.from_engine_args(engine_args)
asyncio.run(benchmark(llm_engine, concurrency_levels, runs_per_user))
import argparse
import threading
import time
import requests
from collections import defaultdict
import random
class TTSStressTester:
def __init__(self, urls, data, concurrency, requests_per_thread):
self.urls = urls
self.data = data
self.concurrency = concurrency
self.requests_per_thread = requests_per_thread
self.stats = {
'total': 0,
'success': 0,
'fail': 0,
'durations': [],
'status_codes': defaultdict(int),
'errors': defaultdict(int)
}
self.lock = threading.Lock()
self.current_url_index = 0
self.url_lock = threading.Lock() # lock used for round-robin URL selection
def _get_next_url(self):
with self.url_lock:
url = self.urls[self.current_url_index]
self.current_url_index = (self.current_url_index + 1) % len(self.urls)
return url
def _send_request(self):
start_time = time.time()
try:
# Generate a random digit string so requests never hit the vLLM prefix cache
self.data["text"] = ",".join(["".join([str(random.randint(0, 9)) for _ in range(5)]) for _ in range(5)])
target_url = self._get_next_url() # get the next URL in round-robin order
response = requests.post(target_url, json=self.data, timeout=10)
elapsed = time.time() - start_time
with self.lock:
self.stats['durations'].append(elapsed)
self.stats['status_codes'][response.status_code] += 1
self.stats['total'] += 1
if response.status_code == 200:
content_type = response.headers.get('Content-Type', '')
if 'audio' in content_type:
self.stats['success'] += 1
else:
self.stats['fail'] += 1
self.stats['errors']['invalid_content_type'] += 1
else:
self.stats['fail'] += 1
except Exception as e:
with self.lock:
self.stats['fail'] += 1
self.stats['errors'][str(type(e).__name__)] += 1
self.stats['durations'].append(time.time() - start_time)
def _worker(self):
for _ in range(self.requests_per_thread):
self._send_request()
def run(self):
threads = []
start_time = time.time()
for _ in range(self.concurrency):
thread = threading.Thread(target=self._worker)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
total_time = time.time() - start_time
self._generate_report(total_time)
def _generate_report(self, total_time):
durations = self.stats['durations']
total_requests = self.stats['total']
print(f"\n{' 测试报告 ':=^40}")
print(f"总请求时间: {total_time:.2f}秒")
print(f"总请求量: {total_requests}")
print(f"成功请求: {self.stats['success']}")
print(f"失败请求: {self.stats['fail']}")
if durations:
avg_duration = sum(durations) / len(durations)
max_duration = max(durations)
min_duration = min(durations)
print(f"\n响应时间统计:")
print(f"平均: {avg_duration:.3f}秒")
print(f"最大: {max_duration:.3f}秒")
print(f"最小: {min_duration:.3f}秒")
sorted_durations = sorted(durations)
for p in [50, 90, 95, 99]:
index = int(p / 100 * len(sorted_durations))
print(f"P{p}: {sorted_durations[index]:.3f}秒")
print("\n状态码分布:")
for code, count in self.stats['status_codes'].items():
print(f"HTTP {code}: {count}次")
if self.stats['errors']:
print("\n错误统计:")
for error, count in self.stats['errors'].items():
print(f"{error}: {count}次")
print(f"\n吞吐量: {total_requests / total_time:.2f} 请求/秒")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='TTS service stress-test script')
parser.add_argument('--urls', nargs='+',
default=['http://localhost:6006/tts'],
help='List of TTS service URLs (separate multiple URLs with spaces)')
parser.add_argument('--text', type=str, default='测试文本', help='Text to synthesize')
parser.add_argument('--character', type=str, default='jay_klee', help='Voice character name')
parser.add_argument('--concurrency', type=int, default=16, help='Number of concurrent threads')
parser.add_argument('--requests', type=int, default=5, help='Number of requests per thread')
args = parser.parse_args()
test_data = {
"text": args.text,
"character": args.character
}
tester = TTSStressTester(
urls=args.urls,
data=test_data,
concurrency=args.concurrency,
requests_per_thread=args.requests
)
print(f"开始压力测试,配置参数:")
print(f"目标服务: {', '.join(args.urls)}")
print(f"并发线程: {args.concurrency}")
print(f"单线程请求数: {args.requests}")
print(f"总预计请求量: {args.concurrency * args.requests}")
print(f"{' 测试启动 ':=^40}")
try:
tester.run()
except KeyboardInterrupt:
print("\n测试被用户中断")
import json
import locale
import os
I18N_JSON_DIR : os.PathLike = os.path.join(os.path.dirname(os.path.relpath(__file__)), 'locale')
def load_language_list(language):
with open(os.path.join(I18N_JSON_DIR, f"{language}.json"), "r", encoding="utf-8") as f:
language_list = json.load(f)
return language_list
def scan_language_list():
language_list = []
for name in os.listdir(I18N_JSON_DIR):
if name.endswith(".json"):language_list.append(name.split('.')[0])
return language_list
class I18nAuto:
def __init__(self, language=None):
if language in ["Auto", None]:
language = locale.getdefaultlocale()[0]
# getlocale can't identify the system's language ((None, None))
if not os.path.exists(os.path.join(I18N_JSON_DIR, f"{language}.json")):
language = "en_US"
self.language = language
self.language_map = load_language_list(language)
def __call__(self, key):
return self.language_map.get(key, key)
def __repr__(self):
return "Use Language: " + self.language
if __name__ == "__main__":
i18n = I18nAuto(language='en_US')
print(i18n)
{
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "This software is open-sourced under the MIT License. The author has no control over the software, and users of the software, as well as those who distribute the audio generated by the software, assume full responsibility.",
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "If you do not agree to these terms, you are not permitted to use or reference any code or files within the software package. For further details, please refer to the LICENSE file in the root directory."
}
import ast
import glob
import json
import os
from collections import OrderedDict
I18N_JSON_DIR : os.PathLike = os.path.join(os.path.dirname(os.path.relpath(__file__)), 'locale')
DEFAULT_LANGUAGE: str = "zh_CN" # default language
TITLE_LEN : int = 60 # display width of section titles
KEY_LEN : int = 30 # display width of key names
SHOW_KEYS : bool = False # whether to print individual keys
SORT_KEYS : bool = False # whether to write the file with keys sorted globally
def extract_i18n_strings(node):
i18n_strings = []
if (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Name)
and node.func.id == "i18n"
):
for arg in node.args:
if isinstance(arg, ast.Str):
i18n_strings.append(arg.s)
for child_node in ast.iter_child_nodes(node):
i18n_strings.extend(extract_i18n_strings(child_node))
return i18n_strings
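# Example: parsing a file that contains gr.Button(i18n("生成语音")) yields ["生成语音"].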
def scan_i18n_strings():
"""
scan the directory for all .py files (recursively)
for each file, parse the code into an AST
for each AST, extract the i18n strings
"""
strings = []
print(" Scanning Files and Extracting i18n Strings ".center(TITLE_LEN, "="))
for filename in glob.iglob("**/*.py", recursive=True):
try:
with open(filename, "r", encoding="utf-8") as f:
code = f.read()
if "I18nAuto" in code:
tree = ast.parse(code)
i18n_strings = extract_i18n_strings(tree)
print(f"{filename.ljust(KEY_LEN*3//2)}: {len(i18n_strings)}")
if SHOW_KEYS:
print("\n".join([s for s in i18n_strings]))
strings.extend(i18n_strings)
except Exception as e:
print(f"\033[31m[Failed] Error occur at {filename}: {e}\033[0m")
code_keys = set(strings)
print(f"{'Total Unique'.ljust(KEY_LEN*3//2)}: {len(code_keys)}")
return code_keys
def update_i18n_json(json_file, standard_keys):
standard_keys = sorted(standard_keys)
print(f" Process {json_file} ".center(TITLE_LEN, "="))
# Read the JSON file
with open(json_file, "r", encoding="utf-8") as f:
json_data = json.load(f, object_pairs_hook=OrderedDict)
# Print the number of entries before processing
len_before = len(json_data)
print(f"{'Total Keys'.ljust(KEY_LEN)}: {len_before}")
# Identify missing keys and fill them in
miss_keys = set(standard_keys) - set(json_data.keys())
if len(miss_keys) > 0:
print(f"{'Missing Keys (+)'.ljust(KEY_LEN)}: {len(miss_keys)}")
for key in miss_keys:
if DEFAULT_LANGUAGE in json_file:
# For the default language, the value equals the key.
json_data[key] = key
else:
# For other languages, set the value to "#!" + key to mark it as untranslated.
json_data[key] = "#!" + key
if SHOW_KEYS:
print(f"{'Added Missing Key'.ljust(KEY_LEN)}: {key}")
# Identify unused keys and remove them
diff_keys = set(json_data.keys()) - set(standard_keys)
if len(diff_keys) > 0:
print(f"{'Unused Keys (-)'.ljust(KEY_LEN)}: {len(diff_keys)}")
for key in diff_keys:
del json_data[key]
if SHOW_KEYS:
print(f"{'Removed Unused Key'.ljust(KEY_LEN)}: {key}")
# Sort by key order
json_data = OrderedDict(
sorted(
json_data.items(),
key=lambda x: (
list(standard_keys).index(x[0]) if x[0] in standard_keys and not x[1].startswith('#!') else len(json_data),
)
)
)
# Print the number of entries after processing
if len(miss_keys) != 0 or len(diff_keys) != 0:
print(f"{'Total Keys (After)'.ljust(KEY_LEN)}: {len(json_data)}")
# Identify keys still awaiting translation
num_miss_translation = 0
duplicate_items = {}
for key, value in json_data.items():
if value.startswith("#!"):
num_miss_translation += 1
if SHOW_KEYS:
print(f"{'Missing Translation'.ljust(KEY_LEN)}: {key}")
if value in duplicate_items:
duplicate_items[value].append(key)
else:
duplicate_items[value] = [key]
# Report any duplicated values
for value, keys in duplicate_items.items():
if len(keys) > 1:
print("\n".join([f"\033[31m{'[Failed] Duplicate Value'.ljust(KEY_LEN)}: {key} -> {value}\033[0m" for key in keys]))
if num_miss_translation > 0:
print(f"\033[31m{'[Failed] Missing Translation'.ljust(KEY_LEN)}: {num_miss_translation}\033[0m")
else:
print(f"\033[32m[Passed] All Keys Translated\033[0m")
# Write the processed result back to the JSON file
with open(json_file, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=4, sort_keys=SORT_KEYS)
f.write("\n")
print(f" Updated {json_file} ".center(TITLE_LEN, "=") + '\n')
if __name__ == "__main__":
code_keys = scan_i18n_strings()
for json_file in os.listdir(I18N_JSON_DIR):
if json_file.endswith(r".json"):
json_file = os.path.join(I18N_JSON_DIR, json_file)
update_i18n_json(json_file, code_keys)
import os
import sys
import threading
import time
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))
import gradio as gr
from indextts.infer_vllm import IndexTTS
import argparse
parser = argparse.ArgumentParser(description="IndexTTS WebUI")
parser.add_argument("--port", type=int, default=6006, help="Port to run the web UI on")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on")
parser.add_argument("--version", type=str, default="1.0", help="Host to run the web UI on")
parser.add_argument("--model_dir", type=str, default="", help="Model checkpoints directory")
parser.add_argument("--gpu_memory_utilization", type=float, default=0.25, help="Port to run the web UI on")
cmd_args = parser.parse_args()
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
model_dir = None
if cmd_args.model_dir:
model_dir = cmd_args.model_dir
else:
if cmd_args.version == "1.0":
model_dir = os.path.join(CURRENT_DIR, "checkpoints/Index-TTS-vLLM")
elif cmd_args.version == "1.5":
model_dir = os.path.join(CURRENT_DIR, "checkpoints/Index-TTS-1.5-vLLM")
async def gen_single(prompts, text, progress=gr.Progress()):
output_path = None
tts.gr_progress = progress
if isinstance(prompts, list):
prompt_paths = [prompt.name for prompt in prompts if prompt is not None]
else:
prompt_paths = [prompts.name] if prompts is not None else []
output = await tts.infer(prompt_paths, text, output_path, verbose=True)
return gr.update(value=output, visible=True)
def update_prompt_audio():
return gr.update(interactive=True)
if __name__ == "__main__":
tts = IndexTTS(model_dir=model_dir, gpu_memory_utilization=cmd_args.gpu_memory_utilization)
with gr.Blocks() as demo:
mutex = threading.Lock()
gr.HTML('''
<h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
<h2><center>(一款工业级可控且高效的零样本文本转语音系统)</h2>
<p align="center">
<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
''')
with gr.Tab("音频生成"):
with gr.Row():
# Use gr.File instead of gr.Audio to support uploading multiple files
prompt_audio = gr.File(
label="请上传参考音频(可上传多个)",
file_count="multiple",
file_types=["audio"]
)
with gr.Column():
input_text_single = gr.TextArea(label="请输入目标文本", key="input_text_single")
gen_button = gr.Button("生成语音", key="gen_button", interactive=True)
output_audio = gr.Audio(label="生成结果", visible=True, key="output_audio")
prompt_audio.upload(
update_prompt_audio,
inputs=[],
outputs=[gen_button]
)
gen_button.click(
gen_single,
inputs=[prompt_audio, input_text_single],
outputs=[output_audio]
)
demo.queue(20)
demo.launch(server_name=cmd_args.host, server_port=cmd_args.port)
import json
import logging
import os
import sys
import threading
import time
import warnings
import pandas as pd
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# current_dir = os.path.dirname(os.path.abspath(__file__))
# sys.path.append(current_dir)
# sys.path.append(os.path.join(current_dir, "indextts"))
import argparse
parser = argparse.ArgumentParser(description="IndexTTS WebUI")
parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
parser.add_argument("--port", type=int, default=6006, help="Port to run the web UI on")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on")
parser.add_argument("--model_dir", type=str, default="checkpoints/IndexTTS-2-vLLM", help="Model checkpoints directory")
parser.add_argument("--is_fp16", action="store_true", default=False, help="Fp16 infer")
parser.add_argument("--gpu_memory_utilization", type=float, default=0.25)
parser.add_argument("--qwenemo_gpu_memory_utilization", type=float, default=0.10)
cmd_args = parser.parse_args()
if not os.path.exists(cmd_args.model_dir):
print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
sys.exit(1)
for file in [
"bpe.model",
"gpt.pth",
"config.yaml",
"s2mel.pth",
"wav2vec2bert_stats.pt"
]:
file_path = os.path.join(cmd_args.model_dir, file)
if not os.path.exists(file_path):
print(f"Required file {file_path} does not exist. Please download it.")
sys.exit(1)
import gradio as gr
# from indextts import infer
from indextts.infer_vllm_v2 import IndexTTS2
from tools.i18n.i18n import I18nAuto
from modelscope.hub import api
i18n = I18nAuto(language="Auto")
MODE = 'local'
# Supported languages
LANGUAGES = {
"中文": "zh_CN",
"English": "en_US"
}
EMO_CHOICES = [i18n("与音色参考音频相同"),
i18n("使用情感参考音频"),
i18n("使用情感向量控制"),
i18n("使用情感描述文本控制")]
os.makedirs("outputs/tasks",exist_ok=True)
os.makedirs("prompts",exist_ok=True)
MAX_LENGTH_TO_USE_SPEED = 70
with open("examples/cases.jsonl", "r", encoding="utf-8") as f:
example_cases = []
for line in f:
line = line.strip()
if not line:
continue
example = json.loads(line)
if example.get("emo_audio",None):
emo_audio_path = os.path.join("examples",example["emo_audio"])
else:
emo_audio_path = None
example_cases.append([os.path.join("examples", example.get("prompt_audio", "sample_prompt.wav")),
EMO_CHOICES[example.get("emo_mode",0)],
example.get("text"),
emo_audio_path,
example.get("emo_weight",1.0),
example.get("emo_text",""),
example.get("emo_vec_1",0),
example.get("emo_vec_2",0),
example.get("emo_vec_3",0),
example.get("emo_vec_4",0),
example.get("emo_vec_5",0),
example.get("emo_vec_6",0),
example.get("emo_vec_7",0),
example.get("emo_vec_8",0)]
)
async def gen_single(emo_control_method,prompt, text,
emo_ref_path, emo_weight,
vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
emo_text,emo_random,
max_text_tokens_per_sentence=120,
*args, progress=gr.Progress()):
output_path = None
if not output_path:
output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
# set gradio progress
# tts.gr_progress = progress
do_sample, top_p, top_k, temperature, \
length_penalty, num_beams, repetition_penalty, max_mel_tokens = args
kwargs = {
"do_sample": bool(do_sample),
"top_p": float(top_p),
"top_k": int(top_k) if int(top_k) > 0 else None,
"temperature": float(temperature),
"length_penalty": float(length_penalty),
"num_beams": num_beams,
"repetition_penalty": float(repetition_penalty),
"max_mel_tokens": int(max_mel_tokens),
# "typical_sampling": bool(typical_sampling),
# "typical_mass": float(typical_mass),
}
if type(emo_control_method) is not int:
emo_control_method = emo_control_method.value
if emo_control_method == 0:
emo_ref_path = None
emo_weight = 1.0
if emo_control_method == 1:
emo_weight = emo_weight
if emo_control_method == 2:
vec = [vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8]
vec_sum = sum([vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8])
if vec_sum > 1.5:
gr.Warning(i18n("情感向量之和不能超过1.5,请调整后重试。"))
return
else:
vec = None
print(f"Emo control mode:{emo_control_method},vec:{vec}")
output = await tts.infer(spk_audio_prompt=prompt, text=text,
output_path=output_path,
emo_audio_prompt=emo_ref_path, emo_alpha=emo_weight,
emo_vector=vec,
use_emo_text=(emo_control_method==3), emo_text=emo_text,use_random=emo_random,
verbose=cmd_args.verbose,
max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
**kwargs)
return gr.update(value=output,visible=True)
def update_prompt_audio():
update_button = gr.update(interactive=True)
return update_button
if __name__ == "__main__":
tts = IndexTTS2(
model_dir=cmd_args.model_dir,
is_fp16=cmd_args.is_fp16,
gpu_memory_utilization=cmd_args.gpu_memory_utilization,
qwenemo_gpu_memory_utilization=cmd_args.qwenemo_gpu_memory_utilization,
)
with gr.Blocks(title="IndexTTS Demo") as demo:
mutex = threading.Lock()
gr.HTML('''
<h2><center>IndexTTS2: A Breakthrough in Emotionally Expressive and Duration-Controlled Auto-Regressive Zero-Shot Text-to-Speech</h2>
<p align="center">
<a href='https://arxiv.org/abs/2506.21619'><img src='https://img.shields.io/badge/ArXiv-2506.21619-red'></a>
</p>
''')
with gr.Tab(i18n("音频生成")):
with gr.Row():
os.makedirs("prompts",exist_ok=True)
prompt_audio = gr.Audio(label=i18n("音色参考音频"),key="prompt_audio",
sources=["upload","microphone"],type="filepath")
prompt_list = os.listdir("prompts")
default = ''
if prompt_list:
default = prompt_list[0]
with gr.Column():
input_text_single = gr.TextArea(label=i18n("文本"),key="input_text_single", placeholder=i18n("请输入目标文本"), info=f"{i18n('当前模型版本')}2.0")
gen_button = gr.Button(i18n("生成语音"), key="gen_button",interactive=True)
output_audio = gr.Audio(label=i18n("生成结果"), visible=True,key="output_audio")
with gr.Accordion(i18n("功能设置")):
# Emotion control options
with gr.Row():
emo_control_method = gr.Radio(
choices=EMO_CHOICES,
type="index",
value=EMO_CHOICES[0],label=i18n("情感控制方式"))
# Emotion reference audio
with gr.Group(visible=False) as emotion_reference_group:
with gr.Row():
emo_upload = gr.Audio(label=i18n("上传情感参考音频"), type="filepath")
with gr.Row():
emo_weight = gr.Slider(label=i18n("情感权重"), minimum=0.0, maximum=1.6, value=0.8, step=0.01)
# Random emotion sampling
with gr.Row():
emo_random = gr.Checkbox(label=i18n("情感随机采样"),value=False,visible=False)
# Emotion vector control
with gr.Group(visible=False) as emotion_vector_group:
with gr.Row():
with gr.Column():
vec1 = gr.Slider(label=i18n("喜"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
vec2 = gr.Slider(label=i18n("怒"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
vec3 = gr.Slider(label=i18n("哀"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
vec4 = gr.Slider(label=i18n("惧"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
with gr.Column():
vec5 = gr.Slider(label=i18n("厌恶"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
vec6 = gr.Slider(label=i18n("低落"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
vec7 = gr.Slider(label=i18n("惊喜"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
vec8 = gr.Slider(label=i18n("平静"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
with gr.Group(visible=False) as emo_text_group:
with gr.Row():
emo_text = gr.Textbox(label=i18n("情感描述文本"), placeholder=i18n("请输入情感描述文本"), value="", info=i18n("例如:高兴,愤怒,悲伤等"))
with gr.Accordion(i18n("高级生成参数设置"), open=False):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(f"**{i18n('GPT2 采样设置')}** _{i18n('参数会影响音频多样性和生成速度详见')}[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_")
with gr.Row():
do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样")
temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=0.8, step=0.1)
with gr.Row():
top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1)
num_beams = gr.Slider(label="num_beams", value=3, minimum=1, maximum=10, step=1)
with gr.Row():
repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
max_mel_tokens = gr.Slider(label="max_mel_tokens", value=1500, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量,过小导致音频被截断", key="max_mel_tokens")
# with gr.Row():
# typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用")
# typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1)
with gr.Column(scale=2):
gr.Markdown(f'**{i18n("分句设置")}** _{i18n("参数会影响音频质量和生成速度")}_')
with gr.Row():
max_text_tokens_per_sentence = gr.Slider(
label=i18n("分句最大Token数"), value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence",
info=i18n("建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高"),
)
with gr.Accordion(i18n("预览分句结果"), open=True) as sentences_settings:
sentences_preview = gr.Dataframe(
headers=[i18n("序号"), i18n("分句内容"), i18n("Token数")],
key="sentences_preview",
wrap=True,
)
advanced_params = [
do_sample, top_p, top_k, temperature,
length_penalty, num_beams, repetition_penalty, max_mel_tokens,
# typical_sampling, typical_mass,
]
if len(example_cases) > 0:
gr.Examples(
examples=example_cases,
examples_per_page=20,
inputs=[prompt_audio,
emo_control_method,
input_text_single,
emo_upload,
emo_weight,
emo_text,
vec1,vec2,vec3,vec4,vec5,vec6,vec7,vec8]
)
def on_input_text_change(text, max_tokens_per_sentence):
if text and len(text) > 0:
text_tokens_list = tts.tokenizer.tokenize(text)
sentences = tts.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=int(max_tokens_per_sentence))
data = []
for i, s in enumerate(sentences):
sentence_str = ''.join(s)
tokens_count = len(s)
data.append([i, sentence_str, tokens_count])
return {
sentences_preview: gr.update(value=data, visible=True, type="array"),
}
else:
df = pd.DataFrame([], columns=[i18n("序号"), i18n("分句内容"), i18n("Token数")])
return {
sentences_preview: gr.update(value=df),
}
def on_method_select(emo_control_method):
if emo_control_method == 1:
return (gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False)
)
elif emo_control_method == 2:
return (gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=True),
gr.update(visible=False)
)
elif emo_control_method == 3:
return (gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=True)
)
else:
return (gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False)
)
emo_control_method.select(on_method_select,
inputs=[emo_control_method],
outputs=[emotion_reference_group,
emo_random,
emotion_vector_group,
emo_text_group]
)
input_text_single.change(
on_input_text_change,
inputs=[input_text_single, max_text_tokens_per_sentence],
outputs=[sentences_preview]
)
max_text_tokens_per_sentence.change(
on_input_text_change,
inputs=[input_text_single, max_text_tokens_per_sentence],
outputs=[sentences_preview]
)
prompt_audio.upload(update_prompt_audio,
inputs=[],
outputs=[gen_button])
gen_button.click(gen_single,
inputs=[emo_control_method,prompt_audio, input_text_single, emo_upload, emo_weight,
vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
emo_text,emo_random,
max_text_tokens_per_sentence,
*advanced_params,
],
outputs=[output_audio])
demo.queue(20)
demo.launch(server_name=cmd_args.host, server_port=cmd_args.port)