"src/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "a7c5007c238830238f68aa88bc37cc5e424fa82b"
Commit 54e5c859 authored by Grigory Sizov, committed by Facebook GitHub Bot

Add HiFiGAN bundle (#2921)

Summary:
Closes [T138011314](https://www.internalfb.com/intern/tasks/?t=138011314)
## Description
- Add bundle `HIFIGAN_VOCODER_V3_LJSPEECH` to prototypes (a short usage sketch follows this list). The bundle contains pre-trained HiFiGAN generator weights from the [original HiFiGAN publication](https://github.com/jik876/hifi-gan#pretrained-model), slightly converted to fit our model implementation
- Add tests
  - unit tests checking that vocoder and mel-transform implementations in the bundle give the same results as the original ones. Part of the original HiFiGAN code is ported to this repo to enable these tests
  - integration test checking that waveform reconstructed from mel spectrogram by the bundle is close enough to the original
- Add docs
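
For reviewers, a minimal usage sketch of the new bundle (not part of the diff; it only relies on the `get_mel_transform()`, `get_vocoder()`, and `sample_rate` APIs added here, and the input file name is hypothetical):

```python
# Round-trip sketch: waveform -> HiFiGAN-style mel spectrogram -> waveform.
import torch
import torchaudio
from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH as bundle

waveform, sample_rate = torchaudio.load("speech.wav")  # hypothetical mono input
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

mel_transform = bundle.get_mel_transform()  # waveform -> mel spectrogram, matching HiFiGAN training
vocoder = bundle.get_vocoder()              # pretrained HiFiGAN generator, used as a vocoder

with torch.no_grad():
    mel_spectrogram = mel_transform(waveform)
    reconstructed = vocoder(mel_spectrogram).squeeze(0)

torchaudio.save("reconstructed.wav", reconstructed, bundle.sample_rate)
```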

Pull Request resolved: https://github.com/pytorch/audio/pull/2921

Reviewed By: nateanl, mthrok

Differential Revision: D42034761

Pulled By: sgrigory

fbshipit-source-id: 8b0dadeed510b3c9371d6aa2c46ec7d8378f6048
parent bf085b1f
@@ -24,3 +24,29 @@ EMFORMER_RNNT_BASE_TEDLIUM3
.. autodata:: EMFORMER_RNNT_BASE_TEDLIUM3
:no-value:
HiFiGAN Vocoder
---------------
Interface
~~~~~~~~~
:py:class:`HiFiGANVocoderBundle` defines the HiFiGAN vocoder pipeline, capable of transforming mel spectrograms into waveforms.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
HiFiGANVocoderBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
HIFIGAN_VOCODER_V3_LJSPEECH
import math
import torch
import torchaudio
from torchaudio.prototype.functional import oscillator_bank
from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH
def test_hifi_gan_pretrained_weights():
"""Test that a waveform reconstructed from mel spectrogram by HiFiGAN bundle is close enough to the original.
The main transformations performed in this test can be represented as
- audio -> reference log mel spectrogram
- audio -> mel spectrogram -> audio -> estimated log mel spectrogram
In the end, we compare the estimated log mel spectrogram to the reference one. See comments in the code for details.
"""
bundle = HIFIGAN_VOCODER_V3_LJSPEECH
# Get HiFiGAN-compatible transformation from waveform to mel spectrogram
mel_transform = bundle.get_mel_transform()
# Get HiFiGAN vocoder
vocoder = bundle.get_vocoder()
# Create a synthetic waveform
ref_waveform = get_sin_sweep(sample_rate=bundle.sample_rate, length=100000)
ref_waveform = ref_waveform[:, : -(ref_waveform.shape[1] % mel_transform.hop_size)]
# Generate mel spectrogram from waveform
mel_spectrogram = mel_transform(ref_waveform)
with torch.no_grad():
# Generate waveform from mel spectrogram
estimated_waveform = vocoder(mel_spectrogram).squeeze(0)
# Measure the reconstruction error.
# Even though the reconstructed audio is perceptually very close to the original, it doesn't score well on
# metrics like Si-SNR. It might be that HiFiGAN introduces non-uniform shifts to the reconstructed waveforms.
# So to evaluate the reconstruction error we compute mel spectrograms of the reference and reconstructed waveforms,
# and compare relative mean squared error of their logarithms.
final_spec = torchaudio.transforms.MelSpectrogram(sample_rate=bundle.sample_rate, normalized=True)
# Log mel spectrogram of the estimated waveform
estimated_spectrogram = final_spec(estimated_waveform)
estimated_spectrogram = torch.log(torch.clamp(estimated_spectrogram, min=1e-5))
# Log mel spectrogram of the reference waveform
ref_spectrogram = final_spec(ref_waveform)
ref_spectrogram = torch.log(torch.clamp(ref_spectrogram, min=1e-5))
# Check that relative MSE is below 4%
mse = ((estimated_spectrogram - ref_spectrogram) ** 2).mean()
mean_ref = ((ref_spectrogram) ** 2).mean()
print(mse / mean_ref)
assert mse / mean_ref < 0.04
def get_sin_sweep(sample_rate, length):
"""Create a waveform which changes frequency from 0 to the Nyquist frequency (half of the sample rate)"""
nyquist_freq = sample_rate / 2
freq = torch.logspace(0, math.log(0.99 * nyquist_freq, 10), length).unsqueeze(-1)
amp = torch.ones((length, 1))
waveform = oscillator_bank(freq, amp, sample_rate=sample_rate)
return waveform.unsqueeze(0)
import importlib
import os
import subprocess
import sys
import torch
from parameterized import parameterized
from torchaudio.prototype.models import (
@@ -11,8 +6,13 @@ from torchaudio.prototype.models import (
hifigan_generator_v2,
hifigan_generator_v3,
)
from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
from .original.env import AttrDict
from .original.meldataset import mel_spectrogram as ref_mel_spectrogram
from .original.models import Generator
class HiFiGANTestImpl(TestBaseMixin):
def _get_model_config(self):
@@ -47,36 +47,9 @@ class HiFiGANTestImpl(TestBaseMixin):
input = torch.rand(batch_size, in_channels, time_length).to(device=self.device, dtype=self.dtype)
return input
def _import_original_impl(self):
"""Clone the original implmentation of HiFi GAN and import necessary objects. Used in a test below checking
that output of our implementation matches the original one.
"""
module_name = "hifigan_cloned"
path_cloned = "/tmp/" + module_name
if not os.path.isdir(path_cloned):
subprocess.run(["git", "clone", "https://github.com/jik876/hifi-gan.git", path_cloned])
subprocess.run(["git", "checkout", "4769534d45265d52a904b850da5a622601885777"], cwd=path_cloned)
# Make sure imports work in the cloned code. Module "utils" is imported inside "models.py" in the cloned code,
# so we need to delete "utils" from the modules cache - a module with this name is already imported by another
# test
sys.path.insert(0, "/tmp")
sys.path.insert(0, path_cloned)
if "utils" in sys.modules:
del sys.modules["utils"]
env = importlib.import_module(module_name + ".env")
models = importlib.import_module(module_name + ".models")
return env.AttrDict, models.Generator
def setUp(self):
super().setUp()
torch.random.manual_seed(31)
# Import code necessary for test_original_implementation_match
self.AttrDict, self.Generator = self._import_original_impl()
def tearDown(self):
# sys.path was modified in test setup, revert the modifications
sys.path.pop(0)
sys.path.pop(0)
@parameterized.expand([(hifigan_generator_v1,), (hifigan_generator_v2,), (hifigan_generator_v3,)])
def test_smoke(self, factory_func):
@@ -122,9 +95,9 @@ class HiFiGANTestImpl(TestBaseMixin):
def test_original_implementation_match(self):
r"""Check that output of our implementation matches the original one."""
model_config = self._get_model_config()
model_config = AttrDict(model_config)
model_config.resblock = "1" if model_config.resblock_type == 1 else "2"
model_ref = Generator(model_config).to(device=self.device, dtype=self.dtype)
model_ref.remove_weight_norm()
inputs = self._get_inputs()
@@ -134,3 +107,27 @@ class HiFiGANTestImpl(TestBaseMixin):
ref_output = model_ref(inputs)
output = model(inputs)
self.assertEqual(ref_output, output)
def test_mel_transform(self):
"""Check that HIFIGAN_VOCODER_V3_LJSPEECH.get_mel_transform generates the same mel spectrogram as the original
HiFiGAN implementation when applied to a synthetic waveform.
There seems to be no way to change dtype in the original implementation, so we feed in the waveform with the
default dtype and cast the output before comparison.
"""
synth_waveform = torch.rand(1, 1000).to(device=self.device)
# Get HiFiGAN-compatible transformation from waveform to mel spectrogram
self.mel_spectrogram = HIFIGAN_VOCODER_V3_LJSPEECH.get_mel_transform().to(dtype=self.dtype, device=self.device)
mel_spec = self.mel_spectrogram(synth_waveform.to(dtype=self.dtype))
# Generate mel spectrogram with original implementation
ref_mel_spec = ref_mel_spectrogram(
synth_waveform,
n_fft=self.mel_spectrogram.n_fft,
num_mels=self.mel_spectrogram.n_mels,
sampling_rate=self.mel_spectrogram.sample_rate,
hop_size=self.mel_spectrogram.hop_size,
win_size=self.mel_spectrogram.win_length,
fmin=self.mel_spectrogram.f_min,
fmax=self.mel_spectrogram.f_max,
)
self.assertEqual(ref_mel_spec.to(dtype=self.dtype), mel_spec, atol=1e-5, rtol=1e-5)
# Reference Implementation of HiFiGAN
The code in this folder was taken from the original implementation
https://github.com/jik876/hifi-gan/tree/4769534d45265d52a904b850da5a622601885777
which was made available under the following license:
MIT License
Copyright (c) 2020 Jungil Kong
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
This code is used for testing that our implementation matches the original one. To enable such testing, the
ported code has been modified in a minimal way, namely:
- Remove objects other than `mel_spectrogram` and its dependencies from `meldataset.py`
- Remove objects other than `AttrDict` from `env.py`
- Remove objects other than `init_weights` and `get_padding` from `utils.py`
- Add `return_complex=False` argument to the `torch.stft` call in `mel_spectrogram` in `meldataset.py`, to make the code
PyTorch 2.0 compatible (see the sketch after this list)
- Remove the import statements required only for the removed functions.
- Format the code to pass pre-commit checks (see `.pre-commit-config.yaml` for configuration).
Apart from the changes listed above, the implementation of the retained functions and classes is kept as-is.
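
As a hedged illustration (not part of the ported code), the snippet below shows what the `return_complex=False` change preserves: the legacy real-valued STFT layout with a trailing dimension of size 2, which the `spec.pow(2).sum(-1)` magnitude computation in `mel_spectrogram` relies on, and which recent PyTorch does not produce unless asked explicitly.

```python
import torch

y = torch.randn(1, 1024)
window = torch.hann_window(1024)

# Legacy layout kept by the ported code: real tensor of shape (..., freq, frames, 2)
real_spec = torch.stft(y, 1024, hop_length=256, window=window, return_complex=False)
# Modern layout: complex tensor of shape (..., freq, frames)
complex_spec = torch.stft(y, 1024, hop_length=256, window=window, return_complex=True)

# Same values, different layout; the magnitude computation in meldataset.py assumes the first one
assert torch.allclose(real_spec, torch.view_as_real(complex_spec))
```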
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
MAX_WAV_VALUE = 32768.0
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window
if fmax not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[str(y.device)],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
spec = spectral_normalize_torch(spec)
return spec
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from .utils import get_padding, init_weights
LRELU_SLOPE = 0.1
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
]
)
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Generator(torch.nn.Module):
def __init__(self, h):
super(Generator, self).__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if h.resblock == "1" else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
def forward(self, x):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
norm_f = weight_norm if not use_spectral_norm else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorP(2),
DiscriminatorP(3),
DiscriminatorP(5),
DiscriminatorP(7),
DiscriminatorP(11),
]
)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for _, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if not use_spectral_norm else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(torch.nn.Module):
def __init__(self):
super(MultiScaleDiscriminator, self).__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorS(use_spectral_norm=True),
DiscriminatorS(),
DiscriminatorS(),
]
)
self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
if i != 0:
y = self.meanpools[i - 1](y)
y_hat = self.meanpools[i - 1](y_hat)
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
def feature_loss(fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
__all__ = [
"EMFORMER_RNNT_BASE_MUSTC",
"EMFORMER_RNNT_BASE_TEDLIUM3",
"HIFIGAN_VOCODER_V3_LJSPEECH",
"HiFiGANVocoderBundle",
]
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch.nn import Module
from torchaudio._internal import load_state_dict_from_url
from torchaudio.prototype.models.hifi_gan import hifigan_generator, HiFiGANGenerator
from torchaudio.transforms import MelSpectrogram
@dataclass
class HiFiGANVocoderBundle:
"""Data class that bundles associated information to use pretrained
:py:class:`~torchaudio.prototype.models.HiFiGANGenerator`.
This class provides interfaces for instantiating the pretrained model along with
the information necessary to retrieve pretrained weights and additional data
to be used with the model.
Torchaudio library instantiates objects of this class, each of which represents
a different pretrained model. Client code should access pretrained models via these
instances.
This bundle can convert mel spectrograms to waveforms and vice versa. A typical use case would be a flow like
`text -> mel spectrogram -> waveform`, where one can use an external component, e.g. Tacotron2,
to generate mel spectrogram from text. Please see below for the code example.
Example: Transform synthetic mel spectrogram to audio.
>>> import torch
>>> import torchaudio
>>> # Since HiFiGAN bundle is in prototypes, it needs to be exported explicitly
>>> from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH as bundle
>>>
>>> # Load the HiFiGAN bundle
>>> vocoder = bundle.get_vocoder()
Downloading: "https://download.pytorch.org/torchaudio/models/hifigan_generator_v3_ljspeech.pth"
100%|████████████| 5.59M/5.59M [00:00<00:00, 18.7MB/s]
>>>
>>> # Generate synthetic mel spectrogram
>>> specgram = torch.sin(0.5 * torch.arange(start=0, end=100)).expand(bundle._vocoder_params["in_channels"], 100)
>>>
>>> # Transform mel spectrogram into audio
>>> waveform = vocoder(specgram)
>>> torchaudio.save('sample.wav', waveform, bundle.sample_rate)
Example: Usage together with Tacotron2, text to audio.
>>> import torch
>>> import torchaudio
>>> # Since HiFiGAN bundle is in prototypes, it needs to be exported explicitly
>>> from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH as bundle_hifigan
>>>
>>> # Load Tacotron2 bundle
>>> bundle_tacotron2 = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
>>> processor = bundle_tacotron2.get_text_processor()
>>> tacotron2 = bundle_tacotron2.get_tacotron2()
>>>
>>> # Use Tacotron2 to convert text to mel spectrogram
>>> text = "A quick brown fox jumped over a lazy dog"
>>> input, lengths = processor(text)
>>> specgram, lengths, _ = tacotron2.infer(input, lengths)
>>>
>>> # Load HiFiGAN bundle
>>> vocoder = bundle_hifigan.get_vocoder()
Downloading: "https://download.pytorch.org/torchaudio/models/hifigan_generator_v3_ljspeech.pth"
100%|████████████| 5.59M/5.59M [00:03<00:00, 1.55MB/s]
>>>
>>> # Use HiFiGAN to convert mel spectrogram to audio
>>> waveform = vocoder(specgram).squeeze(0)
>>> torchaudio.save('sample.wav', waveform, bundle_hifigan.sample_rate)
""" # noqa: E501
_path: str
_vocoder_params: Dict[str, Any] # Vocoder parameters
_mel_params: Dict[str, Any] # Mel transformation parameters
_sample_rate: float
def _get_state_dict(self, dl_kwargs):
url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
return state_dict
def get_vocoder(self, *, dl_kwargs=None) -> HiFiGANGenerator:
"""Construct the HiFiGAN Generator model, which can be used a vocoder, and load the pretrained weight.
The weight file is downloaded from the internet and cached with
:func:`torch.hub.load_state_dict_from_url`
Args:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns:
Variation of :py:class:`~torchaudio.prototype.models.HiFiGANGenerator`.
"""
model = hifigan_generator(**self._vocoder_params)
model.load_state_dict(self._get_state_dict(dl_kwargs))
model.eval()
return model
def get_mel_transform(self) -> Module:
"""Construct an object which transforms waveforms into mel spectrograms."""
return _HiFiGANMelSpectrogram(
n_mels=self._vocoder_params["in_channels"],
sample_rate=self._sample_rate,
**self._mel_params,
)
@property
def sample_rate(self):
"""Sample rate of the audio that the model is trained on.
:type: float
"""
return self._sample_rate
class _HiFiGANMelSpectrogram(torch.nn.Module):
"""
Generate mel spectrogram in a way equivalent to the original HiFiGAN implementation:
https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/meldataset.py#L49-L72
This class wraps around :py:class:`torchaudio.transforms.MelSpectrogram`, but performs extra steps to achieve
equivalence with the HiFiGAN implementation.
Args:
hop_size (int): Length of hop between STFT windows.
n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins.
win_length (int): Window size.
f_min (float or None): Minimum frequency.
f_max (float or None): Maximum frequency.
sample_rate (int): Sample rate of audio signal.
n_mels (int): Number of mel filterbanks.
"""
def __init__(
self,
hop_size: int,
n_fft: int,
win_length: int,
f_min: Optional[float],
f_max: Optional[float],
sample_rate: float,
n_mels: int,
):
super(_HiFiGANMelSpectrogram, self).__init__()
self.mel_transform = MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_size,
f_min=f_min,
f_max=f_max,
n_mels=n_mels,
normalized=False,
pad=0,
mel_scale="slaney",
norm="slaney",
center=False,
)
self.sample_rate = sample_rate
self.hop_size = hop_size
self.n_fft = n_fft
self.win_length = win_length
self.f_min = f_min
self.f_max = f_max
self.n_mels = n_mels
self.pad_size = int((n_fft - hop_size) / 2)
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
"""Generate mel spectrogram from a waveform. Should have same sample rate as ``self.sample_rate``.
Args:
waveform (Tensor): waveform of shape ``(batch_size, time_length)``.
Returns:
Tensor of shape ``(batch_size, n_mel, time_length)``
"""
ref_waveform = F.pad(waveform.unsqueeze(1), (self.pad_size, self.pad_size), mode="reflect")
ref_waveform = ref_waveform.squeeze(1)
spectr = (self.mel_transform.spectrogram(ref_waveform) + 1e-9) ** 0.5
mel_spectrogram = self.mel_transform.mel_scale(spectr)
mel_spectrogram = torch.log(torch.clamp(mel_spectrogram, min=1e-5))
return mel_spectrogram
HIFIGAN_VOCODER_V3_LJSPEECH = HiFiGANVocoderBundle(
"hifigan_generator_v3_ljspeech.pth",
_vocoder_params={
"upsample_rates": (8, 8, 4),
"upsample_kernel_sizes": (16, 16, 8),
"upsample_initial_channel": 256,
"resblock_kernel_sizes": (3, 5, 7),
"resblock_dilation_sizes": ((1, 2), (2, 6), (3, 12)),
"resblock_type": 2,
"in_channels": 80,
"lrelu_slope": 0.1,
},
_mel_params={
"hop_size": 256,
"n_fft": 1024,
"win_length": 1024,
"f_min": 0,
"f_max": 8000,
},
_sample_rate=22050,
)
HIFIGAN_VOCODER_V3_LJSPEECH.__doc__ = """HiFiGAN Vocoder pipeline, trained on *The LJ Speech Dataset*
:cite:`ljspeech17`.
This pipeline can be used with an external component which generates mel spectrograms from text, for example,
Tacotron2 - see examples in :py:class:`HiFiGANVocoderBundle`.
Although this works with the existing Tacotron2 bundles, for the best results one needs to retrain Tacotron2
using the same data preprocessing pipeline which was used for training HiFiGAN. In particular, the original
HiFiGAN implementation uses a custom method of generating mel spectrograms from waveforms, different from
:py:class:`torchaudio.transforms.MelSpectrogram`. We reimplemented this transform as
:py:meth:`HiFiGANVocoderBundle.get_mel_transform`, making sure it is equivalent to the original HiFiGAN code `here
<https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/meldataset.py#L49-L72>`_.
The underlying vocoder is constructed by
:py:func:`torchaudio.prototype.models.hifigan_generator`. The weights are converted from the ones published
with the original paper :cite:`NEURIPS2020_c5d73680` under `MIT License
<https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/LICENSE>`__. See links to
pre-trained models on `GitHub <https://github.com/jik876/hifi-gan#pretrained-model>`__.
Please refer to :py:class:`HiFiGANVocoderBundle` for usage instructions.
"""