Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
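The diff below is mechanical reformatting from that lint run: string literals are normalized to double quotes, over-wrapped calls are re-wrapped or collapsed, slices gain black-style spacing, and imports are re-sorted. As a minimal before/after sketch of the style being applied (using the `--output-path` argument that appears in the diff below; the exact formatter configuration is internal to the repository and is assumed here, not stated by the commit):

```python
import argparse

parser = argparse.ArgumentParser()

# Before the lint pass: single-quoted strings, one argument per line.
# parser.add_argument(
#     '--output-path',
#     type=str,
#     default='./audio.wav',
#     help='[string] Path to the output .wav file.'
# )

# After the lint pass: double-quoted strings, and the call is collapsed onto
# one line because it fits within the configured maximum line length.
parser.add_argument("--output-path", type=str, default="./audio.wav", help="[string] Path to the output .wav file.")
```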
......@@ -21,11 +21,7 @@ def _parse_args():
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
'input_dir',
type=Path,
help='Directory where `*.trans.txt` files are searched.'
)
parser.add_argument("input_dir", type=Path, help="Directory where `*.trans.txt` files are searched.")
return parser.parse_args()
......@@ -34,22 +30,22 @@ def _parse_transcript(path):
for line in trans_fileobj:
line = line.strip()
if line:
yield line.split(' ', maxsplit=1)
yield line.split(" ", maxsplit=1)
def _parse_directory(root_dir: Path):
for trans_file in root_dir.glob('**/*.trans.txt'):
for trans_file in root_dir.glob("**/*.trans.txt"):
trans_dir = trans_file.parent
for id_, transcription in _parse_transcript(trans_file):
audio_path = trans_dir / f'{id_}.flac'
audio_path = trans_dir / f"{id_}.flac"
yield id_, audio_path, transcription
def _main():
args = _parse_args()
for id_, path, transcription in _parse_directory(args.input_dir):
print(f'{id_}\t{path}\t{transcription}')
print(f"{id_}\t{path}\t{transcription}")
if __name__ == '__main__':
if __name__ == "__main__":
_main()
......@@ -12,8 +12,8 @@ example: python parse_voxforge.py voxforge/de/Helge-20150608-aku
Dataset can be obtained from http://www.repository.voxforge1.org/downloads/de/Trunk/Audio/Main/16kHz_16bit/
""" # noqa: E501
import os
import argparse
import os
from pathlib import Path
......@@ -22,11 +22,7 @@ def _parse_args():
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
'input_dir',
type=Path,
help='Directory where `PROMPTS` files are searched.'
)
parser.add_argument("input_dir", type=Path, help="Directory where `PROMPTS` files are searched.")
return parser.parse_args()
......@@ -38,19 +34,19 @@ def _parse_prompts(path):
if not line:
continue
id_, transcript = line.split(' ', maxsplit=1)
id_, transcript = line.split(" ", maxsplit=1)
if not transcript:
continue
transcript = transcript.upper()
filename = id_.split('/')[-1]
audio_path = base_dir / 'wav' / f'{filename}.wav'
filename = id_.split("/")[-1]
audio_path = base_dir / "wav" / f"{filename}.wav"
if os.path.exists(audio_path):
yield id_, audio_path, transcript
def _parse_directory(root_dir: Path):
for prompt_file in root_dir.glob('**/PROMPTS'):
for prompt_file in root_dir.glob("**/PROMPTS"):
try:
yield from _parse_prompts(prompt_file)
except UnicodeDecodeError:
......@@ -60,8 +56,8 @@ def _parse_directory(root_dir: Path):
def _main():
args = _parse_args()
for id_, path, transcription in _parse_directory(args.input_dir):
print(f'{id_}\t{path}\t{transcription}')
print(f"{id_}\t{path}\t{transcription}")
if __name__ == '__main__':
if __name__ == "__main__":
_main()
......@@ -29,7 +29,6 @@ from typing import Tuple, Callable, List
import torch
from torch import Tensor
from torch.utils.data.dataset import random_split
from torchaudio.datasets import LJSPEECH
......@@ -45,8 +44,7 @@ class InverseSpectralNormalization(torch.nn.Module):
class MapMemoryCache(torch.utils.data.Dataset):
r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
"""
r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory."""
def __init__(self, dataset):
self.dataset = dataset
......@@ -84,16 +82,17 @@ class Processed(torch.utils.data.Dataset):
return text_norm, torch.squeeze(melspec, 0)
def split_process_dataset(dataset: str,
file_path: str,
val_ratio: float,
transforms: Callable,
text_preprocessor: Callable[[str], List[int]],
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
def split_process_dataset(
dataset: str,
file_path: str,
val_ratio: float,
transforms: Callable,
text_preprocessor: Callable[[str], List[int]],
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
"""Returns the Training and validation datasets.
Args:
dataset (str): The dataset to use. Avaliable options: [`'ljspeech'`]
dataset (str): The dataset to use. Available options: [`'ljspeech'`]
file_path (str): Path to the data.
val_ratio (float): The fraction of the data to use for the validation set.
transforms (callable): A function/transform that takes in a waveform and
......@@ -105,7 +104,7 @@ def split_process_dataset(dataset: str,
train_dataset (`torch.utils.data.Dataset`): The training set.
val_dataset (`torch.utils.data.Dataset`): The validation set.
"""
if dataset == 'ljspeech':
if dataset == "ljspeech":
data = LJSPEECH(root=file_path, download=False)
val_length = int(len(data) * val_ratio)
......@@ -123,8 +122,9 @@ def split_process_dataset(dataset: str,
return train_dataset, val_dataset
def text_mel_collate_fn(batch: Tuple[Tensor, Tensor],
n_frames_per_step: int = 1) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
def text_mel_collate_fn(
batch: Tuple[Tensor, Tensor], n_frames_per_step: int = 1
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""The collate function padding and adjusting the data based on `n_frames_per_step`.
Modified from https://github.com/NVIDIA/DeepLearningExamples
......@@ -143,13 +143,14 @@ def text_mel_collate_fn(batch: Tuple[Tensor, Tensor],
with shape (n_batch, max of ``mel_specgram_lengths``)
"""
text_lengths, ids_sorted_decreasing = torch.sort(
torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True)
torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True
)
max_input_len = text_lengths[0]
text_padded = torch.zeros((len(batch), max_input_len), dtype=torch.int64)
for i in range(len(ids_sorted_decreasing)):
text = batch[ids_sorted_decreasing[i]][0]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
# Right zero-pad mel-spec
num_mels = batch[0][1].size(0)
......@@ -164,8 +165,8 @@ def text_mel_collate_fn(batch: Tuple[Tensor, Tensor],
mel_specgram_lengths = torch.LongTensor(len(batch))
for i in range(len(ids_sorted_decreasing)):
mel = batch[ids_sorted_decreasing[i]][1]
mel_specgram_padded[i, :, :mel.size(1)] = mel
mel_specgram_padded[i, :, : mel.size(1)] = mel
mel_specgram_lengths[i] = mel.size(1)
gate_padded[i, mel.size(1) - 1:] = 1
gate_padded[i, mel.size(1) - 1 :] = 1
return text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths, gate_padded
......@@ -2,19 +2,15 @@
Text-to-speech pipeline using Tacotron2.
"""
from functools import partial
import argparse
import os
import random
import sys
from functools import partial
import numpy as np
import torch
import torchaudio
import numpy as np
from torchaudio.models import Tacotron2
from torchaudio.models import tacotron2 as pretrained_tacotron2
from utils import prepare_input_sequence
from datasets import InverseSpectralNormalization
from text.text_preprocessing import (
available_symbol_set,
......@@ -22,6 +18,9 @@ from text.text_preprocessing import (
get_symbol_list,
text_to_sequence,
)
from torchaudio.models import Tacotron2
from torchaudio.models import tacotron2 as pretrained_tacotron2
from utils import prepare_input_sequence
def parse_args():
......@@ -33,113 +32,72 @@ def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--checkpoint-name',
"--checkpoint-name",
type=str,
default=None,
choices=list(tacotron2_config_and_urls.keys()),
help='[string] The name of the checkpoint to load.'
)
parser.add_argument(
'--checkpoint-path',
type=str,
default=None,
help='[string] Path to the checkpoint file.'
)
parser.add_argument(
'--output-path',
type=str,
default="./audio.wav",
help='[string] Path to the output .wav file.'
help="[string] The name of the checkpoint to load.",
)
parser.add_argument("--checkpoint-path", type=str, default=None, help="[string] Path to the checkpoint file.")
parser.add_argument("--output-path", type=str, default="./audio.wav", help="[string] Path to the output .wav file.")
parser.add_argument(
'--input-text',
'-i',
"--input-text",
"-i",
type=str,
default="Hello world",
help='[string] Type in something here and TTS will generate it!'
help="[string] Type in something here and TTS will generate it!",
)
parser.add_argument(
'--vocoder',
default='nvidia_waveglow',
choices=['griffin_lim', 'wavernn', 'nvidia_waveglow'],
"--vocoder",
default="nvidia_waveglow",
choices=["griffin_lim", "wavernn", "nvidia_waveglow"],
type=str,
help="Select the vocoder to use.",
)
parser.add_argument(
"--jit",
default=False,
action="store_true",
help="If used, the model and inference function is jitted."
"--jit", default=False, action="store_true", help="If used, the model and inference function is jitted."
)
preprocessor = parser.add_argument_group('text preprocessor setup')
preprocessor = parser.add_argument_group("text preprocessor setup")
preprocessor.add_argument(
'--text-preprocessor',
default='english_characters',
"--text-preprocessor",
default="english_characters",
type=str,
choices=available_symbol_set,
help='select text preprocessor to use.'
help="select text preprocessor to use.",
)
preprocessor.add_argument(
'--phonemizer',
"--phonemizer",
default="DeepPhonemizer",
type=str,
choices=available_phonemizers,
help='select phonemizer to use, only used when text-preprocessor is "english_phonemes"'
help='select phonemizer to use, only used when text-preprocessor is "english_phonemes"',
)
preprocessor.add_argument(
'--phonemizer-checkpoint',
"--phonemizer-checkpoint",
default="./en_us_cmudict_forward.pt",
type=str,
help='the path or name of the checkpoint for the phonemizer, '
'only used when text-preprocessor is "english_phonemes"'
help="the path or name of the checkpoint for the phonemizer, "
'only used when text-preprocessor is "english_phonemes"',
)
preprocessor.add_argument(
'--cmudict-root',
default="./",
type=str,
help='the root directory for storing CMU dictionary files'
"--cmudict-root", default="./", type=str, help="the root directory for storing CMU dictionary files"
)
audio = parser.add_argument_group('audio parameters')
audio.add_argument(
'--sample-rate',
default=22050,
type=int,
help='Sampling rate'
)
audio.add_argument(
'--n-fft',
default=1024,
type=int,
help='Filter length for STFT'
)
audio.add_argument(
'--n-mels',
default=80,
type=int,
help=''
)
audio.add_argument(
'--mel-fmin',
default=0.0,
type=float,
help='Minimum mel frequency'
)
audio.add_argument(
'--mel-fmax',
default=8000.0,
type=float,
help='Maximum mel frequency'
)
audio = parser.add_argument_group("audio parameters")
audio.add_argument("--sample-rate", default=22050, type=int, help="Sampling rate")
audio.add_argument("--n-fft", default=1024, type=int, help="Filter length for STFT")
audio.add_argument("--n-mels", default=80, type=int, help="")
audio.add_argument("--mel-fmin", default=0.0, type=float, help="Minimum mel frequency")
audio.add_argument("--mel-fmax", default=8000.0, type=float, help="Maximum mel frequency")
# parameters for WaveRNN
wavernn = parser.add_argument_group('WaveRNN parameters')
wavernn = parser.add_argument_group("WaveRNN parameters")
wavernn.add_argument(
'--wavernn-checkpoint-name',
"--wavernn-checkpoint-name",
default="wavernn_10k_epochs_8bits_ljspeech",
choices=list(wavernn_config_and_urls.keys()),
help="Select the WaveRNN checkpoint."
help="Select the WaveRNN checkpoint.",
)
wavernn.add_argument(
"--wavernn-loss",
......@@ -152,13 +110,10 @@ def parse_args():
"--wavernn-no-batch-inference",
default=False,
action="store_true",
help="Don't use batch inference for WaveRNN inference."
help="Don't use batch inference for WaveRNN inference.",
)
wavernn.add_argument(
"--wavernn-no-mulaw",
default=False,
action="store_true",
help="Don't use mulaw decoder to decode the signal."
"--wavernn-no-mulaw", default=False, action="store_true", help="Don't use mulaw decoder to decode the signal."
)
wavernn.add_argument(
"--wavernn-batch-timesteps",
......@@ -187,11 +142,11 @@ def unwrap_distributed(state_dict):
unwrapped_state_dict: Unwrapped state_dict.
"""
return {k.replace('module.', ''): v for k, v in state_dict.items()}
return {k.replace("module.", ""): v for k, v in state_dict.items()}
def nvidia_waveglow_vocode(mel_specgram, device, jit=False):
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp16')
waveglow = torch.hub.load("NVIDIA/DeepLearningExamples:torchhub", "nvidia_waveglow", model_math="fp16")
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow = waveglow.to(device)
waveglow.eval()
......@@ -205,13 +160,22 @@ def nvidia_waveglow_vocode(mel_specgram, device, jit=False):
return waveform
def wavernn_vocode(mel_specgram, wavernn_checkpoint_name, wavernn_loss, wavernn_no_mulaw,
wavernn_no_batch_inference, wavernn_batch_timesteps, wavernn_batch_overlap,
device, jit):
def wavernn_vocode(
mel_specgram,
wavernn_checkpoint_name,
wavernn_loss,
wavernn_no_mulaw,
wavernn_no_batch_inference,
wavernn_batch_timesteps,
wavernn_batch_overlap,
device,
jit,
):
from torchaudio.models import wavernn
sys.path.append(os.path.join(os.path.dirname(__file__), "../pipeline_wavernn"))
from wavernn_inference_wrapper import WaveRNNInferenceWrapper
from processing import NormalizeDB
from wavernn_inference_wrapper import WaveRNNInferenceWrapper
wavernn_model = wavernn(wavernn_checkpoint_name).eval().to(device)
wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)
......@@ -234,16 +198,26 @@ def wavernn_vocode(mel_specgram, wavernn_checkpoint_name, wavernn_loss, wavernn_
mel_specgram = transforms(mel_specgram.cpu())
with torch.no_grad():
waveform = wavernn_inference_model(mel_specgram.to(device),
loss_name=wavernn_loss,
mulaw=(not wavernn_no_mulaw),
batched=(not wavernn_no_batch_inference),
timesteps=wavernn_batch_timesteps,
overlap=wavernn_batch_overlap,)
waveform = wavernn_inference_model(
mel_specgram.to(device),
loss_name=wavernn_loss,
mulaw=(not wavernn_no_mulaw),
batched=(not wavernn_no_batch_inference),
timesteps=wavernn_batch_timesteps,
overlap=wavernn_batch_overlap,
)
return waveform.unsqueeze(0)
def griffin_lim_vocode(mel_specgram, n_fft, n_mels, sample_rate, mel_fmin, mel_fmax, jit, ):
def griffin_lim_vocode(
mel_specgram,
n_fft,
n_mels,
sample_rate,
mel_fmin,
mel_fmax,
jit,
):
from torchaudio.transforms import GriffinLim, InverseMelScale
inv_norm = InverseSpectralNormalization()
......@@ -254,7 +228,7 @@ def griffin_lim_vocode(mel_specgram, n_fft, n_mels, sample_rate, mel_fmin, mel_f
f_min=mel_fmin,
f_max=mel_fmax,
mel_scale="slaney",
norm='slaney',
norm="slaney",
)
griffin_lim = GriffinLim(
n_fft=n_fft,
......@@ -263,11 +237,7 @@ def griffin_lim_vocode(mel_specgram, n_fft, n_mels, sample_rate, mel_fmin, mel_f
win_length=1024,
)
vocoder = torch.nn.Sequential(
inv_norm,
inv_mel,
griffin_lim
)
vocoder = torch.nn.Sequential(inv_norm, inv_mel, griffin_lim)
if jit:
vocoder = torch.jit.script(vocoder)
......@@ -286,8 +256,7 @@ def main(args):
if args.checkpoint_path is None and args.checkpoint_name is None:
raise ValueError("Either --checkpoint-path or --checkpoint-name must be specified.")
elif args.checkpoint_path is not None and args.checkpoint_name is not None:
raise ValueError("Both --checkpoint-path and --checkpoint-name are specified, "
"can only specify one.")
raise ValueError("Both --checkpoint-path and --checkpoint-name are specified, " "can only specify one.")
n_symbols = len(get_symbol_list(args.text_preprocessor))
text_preprocessor = partial(
......@@ -301,21 +270,23 @@ def main(args):
if args.checkpoint_path is not None:
tacotron2 = Tacotron2(n_symbol=n_symbols)
tacotron2.load_state_dict(
unwrap_distributed(torch.load(args.checkpoint_path, map_location=device)['state_dict']))
unwrap_distributed(torch.load(args.checkpoint_path, map_location=device)["state_dict"])
)
tacotron2 = tacotron2.to(device).eval()
elif args.checkpoint_name is not None:
tacotron2 = pretrained_tacotron2(args.checkpoint_name).to(device).eval()
if n_symbols != tacotron2.n_symbols:
raise ValueError(f"the number of symbols for text_preprocessor ({n_symbols}) "
"should match the number of symbols for the "
f"pretrained tacotron2 ({tacotron2.n_symbols}).")
raise ValueError(
f"the number of symbols for text_preprocessor ({n_symbols}) "
"should match the number of symbols for the "
f"pretrained tacotron2 ({tacotron2.n_symbols})."
)
if args.jit:
tacotron2 = torch.jit.script(tacotron2)
sequences, lengths = prepare_input_sequence([args.input_text],
text_processor=text_preprocessor)
sequences, lengths = prepare_input_sequence([args.input_text], text_processor=text_preprocessor)
sequences, lengths = sequences.long().to(device), lengths.long().to(device)
with torch.no_grad():
mel_specgram, _, _ = tacotron2.infer(sequences, lengths)
......@@ -324,24 +295,28 @@ def main(args):
waveform = nvidia_waveglow_vocode(mel_specgram=mel_specgram, device=device, jit=args.jit)
elif args.vocoder == "wavernn":
waveform = wavernn_vocode(mel_specgram=mel_specgram,
wavernn_checkpoint_name=args.wavernn_checkpoint_name,
wavernn_loss=args.wavernn_loss,
wavernn_no_mulaw=args.wavernn_no_mulaw,
wavernn_no_batch_inference=args.wavernn_no_batch_inference,
wavernn_batch_timesteps=args.wavernn_batch_timesteps,
wavernn_batch_overlap=args.wavernn_batch_overlap,
device=device,
jit=args.jit)
waveform = wavernn_vocode(
mel_specgram=mel_specgram,
wavernn_checkpoint_name=args.wavernn_checkpoint_name,
wavernn_loss=args.wavernn_loss,
wavernn_no_mulaw=args.wavernn_no_mulaw,
wavernn_no_batch_inference=args.wavernn_no_batch_inference,
wavernn_batch_timesteps=args.wavernn_batch_timesteps,
wavernn_batch_overlap=args.wavernn_batch_overlap,
device=device,
jit=args.jit,
)
elif args.vocoder == "griffin_lim":
waveform = griffin_lim_vocode(mel_specgram=mel_specgram,
n_fft=args.n_fft,
n_mels=args.n_mels,
sample_rate=args.sample_rate,
mel_fmin=args.mel_fmin,
mel_fmax=args.mel_fmax,
jit=args.jit)
waveform = griffin_lim_vocode(
mel_specgram=mel_specgram,
n_fft=args.n_fft,
n_mels=args.n_mels,
sample_rate=args.sample_rate,
mel_fmin=args.mel_fmin,
mel_fmax=args.mel_fmax,
jit=args.jit,
)
torchaudio.save(args.output_path, waveform, args.sample_rate)
......
......@@ -24,33 +24,34 @@
Modified from https://github.com/keithito/tacotron
"""
import inflect
import re
import inflect
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(text: str) -> str:
return re.sub(_comma_number_re, lambda m: m.group(1).replace(',', ''), text)
return re.sub(_comma_number_re, lambda m: m.group(1).replace(",", ""), text)
def _expand_pounds(text: str) -> str:
return re.sub(_pounds_re, r'\1 pounds', text)
return re.sub(_pounds_re, r"\1 pounds", text)
def _expand_dollars_repl_fn(m):
"""The replacement function for expanding dollars."""
match = m.group(1)
parts = match.split('.')
parts = match.split(".")
if len(parts) > 2:
return match + ' dollars' # Unexpected format
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
if len(parts) > 1 and parts[1]:
if len(parts[1]) == 1:
......@@ -61,17 +62,17 @@ def _expand_dollars_repl_fn(m):
else:
cents = 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return 'zero dollars'
return "zero dollars"
def _expand_dollars(text: str) -> str:
......@@ -79,7 +80,7 @@ def _expand_dollars(text: str) -> str:
def _expand_decimal_point(text: str) -> str:
return re.sub(_decimal_number_re, lambda m: m.group(1).replace('.', ' point '), text)
return re.sub(_decimal_number_re, lambda m: m.group(1).replace(".", " point "), text)
def _expand_ordinal(text: str) -> str:
......@@ -91,15 +92,15 @@ def _expand_number_repl_fn(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
return "two thousand"
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword='')
return _inflect.number_to_words(num, andword="")
def _expand_number(text: str) -> str:
......
......@@ -24,44 +24,47 @@
Modified from https://github.com/keithito/tacotron
"""
from typing import List, Union, Optional
import re
from typing import List, Union, Optional
from unidecode import unidecode
from torchaudio.datasets import CMUDict
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
_pad = '_'
_punctuation = '!\'(),.:;? '
_special = '-'
_letters = 'abcdefghijklmnopqrstuvwxyz'
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
_pad = "_"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "abcdefghijklmnopqrstuvwxyz"
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters)
_phonemizer = None
......@@ -71,23 +74,25 @@ available_symbol_set = set(["english_characters", "english_phonemes"])
available_phonemizers = set(["DeepPhonemizer"])
def get_symbol_list(symbol_list: str = "english_characters",
cmudict_root: Optional[str] = "./") -> List[str]:
def get_symbol_list(symbol_list: str = "english_characters", cmudict_root: Optional[str] = "./") -> List[str]:
if symbol_list == "english_characters":
return [_pad] + list(_special) + list(_punctuation) + list(_letters)
elif symbol_list == "english_phonemes":
return [_pad] + list(_special) + list(_punctuation) + CMUDict(cmudict_root).symbols
else:
raise ValueError(f"The `symbol_list` {symbol_list} is not supported."
f"Supported `symbol_list` includes {available_symbol_set}.")
raise ValueError(
f"The `symbol_list` {symbol_list} is not supported."
f"Supported `symbol_list` includes {available_symbol_set}."
)
def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
if phonemizer == "DeepPhonemizer":
from dp.phonemizer import Phonemizer
global _phonemizer
_other_symbols = ''.join(list(_special) + list(_punctuation))
_phone_symbols_re = r'(\[[A-Z]+?\]|' + '[' + _other_symbols + '])' # [\[([A-Z]+?)\]|[-!'(),.:;? ]]
_other_symbols = "".join(list(_special) + list(_punctuation))
_phone_symbols_re = r"(\[[A-Z]+?\]|" + "[" + _other_symbols + "])" # [\[([A-Z]+?)\]|[-!'(),.:;? ]]
if _phonemizer is None:
# using a global variable so that we don't have to reload the checkpoint
......@@ -97,7 +102,7 @@ def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
# Example:
# sent = "hello world!"
# '[HH][AH][L][OW] [W][ER][L][D]!'
sent = _phonemizer(sent, lang='en_us')
sent = _phonemizer(sent, lang="en_us")
# ['[HH]', '[AH]', '[L]', '[OW]', ' ', '[W]', '[ER]', '[L]', '[D]', '!']
ret = re.findall(_phone_symbols_re, sent)
......@@ -107,16 +112,19 @@ def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
return ret
else:
raise ValueError(f"The `phonemizer` {phonemizer} is not supported. "
"Supported `symbol_list` includes `'DeepPhonemizer'`.")
raise ValueError(
f"The `phonemizer` {phonemizer} is not supported. " "Supported `symbol_list` includes `'DeepPhonemizer'`."
)
def text_to_sequence(sent: str,
symbol_list: Union[str, List[str]] = "english_characters",
phonemizer: Optional[str] = "DeepPhonemizer",
checkpoint: Optional[str] = "./en_us_cmudict_forward.pt",
cmudict_root: Optional[str] = "./") -> List[int]:
r'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
def text_to_sequence(
sent: str,
symbol_list: Union[str, List[str]] = "english_characters",
phonemizer: Optional[str] = "DeepPhonemizer",
checkpoint: Optional[str] = "./en_us_cmudict_forward.pt",
cmudict_root: Optional[str] = "./",
) -> List[int]:
r"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
sent (str): The input sentence to convert to a sequence.
......@@ -138,19 +146,20 @@ def text_to_sequence(sent: str,
[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2]
>>> text_to_sequence("hello world!", "english_phonemes")
[54, 20, 65, 69, 11, 92, 44, 65, 38, 2]
'''
"""
if symbol_list == "english_phonemes":
if any(param is None for param in [phonemizer, checkpoint, cmudict_root]):
raise ValueError(
"When `symbol_list` is 'english_phonemes', "
"all of `phonemizer`, `checkpoint`, and `cmudict_root` must be provided.")
"all of `phonemizer`, `checkpoint`, and `cmudict_root` must be provided."
)
sent = unidecode(sent) # convert to ascii
sent = sent.lower() # lower case
sent = normalize_numbers(sent) # expand numbers
for regex, replacement in _abbreviations: # expand abbreviations
sent = re.sub(regex, replacement, sent)
sent = re.sub(_whitespace_re, ' ', sent) # collapse whitespace
sent = re.sub(_whitespace_re, " ", sent) # collapse whitespace
if isinstance(symbol_list, list):
symbols = symbol_list
......
This diff is collapsed.
......@@ -53,21 +53,19 @@ def pad_sequences(batch: List[Tensor]) -> Tuple[Tensor, Tensor]:
Modified from https://github.com/NVIDIA/DeepLearningExamples.
"""
input_lengths, ids_sorted_decreasing = torch.sort(
torch.LongTensor([len(x) for x in batch]), dim=0, descending=True)
input_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor([len(x) for x in batch]), dim=0, descending=True)
max_input_len = input_lengths[0]
text_padded = torch.LongTensor(len(batch), max_input_len)
text_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
text = batch[ids_sorted_decreasing[i]]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
return text_padded, input_lengths
def prepare_input_sequence(texts: List[str],
text_processor: Callable[[str], List[int]]) -> Tuple[Tensor, Tensor]:
def prepare_input_sequence(texts: List[str], text_processor: Callable[[str], List[int]]) -> Tuple[Tensor, Tensor]:
d = []
for text in texts:
d.append(torch.IntTensor(text_processor(text)[:]))
......
......@@ -51,7 +51,11 @@ class Processed(torch.utils.data.Dataset):
def split_process_librispeech(
datasets, transforms, language_model, root, folder_in_archive,
datasets,
transforms,
language_model,
root,
folder_in_archive,
):
def create(tags, cache=True):
......@@ -66,7 +70,10 @@ def split_process_librispeech(
[
Processed(
LIBRISPEECH(
root, tag, folder_in_archive=folder_in_archive, download=False,
root,
tag,
folder_in_archive=folder_in_archive,
download=False,
),
transform,
language_model.encode,
......
......@@ -7,16 +7,15 @@ from time import time
import torch
import torchaudio
from ctc_decoders import GreedyDecoder
from datasets import collate_factory, split_process_librispeech
from languagemodels import LanguageModel
from torch.optim import SGD, Adadelta, Adam, AdamW
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.functional import edit_distance
from torchaudio.models.wav2letter import Wav2Letter
from ctc_decoders import GreedyDecoder
from datasets import collate_factory, split_process_librispeech
from languagemodels import LanguageModel
from transforms import Normalize, UnsqueezeFirst
from utils import MetricLogger, count_parameters, save_checkpoint
......@@ -80,20 +79,14 @@ def parse_args():
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch", default=0, type=int, metavar="N", help="manual epoch number"
)
parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="manual epoch number")
parser.add_argument(
"--reduce-lr-valid",
action="store_true",
help="reduce learning rate based on validation loss",
)
parser.add_argument(
"--normalize", action="store_true", help="normalize model input"
)
parser.add_argument(
"--progress-bar", action="store_true", help="use progress bar while training"
)
parser.add_argument("--normalize", action="store_true", help="normalize model input")
parser.add_argument("--progress-bar", action="store_true", help="use progress bar while training")
parser.add_argument(
"--decoder",
metavar="D",
......@@ -101,9 +94,7 @@ def parse_args():
choices=["greedy"],
help="decoder to use",
)
parser.add_argument(
"--batch-size", default=128, type=int, metavar="N", help="mini-batch size"
)
parser.add_argument("--batch-size", default=128, type=int, metavar="N", help="mini-batch size")
parser.add_argument(
"--n-bins",
default=13,
......@@ -139,12 +130,8 @@ def parse_args():
metavar="GAMMA",
help="learning rate exponential decay constant",
)
parser.add_argument(
"--momentum", default=0.8, type=float, metavar="M", help="momentum"
)
parser.add_argument(
"--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay"
)
parser.add_argument("--momentum", default=0.8, type=float, metavar="M", help="momentum")
parser.add_argument("--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay")
parser.add_argument("--eps", metavar="EPS", type=float, default=1e-8)
parser.add_argument("--rho", metavar="RHO", type=float, default=0.95)
parser.add_argument("--clip-grad", metavar="NORM", type=float, default=0.0)
......@@ -172,13 +159,9 @@ def parse_args():
type=str,
help="select which part of librispeech to validate with",
)
parser.add_argument(
"--distributed", action="store_true", help="enable DistributedDataParallel"
)
parser.add_argument("--distributed", action="store_true", help="enable DistributedDataParallel")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument(
"--world-size", type=int, default=8, help="the world size to initiate DDP"
)
parser.add_argument("--world-size", type=int, default=8, help="the world size to initiate DDP")
parser.add_argument("--jit", action="store_true", help="if used, model is jitted")
args = parser.parse_args()
......@@ -263,9 +246,7 @@ def train_one_epoch(
metric = MetricLogger("train", disable=disable_logger)
metric["epoch"] = epoch
for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
data_loader, maxsize=2
):
for inputs, targets, tensors_lengths, target_lengths in bg_iterator(data_loader, maxsize=2):
start = time()
inputs = inputs.to(device, non_blocking=True)
......@@ -286,9 +267,7 @@ def train_one_epoch(
loss.backward()
if clip_grad > 0:
metric["gradient"] = torch.nn.utils.clip_grad_norm_(
model.parameters(), clip_grad
)
metric["gradient"] = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
optimizer.step()
......@@ -335,9 +314,7 @@ def evaluate(
metric = MetricLogger("validation", disable=disable_logger)
metric["epoch"] = epoch
for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
data_loader, maxsize=2
):
for inputs, targets, tensors_lengths, target_lengths in bg_iterator(data_loader, maxsize=2):
inputs = inputs.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
......@@ -351,9 +328,7 @@ def evaluate(
# input_lengths: batch size
# target_lengths: batch size
metric["cumulative loss"] += criterion(
outputs, targets, tensors_lengths, target_lengths
).item()
metric["cumulative loss"] += criterion(outputs, targets, tensors_lengths, target_lengths).item()
metric["dataset length"] += len(inputs)
metric["iteration"] += 1
......@@ -518,9 +493,7 @@ def main(rank, args):
else:
raise ValueError("Selected scheduler not supported")
criterion = torch.nn.CTCLoss(
blank=language_model.mapping[char_blank], zero_infinity=False
)
criterion = torch.nn.CTCLoss(blank=language_model.mapping[char_blank], zero_infinity=False)
# Data Loader
......@@ -569,9 +542,7 @@ def main(rank, args):
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
logging.info(
"Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
)
logging.info("Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"])
else:
logging.info("Checkpoint: not found")
......@@ -649,9 +620,7 @@ def main(rank, args):
def spawn_main(main, args):
if args.distributed:
torch.multiprocessing.spawn(
main, args=(args,), nprocs=args.world_size, join=True
)
torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size, join=True)
else:
main(0, args)
......
import random
import torch
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
from torch.utils.data.dataset import random_split
from torchaudio.datasets import LJSPEECH, LIBRITTS
from torchaudio.transforms import MuLawEncoding
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
class MapMemoryCache(torch.utils.data.Dataset):
r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
"""
r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory."""
def __init__(self, dataset):
self.dataset = dataset
......@@ -47,16 +45,16 @@ class Processed(torch.utils.data.Dataset):
def split_process_dataset(args, transforms):
if args.dataset == 'ljspeech':
if args.dataset == "ljspeech":
data = LJSPEECH(root=args.file_path, download=False)
val_length = int(len(data) * args.val_ratio)
lengths = [len(data) - val_length, val_length]
train_dataset, val_dataset = random_split(data, lengths)
elif args.dataset == 'libritts':
train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False)
val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False)
elif args.dataset == "libritts":
train_dataset = LIBRITTS(root=args.file_path, url="train-clean-100", download=False)
val_dataset = LIBRITTS(root=args.file_path, url="dev-clean", download=False)
else:
raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}")
......@@ -88,14 +86,8 @@ def collate_factory(args):
# random start position in waveform
wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets]
waveform_combine = [
x[0][wave_offsets[i]: wave_offsets[i] + wave_length + 1]
for i, x in enumerate(batch)
]
specgram = [
x[1][:, spec_offsets[i]: spec_offsets[i] + spec_length]
for i, x in enumerate(batch)
]
waveform_combine = [x[0][wave_offsets[i] : wave_offsets[i] + wave_length + 1] for i, x in enumerate(batch)]
specgram = [x[1][:, spec_offsets[i] : spec_offsets[i] + spec_length] for i, x in enumerate(batch)]
specgram = torch.stack(specgram)
waveform_combine = torch.stack(waveform_combine)
......
......@@ -2,44 +2,46 @@ import argparse
import torch
import torchaudio
from torchaudio.transforms import MelSpectrogram
from processing import NormalizeDB
from torchaudio.datasets import LJSPEECH
from torchaudio.models import wavernn
from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS
from torchaudio.datasets import LJSPEECH
from torchaudio.transforms import MelSpectrogram
from wavernn_inference_wrapper import WaveRNNInferenceWrapper
from processing import NormalizeDB
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--output-wav-path", default="./output.wav", type=str, metavar="PATH",
"--output-wav-path",
default="./output.wav",
type=str,
metavar="PATH",
help="The path to output the reconstructed wav file.",
)
parser.add_argument(
"--jit", default=False, action="store_true",
help="If used, the model and inference function is jitted."
)
parser.add_argument(
"--no-batch-inference", default=False, action="store_true",
help="Don't use batch inference."
"--jit", default=False, action="store_true", help="If used, the model and inference function is jitted."
)
parser.add_argument("--no-batch-inference", default=False, action="store_true", help="Don't use batch inference.")
parser.add_argument(
"--no-mulaw", default=False, action="store_true",
help="Don't use mulaw decoder to decode the signal."
"--no-mulaw", default=False, action="store_true", help="Don't use mulaw decoder to decode the signal."
)
parser.add_argument(
"--checkpoint-name", default="wavernn_10k_epochs_8bits_ljspeech",
"--checkpoint-name",
default="wavernn_10k_epochs_8bits_ljspeech",
choices=list(_MODEL_CONFIG_AND_URLS.keys()),
help="Select the WaveRNN checkpoint."
help="Select the WaveRNN checkpoint.",
)
parser.add_argument(
"--batch-timesteps", default=100, type=int,
"--batch-timesteps",
default=100,
type=int,
help="The time steps for each batch. Only used when batch inference is used",
)
parser.add_argument(
"--batch-overlap", default=5, type=int,
"--batch-overlap",
default=5,
type=int,
help="The overlapping time steps between batches. Only used when batch inference is used",
)
args = parser.parse_args()
......@@ -51,15 +53,15 @@ def main(args):
waveform, sample_rate, _, _ = LJSPEECH("./", download=True)[0]
mel_kwargs = {
'sample_rate': sample_rate,
'n_fft': 2048,
'f_min': 40.,
'n_mels': 80,
'win_length': 1100,
'hop_length': 275,
'mel_scale': 'slaney',
'norm': 'slaney',
'power': 1,
"sample_rate": sample_rate,
"n_fft": 2048,
"f_min": 40.0,
"n_mels": 80,
"win_length": 1100,
"hop_length": 275,
"mel_scale": "slaney",
"norm": "slaney",
"power": 1,
}
transforms = torch.nn.Sequential(
MelSpectrogram(**mel_kwargs),
......@@ -74,11 +76,13 @@ def main(args):
wavernn_inference_model = torch.jit.script(wavernn_inference_model)
with torch.no_grad():
output = wavernn_inference_model(mel_specgram.to(device),
mulaw=(not args.no_mulaw),
batched=(not args.no_batch_inference),
timesteps=args.batch_timesteps,
overlap=args.batch_overlap,)
output = wavernn_inference_model(
mel_specgram.to(device),
mulaw=(not args.no_mulaw),
batched=(not args.no_batch_inference),
timesteps=args.batch_timesteps,
overlap=args.batch_overlap,
)
torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate)
......
......@@ -6,8 +6,7 @@ from torch.nn import functional as F
class LongCrossEntropyLoss(nn.Module):
r""" CrossEntropy loss
"""
r"""CrossEntropy loss"""
def __init__(self):
super(LongCrossEntropyLoss, self).__init__()
......@@ -21,7 +20,7 @@ class LongCrossEntropyLoss(nn.Module):
class MoLLoss(nn.Module):
r""" Discretized mixture of logistic distributions loss
r"""Discretized mixture of logistic distributions loss
Adapted from wavenet vocoder
(https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py)
......@@ -57,10 +56,8 @@ class MoLLoss(nn.Module):
# unpack parameters (n_batch, n_time, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix: 2 * nr_mix]
log_scales = torch.clamp(
y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min
)
means = y_hat[:, :, nr_mix : 2 * nr_mix]
log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=self.log_scale_min)
# (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures)
y = y.expand_as(means)
......@@ -89,15 +86,11 @@ class MoLLoss(nn.Module):
inner_inner_cond = (cdf_delta > 1e-5).float()
inner_inner_out = inner_inner_cond * torch.log(
torch.clamp(cdf_delta, min=1e-12)
) + (1.0 - inner_inner_cond) * (
inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * (
log_pdf_mid - math.log((self.num_classes - 1) / 2)
)
inner_cond = (y > 0.999).float()
inner_out = (
inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
)
inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
cond = (y < -0.999).float()
log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
......@@ -110,8 +103,7 @@ class MoLLoss(nn.Module):
def _log_sum_exp(x):
r""" Numerically stable log_sum_exp implementation that prevents overflow
"""
r"""Numerically stable log_sum_exp implementation that prevents overflow"""
axis = len(x.size()) - 1
m, _ = torch.max(x, dim=axis)
......
......@@ -8,14 +8,13 @@ from typing import List
import torch
import torchaudio
from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import NormalizeDB
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.models.wavernn import WaveRNN
from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint
......@@ -43,9 +42,7 @@ def parse_args():
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch", default=0, type=int, metavar="N", help="manual epoch number"
)
parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="manual epoch number")
parser.add_argument(
"--print-freq",
default=10,
......@@ -60,11 +57,13 @@ def parse_args():
type=str,
help="select dataset to train with",
)
parser.add_argument("--batch-size", default=256, type=int, metavar="N", help="mini-batch size")
parser.add_argument(
"--batch-size", default=256, type=int, metavar="N", help="mini-batch size"
)
parser.add_argument(
"--learning-rate", default=1e-4, type=float, metavar="LR", help="learning rate",
"--learning-rate",
default=1e-4,
type=float,
metavar="LR",
help="learning rate",
)
parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0)
parser.add_argument(
......@@ -73,9 +72,7 @@ def parse_args():
action="store_true",
help="if used, waveform is mulaw encoded",
)
parser.add_argument(
"--jit", default=False, action="store_true", help="if used, model is jitted"
)
parser.add_argument("--jit", default=False, action="store_true", help="if used, model is jitted")
parser.add_argument(
"--upsample-scales",
default=[5, 5, 11],
......@@ -83,7 +80,10 @@ def parse_args():
help="the list of upsample scales",
)
parser.add_argument(
"--n-bits", default=8, type=int, help="the bits of output waveform",
"--n-bits",
default=8,
type=int,
help="the bits of output waveform",
)
parser.add_argument(
"--sample-rate",
......@@ -98,10 +98,16 @@ def parse_args():
help="the number of samples between the starts of consecutive frames",
)
parser.add_argument(
"--win-length", default=1100, type=int, help="the length of the STFT window",
"--win-length",
default=1100,
type=int,
help="the length of the STFT window",
)
parser.add_argument(
"--f-min", default=40.0, type=float, help="the minimum frequency",
"--f-min",
default=40.0,
type=float,
help="the minimum frequency",
)
parser.add_argument(
"--min-level-db",
......@@ -110,13 +116,22 @@ def parse_args():
help="the minimum db value for spectrogram normalization",
)
parser.add_argument(
"--n-res-block", default=10, type=int, help="the number of ResBlock in stack",
"--n-res-block",
default=10,
type=int,
help="the number of ResBlock in stack",
)
parser.add_argument(
"--n-rnn", default=512, type=int, help="the dimension of RNN layer",
"--n-rnn",
default=512,
type=int,
help="the dimension of RNN layer",
)
parser.add_argument(
"--n-fc", default=512, type=int, help="the dimension of fully connected layer",
"--n-fc",
default=512,
type=int,
help="the dimension of fully connected layer",
)
parser.add_argument(
"--kernel-size",
......@@ -125,7 +140,10 @@ def parse_args():
help="the number of kernel size in the first Conv1d layer",
)
parser.add_argument(
"--n-freq", default=80, type=int, help="the number of spectrogram bins to use",
"--n-freq",
default=80,
type=int,
help="the number of spectrogram bins to use",
)
parser.add_argument(
"--n-hidden-melresnet",
......@@ -134,10 +152,16 @@ def parse_args():
help="the number of hidden dimensions of resblock in melresnet",
)
parser.add_argument(
"--n-output-melresnet", default=128, type=int, help="the output dimension of melresnet",
"--n-output-melresnet",
default=128,
type=int,
help="the output dimension of melresnet",
)
parser.add_argument(
"--n-fft", default=2048, type=int, help="the number of Fourier bins",
"--n-fft",
default=2048,
type=int,
help="the number of Fourier bins",
)
parser.add_argument(
"--loss",
......@@ -159,10 +183,16 @@ def parse_args():
help="the ratio of waveforms for validation",
)
parser.add_argument(
"--file-path", default="", type=str, help="the path of audio files",
"--file-path",
default="",
type=str,
help="the path of audio files",
)
parser.add_argument(
"--normalization", default=True, action="store_true", help="if True, spectrogram is normalized",
"--normalization",
default=True,
action="store_true",
help="if True, spectrogram is normalized",
)
args = parser.parse_args()
......@@ -199,9 +229,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):
loss.backward()
if args.clip_grad > 0:
gradient = torch.nn.utils.clip_grad_norm_(
model.parameters(), args.clip_grad
)
gradient = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
sums["gradient"] += gradient.item()
metric["gradient"] = gradient.item()
......@@ -271,8 +299,8 @@ def main(args):
sample_rate=args.sample_rate,
n_mels=args.n_freq,
f_min=args.f_min,
mel_scale='slaney',
norm='slaney',
mel_scale="slaney",
norm="slaney",
**melkwargs,
),
NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization),
......@@ -349,9 +377,7 @@ def main(args):
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
logging.info(
f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}"
)
logging.info(f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}")
else:
logging.info("Checkpoint: not found")
......@@ -369,7 +395,12 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
train_one_epoch(
model, criterion, optimizer, train_loader, devices[0], epoch,
model,
criterion,
optimizer,
train_loader,
devices[0],
epoch,
)
if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:
......
......@@ -3,8 +3,7 @@ import torch.nn as nn
class NormalizeDB(nn.Module):
r"""Normalize the spectrogram with a minimum db value
"""
r"""Normalize the spectrogram with a minimum db value"""
def __init__(self, min_level_db, normalization):
super().__init__()
......@@ -14,15 +13,12 @@ class NormalizeDB(nn.Module):
def forward(self, specgram):
specgram = torch.log10(torch.clamp(specgram.squeeze(0), min=1e-5))
if self.normalization:
return torch.clamp(
(self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1
)
return torch.clamp((self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1)
return specgram
def normalized_waveform_to_bits(waveform: torch.Tensor, bits: int) -> torch.Tensor:
r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]
"""
r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]"""
assert abs(waveform).max() <= 1.0
waveform = (waveform + 1.0) * (2 ** bits - 1) / 2
......@@ -30,7 +26,6 @@ def normalized_waveform_to_bits(waveform: torch.Tensor, bits: int) -> torch.Tens
def bits_to_normalized_waveform(label: torch.Tensor, bits: int) -> torch.Tensor:
r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1]
"""
r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1]"""
return 2 * label / (2 ** bits - 1.0) - 1.0
......@@ -7,8 +7,7 @@ import torch
class MetricLogger:
r"""Logger for model metrics
"""
r"""Logger for model metrics"""
def __init__(self, group, print_freq=1):
self.print_freq = print_freq
......@@ -55,7 +54,6 @@ def save_checkpoint(state, is_best, filename):
def count_parameters(model):
r"""Count the total number of parameters in the model
"""
r"""Count the total number of parameters in the model"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
......@@ -21,16 +21,15 @@
# *****************************************************************************
from torchaudio.models.wavernn import WaveRNN
import torch
import torchaudio
from torch import Tensor
from processing import normalized_waveform_to_bits
from torch import Tensor
from torchaudio.models.wavernn import WaveRNN
def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
r'''Fold the tensor with overlap for quick batched inference.
r"""Fold the tensor with overlap for quick batched inference.
Overlap will be used for crossfading in xfade_and_unfold().
x = [[h1, h2, ... hn]]
......@@ -47,7 +46,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
Return:
folded (tensor): folded tensor of size (n_folds, timesteps + 2 * overlap, channel).
'''
"""
_, channels, total_len = x.size()
......@@ -74,7 +73,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
r'''Applies a crossfade and unfolds into a 1d array.
r"""Applies a crossfade and unfolds into a 1d array.
y = [[seq1],
[seq2],
......@@ -93,7 +92,7 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
Returns:
unfolded waveform (Tensor) : waveform in a 1d tensor of size (channels, total_len).
'''
"""
num_folds, channels, length = y.shape
timesteps = length - 2 * overlap
......@@ -130,17 +129,13 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
class WaveRNNInferenceWrapper(torch.nn.Module):
def __init__(self, wavernn: WaveRNN):
super().__init__()
self.wavernn_model = wavernn
def forward(self,
specgram: Tensor,
mulaw: bool = True,
batched: bool = True,
timesteps: int = 100,
overlap: int = 5) -> Tensor:
def forward(
self, specgram: Tensor, mulaw: bool = True, batched: bool = True, timesteps: int = 100, overlap: int = 5
) -> Tensor:
r"""Inference function for WaveRNN.
Based on the implementation from
......
from . import (
train,
trainer
)
from . import train, trainer
__all__ = ['train', 'trainer']
__all__ = ["train", "trainer"]
#!/usr/bin/env python3
"""Train Conv-TasNet"""
import time
import pathlib
import argparse
import pathlib
import time
import conv_tasnet
import torch
import torchaudio
import torchaudio.models
import conv_tasnet
from utils import dist_utils
from utils.dataset import utils as dataset_utils
......@@ -16,15 +15,14 @@ _LG = dist_utils.getLogger(__name__)
def _parse_args(args):
parser = argparse.ArgumentParser(description=__doc__,)
parser = argparse.ArgumentParser(
description=__doc__,
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug behavior. Each epoch will end with just one batch.")
group = parser.add_argument_group("Model Options")
group.add_argument(
"--num-speakers", required=True, type=int, help="The number of speakers."
"--debug", action="store_true", help="Enable debug behavior. Each epoch will end with just one batch."
)
group = parser.add_argument_group("Model Options")
group.add_argument("--num-speakers", required=True, type=int, help="The number of speakers.")
group = parser.add_argument_group("Dataset Options")
group.add_argument(
"--sample-rate",
......@@ -132,7 +130,8 @@ def _get_model(
_LG.info_on_master(" - X: %d", msk_num_layers)
_LG.info_on_master(" - R: %d", msk_num_stacks)
_LG.info_on_master(
" - Receptive Field: %s [samples]", model.mask_generator.receptive_field,
" - Receptive Field: %s [samples]",
model.mask_generator.receptive_field,
)
return model
......@@ -141,11 +140,9 @@ def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_
train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
dataset_type, dataset_dir, num_speakers, sample_rate, task
)
train_collate_fn = dataset_utils.get_collate_fn(
dataset_type, mode='train', sample_rate=sample_rate, duration=4
)
train_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode="train", sample_rate=sample_rate, duration=4)
test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test')
test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode="test")
train_loader = torch.utils.data.DataLoader(
train_dataset,
......@@ -173,8 +170,12 @@ def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_
def _write_header(log_path, args):
rows = [
[f"# torch: {torch.__version__}", ],
[f"# torchaudio: {torchaudio.__version__}", ]
[
f"# torch: {torch.__version__}",
],
[
f"# torchaudio: {torchaudio.__version__}",
],
]
rows.append(["# arguments"])
for key, item in vars(args).items():
......@@ -212,9 +213,7 @@ def train(args):
model = _get_model(num_sources=args.num_speakers)
model.to(device)
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[device] if torch.cuda.is_available() else None
)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device] if torch.cuda.is_available() else None)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
if args.resume:
......@@ -222,13 +221,9 @@ def train(args):
model.module.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
else:
dist_utils.synchronize_params(
str(args.save_dir / "tmp.pt"), device, model, optimizer
)
dist_utils.synchronize_params(str(args.save_dir / "tmp.pt"), device, model, optimizer)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.5, patience=3
)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)
train_loader, valid_loader, eval_loader = _get_dataloader(
args.dataset,
......
import time
from typing import Tuple
from collections import namedtuple
from typing import Tuple
import torch
import torch.distributed as dist
from utils import dist_utils, metrics
_LG = dist_utils.getLogger(__name__)
Metric = namedtuple("SNR", ["si_snri", "sdri"])
Metric.__str__ = (
lambda self: f"SI-SNRi: {self.si_snri:10.3e}, SDRi: {self.sdri:10.3e}"
)
Metric.__str__ = lambda self: f"SI-SNRi: {self.si_snri:10.3e}, SDRi: {self.sdri:10.3e}"
def si_sdr_improvement(
estimate: torch.Tensor,
reference: torch.Tensor,
mix: torch.Tensor,
mask: torch.Tensor
estimate: torch.Tensor, reference: torch.Tensor, mix: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute the improvement of scale-invariant SDR. (SI-SNRi) and bare SDR (SDRi).
......@@ -66,11 +60,7 @@ class OccasionalLogger:
def log(self, metric, progress, force=False):
now = time.monotonic()
if (
force
or now > self.last_time + self.time_interval
or progress > self.last_progress + self.progress_interval
):
if force or now > self.last_time + self.time_interval or progress > self.last_progress + self.progress_interval:
self.last_time = now
self.last_progress = progress
_LG.info_on_master("train: %s [%3d%%]", metric, 100 * progress)
......@@ -117,9 +107,7 @@ class Trainer:
loss = -si_snri
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.grad_clip, norm_type=2.0
)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip, norm_type=2.0)
self.optimizer.step()
metric = Metric(si_snri.item(), sdri.item())
......