Commit e6e33f1a authored by chenzk
v1.0
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from typing import Optional, Tuple
from torch import nn
from torch.nn import Parameter, Linear
from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
from tts.modules.ar_dur.commons.transformer import TransformerFFNLayer, MultiheadAttention
from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
import torch.nn.functional as F
DEFAULT_MAX_SOURCE_POSITIONS = 3000
DEFAULT_MAX_TARGET_POSITIONS = 3000
class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored.
"""
def __init__(self, embedding_dim, padding_idx, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size,
embedding_dim,
padding_idx,
)
self.register_buffer('_float_tensor', torch.FloatTensor(1))
@staticmethod
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input.shape[:2]
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embeddings if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
self.padding_idx,
)
self.weights = self.weights.to(self._float_tensor)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
positions = make_positions(input, self.padding_idx) if positions is None else positions
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
def max_positions(self):
"""Maximum number of supported positions."""
return int(1e5) # an arbitrary large number
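# Minimal usage sketch (illustrative only; `_demo_sinusoidal_positions` is a hypothetical
# helper name, not part of the original API). The table is indexed by token positions
# derived from a padded [bsz, seqlen] id tensor, so padding rows map to the zeroed
# `padding_idx` entry.
def _demo_sinusoidal_positions():
    pos_emb = SinusoidalPositionalEmbedding(embedding_dim=8, padding_idx=0, init_size=16)
    tokens = torch.tensor([[5, 6, 7, 0]])  # 0 is the padding index
    out = pos_emb(tokens)  # positions [1, 2, 3, 0] select rows of the sinusoid table
    return out.shape  # torch.Size([1, 4, 8])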
class RotaryEmbeddings(nn.Module):
cos: torch.Tensor
sin: torch.Tensor
theta: torch.Tensor
def __init__(
self,
width: int,
*,
seq_len: int = 40000,
base: int = 10000,
device: Optional[torch.device] = None,
):
"""Rotary embeddings (Su et al., 2021) layer. The rotary embedding
will be precomputed for up to 'seq_len' positions. The embedding
will be recomputed when a longer sequence is found in the input.
:param width:
Rotary embedding dimensionality, must be even.
:param seq_len:
Number of positions to precompute initially.
:param base:
The base used for Θ_i, determines the cycle length of the
embeddings.
:param device: Device on which the module is to be initialized.
"""
super().__init__()
if width % 2:
raise ValueError(f"Width of rotary embeddings must be even, was: {width}")
# Ignore allocations on the meta device as we don't persist our buffer,
# i.e., we don't expect the backing tensor to be replaced with pretrained weights.
if device is not None and device.type == "meta":
device = None
# Θ_i = 10000^(-2(i-1)/d)
theta = torch.pow(
base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width
)
self.register_buffer("theta", theta, persistent=False)
self._create_rotary_embed(width=width, length=seq_len)
def _create_rotary_embed(self, *, width: int, length: int):
# mΘ
position = torch.arange(length, device=self.theta.device).unsqueeze(1)
m_theta = position * self.theta.unsqueeze(0)
# We apply both sin and cos twice (see Eq 15, 34), but the ordering
# is changed for compatibility with most common implementations.
m_theta = torch.cat([m_theta, m_theta], dim=-1)
re_cos = m_theta.cos().view([length, width])
re_sin = m_theta.sin().view([length, width])
self.register_buffer("cos", re_cos, persistent=False)
self.register_buffer("sin", re_sin, persistent=False)
def _rotate(self, input: torch.Tensor):
"""Rotate the input tensor by half of its innermost width.
input (Tensor): array to rotate.
RETURNS (Tensor): rotated array.
Shapes:
input - (..., width)
output - (..., width)
"""
half_idx = input.shape[-1] // 2
input_1 = -input[..., half_idx:]
input_2 = input[..., :half_idx]
return torch.cat([input_1, input_2], dim=-1)
def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None):
"""
Apply rotary embeddings to an array.
:param input: Array to apply the rotary embeddings to.
:param positions: positions of the inputs. If no positions are
provided, they are assumed to be [0, seq_len).
:return: Array with the rotary embeddings applied.
Shapes:
input - (batch_size, num_heads, seq_len, width_per_head)
positions - (batch_size, seq_len)
output - (batch_size, num_heads, seq_len, width_per_head)
"""
batch_size, _, seq_len, width = input.shape
if positions is None:
# Fastpath: positions from [0..seq_len), avoid indexing.
if self.cos.size(-2) < seq_len:
self._create_rotary_embed(width=width, length=seq_len)
rot_cos = self.cos[:seq_len, :].view(1, 1, seq_len, width)
rot_sin = self.sin[:seq_len, :].view(1, 1, seq_len, width)
else:
max_len = int(positions.max()) + 1
if self.cos.size(-2) < max_len:
self._create_rotary_embed(width=width, length=max_len)
# Flatten positions to index cos/sin arrays, then unflatten.
#
# Example shapes:
#
# positions_flat - (batch_size * seq_len)
# self.cos - (max_len, width)
# rot_cos - (batch_size, seq_len, width)
positions_flat = positions.view(-1)
rot_cos = self.cos[positions_flat].view(batch_size, 1, seq_len, width)
rot_sin = self.sin[positions_flat].view(batch_size, 1, seq_len, width)
# Eq 34 with ordering changed for compatibility.
return rot_cos * input + rot_sin * self._rotate(input)
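# Minimal usage sketch (illustrative only; `_demo_rotary_embeddings` is a hypothetical
# helper name). Rotary embeddings act on (batch, num_heads, seq_len, width_per_head)
# tensors; when no positions are given, they are assumed to be [0, seq_len), which
# should match passing the same positions explicitly.
def _demo_rotary_embeddings():
    rot = RotaryEmbeddings(width=64)
    q = torch.randn(2, 4, 10, 64)  # (batch, num_heads, seq_len, width_per_head)
    q_default = rot(q)  # implicit positions [0, 10)
    positions = torch.arange(10).unsqueeze(0).repeat(2, 1)  # (batch, seq_len)
    q_explicit = rot(q, positions=positions)
    return torch.allclose(q_default, q_explicit)  # expected: True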
class RotMultiheadAttention(MultiheadAttention):
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
add_bias_kv=False, add_zero_attn=False, self_attention=False,
encoder_decoder_attention=False):
super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
encoder_decoder_attention=encoder_decoder_attention)
self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
def forward(
self,
query, key, value,
spk_pos_ids_flat=None,
key_padding_mask=None,
incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None,
before_softmax=False,
need_head_weights=False,
enc_dec_attn_constraint_mask=None,
reset_attn_weight=None
):
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: True).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if 'prev_key' in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
# self-attention
q, k, v = self.in_proj_qkv(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.in_proj_q(query)
if key is None:
assert value is None
k = v = None
else:
k = self.in_proj_k(key)
v = self.in_proj_v(key)
else:
q = self.in_proj_q(query)
k = self.in_proj_k(key)
v = self.in_proj_v(value)
q = q * self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
# Apply rot embedding and store incremental_state
q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if 'prev_key' in saved_state:
prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
k = torch.cat((prev_key, k), dim=1)
if 'prev_value' in saved_state:
prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
v = torch.cat((prev_value, v), dim=1)
saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
bsz, self.num_heads, -1, self.head_dim)
self._set_input_buffer(incremental_state, saved_state)
if incremental_state is not None:
key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
else:
key_pos = spk_pos_ids_flat
k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
src_len = k.size(1)
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
if len(attn_mask.shape) == 2:
attn_mask = attn_mask.unsqueeze(0)
elif len(attn_mask.shape) == 3:
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights + attn_mask
if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
-1e8,
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
-1e8,
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v
attn_weights_float = softmax(attn_weights, dim=-1)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
if reset_attn_weight is not None:
if reset_attn_weight:
self.last_attn_probs = attn_probs.detach()
else:
assert self.last_attn_probs is not None
attn_probs = self.last_attn_probs
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
if need_weights:
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
else:
attn_weights = None
return attn, (attn_weights, attn_logits)
class RotMultiheadAttention2(MultiheadAttention):
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
add_bias_kv=False, add_zero_attn=False, self_attention=False,
encoder_decoder_attention=False):
super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
encoder_decoder_attention=encoder_decoder_attention)
self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
def forward(
self,
query, key, value,
spk_pos_ids_flat=None,
key_padding_mask=None,
incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None,
before_softmax=False,
need_head_weights=False,
enc_dec_attn_constraint_mask=None,
reset_attn_weight=None
):
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: True).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if 'prev_key' in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
# self-attention
q, k, v = self.in_proj_qkv(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.in_proj_q(query)
if key is None:
assert value is None
k = v = None
else:
k = self.in_proj_k(key)
v = self.in_proj_v(key)
else:
q = self.in_proj_q(query)
k = self.in_proj_k(key)
v = self.in_proj_v(value)
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
# Apply rot embedding and store incremental_state
q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if 'prev_key' in saved_state:
prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
k = torch.cat((prev_key, k), dim=1)
if 'prev_value' in saved_state:
prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
v = torch.cat((prev_value, v), dim=1)
saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
bsz, self.num_heads, -1, self.head_dim)
self._set_input_buffer(incremental_state, saved_state)
key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
src_len = k.size(1)
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if attn_mask is not None:
if len(attn_mask.shape) == 2:
attn_mask = attn_mask.unsqueeze(0)
elif len(attn_mask.shape) == 3:
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
bsz * self.num_heads, tgt_len, src_len)
attn = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0, is_causal=False)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_logits = None
attn_weights = None
return attn, (attn_weights, attn_logits)
class RotDecSALayer(nn.Module):
def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False, bias=True):
super().__init__()
self.c = c
self.dropout = dropout
self.layer_norm1 = LayerNorm(c)
self.self_attn = RotMultiheadAttention(
c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
)
self.layer_norm2 = LayerNorm(c)
self.ffn = TransformerFFNLayer(
c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size,
dropout=relu_dropout, act=act, bias=bias)
self.post_ln = post_ln
def forward(
self,
x,
encoder_out=None,
encoder_padding_mask=None,
incremental_state=None,
self_attn_mask=None,
self_attn_padding_mask=None,
attn_out=None,
reset_attn_weight=None,
spk_pos_ids_flat=None,
**kwargs,
):
layer_norm_training = kwargs.get('layer_norm_training', None)
if layer_norm_training is not None:
self.layer_norm1.training = layer_norm_training
self.layer_norm2.training = layer_norm_training
residual = x
if not self.post_ln:
x = self.layer_norm1(x)
x, (attn_weights, _) = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
attn_mask=self_attn_mask,
spk_pos_ids_flat=spk_pos_ids_flat
)
x = F.dropout(x, self.dropout, training=self.training)
x = residual + x
if self.post_ln:
x = self.layer_norm1(x)
residual = x
if not self.post_ln:
x = self.layer_norm2(x)
x = self.ffn(x, incremental_state=incremental_state)
x = F.dropout(x, self.dropout, training=self.training)
x = residual + x
if self.post_ln:
x = self.layer_norm2(x)
return x, attn_weights
def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
self.self_attn.clear_buffer(incremental_state)
self.ffn.clear_buffer(incremental_state)
def set_buffer(self, name, tensor, incremental_state):
return set_incremental_state(self, incremental_state, name, tensor)
class RotDecSALayer2(RotDecSALayer):
def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9,
ffn_hidden_size=1024, act='gelu', post_ln=False):
super().__init__(c, num_heads, dropout, attention_dropout, relu_dropout, kernel_size, ffn_hidden_size, act,
post_ln)
self.self_attn = RotMultiheadAttention2(
c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
)
class RotTransformerDecoderLayer(nn.Module):
def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False,
op_version=1, bias=True):
super().__init__()
self.hidden_size = hidden_size
self.dropout = dropout
self.num_heads = num_heads
if op_version == 1:
self.op = RotDecSALayer(
hidden_size, num_heads, dropout=dropout,
attention_dropout=0.0, relu_dropout=dropout,
kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
post_ln=post_ln, bias=bias)
else:
self.op = RotDecSALayer2(
hidden_size, num_heads, dropout=dropout,
attention_dropout=0.0, relu_dropout=dropout,
kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
post_ln=post_ln)
def forward(self, x, **kwargs):
return self.op(x, **kwargs)
def clear_buffer(self, *args):
return self.op.clear_buffer(*args)
def set_buffer(self, *args):
return self.op.set_buffer(*args)
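# Minimal usage sketch (illustrative only; `_demo_rot_decoder_layer` is a hypothetical
# helper name). The layer expects [T, B, C] inputs; with the default `op_version=1`
# it routes through RotDecSALayer (rotary self-attention followed by a left-padded
# convolutional FFN). No masks or incremental state are used in this sketch.
def _demo_rot_decoder_layer():
    layer = RotTransformerDecoderLayer(hidden_size=256, dropout=0.1, num_heads=4)
    x = torch.randn(50, 2, 256)  # [T, B, C]
    y, attn_weights = layer(x)
    return y.shape  # torch.Size([50, 2, 256])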
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import torch
import torch.nn.functional as F
def make_positions(tensor, padding_idx):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA. In particular XLA
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
# how to handle the dtype kwarg in cumsum.
mask = tensor.ne(padding_idx).int()
return (
torch.cumsum(mask, dim=1).type_as(mask) * mask
).long() + padding_idx
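# Worked example (illustrative only; `_demo_make_positions` is a hypothetical helper
# name). With padding_idx=0, non-padding tokens receive positions 1, 2, 3, ... and
# padding positions stay at 0.
def _demo_make_positions():
    tokens = torch.tensor([[7, 8, 9, 0, 0]])
    return make_positions(tokens, padding_idx=0)  # tensor([[1, 2, 3, 0, 0]])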
def softmax(x, dim):
return F.softmax(x, dim=dim, dtype=torch.float32)
def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
if maxlen is None:
maxlen = lengths.max()
mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
mask = mask.type(dtype)
return mask
def weights_nonzero_speech(target):
# target : B x T x mel
# Assign weight 1.0 to all labels except for padding (id=0).
dim = target.size(-1)
return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
def _get_full_incremental_state_key(module_instance, key):
module_name = module_instance.__class__.__name__
# assign a unique ID to each module instance, so that incremental state is
# not shared across module instances
if not hasattr(module_instance, '_instance_id'):
INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
def get_incremental_state(module, incremental_state, key):
"""Helper for getting incremental state for an nn.Module."""
full_key = _get_full_incremental_state_key(module, key)
if incremental_state is None or full_key not in incremental_state:
return None
return incremental_state[full_key]
def set_incremental_state(module, incremental_state, key, value):
"""Helper for setting incremental state for an nn.Module."""
if incremental_state is not None:
full_key = _get_full_incremental_state_key(module, key)
incremental_state[full_key] = value
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t)
def fill_with_neg_inf2(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(-1e8).type_as(t)
def select_attn(attn_logits, type='best'):
"""
:param attn_logits: list of n_layers tensors, each [B, n_head, T_sp, T_txt]
:return: [B, T_sp, T_txt] attention, from the best layer/head ('best') or averaged over layers and heads ('mean')
"""
encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
# [n_layers * n_head, B, T_sp, T_txt]
encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
if type == 'best':
indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
encdec_attn = encdec_attn.gather(
0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
return encdec_attn
elif type == 'mean':
return encdec_attn.mean(0)
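# Minimal usage sketch (illustrative only; `_demo_select_attn` is a hypothetical helper
# name). `select_attn` takes one [B, n_head, T_sp, T_txt] logit tensor per layer and,
# with type='best', keeps the single layer/head whose alignment is the most peaked.
def _demo_select_attn():
    attn_logits = [torch.randn(2, 4, 12, 7) for _ in range(3)]  # 3 layers, B=2, 4 heads
    best = select_attn(attn_logits, type='best')
    return best.shape  # torch.Size([2, 12, 7])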
def make_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
Tensor: Mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
"""
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if not isinstance(lengths, list):
lengths = lengths.tolist()
bs = int(len(lengths))
if xs is None:
maxlen = int(max(lengths))
else:
maxlen = xs.size(length_dim)
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
if xs is not None:
assert xs.size(0) == bs, (xs.size(0), bs)
if length_dim < 0:
length_dim = xs.dim() + length_dim
# ind = (:, None, ..., None, :, None, ..., None)
ind = tuple(
slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
)
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of non-padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
ByteTensor: mask tensor containing indices of non-padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
return ~make_pad_mask(lengths, xs, length_dim)
def get_mask_from_lengths(lengths):
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len).to(lengths.device)
mask = (ids < lengths.unsqueeze(1)).bool()
return mask
def group_hidden_by_segs(h, seg_ids, max_len):
"""
:param h: [B, T, H]
:param seg_ids: [B, T], 1-based segment ids (0 = padding)
:param max_len: number of segments T_ph
:return: h_gby_segs [B, T_ph, H], cnt_gby_segs [B, T_ph]
"""
B, T, H = h.shape
h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
all_ones = h.new_ones(h.shape[:2])
cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
h_gby_segs = h_gby_segs[:, 1:]
cnt_gby_segs = cnt_gby_segs[:, 1:]
h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
return h_gby_segs, cnt_gby_segs
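# Worked example (illustrative only; `_demo_group_hidden_by_segs` is a hypothetical
# helper name). Frames sharing a segment id are mean-pooled; id 0 is treated as
# padding and dropped.
def _demo_group_hidden_by_segs():
    h = torch.arange(6, dtype=torch.float).view(1, 6, 1)  # [B=1, T=6, H=1]
    seg_ids = torch.tensor([[1, 1, 2, 2, 2, 0]])           # [B, T], 0 = padding
    h_segs, cnt = group_hidden_by_segs(h, seg_ids, max_len=2)
    return h_segs.squeeze(-1), cnt  # tensor([[0.5, 3.0]]), tensor([[2., 3.]])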
def expand_by_repeat_times(source_encoding, lengths):
"""
source_encoding: [T, C]
lengths: list of int, [T], how many times each token should be repeated
return:
expanded_encoding: [T_expand, C]
"""
hid_dim = source_encoding.shape[1]
out2source = []
for i, length in enumerate(lengths):
out2source += [i for _ in range(length)]
out2source = torch.LongTensor(out2source).to(source_encoding.device)
out2source_ = out2source[:, None].repeat([1, hid_dim])
expanded_encoding = torch.gather(source_encoding, 0, out2source_) # [T_expand, C]
return expanded_encoding
def expand_word2ph(word_encoding, ph2word):
word_encoding = F.pad(word_encoding,[0,0,1,0])
ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
out = torch.gather(word_encoding, 1, ph2word_) # [B, T, H]
return out
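# Worked example (illustrative only; `_demo_expand_word2ph` is a hypothetical helper
# name). `ph2word` holds a 1-based word index per phone; each phone gathers its
# word's encoding (index 0 selects the zero row added by the padding step).
def _demo_expand_word2ph():
    word_encoding = torch.tensor([[[10.], [20.], [30.]]])  # [B=1, T_word=3, H=1]
    ph2word = torch.tensor([[1, 1, 2, 3, 3]])              # [B, T_ph]
    return expand_word2ph(word_encoding, ph2word).squeeze(-1)  # tensor([[10., 10., 20., 30., 30.]])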
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import nn
from torch.nn import Parameter, Linear
from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
import torch.nn.functional as F
DEFAULT_MAX_SOURCE_POSITIONS = 3000
DEFAULT_MAX_TARGET_POSITIONS = 3000
class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored.
"""
def __init__(self, embedding_dim, padding_idx, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size,
embedding_dim,
padding_idx,
)
self.register_buffer('_float_tensor', torch.FloatTensor(1))
@staticmethod
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input.shape[:2]
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embeddings if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
self.padding_idx,
)
self.weights = self.weights.to(self._float_tensor)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
positions = make_positions(input, self.padding_idx) if positions is None else positions
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
def max_positions(self):
"""Maximum number of supported positions."""
return int(1e5) # an arbitrary large number
class TransformerFFNLayer(nn.Module):
def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu', bias=True):
super().__init__()
self.kernel_size = kernel_size
self.dropout = dropout
self.act = act
if padding == 'SAME':
self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size,
padding=kernel_size // 2, bias=bias)
elif padding == 'LEFT':
self.ffn_1 = nn.Sequential(
nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
nn.Conv1d(hidden_size, filter_size, kernel_size, bias=bias)
)
self.ffn_2 = Linear(filter_size, hidden_size, bias=bias)
def forward(self, x, incremental_state=None):
# x: T x B x C
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if 'prev_input' in saved_state:
prev_input = saved_state['prev_input']
x = torch.cat((prev_input, x), dim=0)
x = x[-self.kernel_size:]
saved_state['prev_input'] = x
self._set_input_buffer(incremental_state, saved_state)
x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
x = x * self.kernel_size ** -0.5
if incremental_state is not None:
x = x[-1:]
if self.act == 'gelu':
x = F.gelu(x)
if self.act == 'relu':
x = F.relu(x)
x = F.dropout(x, self.dropout, training=self.training)
x = self.ffn_2(x)
return x
def _get_input_buffer(self, incremental_state):
return get_incremental_state(
self,
incremental_state,
'f',
) or {}
def _set_input_buffer(self, incremental_state, buffer):
set_incremental_state(
self,
incremental_state,
'f',
buffer,
)
def clear_buffer(self, incremental_state):
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if 'prev_input' in saved_state:
del saved_state['prev_input']
self._set_input_buffer(incremental_state, saved_state)
class MultiheadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
add_bias_kv=False, add_zero_attn=False, self_attention=False,
encoder_decoder_attention=False):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
'value to be of the same size'
if self.qkv_same_dim:
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
else:
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
if bias:
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.reset_parameters()
self.enable_torch_version = False
self.last_attn_probs = None
def reset_parameters(self):
if self.qkv_same_dim:
nn.init.xavier_uniform_(self.in_proj_weight)
else:
nn.init.xavier_uniform_(self.k_proj_weight)
nn.init.xavier_uniform_(self.v_proj_weight)
nn.init.xavier_uniform_(self.q_proj_weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.in_proj_bias is not None:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj.bias, 0.)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
def forward(
self,
query, key, value,
key_padding_mask=None,
incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None,
before_softmax=False,
need_head_weights=False,
enc_dec_attn_constraint_mask=None,
reset_attn_weight=None
):
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: True).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
if self.qkv_same_dim:
return F.multi_head_attention_forward(query, key, value,
self.embed_dim, self.num_heads,
self.in_proj_weight,
self.in_proj_bias, self.bias_k, self.bias_v,
self.add_zero_attn, self.dropout,
self.out_proj.weight, self.out_proj.bias,
self.training, key_padding_mask, need_weights,
attn_mask)
else:
return F.multi_head_attention_forward(query, key, value,
self.embed_dim, self.num_heads,
torch.empty([0]),
self.in_proj_bias, self.bias_k, self.bias_v,
self.add_zero_attn, self.dropout,
self.out_proj.weight, self.out_proj.bias,
self.training, key_padding_mask, need_weights,
attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if 'prev_key' in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
# self-attention
q, k, v = self.in_proj_qkv(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.in_proj_q(query)
if key is None:
assert value is None
k = v = None
else:
k = self.in_proj_k(key)
v = self.in_proj_v(key)
else:
q = self.in_proj_q(query)
k = self.in_proj_k(key)
v = self.in_proj_v(value)
q = q * self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if 'prev_key' in saved_state:
prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
k = torch.cat((prev_key, k), dim=1)
if 'prev_value' in saved_state:
prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
v = torch.cat((prev_value, v), dim=1)
if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
prev_key_padding_mask = saved_state['prev_key_padding_mask']
if static_kv:
key_padding_mask = prev_key_padding_mask
else:
key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state['prev_key_padding_mask'] = key_padding_mask
self._set_input_buffer(incremental_state, saved_state)
src_len = k.size(1)
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
if len(attn_mask.shape) == 2:
attn_mask = attn_mask.unsqueeze(0)
elif len(attn_mask.shape) == 3:
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights + attn_mask
if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
-1e8,
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
-1e8,
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v
attn_weights_float = softmax(attn_weights, dim=-1)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
if reset_attn_weight is not None:
if reset_attn_weight:
self.last_attn_probs = attn_probs.detach()
else:
assert self.last_attn_probs is not None
attn_probs = self.last_attn_probs
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
if need_weights:
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
else:
attn_weights = None
return attn, (attn_weights, attn_logits)
def in_proj_qkv(self, query):
return self._in_proj(query).chunk(3, dim=-1)
def in_proj_q(self, query):
if self.qkv_same_dim:
return self._in_proj(query, end=self.embed_dim)
else:
bias = self.in_proj_bias
if bias is not None:
bias = bias[:self.embed_dim]
return F.linear(query, self.q_proj_weight, bias)
def in_proj_k(self, key):
if self.qkv_same_dim:
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
else:
weight = self.k_proj_weight
bias = self.in_proj_bias
if bias is not None:
bias = bias[self.embed_dim:2 * self.embed_dim]
return F.linear(key, weight, bias)
def in_proj_v(self, value):
if self.qkv_same_dim:
return self._in_proj(value, start=2 * self.embed_dim)
else:
weight = self.v_proj_weight
bias = self.in_proj_bias
if bias is not None:
bias = bias[2 * self.embed_dim:]
return F.linear(value, weight, bias)
def _in_proj(self, input, start=0, end=None):
weight = self.in_proj_weight
bias = self.in_proj_bias
weight = weight[start:end, :]
if bias is not None:
bias = bias[start:end]
return F.linear(input, weight, bias)
def _get_input_buffer(self, incremental_state):
return get_incremental_state(
self,
incremental_state,
'attn_state',
) or {}
def _set_input_buffer(self, incremental_state, buffer):
set_incremental_state(
self,
incremental_state,
'attn_state',
buffer,
)
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
return attn_weights
def clear_buffer(self, incremental_state=None):
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if 'prev_key' in saved_state:
del saved_state['prev_key']
if 'prev_value' in saved_state:
del saved_state['prev_value']
self._set_input_buffer(incremental_state, saved_state)
class EncSALayer(nn.Module):
def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu',
ffn_hidden_size=1024):
super().__init__()
self.c = c
self.dropout = dropout
self.num_heads = num_heads
if num_heads > 0:
self.layer_norm1 = LayerNorm(c)
self.self_attn = MultiheadAttention(
self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
self.layer_norm2 = LayerNorm(c)
self.ffn = TransformerFFNLayer(
c, ffn_hidden_size, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
def forward(self, x, encoder_padding_mask=None, **kwargs):
layer_norm_training = kwargs.get('layer_norm_training', None)
if layer_norm_training is not None:
self.layer_norm1.training = layer_norm_training
self.layer_norm2.training = layer_norm_training
if self.num_heads > 0:
residual = x
x = self.layer_norm1(x)
x, _, = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask
)
x = F.dropout(x, self.dropout, training=self.training)
x = residual + x
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
residual = x
x = self.layer_norm2(x)
x = self.ffn(x)
x = F.dropout(x, self.dropout, training=self.training)
x = residual + x
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
return x
class DecSALayer(nn.Module):
def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
super().__init__()
self.c = c
self.dropout = dropout
self.layer_norm1 = LayerNorm(c)
self.self_attn = MultiheadAttention(
c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
)
self.layer_norm2 = LayerNorm(c)
self.encoder_attn = MultiheadAttention(
c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
)
self.layer_norm3 = LayerNorm(c)
self.ffn = TransformerFFNLayer(
c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
self.post_ln = post_ln
def forward(
self,
x,
encoder_out=None,
encoder_padding_mask=None,
incremental_state=None,
self_attn_mask=None,
self_attn_padding_mask=None,
attn_out=None,
reset_attn_weight=None,
**kwargs,
):
layer_norm_training = kwargs.get('layer_norm_training', None)
if layer_norm_training is not None:
self.layer_norm1.training = layer_norm_training
self.layer_norm2.training = layer_norm_training
self.layer_norm3.training = layer_norm_training
residual = x
if not self.post_ln:
x = self.layer_norm1(x)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
attn_mask=self_attn_mask
)
x = F.dropout(x, self.dropout, training=self.training)
x = residual + x
if self.post_ln:
x = self.layer_norm1(x)
attn_logits = None
if encoder_out is not None or attn_out is not None:
residual = x
if not self.post_ln:
x = self.layer_norm2(x)
if encoder_out is not None:
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
'enc_dec_attn_constraint_mask'),
reset_attn_weight=reset_attn_weight
)
attn_logits = attn[1]
elif attn_out is not None:
x = self.encoder_attn.in_proj_v(attn_out)
if encoder_out is not None or attn_out is not None:
x = F.dropout(x, self.dropout, training=self.training)
x = residual + x
if self.post_ln:
x = self.layer_norm2(x)
residual = x
if not self.post_ln:
x = self.layer_norm3(x)
x = self.ffn(x, incremental_state=incremental_state)
x = F.dropout(x, self.dropout, training=self.training)
x = residual + x
if self.post_ln:
x = self.layer_norm3(x)
return x, attn_logits
def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
self.encoder_attn.clear_buffer(incremental_state)
self.ffn.clear_buffer(incremental_state)
def set_buffer(self, name, tensor, incremental_state):
return set_incremental_state(self, incremental_state, name, tensor)
class TransformerEncoderLayer(nn.Module):
def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024):
super().__init__()
self.hidden_size = hidden_size
self.dropout = dropout
self.num_heads = num_heads
self.op = EncSALayer(
hidden_size, num_heads, dropout=dropout,
attention_dropout=0.0, relu_dropout=dropout,
kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size)
def forward(self, x, **kwargs):
return self.op(x, **kwargs)
class TransformerDecoderLayer(nn.Module):
def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024, post_ln=False):
super().__init__()
self.hidden_size = hidden_size
self.dropout = dropout
self.num_heads = num_heads
self.op = DecSALayer(
hidden_size, num_heads, dropout=dropout,
attention_dropout=0.0, relu_dropout=dropout,
kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
post_ln=post_ln)
def forward(self, x, **kwargs):
return self.op(x, **kwargs)
def clear_buffer(self, *args):
return self.op.clear_buffer(*args)
def set_buffer(self, *args):
return self.op.set_buffer(*args)
class FFTBlocks(nn.Module):
def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
num_heads=2, use_pos_embed=True, use_last_norm=True,
use_pos_embed_alpha=True, ffn_hidden_size=1024):
super().__init__()
self.num_layers = num_layers
embed_dim = self.hidden_size = hidden_size
self.dropout = dropout
self.use_pos_embed = use_pos_embed
self.use_last_norm = use_last_norm
if use_pos_embed:
self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
self.padding_idx = 0
self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
self.embed_positions = SinusoidalPositionalEmbedding(
embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(self.hidden_size, self.dropout,
kernel_size=ffn_kernel_size, num_heads=num_heads,
ffn_hidden_size=ffn_hidden_size)
for _ in range(self.num_layers)
])
if self.use_last_norm:
self.layer_norm = nn.LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
"""
:param x: [B, T, C]
:param padding_mask: [B, T]
:return: [B, T, C] or [L, B, T, C]
"""
padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
if self.use_pos_embed:
positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
x = x + positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1) * nonpadding_mask_TB
hiddens = []
for layer in self.layers:
x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
hiddens.append(x)
if self.use_last_norm:
x = self.layer_norm(x) * nonpadding_mask_TB
if return_hiddens:
x = torch.stack(hiddens, 0) # [L, T, B, C]
x = x.transpose(1, 2) # [L, B, T, C]
else:
x = x.transpose(0, 1) # [B, T, C]
return x
class FastSpeechEncoder(FFTBlocks):
def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9,
dropout=0.0, num_heads=2, ffn_hidden_size=1024):
super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
use_pos_embed=False, dropout=dropout, ffn_hidden_size=ffn_hidden_size)
self.embed_tokens = Embedding(dict_size, hidden_size, 0)
self.embed_scale = math.sqrt(hidden_size)
self.padding_idx = 0
self.embed_positions = SinusoidalPositionalEmbedding(
hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
)
def forward(self, txt_tokens, attn_mask=None, other_embeds=0):
"""
:param txt_tokens: [B, T]
:return: encoder output [B, T, C]
"""
encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
x = self.forward_embedding(txt_tokens) + other_embeds # [B, T, H]
if self.num_layers > 0:
x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
return x
def forward_embedding(self, txt_tokens):
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(txt_tokens)
if self.use_pos_embed:
positions = self.embed_positions(txt_tokens)
x = x + positions
x = F.dropout(x, p=self.dropout, training=self.training)
return x
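# Minimal usage sketch (illustrative only; `_demo_fastspeech_encoder` is a hypothetical
# helper name). The encoder maps padded token ids [B, T] to hidden states [B, T, C];
# index 0 is the padding symbol.
def _demo_fastspeech_encoder():
    enc = FastSpeechEncoder(dict_size=100, hidden_size=256, num_layers=2, num_heads=2)
    txt_tokens = torch.randint(1, 100, (2, 17))  # [B, T]
    return enc(txt_tokens).shape  # torch.Size([2, 17, 256])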
# MIT License
# Copyright (c) 2023 Alexander Tong
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2023] [Alexander Tong]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
# This modified file is released under the same license.
import math
import torch
from typing import Union
from torch.distributions import LogisticNormal
class LogitNormalTrainingTimesteps:
def __init__(self, T=1000.0, loc=0.0, scale=1.0):
assert T > 0
self.T = T
self.dist = LogisticNormal(loc, scale)
def sample(self, size, device):
t = self.dist.sample(size)[..., 0].to(device)
return t
def pad_t_like_x(t, x):
"""Function to reshape the time vector t by the number of dimensions of x.
Parameters
----------
x : Tensor, shape (bs, *dim)
represents the source minibatch
t : FloatTensor, shape (bs)
Returns
-------
t : Tensor, shape (bs, 1, ..., 1), reshaped to have the same number of dimensions as x
Example
-------
x: Tensor (bs, C, W, H)
t: Vector (bs)
pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
"""
if isinstance(t, (float, int)):
return t
return t.reshape(-1, *([1] * (x.dim() - 1)))
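# A minimal sketch (not part of the original file) of what pad_t_like_x does:
# it reshapes a per-sample time vector so it broadcasts against a batch of
# feature tensors. The helper name `_demo_pad_t_like_x` is illustrative only.
def _demo_pad_t_like_x():
    bs, c, w, h = 4, 3, 8, 8
    x = torch.randn(bs, c, w, h)        # source minibatch
    t = torch.rand(bs)                  # one time value per sample
    t_padded = pad_t_like_x(t, x)
    assert t_padded.shape == (bs, 1, 1, 1)
    return t_padded * x                 # broadcasting now scales each sample by its own t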
class ConditionalFlowMatcher:
"""Base class for conditional flow matching methods. This class implements the independent
conditional flow matching methods from [1] and serves as a parent class for all other flow
matching methods.
It implements:
- drawing samples from the Gaussian probability path N(t * x1 + (1 - t) * x0, sigma)
- the conditional vector field ut(x1|x0) = x1 - x0
- the score function $\nabla \log p_t(x|x0, x1)$
"""
def __init__(self, sigma: Union[float, int] = 0.0):
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.
Parameters
----------
sigma : Union[float, int]
"""
self.sigma = sigma
self.time_sampler = LogitNormalTrainingTimesteps()
def compute_mu_t(self, x0, x1, t):
"""
Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
t : FloatTensor, shape (bs)
Returns
-------
mean mu_t: t * x1 + (1 - t) * x0
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
t = pad_t_like_x(t, x0)
return t * x1 + (1 - t) * x0
def compute_sigma_t(self, t):
"""
Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
Parameters
----------
t : FloatTensor, shape (bs)
Returns
-------
standard deviation sigma
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
del t
return self.sigma
def sample_xt(self, x0, x1, t, epsilon):
"""
Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
t : FloatTensor, shape (bs)
epsilon : Tensor, shape (bs, *dim)
noise sample from N(0, 1)
Returns
-------
xt : Tensor, shape (bs, *dim)
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
mu_t = self.compute_mu_t(x0, x1, t)
sigma_t = self.compute_sigma_t(t)
sigma_t = pad_t_like_x(sigma_t, x0)
return mu_t + sigma_t * epsilon
def compute_conditional_flow(self, x0, x1, t, xt):
"""
Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
Returns
-------
ut : conditional vector field ut(x1|x0) = x1 - x0
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
del t, xt
return x1 - x0
def sample_noise_like(self, x):
return torch.randn_like(x)
def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
"""
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
(optionally) t : Tensor, shape (bs)
represents the time levels
if None, sampled from the logit-normal time sampler (see LogitNormalTrainingTimesteps)
return_noise : bool
return the noise sample epsilon
Returns
-------
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
ut : conditional vector field ut(x1|x0) = x1 - x0
(optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
if t is None:
# t = torch.rand(x0.shape[0]).type_as(x0)
t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0)
assert len(t) == x0.shape[0], "t has to have batch size dimension"
eps = self.sample_noise_like(x0)
xt = self.sample_xt(x0, x1, t, eps)
ut = self.compute_conditional_flow(x0, x1, t, xt)
if return_noise:
return t, xt, ut, eps
else:
return t, xt, ut
def compute_lambda(self, t):
"""Compute the lambda function, see Eq.(23) [3].
Parameters
----------
t : FloatTensor, shape (bs)
Returns
-------
lambda : score weighting function
References
----------
[4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
"""
sigma_t = self.compute_sigma_t(t)
return 2 * sigma_t / (self.sigma**2 + 1e-8)
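# A minimal sketch (not part of the original file) of how the independent CFM
# objective is typically assembled at train time. With sigma=0 and an explicit t,
# xt is exactly the interpolation t * x1 + (1 - t) * x0 and the regression target
# is ut = x1 - x0. The helper name `_demo_independent_cfm` is illustrative only.
def _demo_independent_cfm():
    fm = ConditionalFlowMatcher(sigma=0.0)
    x0 = torch.randn(8, 16)             # source (noise) minibatch
    x1 = torch.randn(8, 16)             # target (data) minibatch
    t = torch.rand(8)                   # explicit time levels in (0, 1)
    t_out, xt, ut = fm.sample_location_and_conditional_flow(x0, x1, t=t)
    t_pad = pad_t_like_x(t_out, x0)
    assert torch.allclose(xt, t_pad * x1 + (1 - t_pad) * x0, atol=1e-6)
    assert torch.allclose(ut, x1 - x0)
    # A velocity model v(xt, t) would then be trained with
    # loss = ((v(xt, t_out) - ut) ** 2).mean()
    return t_out, xt, ut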
class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
"""Albergo et al. 2023 trigonometric interpolants class. This class inherits the
ConditionalFlowMatcher and override the compute_mu_t and compute_conditional_flow functions in
order to compute [3]'s trigonometric interpolants.
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
"""
def compute_mu_t(self, x0, x1, t):
r"""Compute the mean of the probability path (Eq.5) from [3].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
t : FloatTensor, shape (bs)
Returns
-------
mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1
References
----------
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
"""
t = pad_t_like_x(t, x0)
return torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1
def compute_conditional_flow(self, x0, x1, t, xt):
r"""Compute the conditional vector field similar to [3].
ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
see Eq.(21) [3].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
Returns
-------
ut : conditional vector field
ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0)
References
----------
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
"""
del xt
t = pad_t_like_x(t, x0)
return math.pi / 2 * (torch.cos(math.pi / 2 * t) * x1 - torch.sin(math.pi / 2 * t) * x0)
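if __name__ == "__main__":
    # A minimal self-check (not part of the original file): for the trigonometric
    # interpolant, the conditional vector field should equal the time derivative
    # of compute_mu_t, which we verify with a central finite difference.
    torch.manual_seed(0)
    vp = VariancePreservingConditionalFlowMatcher(sigma=0.0)
    x0, x1 = torch.randn(4, 10), torch.randn(4, 10)
    t = torch.rand(4) * 0.8 + 0.1       # stay away from the endpoints
    h = 1e-3
    numeric = (vp.compute_mu_t(x0, x1, t + h) - vp.compute_mu_t(x0, x1, t - h)) / (2 * h)
    analytic = vp.compute_conditional_flow(x0, x1, t, xt=None)
    assert torch.allclose(numeric, analytic, atol=1e-3)
    print("trigonometric interpolant: finite-difference check passed", analytic.shape)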
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
from tts.modules.ar_dur.commons.layers import Embedding
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
from tts.modules.ar_dur.ar_dur_predictor import expand_states
from tts.modules.llm_dit.transformer import Transformer
from tts.modules.llm_dit.time_embedding import TimestepEmbedding
class Diffusion(nn.Module):
def __init__(self):
super().__init__()
# Hparams
# cond dim
self.local_cond_dim = 512
self.ctx_mask_dim = 16
self.in_channels = 32
self.out_channels = 32
# LLM
self.encoder_dim = 1024
self.encoder_n_layers = 24
self.encoder_n_heads = 16
self.max_seq_len = 16384
self.multiple_of = 256
self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
self.local_cond_project = nn.Linear(
self.out_channels + self.ctx_mask_dim, self.local_cond_dim)
self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)
self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
self.postnet = nn.Linear(self.encoder_dim, self.out_channels)
self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
# The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS),
# which is licensed under the MIT License.
self.f5_time_embed = TimestepEmbedding(self.encoder_dim)
# text encoder
self.ph_encoder = RelTransformerEncoder(
302, self.encoder_dim, self.encoder_dim,
self.encoder_dim * 2, 4, 6,
3, 0.0, prenet=True, pre_ln=True)
self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
self.ph_pos_embed = PosEmb(self.encoder_dim)
self.ling_pre_net = torch.nn.Sequential(*[
torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
for i, s in enumerate([2, 2])
])
def forward(self, inputs, sigmas=None, x_noisy=None):
ctx_mask = inputs['ctx_mask']
ctx_feature = inputs['lat_ctx'] * ctx_mask
""" local conditioning (prompt_latent + spk_embed) """
ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
# ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:, :, None])
local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
local_cond = self.local_cond_project(local_cond)
""" diffusion target latent """
x = inputs['lat']
# Here, x is x1 in CFM
x0 = torch.randn_like(x)
t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)
# define noisy_input and target
t = t.bfloat16()
x_noisy = (xt * (1 - ctx_mask)).bfloat16()
target = ut
# concat condition.
x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)
x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"])
pred = self.postnet(encoder_out)
return pred, target
def forward_ling_encoder(self, txt_tokens, tone_tokens):
ph_tokens = txt_tokens
ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1]
# enc_ph
ph_enc_oembed = self.tone_embed(tone_tokens)
ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
ph_enc_oembed = ph_enc_oembed * ph_nonpadding
x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
return x_ling
def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=[1.0,1.0]):
""" When we use torchdiffeq, we need to include the CFG process inside _forward() """
x = x * (1 - ctx_mask)
x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))
pred = self.postnet(pred_v)
""" Perform multi-cond CFG """
cond_spk_txt, cond_txt, uncond = pred.chunk(3)
pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
return pred
@torch.no_grad()
def inference(self, inputs, timesteps=20, seq_cfg_w=[1.0, 1.0], **kwargs):
# txt embedding
x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)
# speaker embedding
ctx_feature = inputs['lat_ctx']
ctx_feature[1:, :, :] = 0 # prefix spk cfg
ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])
# local conditioning.
local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
local_cond = self.local_cond_project(local_cond)
''' ODE solver: Euler-style updates with AMO overshooting (see amo_sampling below) '''
bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))
# Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS),
# which is licensed under the MIT License.
sway_sampling_coef = -1.0
t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
if sway_sampling_coef is not None:
t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)
# AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415)
def amo_sampling(z_t, t, t_next, v):
# Upcast to avoid precision issues when computing prev_sample
z_t = z_t.to(torch.float32)
# Constant definition in Algorithm 1
s = t_next
c = 3
# Line 7 in Algorithm 1
o = min(t_next + c * (t_next - t), 1)
pred_z_o = z_t + (o - t) * v
# Line 11 in Algorithm 1
a = s / o
b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
noise_i = torch.randn(size=z_t.shape, device=z_t.device)
z_t_next = a * pred_z_o + b * noise_i
return z_t_next.to(v.dtype)
x = torch.randn([1, frm_len, self.out_channels], device=device)
for step_index in range(timesteps):
x = x.to(torch.float32)
sigma = t_schedule[step_index].to(x_ling.dtype)
sigma_next = t_schedule[step_index + 1]
model_out = self._forward(torch.cat([x] * bsz), local_cond, x_ling, timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
x = amo_sampling(x, sigma, sigma_next, model_out)
# Cast sample back to model compatible dtype
x = x.to(model_out.dtype)
return x
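if __name__ == "__main__":
    # A minimal sketch (not part of the original file) of the sway-sampled time
    # schedule used by `inference` above, shown in isolation:
    #   t' = t + coef * (cos(pi/2 * t) - 1 + t)
    # With a negative coefficient the endpoints stay at 0 and 1, but the steps
    # are packed more densely near t = 0 (the noise end of the trajectory).
    timesteps = 20
    sway_sampling_coef = -1.0
    t = torch.linspace(0, 1, timesteps + 1)
    t_sway = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
    assert abs(t_sway[0].item()) < 1e-6 and abs(t_sway[-1].item() - 1.0) < 1e-6
    print("first step:", (t_sway[1] - t_sway[0]).item(),
          "last step:", (t_sway[-1] - t_sway[-2]).item())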
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import nn
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep): # noqa: F821
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
return time
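if __name__ == "__main__":
    # A minimal shape check (not part of the original file): a batch of scalar
    # timesteps is lifted to sinusoidal features of size freq_embed_dim and then
    # projected by the MLP to the model width `dim`. Sizes are illustrative only.
    dim, freq_embed_dim = 512, 256
    embed = TimestepEmbedding(dim, freq_embed_dim=freq_embed_dim)
    t = torch.rand(4)                       # one diffusion time per sample
    sin_feat = embed.time_embed(t)          # (4, freq_embed_dim)
    cond = embed(t)                         # (4, dim)
    assert sin_feat.shape == (4, freq_embed_dim) and cond.shape == (4, dim)
    print("timestep conditioning:", cond.shape)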
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
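# A minimal sketch (not part of the original file) of how the rotary-embedding
# helpers above fit together; the function name and sizes are illustrative only.
# head_dim must be even so channel pairs can be treated as complex numbers.
def _demo_rotary_shapes():
    bsz, seqlen, n_heads, head_dim = 2, 16, 4, 32
    freqs_cis = precompute_freqs_cis(head_dim, 64)   # (64, head_dim // 2), complex64
    xq = torch.randn(bsz, seqlen, n_heads, head_dim)
    xk = torch.randn(bsz, seqlen, n_heads, head_dim)
    # Slice the table to the current sequence length before applying it.
    xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis[:seqlen])
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
    # The rotation has unit modulus per channel pair, so per-head norms are preserved.
    assert torch.allclose(xq_rot.norm(dim=-1), xq.norm(dim=-1), atol=1e-4)
    return xq_rot, xk_rot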
class AdaLNZero(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 6)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLNZero_Out(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 2)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class Attention(nn.Module):
def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
super().__init__()
self.encoder_n_kv_heads = encoder_n_heads
model_parallel_size = 1
self.n_local_heads = encoder_n_heads // model_parallel_size
self.n_local_kv_heads = self.encoder_n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = encoder_dim // encoder_n_heads
self.wq = nn.Linear(
encoder_dim,
encoder_n_heads * self.head_dim,
)
self.wk = nn.Linear(
encoder_dim,
self.encoder_n_kv_heads * self.head_dim,
)
self.wv = nn.Linear(
encoder_dim,
self.encoder_n_kv_heads * self.head_dim,
)
self.wo = nn.Linear(
encoder_n_heads * self.head_dim,
encoder_dim,
)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = xk.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = xv.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
attn_mask = mask[:, None, None, :] if mask is not None else None
output = F.scaled_dot_product_attention(xq, keys, values, attn_mask, is_causal=False)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(
dim, hidden_dim
)
self.w2 = nn.Linear(
hidden_dim, dim
)
def forward(self, x):
return self.w2(F.silu(self.w1(x)))
class TransformerBlock(nn.Module):
def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
super().__init__()
self.encoder_n_heads = encoder_n_heads
self.encoder_dim = encoder_dim
self.head_dim = encoder_dim // encoder_n_heads
self.attention = Attention(encoder_dim, encoder_n_heads, max_seq_len)
self.feed_forward = FeedForward(
dim=encoder_dim,
hidden_dim=2 * encoder_dim,
multiple_of=256,
ffn_dim_multiplier=None,
)
self.attention_norm = AdaLNZero(encoder_dim)
self.ffn_norm = nn.LayerNorm(encoder_dim, elementwise_affine=False, eps=1e-6)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
"""
Perform a forward pass through the TransformerBlock.
Args:
x (torch.Tensor): Input tensor.
t (torch.Tensor): Conditioning embedding (e.g. a timestep embedding) consumed by the AdaLN-Zero modulation.
start_pos (int): Starting position for attention caching.
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
Returns:
torch.Tensor: Output tensor after applying attention and feedforward layers.
"""
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attention_norm(x, emb=t)
# attention
attn_output = self.attention(norm, start_pos, freqs_cis, mask=mask)
# process attention output for input x
h = x + gate_msa.unsqueeze(1) * attn_output
norm = self.ffn_norm(h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.feed_forward(norm)
out = h + gate_mlp.unsqueeze(1) * ff_output
return out
class Transformer(nn.Module):
def __init__(self, encoder_n_layers, encoder_dim, encoder_n_heads, max_seq_len):
super().__init__()
# Decoder
self.layers = torch.nn.ModuleList()
for _ in range(encoder_n_layers):
self.layers.append(TransformerBlock(encoder_dim, encoder_n_heads, max_seq_len))
self.norm = AdaLNZero_Out(encoder_dim)
self.out_proj = nn.Linear(encoder_dim, encoder_dim)
# Rope embedding
freqs_cis = precompute_freqs_cis(
encoder_dim // encoder_n_heads, max_seq_len
)
self.register_buffer("freqs_cis", torch.view_as_real(freqs_cis), persistent=False)
def forward(self, x, t, attn_mask, start_pos=0):
freqs_cis = torch.view_as_complex(self.freqs_cis.float())[start_pos: start_pos + x.size(1)]
for i, layer in enumerate(self.layers):
x = layer(x, t, start_pos, freqs_cis, attn_mask)
x = self.norm(x, t)
x = self.out_proj(x)
return x
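if __name__ == "__main__":
    # A minimal smoke test (not part of the original file). Sizes are illustrative
    # only: a 2-layer DiT-style stack with AdaLN-Zero conditioning. `t` stands in
    # for the timestep embedding produced elsewhere (e.g. by a TimestepEmbedding
    # module); here it is just a random vector of width `dim`.
    torch.manual_seed(0)
    dim, n_heads, n_layers, max_seq_len = 64, 4, 2, 128
    model = Transformer(n_layers, dim, n_heads, max_seq_len)
    x = torch.randn(2, 10, dim)                          # (batch, frames, dim)
    t = torch.randn(2, dim)                              # per-sample conditioning vector
    attn_mask = torch.ones(2, 10, dtype=torch.bool)      # True = attend to this position
    out = model(x, t, attn_mask)
    assert out.shape == (2, 10, dim)
    print("Transformer output:", out.shape)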
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator=None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = torch.randn(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * (torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar)
else:
return 0.5 * (
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar
)
def nll(self, sample, dims) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
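if __name__ == "__main__":
    # A minimal sketch (not part of the original file). The parameter tensor packs
    # mean and log-variance along dim=1, so a latent with C channels needs 2*C
    # parameter channels. Sizes are illustrative only.
    torch.manual_seed(0)
    b, c, t = 2, 32, 50
    params = torch.randn(b, 2 * c, t)            # (mean, logvar) stacked along dim=1
    posterior = DiagonalGaussianDistribution(params)
    z = posterior.sample()                       # (b, c, t) reparameterized sample
    kl = posterior.kl()                          # per-element KL against N(0, I)
    assert z.shape == (b, c, t) and kl.shape == (b, c, t)
    assert torch.equal(posterior.mode(), posterior.mean)
    print("sample:", z.shape, "mean KL:", kl.mean().item())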
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from torch.nn.utils import weight_norm, remove_weight_norm
from torch.nn import Conv1d
import numpy as np
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
class Upsample(nn.Module):
def __init__(self, mult, r):
super(Upsample, self).__init__()
self.r = r
self.upsample = nn.Sequential(nn.Upsample(mode="nearest", scale_factor=r),
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(mult, mult // 2, kernel_size=7, stride=1))
)
r_kernel = r if r >= 5 else 5
self.trans_upsample = nn.Sequential(nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(mult, mult // 2,
kernel_size=r_kernel * 2, stride=r,
padding=r_kernel - r // 2,
output_padding=r % 2)
))
def forward(self, x):
x = torch.sin(x) + x
out1 = self.upsample(x)
out2 = self.trans_upsample(x)
return out1 + out2
class Downsample(nn.Module):
def __init__(self, mult, r):
super(Downsample, self).__init__()
self.r = r
r_kernel = r if r >= 5 else 5
self.trans_downsample = nn.Sequential(nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.Conv1d(mult, mult * 2,
kernel_size=r_kernel * 2, stride=r,
padding=r_kernel - r // 2)
))
def forward(self, x):
out = self.trans_downsample(x)
return out
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def weights_zero_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.fill_(0.0)
m.bias.data.fill_(0.0)
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class Audio2Mel(nn.Module):
def __init__(
self,
hop_length=300,
sampling_rate=24000,
n_mel_channels=80,
mel_fmin=0.,
mel_fmax=None,
frame_size=0.05,
device='cpu'
):
super().__init__()
##############################################
# FFT Parameters #
##############################################
self.n_fft = int(np.power(2., np.ceil(np.log(sampling_rate * frame_size) / np.log(2))))
window = torch.hann_window(int(sampling_rate * frame_size)).float()
mel_basis = librosa_mel_fn(
sr=sampling_rate, n_fft=self.n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
) # Mel filter bank (librosa); keyword arguments keep this call compatible with librosa >= 0.10
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer("mel_basis", mel_basis)
self.register_buffer("window", window)
self.hop_length = hop_length
self.win_length = int(sampling_rate * frame_size)
self.sampling_rate = sampling_rate
self.n_mel_channels = n_mel_channels
def forward(self, audio):
# Newer PyTorch requires an explicit return_complex; convert back to the
# real/imaginary layout expected by the code below.
fft = torch.stft(
audio.squeeze(1),
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=True,
return_complex=True,
)
fft = torch.view_as_real(fft)
real_part, imag_part = fft.unbind(-1)
magnitude = torch.sqrt(torch.clamp(real_part ** 2 + imag_part ** 2, min=1e-5))
mel_output = torch.matmul(self.mel_basis, magnitude)
log_mel_spec = 20 * torch.log10(torch.clamp(mel_output, min=1e-5)) - 20
norm_mel = (log_mel_spec + 115.) / 115.
mel_comp = torch.clamp(norm_mel * 8. - 4., -4., 4.)
return mel_comp
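# A minimal, hedged shape sketch (not part of the original file); the helper name
# and the one-second input are illustrative only. With the defaults above (24 kHz
# audio, a 50 ms window, a 300-sample hop and a centered STFT), one second of audio
# yields 80 mel bins over 24000 // 300 + 1 = 81 frames, clamped to [-4, 4].
def _demo_audio2mel_shapes():
    audio2mel = Audio2Mel()
    audio = torch.randn(2, 1, 24000)     # (batch, 1, samples): one second at 24 kHz
    mel = audio2mel(audio)
    assert mel.shape == (2, 80, 24000 // 300 + 1)
    return mel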
class ResnetBlock(nn.Module):
def __init__(self, dim, dilation=1, dim_in=None):
super().__init__()
if dim_in is None:
dim_in = dim
self.block = nn.Sequential(
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(dilation),
WNConv1d(dim_in, dim, kernel_size=3, dilation=dilation),
nn.LeakyReLU(0.2),
WNConv1d(dim, dim, kernel_size=1),
)
self.shortcut = WNConv1d(dim_in, dim, kernel_size=1)
def forward(self, x):
return self.shortcut(x) + self.block(x)
'''
Based on the HiFi-GAN v2 structure (https://arxiv.org/pdf/2010.05646.pdf).
Multi-scale here mainly means different kernel sizes: three parallel convolution branches,
each using its own series of dilation sizes internally, interleaved with ordinary (non-dilated) convolution layers.
'''
class ResBlockMRFV2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlockMRFV2, self).__init__()
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, 0.2)
xt = c1(xt)
xt = F.leaky_relu(xt, 0.2)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlockMRFV2Inter(torch.nn.Module):
def __init__(self, channels, kernel_size=3):
super(ResBlockMRFV2Inter, self).__init__()
self.block1 = ResBlockMRFV2(channels)
self.block2 = ResBlockMRFV2(channels, 7)
self.block3 = ResBlockMRFV2(channels, 11)
def forward(self, x):
xs = self.block1(x)
xs += self.block2(x)
xs += self.block3(x)
x = xs / 3
return x
class Generator(nn.Module):
def __init__(self, input_size_, ngf, n_residual_layers, num_band, args, ratios=[5, 5, 4, 3], onnx_export=False,
device='cpu'):
super().__init__()
self.hop_length = args.frame_shift
self.args = args
self.onnx_export = onnx_export
# ------------- Define upsample layers ----------------
mult = int(2 ** len(ratios))
model_up = []
input_size = input_size_
model_up += [
nn.ReflectionPad1d(3),
WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0),
]
# Upsample to raw audio scale
for i, r in enumerate(ratios):
model_up += [Upsample(mult * ngf, r)]
model_up += [ResBlockMRFV2Inter(mult * ngf // 2)]
mult //= 2
model_up += [
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(3),
WNConv1d(ngf, num_band, kernel_size=7, padding=0),
nn.Tanh(),
]
if not args.use_tanh:
model_up[-1] = nn.Conv1d(num_band, num_band, 1)
model_up[-2].apply(weights_zero_init)
self.model_up = nn.Sequential(*model_up)
self.apply(weights_init)
def forward(self, mel, step=None):
# mel input: (batch_size, seq_num, 80)
if self.onnx_export:
mel = mel.transpose(1, 2)
# on onnx, for engineering, mel input: (batch_size, 80, seq_num)
# Between Down and up
x = mel
# Upsample pipeline
cnt_after_upsample = 0
for i, m in enumerate(self.model_up):
x = m(x)
if type(m) == Upsample:
cnt_after_upsample += 1
return x
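if __name__ == "__main__":
    # A minimal smoke test (not part of the original file) for the multi-receptive-
    # field blocks above. get_padding keeps "same" output length for odd kernels,
    # so every branch preserves the time dimension and the branches can be summed
    # and averaged. Sizes are illustrative only.
    torch.manual_seed(0)
    assert get_padding(3, dilation=1) == 1 and get_padding(7, dilation=3) == 9
    block = ResBlockMRFV2Inter(channels=32)
    x = torch.randn(2, 32, 100)          # (batch, channels, frames)
    y = block(x)
    assert y.shape == x.shape
    print("ResBlockMRFV2Inter keeps shape:", y.shape)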
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from torch import nn
from tts.modules.wavvae.encoder.common_modules.seanet import SEANetEncoder
class Encoder(nn.Module):
def __init__(
self,
dowmsamples: List[int] = [6, 5, 5, 4, 2],
):
super().__init__()
self.frame_rate = 25  # not used by this module
self.encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
true_skip=False, compress=2)
def forward(self, audio: torch.Tensor):
audio = audio.unsqueeze(1) # audio(16,24000)
emb = self.encoder(audio)
return emb
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from torch import nn
import torch.nn.functional as F
from tts.modules.wavvae.decoder.seanet_encoder import Encoder
from tts.modules.wavvae.decoder.diag_gaussian import DiagonalGaussianDistribution
from tts.modules.wavvae.decoder.hifigan_modules import Generator, Upsample
class WavVAE_V3(nn.Module):
def __init__(self, hparams=None):
super().__init__()
self.encoder = Encoder(dowmsamples=[6, 5, 4, 4, 2])
self.proj_to_z = nn.Linear(512, 64)
self.proj_to_decoder = nn.Linear(32, 320)
config_path = hparams['melgan_config']
args = argparse.Namespace()
args.__dict__.update(config_path)
self.latent_upsampler = Upsample(320, 4)
self.decoder = Generator(
input_size_=160, ngf=128, n_residual_layers=4,
num_band=1, args=args, ratios=[5,4,4,3])
def encode_latent(self, audio):
    """Encode a waveform into the 25 Hz latent representation."""
    posterior = self.encode(audio)
    latent = posterior.sample().permute(0, 2, 1)  # (b, t, latent_channel)
    return latent
def encode(self, audio):
x = self.encoder(audio).permute(0, 2, 1)
x = self.proj_to_z(x).permute(0, 2, 1)
posterior = DiagonalGaussianDistribution(x)
return posterior
def decode(self, latent):
latent = self.proj_to_decoder(latent).permute(0, 2, 1)
return self.decoder(self.latent_upsampler(latent))
def forward(self, audio):
posterior = self.encode(audio)
latent = posterior.sample().permute(0, 2, 1) # (b, t, latent_channel)
recon_wav = self.decode(latent)
return recon_wav, posterior