Commit ace07110 authored by patil-suraj

style

parent 988369a0
...@@ -145,8 +145,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        if n_spks > 1:
            self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
            self.spk_mlp = torch.nn.Sequential(
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
            )
        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
...
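This hunk only touches formatting, but for context: `SinusoidalPosEmb` turns the scalar diffusion timestep into a `dim`-dimensional vector that the MLP above then projects. A minimal sketch of such an embedding (an assumption about its shape, not code from this commit):

    # Hypothetical sketch of a standard sinusoidal timestep embedding.
    import math
    import torch

    class SinusoidalPosEmb(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, t):
            half_dim = self.dim // 2
            # Geometrically spaced frequencies between 1 and 1/10000.
            freqs = torch.exp(-math.log(10000.0) * torch.arange(half_dim, device=t.device) / (half_dim - 1))
            angles = t[:, None].float() * freqs[None, :]
            return torch.cat([angles.sin(), angles.cos()], dim=-1)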
# tokenizer
import os
import re
from shutil import copyfile

import torch

try:
    from transformers import PreTrainedTokenizer
except:
...@@ -25,17 +26,95 @@ except:
valid_symbols = [
"AA",
"AA0",
"AA1",
"AA2",
"AE",
"AE0",
"AE1",
"AE2",
"AH",
"AH0",
"AH1",
"AH2",
"AO",
"AO0",
"AO1",
"AO2",
"AW",
"AW0",
"AW1",
"AW2",
"AY",
"AY0",
"AY1",
"AY2",
"B",
"CH",
"D",
"DH",
"EH",
"EH0",
"EH1",
"EH2",
"ER",
"ER0",
"ER1",
"ER2",
"EY",
"EY0",
"EY1",
"EY2",
"F",
"G",
"HH",
"IH",
"IH0",
"IH1",
"IH2",
"IY",
"IY0",
"IY1",
"IY2",
"JH",
"K",
"L",
"M",
"N",
"NG",
"OW",
"OW0",
"OW1",
"OW2",
"OY",
"OY0",
"OY1",
"OY2",
"P",
"R",
"S",
"SH",
"T",
"TH",
"UH",
"UH0",
"UH1",
"UH2",
"UW",
"UW0",
"UW1",
"UW2",
"V",
"W",
"Y",
"Z",
"ZH",
]

_valid_symbol_set = set(valid_symbols)


def intersperse(lst, item):
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
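The rest of `intersperse` is collapsed by the diff. A minimal sketch of the full helper (the two missing lines are an assumption based on the standard Glow-TTS/Grad-TTS implementation) and its effect:

    def intersperse(lst, item):
        # Insert `item` between every element and at both ends.
        result = [item] * (len(lst) * 2 + 1)
        result[1::2] = lst
        return result

    # e.g. intersperse([5, 9, 2], 0) -> [0, 5, 0, 9, 0, 2, 0]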
...@@ -46,7 +125,7 @@ def intersperse(lst, item):
class CMUDict:
    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
...@@ -61,15 +140,15 @@ class CMUDict:
        return self._entries.get(word.upper())


_alt_re = re.compile(r"\([0-9]+\)")


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            parts = line.split(" ")
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
...@@ -80,36 +159,38 @@ def _parse_cmudict(file):


def _get_pronunciation(s):
    parts = s.strip().split(" ")
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return " ".join(parts)
_whitespace_re = re.compile(r"\s+")

_abbreviations = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]
def expand_abbreviations(text):
...@@ -127,7 +208,7 @@ def lowercase(text):


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)


def convert_to_ascii(text):
...@@ -156,46 +237,42 @@ def english_cleaners(text):
    return text


_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")


def _remove_commas(m):
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace(".", " point ")


def _expand_dollars(m):
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return "%s %s" % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s" % (cents, cent_unit)
    else:
        return "zero dollars"
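# Illustrative note (not part of the original file): for "$2.50" the captured amount "2.50"
# expands to "2 dollars, 50 cents"; "1" -> "1 dollar"; "0.01" -> "1 cent"; "0.00" -> "zero dollars".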
def _expand_ordinal(m):
...@@ -206,37 +283,37 @@ def _expand_number(m):
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return "two thousand"
        elif num > 2000 and num < 2010:
            return "two thousand " + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            return _inflect.number_to_words(num // 100) + " hundred"
        else:
            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    else:
        return _inflect.number_to_words(num, andword="")
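# Illustrative note (not part of the original file): year-like values are read naturally,
# e.g. 2000 -> "two thousand", 2008 -> "two thousand eight", 1900 -> "nineteen hundred".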
def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r"\1 pounds", text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
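# Illustrative note (not part of the original file): normalize_numbers("I owe $2.50") -> "I owe two dollars, fifty cents";
# the currency match is expanded first, then any remaining digits are spelled out by _expand_number.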
""" from https://github.com/keithito/tacotron """ """ from https://github.com/keithito/tacotron """
_pad = '_' _pad = "_"
_punctuation = '!\'(),.:;? ' _punctuation = "!'(),.:;? "
_special = '-' _special = "-"
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
# Prepend "@" to ARPAbet symbols to ensure uniqueness: # Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet = ['@' + s for s in valid_symbols] _arpabet = ["@" + s for s in valid_symbols]
# Export all symbols: # Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
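# Illustrative note (not part of the original file): the table holds 1 pad + 1 special + 10 punctuation
# + 52 letters + 84 "@"-prefixed ARPAbet phones = 148 symbols, so len(symbols) (148) is free to serve as
# the blank id that GradTTSTokenizer.__call__ passes to intersperse() further down.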
...@@ -245,7 +322,7 @@ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpab
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


def get_arpabet(word, dictionary):
...@@ -257,7 +334,7 @@ def get_arpabet(word, dictionary):


def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
...@@ -269,9 +346,9 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
    Returns:
        List of integers corresponding to the symbols in the text
    """
    sequence = []
    space = _symbols_to_sequence(" ")
    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
...@@ -292,7 +369,7 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # remove trailing space
    if dictionary is not None:
        sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
...@@ -300,16 +377,16 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):


def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    result = ""
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == "@":
                s = "{%s}" % s[1:]
            result += s
    return result.replace("}{", " ")


def _clean_text(text, cleaner_names):
...@@ -323,17 +400,18 @@ def _symbols_to_sequence(symbols):


def _arpabet_to_sequence(text):
    return _symbols_to_sequence(["@" + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s != "_" and s != "~"


VOCAB_FILES_NAMES = {
    "dict_file": "dict_file.txt",
}


class GradTTSTokenizer(PreTrainedTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES
...@@ -341,17 +419,17 @@ class GradTTSTokenizer(PreTrainedTokenizer):
        super().__init__(**kwargs)
        self.cmu = CMUDict(dict_file)
        self.dict_file = dict_file

    def __call__(self, text):
        x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None]
        x_lengths = torch.LongTensor([x.shape[-1]])
        return x, x_lengths

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        dict_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
        )
        copyfile(self.dict_file, dict_file)
        return (dict_file,)
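For orientation, a minimal usage sketch of the tokenizer defined above (the dictionary path and constructor keyword are assumptions for illustration, not taken from this commit):

    # Hypothetical usage sketch -- dict_file path is made up.
    tokenizer = GradTTSTokenizer(dict_file="cmu_dictionary.txt")
    x, x_lengths = tokenizer("Hello world")
    # x: LongTensor of shape (1, L) with the blank id interspersed between symbol ids; x_lengths: shape (1,)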
...@@ -4,13 +4,13 @@ import math
import torch
from torch import nn

import tqdm

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin

from .grad_tts_utils import GradTTSTokenizer  # flake8: noqa


def sequence_mask(length, max_length=None):
...@@ -382,7 +382,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
        self.window_size = window_size
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks

        self.emb = torch.nn.Embedding(n_vocab, n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
...@@ -403,7 +403,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
            n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
        )

    def forward(self, x, x_lengths, spk=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
...@@ -424,26 +424,30 @@ class GradTTS(DiffusionPipeline):
    def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(
            unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer
        )

    @torch.no_grad()
    def __call__(
        self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None
    ):
        if torch_device is None:
            torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.unet.to(torch_device)
        self.text_encoder.to(torch_device)

        x, x_lengths = self.tokenizer(text)
        x = x.to(torch_device)
        x_lengths = x_lengths.to(torch_device)

        if speaker_id is not None:
            speaker_id = torch.LongTensor([speaker_id]).to(torch_device)

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.text_encoder(x, x_lengths)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
...@@ -461,16 +465,16 @@ class GradTTS(DiffusionPipeline):
        # Sample latent representation from terminal distribution N(mu_y, I)
        z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
        xt = z * y_mask

        h = 1.0 / num_inference_steps
        for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
            time = t.unsqueeze(-1).unsqueeze(-1)
            residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
            xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
            xt = xt * y_mask

        return xt[:, :, :y_max_length]
\ No newline at end of file
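For orientation, a hedged sketch of how the pipeline above would be driven end to end (the checkpoint path is made up for illustration; the output still needs an external vocoder to become audio):

    # Hypothetical usage sketch -- the pretrained identifier is illustrative only.
    pipeline = GradTTS.from_pretrained("path/to/grad-tts-checkpoint")
    mel = pipeline("Hello world, this is a diffusion based text to speech pipeline.", num_inference_steps=50)
    # `mel` is a mel-spectrogram tensor of shape (1, n_feats, T); a vocoder such as HiFi-GAN turns it into audio.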
...@@ -19,6 +19,6 @@
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_utils import SchedulerMixin
...@@ -36,11 +36,11 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
        self.timesteps = int(timesteps)
        self.set_format(tensor_format=tensor_format)

    def sample_noise(self, timestep):
        noise = self.beta_start + (self.beta_end - self.beta_start) * timestep
        return noise

    def step(self, xt, residual, mu, h, timestep):
        noise_t = self.sample_noise(timestep)
        dxt = 0.5 * (mu - xt - residual)
...
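The diff cuts `step` off after the drift term. A sketch of how it plausibly completes, following the original Grad-TTS reverse-ODE update (an assumption, not part of this commit):

    def step(self, xt, residual, mu, h, timestep):
        # residual is the network's noise/score estimate at this timestep.
        noise_t = self.sample_noise(timestep)   # beta(t) from the linear schedule above
        dxt = 0.5 * (mu - xt - residual)        # drift toward the encoder mean minus the score
        dxt = dxt * noise_t * h                 # scale by beta(t) and the step size h
        return xt - dxt                         # one Euler step of the reverse process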