Commit 719799a2 authored by lidc's avatar lidc
Browse files

增加了pytorch框架下的音频处理模型FastSpeech和ECAPA-TDNN的测试代码

parent 13a50bfe
import torch
import numpy as np
import librosa.util as librosa_util
from scipy.signal import get_window
def window_sumsquare(
window,
n_frames,
hop_length,
win_length,
n_fft,
dtype=np.float32,
norm=None,
):
"""
# from librosa 0.6
Compute the sum-square envelope of a window function at a given hop length.
This is used to estimate modulation effects induced by windowing
observations in short-time fourier transforms.
Parameters
----------
window : string, tuple, number, callable, or list-like
Window specification, as in `get_window`
n_frames : int > 0
The number of analysis frames
hop_length : int > 0
The number of samples to advance between frames
win_length : [optional]
The length of the window function. By default, this matches `n_fft`.
n_fft : int > 0
The length of each analysis frame.
dtype : np.dtype
The data type of the output
Returns
-------
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
The sum-squared envelope of the window function
"""
if win_length is None:
win_length = n_fft
n = n_fft + hop_length * (n_frames - 1)
x = np.zeros(n, dtype=dtype)
# Compute the squared window at the desired length
win_sq = get_window(window, win_length, fftbins=True)
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
win_sq = librosa_util.pad_center(win_sq, n_fft)
# Fill the envelope
for i in range(n_frames):
sample = i * hop_length
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
return x
def griffin_lim(magnitudes, stft_fn, n_iters=30):
"""
PARAMS
------
magnitudes: spectrogram magnitudes
stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
"""
angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
angles = angles.astype(np.float32)
angles = torch.autograd.Variable(torch.from_numpy(angles))
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
for i in range(n_iters):
_, angles = stft_fn.transform(signal)
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
return signal
def dynamic_range_compression(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C
import torch
import torch.nn.functional as F
import numpy as np
from scipy.signal import get_window
from librosa.util import pad_center, tiny
from librosa.filters import mel as librosa_mel_fn
from audio.audio_processing import (
dynamic_range_compression,
dynamic_range_decompression,
window_sumsquare,
)
class STFT(torch.nn.Module):
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
def __init__(self, filter_length, hop_length, win_length, window="hann"):
super(STFT, self).__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.forward_transform = None
scale = self.filter_length / self.hop_length
fourier_basis = np.fft.fft(np.eye(self.filter_length))
cutoff = int((self.filter_length / 2 + 1))
fourier_basis = np.vstack(
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
)
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
inverse_basis = torch.FloatTensor(
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
)
if window is not None:
assert filter_length >= win_length
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
fft_window = pad_center(fft_window, filter_length)
fft_window = torch.from_numpy(fft_window).float()
# window the bases
forward_basis *= fft_window
inverse_basis *= fft_window
self.register_buffer("forward_basis", forward_basis.float())
self.register_buffer("inverse_basis", inverse_basis.float())
def transform(self, input_data):
num_batches = input_data.size(0)
num_samples = input_data.size(1)
self.num_samples = num_samples
# similar to librosa, reflect-pad the input
input_data = input_data.view(num_batches, 1, num_samples)
input_data = F.pad(
input_data.unsqueeze(1),
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
mode="reflect",
)
input_data = input_data.squeeze(1)
forward_transform = F.conv1d(
input_data.cuda(),
torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(),
stride=self.hop_length,
padding=0,
).cpu()
cutoff = int((self.filter_length / 2) + 1)
real_part = forward_transform[:, :cutoff, :]
imag_part = forward_transform[:, cutoff:, :]
magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
return magnitude, phase
def inverse(self, magnitude, phase):
recombine_magnitude_phase = torch.cat(
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
)
inverse_transform = F.conv_transpose1d(
recombine_magnitude_phase,
torch.autograd.Variable(self.inverse_basis, requires_grad=False),
stride=self.hop_length,
padding=0,
)
if self.window is not None:
window_sum = window_sumsquare(
self.window,
magnitude.size(-1),
hop_length=self.hop_length,
win_length=self.win_length,
n_fft=self.filter_length,
dtype=np.float32,
)
# remove modulation effects
approx_nonzero_indices = torch.from_numpy(
np.where(window_sum > tiny(window_sum))[0]
)
window_sum = torch.autograd.Variable(
torch.from_numpy(window_sum), requires_grad=False
)
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
approx_nonzero_indices
]
# scale by hop ratio
inverse_transform *= float(self.filter_length) / self.hop_length
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
return inverse_transform
def forward(self, input_data):
self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction
class TacotronSTFT(torch.nn.Module):
def __init__(
self,
filter_length,
hop_length,
win_length,
n_mel_channels,
sampling_rate,
mel_fmin,
mel_fmax,
):
super(TacotronSTFT, self).__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate
self.stft_fn = STFT(filter_length, hop_length, win_length)
mel_basis = librosa_mel_fn(
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer("mel_basis", mel_basis)
def spectral_normalize(self, magnitudes):
output = dynamic_range_compression(magnitudes)
return output
def spectral_de_normalize(self, magnitudes):
output = dynamic_range_decompression(magnitudes)
return output
def mel_spectrogram(self, y):
"""Computes mel-spectrograms from a batch of waves
PARAMS
------
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
RETURNS
-------
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
"""
assert torch.min(y.data) >= -1
assert torch.max(y.data) <= 1
magnitudes, phases = self.stft_fn.transform(y)
magnitudes = magnitudes.data
mel_output = torch.matmul(self.mel_basis, magnitudes)
mel_output = self.spectral_normalize(mel_output)
energy = torch.norm(magnitudes, dim=1)
return mel_output, energy
import torch
import numpy as np
from scipy.io.wavfile import write
from audio.audio_processing import griffin_lim
def get_mel_from_wav(audio, _stft):
audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
audio = torch.autograd.Variable(audio, requires_grad=False)
melspec, energy = _stft.mel_spectrogram(audio)
melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
return melspec, energy
def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
mel = torch.stack([mel])
mel_decompress = _stft.spectral_de_normalize(mel)
mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
spec_from_mel_scaling = 1000
spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
spec_from_mel = spec_from_mel * spec_from_mel_scaling
audio = griffin_lim(
torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
)
audio = audio.squeeze()
audio = audio.cpu().numpy()
audio_path = out_filename
write(audio_path, _stft.sampling_rate, audio)
transformer:
encoder_layer: 4
encoder_head: 2
encoder_hidden: 256
decoder_layer: 6
decoder_head: 2
decoder_hidden: 256
conv_filter_size: 1024
conv_kernel_size: [9, 1]
encoder_dropout: 0.2
decoder_dropout: 0.2
variance_predictor:
filter_size: 256
kernel_size: 3
dropout: 0.5
variance_embedding:
pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
n_bins: 256
# gst:
# use_gst: False
# conv_filters: [32, 32, 64, 64, 128, 128]
# gru_hidden: 128
# token_size: 128
# n_style_token: 10
# attn_head: 4
multi_speaker: True
max_seq_len: 1000
vocoder:
model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN'
speaker: "universal" # support 'LJSpeech', 'universal'
dataset: "AISHELL3"
path:
corpus_path: "/home/ming/Data/AISHELL-3"
lexicon_path: "lexicon/pinyin-lexicon-r.txt"
raw_path: "./raw_data/AISHELL3"
preprocessed_path: "./preprocessed_data/AISHELL3"
preprocessing:
val_size: 512
text:
text_cleaners: []
language: "zh"
audio:
sampling_rate: 22050
max_wav_value: 32768.0
stft:
filter_length: 1024
hop_length: 256
win_length: 1024
mel:
n_mel_channels: 80
mel_fmin: 0
mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder
pitch:
feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
normalization: True
energy:
feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
normalization: True
path:
ckpt_path: "./output/ckpt/AISHELL3"
log_path: "./output/log/AISHELL3"
result_path: "./output/result/AISHELL3"
optimizer:
batch_size: 16
betas: [0.9, 0.98]
eps: 0.000000001
weight_decay: 0.0
grad_clip_thresh: 1.0
grad_acc_step: 1
warm_up_step: 4000
anneal_steps: [300000, 400000, 500000]
anneal_rate: 0.3
step:
total_step: 900000
log_step: 100
synth_step: 1000
val_step: 1000
save_step: 100000
transformer:
encoder_layer: 4
encoder_head: 2
encoder_hidden: 256
decoder_layer: 6
decoder_head: 2
decoder_hidden: 256
conv_filter_size: 1024
conv_kernel_size: [9, 1]
encoder_dropout: 0.2
decoder_dropout: 0.2
variance_predictor:
filter_size: 256
kernel_size: 3
dropout: 0.5
variance_embedding:
pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
n_bins: 256
# gst:
# use_gst: False
# conv_filters: [32, 32, 64, 64, 128, 128]
# gru_hidden: 128
# token_size: 128
# n_style_token: 10
# attn_head: 4
multi_speaker: False
max_seq_len: 1000
vocoder:
model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN'
speaker: "LJSpeech" # support 'LJSpeech', 'universal'
dataset: "LJSpeech"
path:
corpus_path: "/home/ming/Data/LJSpeech-1.1"
lexicon_path: "lexicon/librispeech-lexicon.txt"
raw_path: "./raw_data/LJSpeech"
preprocessed_path: "./preprocessed_data/LJSpeech"
preprocessing:
val_size: 512
text:
text_cleaners: ["english_cleaners"]
language: "en"
audio:
sampling_rate: 22050
max_wav_value: 32768.0
stft:
filter_length: 1024
hop_length: 256
win_length: 1024
mel:
n_mel_channels: 80
mel_fmin: 0
mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder
pitch:
feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
normalization: True
energy:
feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
normalization: True
path:
ckpt_path: "./output/ckpt/LJSpeech"
log_path: "./output/log/LJSpeech"
result_path: "./output/result/LJSpeech"
optimizer:
batch_size: 16
betas: [0.9, 0.98]
eps: 0.000000001
weight_decay: 0.0
grad_clip_thresh: 1.0
grad_acc_step: 1
warm_up_step: 4000
anneal_steps: [300000, 400000, 500000]
anneal_rate: 0.3
step:
total_step: 900000
log_step: 100
synth_step: 1000
val_step: 1000
save_step: 100000
transformer:
encoder_layer: 4
encoder_head: 2
encoder_hidden: 256
decoder_layer: 4
decoder_head: 2
decoder_hidden: 256
conv_filter_size: 1024
conv_kernel_size: [9, 1]
encoder_dropout: 0.2
decoder_dropout: 0.2
variance_predictor:
filter_size: 256
kernel_size: 3
dropout: 0.5
variance_embedding:
pitch_quantization: "log" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
n_bins: 256
# gst:
# use_gst: False
# conv_filters: [32, 32, 64, 64, 128, 128]
# gru_hidden: 128
# token_size: 128
# n_style_token: 10
# attn_head: 4
multi_speaker: False
max_seq_len: 1000
vocoder:
model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN'
speaker: "LJSpeech" # support 'LJSpeech', 'universal'
dataset: "LJSpeech_paper"
path:
corpus_path: "/home/ming/Data/LJSpeech-1.1"
lexicon_path: "lexicon/librispeech-lexicon.txt"
raw_path: "./raw_data/LJSpeech"
preprocessed_path: "./preprocessed_data/LJSpeech_paper"
preprocessing:
val_size: 512
text:
text_cleaners: ["english_cleaners"]
language: "en"
audio:
sampling_rate: 22050
max_wav_value: 32768.0
stft:
filter_length: 1024
hop_length: 256
win_length: 1024
mel:
n_mel_channels: 80
mel_fmin: 0
mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder
pitch:
feature: "frame_level" # support 'phoneme_level' or 'frame_level'
normalization: False
energy:
feature: "frame_level" # support 'phoneme_level' or 'frame_level'
normalization: False
path:
ckpt_path: "./output/ckpt/LJSpeech"
log_path: "./output/log/LJSpeech"
result_path: "./output/result/LJSpeech"
optimizer:
batch_size: 48
betas: [0.9, 0.98]
eps: 0.000000001
weight_decay: 0.0
grad_clip_thresh: 1.0
grad_acc_step: 1
warm_up_step: 4000
anneal_steps: []
anneal_rate: 1.0
step:
total_step: 160000
log_step: 100
synth_step: 1000
val_step: 1000
save_step: 10000
transformer:
encoder_layer: 4
encoder_head: 2
encoder_hidden: 256
decoder_layer: 6
decoder_head: 2
decoder_hidden: 256
conv_filter_size: 1024
conv_kernel_size: [9, 1]
encoder_dropout: 0.2
decoder_dropout: 0.2
variance_predictor:
filter_size: 256
kernel_size: 3
dropout: 0.5
variance_embedding:
pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
n_bins: 256
# gst:
# use_gst: False
# conv_filters: [32, 32, 64, 64, 128, 128]
# gru_hidden: 128
# token_size: 128
# n_style_token: 10
# attn_head: 4
multi_speaker: True
max_seq_len: 1000
vocoder:
model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN'
speaker: "universal" # support 'LJSpeech', 'universal'
dataset: "LibriTTS"
path:
corpus_path: "/home/ming/Data/LibriTTS/train-clean-360"
lexicon_path: "lexicon/librispeech-lexicon.txt"
raw_path: "./raw_data/LibriTTS"
preprocessed_path: "./preprocessed_data/LibriTTS"
preprocessing:
val_size: 512
text:
text_cleaners: ["english_cleaners"]
language: "en"
audio:
sampling_rate: 22050
max_wav_value: 32768.0
stft:
filter_length: 1024
hop_length: 256
win_length: 1024
mel:
n_mel_channels: 80
mel_fmin: 0
mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder
pitch:
feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
normalization: True
energy:
feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
normalization: True
path:
ckpt_path: "./output/ckpt/LibriTTS"
log_path: "./output/log/LibriTTS"
result_path: "./output/result/LibriTTS"
optimizer:
batch_size: 16
betas: [0.9, 0.98]
eps: 0.000000001
weight_decay: 0.0
grad_clip_thresh: 1.0
grad_acc_step: 1
warm_up_step: 4000
anneal_steps: [300000, 400000, 500000]
anneal_rate: 0.3
step:
total_step: 900000
log_step: 100
synth_step: 1000
val_step: 1000
save_step: 100000
# Config
Here are the config files used to train the single/multi-speaker TTS models.
4 different configurations are given:
- LJSpeech: suggested configuration for LJSpeech dataset.
- LibriTTS: suggested configuration for LibriTTS dataset.
- AISHELL3: suggested configuration for AISHELL-3 dataset.
- LJSpeech_paper: closed to the setting proposed in the original FastSpeech 2 paper.
Some important hyper-parameters are explained here.
## preprocess.yaml
- **path.lexicon_path**: the lexicon (which maps words to phonemes) used by Montreal Forced Aligner.
We provide an English lexicon and a Mandarin lexicon.
Erhua (ㄦ化音) is handled in the Mandarin lexicon.
- **mel.stft.mel_fmax**: set it to 8000 if HiFi-GAN vocoder is used, and set it to null if MelGAN is used.
- **pitch.feature & energy.feature**: the original paper proposed to predict and apply frame-level pitch and energy features to the inputs of the TTS decoder to control the pitch and energy of the synthesized utterances.
However, in our experiments, we find that using phoneme-level features makes the prosody of the synthesized utterances more natural.
- **pitch.normalization & energy.normalization**: to normalize the pitch and energy values or not.
The original paper did not normalize these values.
## train.yaml
- **optimizer.grad_acc_step**: the number of batches of gradient accumulation before updating the model parameters and call optimizer.zero_grad(), which is useful if you wish to train the model with a large batch size but you do not have sufficient GPU memory.
- **optimizer.anneal_steps & optimizer.anneal_rate**: the learning rate is reduced at the **anneal_steps** by the ratio specified with **anneal_rate**.
## model.yaml
- **transformer.decoder_layer**: the original paper used a 4-layer decoder, but we find it better to use a 6-layer decoder, especially for multi-speaker TTS.
- **variance_embedding.pitch_quantization**: when the pitch values are normalized as specified in ``preprocess.yaml``, it is not valid to use log-scale quantization bins as proposed in the original paper, so we use linear-scaled bins instead.
- **multi_speaker**: to apply a speaker embedding table to enable multi-speaker TTS or not.
- **vocoder.speaker**: should be set to 'universal' if any dataset other than LJSpeech is used.
\ No newline at end of file
import json
import math
import os
import numpy as np
from torch.utils.data import Dataset
from text import text_to_sequence
from utils.tools import pad_1D, pad_2D
class Dataset(Dataset):
def __init__(
self, filename, preprocess_config, train_config, sort=False, drop_last=False
):
self.dataset_name = preprocess_config["dataset"]
self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
self.batch_size = train_config["optimizer"]["batch_size"]
self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
filename
)
with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
self.speaker_map = json.load(f)
self.sort = sort
self.drop_last = drop_last
def __len__(self):
return len(self.text)
def __getitem__(self, idx):
basename = self.basename[idx]
speaker = self.speaker[idx]
speaker_id = self.speaker_map[speaker]
raw_text = self.raw_text[idx]
phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
mel_path = os.path.join(
self.preprocessed_path,
"mel",
"{}-mel-{}.npy".format(speaker, basename),
)
mel = np.load(mel_path)
pitch_path = os.path.join(
self.preprocessed_path,
"pitch",
"{}-pitch-{}.npy".format(speaker, basename),
)
pitch = np.load(pitch_path)
energy_path = os.path.join(
self.preprocessed_path,
"energy",
"{}-energy-{}.npy".format(speaker, basename),
)
energy = np.load(energy_path)
duration_path = os.path.join(
self.preprocessed_path,
"duration",
"{}-duration-{}.npy".format(speaker, basename),
)
duration = np.load(duration_path)
sample = {
"id": basename,
"speaker": speaker_id,
"text": phone,
"raw_text": raw_text,
"mel": mel,
"pitch": pitch,
"energy": energy,
"duration": duration,
}
return sample
def process_meta(self, filename):
with open(
os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8"
) as f:
name = []
speaker = []
text = []
raw_text = []
for line in f.readlines():
n, s, t, r = line.strip("\n").split("|")
name.append(n)
speaker.append(s)
text.append(t)
raw_text.append(r)
return name, speaker, text, raw_text
def reprocess(self, data, idxs):
ids = [data[idx]["id"] for idx in idxs]
speakers = [data[idx]["speaker"] for idx in idxs]
texts = [data[idx]["text"] for idx in idxs]
raw_texts = [data[idx]["raw_text"] for idx in idxs]
mels = [data[idx]["mel"] for idx in idxs]
pitches = [data[idx]["pitch"] for idx in idxs]
energies = [data[idx]["energy"] for idx in idxs]
durations = [data[idx]["duration"] for idx in idxs]
text_lens = np.array([text.shape[0] for text in texts])
mel_lens = np.array([mel.shape[0] for mel in mels])
speakers = np.array(speakers)
texts = pad_1D(texts)
mels = pad_2D(mels)
pitches = pad_1D(pitches)
energies = pad_1D(energies)
durations = pad_1D(durations)
return (
ids,
raw_texts,
speakers,
texts,
text_lens,
max(text_lens),
mels,
mel_lens,
max(mel_lens),
pitches,
energies,
durations,
)
def collate_fn(self, data):
data_size = len(data)
if self.sort:
len_arr = np.array([d["text"].shape[0] for d in data])
idx_arr = np.argsort(-len_arr)
else:
idx_arr = np.arange(data_size)
tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :]
idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)]
idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist()
if not self.drop_last and len(tail) > 0:
idx_arr += [tail.tolist()]
output = list()
for idx in idx_arr:
output.append(self.reprocess(data, idx))
return output
class TextDataset(Dataset):
def __init__(self, filepath, preprocess_config):
self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
filepath
)
with open(
os.path.join(
preprocess_config["path"]["preprocessed_path"], "speakers.json"
)
) as f:
self.speaker_map = json.load(f)
def __len__(self):
return len(self.text)
def __getitem__(self, idx):
basename = self.basename[idx]
speaker = self.speaker[idx]
speaker_id = self.speaker_map[speaker]
raw_text = self.raw_text[idx]
phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
return (basename, speaker_id, phone, raw_text)
def process_meta(self, filename):
with open(filename, "r", encoding="utf-8") as f:
name = []
speaker = []
text = []
raw_text = []
for line in f.readlines():
n, s, t, r = line.strip("\n").split("|")
name.append(n)
speaker.append(s)
text.append(t)
raw_text.append(r)
return name, speaker, text, raw_text
def collate_fn(self, data):
ids = [d[0] for d in data]
speakers = np.array([d[1] for d in data])
texts = [d[2] for d in data]
raw_texts = [d[3] for d in data]
text_lens = np.array([text.shape[0] for text in texts])
texts = pad_1D(texts)
return ids, raw_texts, speakers, texts, text_lens, max(text_lens)
if __name__ == "__main__":
# Test
import torch
import yaml
from torch.utils.data import DataLoader
from utils.utils import to_device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
preprocess_config = yaml.load(
open("./config/LJSpeech/preprocess.yaml", "r"), Loader=yaml.FullLoader
)
train_config = yaml.load(
open("./config/LJSpeech/train.yaml", "r"), Loader=yaml.FullLoader
)
train_dataset = Dataset(
"train.txt", preprocess_config, train_config, sort=True, drop_last=True
)
val_dataset = Dataset(
"val.txt", preprocess_config, train_config, sort=False, drop_last=False
)
train_loader = DataLoader(
train_dataset,
batch_size=train_config["optimizer"]["batch_size"] * 4,
shuffle=True,
collate_fn=train_dataset.collate_fn,
)
val_loader = DataLoader(
val_dataset,
batch_size=train_config["optimizer"]["batch_size"],
shuffle=False,
collate_fn=val_dataset.collate_fn,
)
n_batch = 0
for batchs in train_loader:
for batch in batchs:
to_device(batch, device)
n_batch += 1
print(
"Training set with size {} is composed of {} batches.".format(
len(train_dataset), n_batch
)
)
n_batch = 0
for batchs in val_loader:
for batch in batchs:
to_device(batch, device)
n_batch += 1
print(
"Validation set with size {} is composed of {} batches.".format(
len(val_dataset), n_batch
)
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment