Commit e6e33f1a authored by chenzk

v1.0
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import torch
import torch.nn.functional as F
def make_positions(tensor, padding_idx):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA. In particular XLA
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
# how to handle the dtype kwarg in cumsum.
mask = tensor.ne(padding_idx).int()
return (
torch.cumsum(mask, dim=1).type_as(mask) * mask
).long() + padding_idx
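# Usage sketch (illustrative values, not part of the original file): with
# padding_idx=1, positions start at padding_idx + 1 = 2 and pads keep padding_idx.
# >>> tokens = torch.tensor([[7, 8, 9, 1, 1]])
# >>> make_positions(tokens, padding_idx=1)
# tensor([[2, 3, 4, 1, 1]])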
def softmax(x, dim):
return F.softmax(x, dim=dim, dtype=torch.float32)
def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
if maxlen is None:
maxlen = lengths.max()
    mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
    return mask.type(dtype)  # .type() is not in-place; its result must be returned
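# Usage sketch (illustrative values): True marks valid (non-padded) steps.
# >>> sequence_mask(torch.tensor([3, 1]))
# tensor([[ True,  True,  True],
#         [ True, False, False]])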
def weights_nonzero_speech(target):
# target : B x T x mel
# Assign weight 1.0 to all labels except for padding (id=0).
dim = target.size(-1)
return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
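# Usage sketch (illustrative values): all-zero mel frames get weight 0.
# >>> target = torch.tensor([[[0.5, -0.5], [0.0, 0.0]]])  # B=1, T=2, mel=2
# >>> weights_nonzero_speech(target)
# tensor([[[1., 1.],
#          [0., 0.]]])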
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
def _get_full_incremental_state_key(module_instance, key):
module_name = module_instance.__class__.__name__
# assign a unique ID to each module instance, so that incremental state is
# not shared across module instances
if not hasattr(module_instance, '_instance_id'):
INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
def get_incremental_state(module, incremental_state, key):
"""Helper for getting incremental state for an nn.Module."""
full_key = _get_full_incremental_state_key(module, key)
if incremental_state is None or full_key not in incremental_state:
return None
return incremental_state[full_key]
def set_incremental_state(module, incremental_state, key, value):
"""Helper for setting incremental state for an nn.Module."""
if incremental_state is not None:
full_key = _get_full_incremental_state_key(module, key)
incremental_state[full_key] = value
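# Usage sketch (illustrative values): keys are namespaced per module instance,
# so two modules of the same class never clobber each other's cached state.
# >>> layer = torch.nn.Linear(2, 2)  # any nn.Module works
# >>> state = {}
# >>> set_incremental_state(layer, state, 'prev_key', torch.zeros(1, 4))
# >>> get_incremental_state(layer, state, 'prev_key').shape
# torch.Size([1, 4])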
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t)
def fill_with_neg_inf2(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(-1e8).type_as(t)
def select_attn(attn_logits, type='best'):
    """
    :param attn_logits: list of per-layer attention tensors, stacked to [n_layers, B, n_head, T_sp, T_txt]
    :param type: 'best' keeps the most confident layer/head; 'mean' averages over all of them
    :return: [B, T_sp, T_txt]
    """
encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
# [n_layers * n_head, B, T_sp, T_txt]
encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
if type == 'best':
indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
encdec_attn = encdec_attn.gather(
0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
return encdec_attn
elif type == 'mean':
return encdec_attn.mean(0)
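# Usage sketch (illustrative shapes): 3 layers, B=2, 4 heads, T_sp=10, T_txt=6.
# 'best' keeps the single layer/head whose alignment is most confident.
# >>> attn = [torch.randn(2, 4, 10, 6) for _ in range(3)]
# >>> select_attn(attn, 'best').shape
# torch.Size([2, 10, 6])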
def make_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
Tensor: Mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
"""
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if not isinstance(lengths, list):
lengths = lengths.tolist()
bs = int(len(lengths))
if xs is None:
maxlen = int(max(lengths))
else:
maxlen = xs.size(length_dim)
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
if xs is not None:
assert xs.size(0) == bs, (xs.size(0), bs)
if length_dim < 0:
length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
ind = tuple(
slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
)
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of non-padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
return ~make_pad_mask(lengths, xs, length_dim)
def get_mask_from_lengths(lengths):
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len).to(lengths.device)
mask = (ids < lengths.unsqueeze(1)).bool()
return mask
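# Usage sketch (illustrative values): True marks valid positions, i.e. the
# complement of make_pad_mask for the same lengths.
# >>> get_mask_from_lengths(torch.tensor([3, 1]))
# tensor([[ True,  True,  True],
#         [ True, False, False]])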
def group_hidden_by_segs(h, seg_ids, max_len):
"""
:param h: [B, T, H]
:param seg_ids: [B, T]
:return: h_ph: [B, T_ph, H]
"""
B, T, H = h.shape
h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
all_ones = h.new_ones(h.shape[:2])
cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
h_gby_segs = h_gby_segs[:, 1:]
cnt_gby_segs = cnt_gby_segs[:, 1:]
h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
return h_gby_segs, cnt_gby_segs
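# Usage sketch (illustrative values): frames 0-1 belong to segment 1, frame 2
# to segment 2, and frame 3 is padding (seg id 0); outputs are per-segment means.
# >>> h = torch.tensor([[[1., 1.], [3., 3.], [5., 5.], [7., 7.]]])
# >>> seg_ids = torch.tensor([[1, 1, 2, 0]])
# >>> group_hidden_by_segs(h, seg_ids, max_len=2)[0]
# tensor([[[2., 2.],
#          [5., 5.]]])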
def expand_by_repeat_times(source_encoding, lengths):
"""
source_encoding: [T, C]
lengths, list of int, [T,], how many times each token should repeat
return:
expanded_encoding: [T_expand, C]
"""
hid_dim = source_encoding.shape[1]
out2source = []
for i, length in enumerate(lengths):
out2source += [i for _ in range(length)]
out2source = torch.LongTensor(out2source).to(source_encoding.device)
out2source_ = out2source[:, None].repeat([1, hid_dim])
    expanded_encoding = torch.gather(source_encoding, 0, out2source_)  # [T_expand, C]
return expanded_encoding
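# Usage sketch (illustrative values): token 0 repeats once, token 1 twice.
# >>> enc = torch.tensor([[0., 0.], [1., 1.]])
# >>> expand_by_repeat_times(enc, [1, 2])
# tensor([[0., 0.],
#         [1., 1.],
#         [1., 1.]])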
def expand_word2ph(word_encoding, ph2word):
    word_encoding = F.pad(word_encoding, [0, 0, 1, 0])  # prepend a zero row so word id 0 (padding) gathers zeros
ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
out = torch.gather(word_encoding, 1, ph2word_) # [B, T, H]
return out
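# Usage sketch (illustrative values): ph2word is 1-indexed (0 = padding), and
# the F.pad above prepends a zero row so id 0 would gather zeros.
# >>> word_enc = torch.tensor([[[1., 1.], [2., 2.]]])
# >>> ph2word = torch.tensor([[1, 1, 2]])
# >>> expand_word2ph(word_enc, ph2word)
# tensor([[[1., 1.],
#          [1., 1.],
#          [2., 2.]]])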
# MIT License
# Copyright (c) 2023 Alexander Tong
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2023] [Alexander Tong]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text
# available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
# This modified file is released under the same license.
import math
import torch
from typing import Union
from torch.distributions import LogisticNormal
class LogitNormalTrainingTimesteps:
def __init__(self, T=1000.0, loc=0.0, scale=1.0):
assert T > 0
self.T = T
self.dist = LogisticNormal(loc, scale)
def sample(self, size, device):
t = self.dist.sample(size)[..., 0].to(device)
return t
def pad_t_like_x(t, x):
"""Function to reshape the time vector t by the number of dimensions of x.
Parameters
----------
x : Tensor, shape (bs, *dim)
represents the source minibatch
t : FloatTensor, shape (bs)
Returns
-------
    t : Tensor, shape (bs, 1, ..., 1) with as many trailing singleton dimensions as x
Example
-------
x: Tensor (bs, C, W, H)
t: Vector (bs)
pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
"""
if isinstance(t, (float, int)):
return t
return t.reshape(-1, *([1] * (x.dim() - 1)))
class ConditionalFlowMatcher:
"""Base class for conditional flow matching methods. This class implements the independent
conditional flow matching methods from [1] and serves as a parent class for all other flow
matching methods.
It implements:
    - Drawing samples from the Gaussian probability path N(t * x1 + (1 - t) * x0, sigma)
    - conditional flow matching ut(x1|x0) = x1 - x0
    - score function $\nabla \log p_t(x|x0, x1)$
"""
def __init__(self, sigma: Union[float, int] = 0.0):
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.
Parameters
----------
sigma : Union[float, int]
"""
self.sigma = sigma
self.time_sampler = LogitNormalTrainingTimesteps()
def compute_mu_t(self, x0, x1, t):
"""
Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
t : FloatTensor, shape (bs)
Returns
-------
mean mu_t: t * x1 + (1 - t) * x0
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
t = pad_t_like_x(t, x0)
return t * x1 + (1 - t) * x0
def compute_sigma_t(self, t):
"""
Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
Parameters
----------
t : FloatTensor, shape (bs)
Returns
-------
standard deviation sigma
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
del t
return self.sigma
def sample_xt(self, x0, x1, t, epsilon):
"""
Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
t : FloatTensor, shape (bs)
epsilon : Tensor, shape (bs, *dim)
noise sample from N(0, 1)
Returns
-------
xt : Tensor, shape (bs, *dim)
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
mu_t = self.compute_mu_t(x0, x1, t)
sigma_t = self.compute_sigma_t(t)
sigma_t = pad_t_like_x(sigma_t, x0)
return mu_t + sigma_t * epsilon
def compute_conditional_flow(self, x0, x1, t, xt):
"""
Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
Returns
-------
ut : conditional vector field ut(x1|x0) = x1 - x0
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
del t, xt
return x1 - x0
def sample_noise_like(self, x):
return torch.randn_like(x)
def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
"""
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
(optionally) t : Tensor, shape (bs)
represents the time levels
            if None, drawn from the logit-normal time sampler used by this class
return_noise : bool
return the noise sample epsilon
Returns
-------
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
ut : conditional vector field ut(x1|x0) = x1 - x0
(optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
if t is None:
# t = torch.rand(x0.shape[0]).type_as(x0)
t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0)
assert len(t) == x0.shape[0], "t has to have batch size dimension"
eps = self.sample_noise_like(x0)
xt = self.sample_xt(x0, x1, t, eps)
ut = self.compute_conditional_flow(x0, x1, t, xt)
if return_noise:
return t, xt, ut, eps
else:
return t, xt, ut
def compute_lambda(self, t):
"""Compute the lambda function, see Eq.(23) [3].
Parameters
----------
t : FloatTensor, shape (bs)
Returns
-------
lambda : score weighting function
References
----------
        [4] Simulation-free Schrödinger bridges via score and flow matching, Preprint, Tong et al.
"""
sigma_t = self.compute_sigma_t(t)
return 2 * sigma_t / (self.sigma**2 + 1e-8)
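# Training-step sketch (`vf_net` is an illustrative placeholder, not part of
# this file): draw (t, xt, ut) and regress the predicted field onto ut.
# >>> fm = ConditionalFlowMatcher(sigma=0.0)
# >>> x1 = torch.randn(4, 32)        # data batch
# >>> x0 = torch.randn_like(x1)      # source/noise batch
# >>> t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
# >>> # loss = ((vf_net(xt, t) - ut) ** 2).mean()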
class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
"""Albergo et al. 2023 trigonometric interpolants class. This class inherits the
ConditionalFlowMatcher and override the compute_mu_t and compute_conditional_flow functions in
order to compute [3]'s trigonometric interpolants.
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
"""
def compute_mu_t(self, x0, x1, t):
r"""Compute the mean of the probability path (Eq.5) from [3].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
t : FloatTensor, shape (bs)
Returns
-------
mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1
References
----------
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
"""
t = pad_t_like_x(t, x0)
return torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1
def compute_conditional_flow(self, x0, x1, t, xt):
r"""Compute the conditional vector field similar to [3].
ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
see Eq.(21) [3].
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
Returns
-------
ut : conditional vector field
            ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0)
References
----------
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
"""
del xt
t = pad_t_like_x(t, x0)
return math.pi / 2 * (torch.cos(math.pi / 2 * t) * x1 - torch.sin(math.pi / 2 * t) * x0)
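# Endpoint check (illustrative): the trigonometric path recovers x0 at t=0 and
# x1 at t=1, matching the linear path's endpoints.
# >>> vp = VariancePreservingConditionalFlowMatcher()
# >>> x0, x1 = torch.zeros(1, 2), torch.ones(1, 2)
# >>> vp.compute_mu_t(x0, x1, torch.tensor([0.]))
# tensor([[0., 0.]])
# >>> vp.compute_mu_t(x0, x1, torch.tensor([1.]))
# tensor([[1., 1.]])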
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
from tts.modules.ar_dur.commons.layers import Embedding
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
from tts.modules.ar_dur.ar_dur_predictor import expand_states
from tts.modules.llm_dit.transformer import Transformer
from tts.modules.llm_dit.time_embedding import TimestepEmbedding
class Diffusion(nn.Module):
def __init__(self):
super().__init__()
# Hparams
# cond dim
self.local_cond_dim = 512
self.ctx_mask_dim = 16
self.in_channels = 32
self.out_channels = 32
# LLM
self.encoder_dim = 1024
self.encoder_n_layers = 24
self.encoder_n_heads = 16
self.max_seq_len = 16384
self.multiple_of = 256
self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
self.local_cond_project = nn.Linear(
self.out_channels + self.ctx_mask_dim, self.local_cond_dim)
self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)
self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
self.postnet = nn.Linear(self.encoder_dim, self.out_channels)
self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
# The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS),
# which is licensed under the MIT License.
self.f5_time_embed = TimestepEmbedding(self.encoder_dim)
# text encoder
self.ph_encoder = RelTransformerEncoder(
302, self.encoder_dim, self.encoder_dim,
self.encoder_dim * 2, 4, 6,
3, 0.0, prenet=True, pre_ln=True)
self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
self.ph_pos_embed = PosEmb(self.encoder_dim)
self.ling_pre_net = torch.nn.Sequential(*[
torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
for i, s in enumerate([2, 2])
])
def forward(self, inputs, sigmas=None, x_noisy=None):
ctx_mask = inputs['ctx_mask']
ctx_feature = inputs['lat_ctx'] * ctx_mask
""" local conditioning (prompt_latent + spk_embed) """
ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
# ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:, :, None])
local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
local_cond = self.local_cond_project(local_cond)
""" diffusion target latent """
x = inputs['lat']
# Here, x is x1 in CFM
x0 = torch.randn_like(x)
t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)
# define noisy_input and target
t = t.bfloat16()
x_noisy = (xt * (1 - ctx_mask)).bfloat16()
target = ut
# concat condition.
x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)
x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
pred = self.postnet(encoder_out)
return pred, target
def forward_ling_encoder(self, txt_tokens, tone_tokens):
ph_tokens = txt_tokens
ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1]
# enc_ph
ph_enc_oembed = self.tone_embed(tone_tokens)
ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
ph_enc_oembed = ph_enc_oembed * ph_nonpadding
x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
return x_ling
def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=[1.0,1.0]):
""" When we use torchdiffeq, we need to include the CFG process inside _forward() """
x = x * (1 - ctx_mask)
x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))
pred = self.postnet(pred_v)
""" Perform multi-cond CFG """
cond_spk_txt, cond_txt, uncond = pred.chunk(3)
pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
return pred
@torch.no_grad()
def inference(self, inputs, timesteps=20, seq_cfg_w=[1.0, 1.0], **kwargs):
# txt embedding
x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)
# speaker embedding
ctx_feature = inputs['lat_ctx']
ctx_feature[1:, :, :] = 0 # prefix spk cfg
ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])
# local conditioning.
local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
local_cond = self.local_cond_project(local_cond)
        ''' ODE solver (AMO sampling over a sway-sampled timestep schedule) '''
bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))
# Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS),
# which is licensed under the MIT License.
sway_sampling_coef = -1.0
t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
if sway_sampling_coef is not None:
t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)
# AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415)
def amo_sampling(z_t, t, t_next, v):
# Upcast to avoid precision issues when computing prev_sample
z_t = z_t.to(torch.float32)
# Constant definition in Algorithm 1
s = t_next
c = 3
# Line 7 in Algorithm 1
o = min(t_next + c * (t_next - t), 1)
pred_z_o = z_t + (o - t) * v
# Line 11 in Algorithm 1
a = s / o
b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
noise_i = torch.randn(size=z_t.shape, device=z_t.device)
z_t_next = a * pred_z_o + b * noise_i
return z_t_next.to(v.dtype)
x = torch.randn([1, frm_len, self.out_channels], device=device)
for step_index in range(timesteps):
x = x.to(torch.float32)
sigma = t_schedule[step_index].to(x_ling.dtype)
sigma_next = t_schedule[step_index + 1]
model_out = self._forward(torch.cat([x] * bsz), local_cond, x_ling, timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
x = amo_sampling(x, sigma, sigma_next, model_out)
# Cast sample back to model compatible dtype
x = x.to(model_out.dtype)
return x
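# Standalone check of the sway-sampled schedule used in inference() above
# (illustrative): with coef = -1.0 the schedule reduces to 1 - cos(pi/2 * t),
# which packs more solver steps into the early (noisier) timesteps.
# >>> t = torch.linspace(0, 1, 5)
# >>> t + (-1.0) * (torch.cos(torch.pi / 2 * t) - 1 + t)
# tensor([0.0000, 0.0761, 0.2929, 0.6173, 1.0000])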
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import nn
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep): # noqa: F821
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
return time
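# Usage sketch (illustrative shapes): one scalar timestep per batch element is
# mapped to a `dim`-sized conditioning vector.
# >>> embed = TimestepEmbedding(dim=1024)
# >>> t = torch.rand(8)
# >>> embed(t).shape
# torch.Size([8, 1024])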