Commit b5e4663a authored by Grigory Sizov's avatar Grigory Sizov Committed by Facebook GitHub Bot
Browse files

Add HiFi GAN Generator to prototypes (#2860)

Summary:
Part 1 of [T138011314](https://www.internalfb.com/intern/tasks/?t=138011314)

This PR ports the generator part of [HiFi GAN](https://arxiv.org/abs/2010.05646v2) from [the original implementation](https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/models.py#L75)

Adds tests:
- Smoke tests for architectures V1, V2, V3
- Check that output shapes are correct
- Check that the model is torchscriptable and scripting doesn't change the output
- Check that our code's output matches the original implementation. Here I clone the original repo inside `/tmp` and import necessary objects from inside the test function. On test teardown I restore `sys.path`, but don't remove the cloned code, so that it can be reused on subsequent runs - let me know if removing it would be a better practice

There are no quantization tests, because the model consists mainly of `Conv1d` and `ConvTransposed1d`, and they are [not supported by dynamic quantization](https://pytorch.org/docs/stable/quantization.html)

Pull Request resolved: https://github.com/pytorch/audio/pull/2860

Reviewed By: nateanl

Differential Revision: D41433416

Pulled By: sgrigory

fbshipit-source-id: f135c560df20f5138f01e3efdd182621edabb4f5
parent ba683bd1
...@@ -64,3 +64,30 @@ conformer_wav2vec2_pretrain_large ...@@ -64,3 +64,30 @@ conformer_wav2vec2_pretrain_large
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_wav2vec2_pretrain_large .. autofunction:: conformer_wav2vec2_pretrain_large
HiFiGANGenerator
~~~~~~~~~~~~~~~~
.. autoclass:: HiFiGANGenerator
.. automethod:: forward
hifigan_generator
~~~~~~~~~~~~~~~~~
.. autofunction:: hifigan_generator
hifigan_generator_v1
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: hifigan_generator_v1
hifigan_generator_v2
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: hifigan_generator_v2
hifigan_generator_v3
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: hifigan_generator_v3
...@@ -464,6 +464,17 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop ...@@ -464,6 +464,17 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop
year=2021, year=2021,
author={Guoguo Chen and Shuzhou Chai and Guanbo Wang and Jiayu Du and Wei-Qiang Zhang and Chao Weng and Dan Su and Daniel Povey and Jan Trmal and Junbo Zhang and Mingjie Jin and Sanjeev Khudanpur and Shinji Watanabe and Shuaijiang Zhao and Wei Zou and Xiangang Li and Xuchen Yao and Yongqing Wang and Yujun Wang and Zhao You and Zhiyong Yan} author={Guoguo Chen and Shuzhou Chai and Guanbo Wang and Jiayu Du and Wei-Qiang Zhang and Chao Weng and Dan Su and Daniel Povey and Jan Trmal and Junbo Zhang and Mingjie Jin and Sanjeev Khudanpur and Shinji Watanabe and Shuaijiang Zhao and Wei Zou and Xiangang Li and Xuchen Yao and Yongqing Wang and Yujun Wang and Zhao You and Zhiyong Yan}
} }
@inproceedings{NEURIPS2020_c5d73680,
author = {Kong, Jungil and Kim, Jaehyeon and Bae, Jaekyoung},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
pages = {17022--17033},
publisher = {Curran Associates, Inc.},
title = {HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis},
url = {https://proceedings.neurips.cc/paper/2020/file/c5d736809766d46260d816d8dbc9eb44-Paper.pdf},
volume = {33},
year = {2020}
}
@inproceedings{ko15_interspeech, @inproceedings{ko15_interspeech,
author={Tom Ko and Vijayaditya Peddinti and Daniel Povey and Sanjeev Khudanpur}, author={Tom Ko and Vijayaditya Peddinti and Daniel Povey and Sanjeev Khudanpur},
title={{Audio augmentation for speech recognition}}, title={{Audio augmentation for speech recognition}},
......
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .hifi_gan_test_impl import HiFiGANTestImpl
class HiFiGANFloat32CPUTest(HiFiGANTestImpl, PytorchTestCase):
    # Run the shared HiFi GAN test suite in float32 on CPU.
    dtype = torch.float32
    device = torch.device("cpu")
class HiFiGANFloat64CPUTest(HiFiGANTestImpl, PytorchTestCase):
    # Run the shared HiFi GAN test suite in float64 on CPU.
    dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .hifi_gan_test_impl import HiFiGANTestImpl
@skipIfNoCuda
class HiFiGANFloat32GPUTest(HiFiGANTestImpl, PytorchTestCase):
    # Run the shared HiFi GAN test suite in float32 on CUDA.
    # Renamed from ``HiFiGANFloat32CPUTest``: the device under test is CUDA, and the old
    # name duplicated the class name used by the CPU test file.
    dtype = torch.float32
    device = torch.device("cuda")
@skipIfNoCuda
class HiFiGANFloat64GPUTest(HiFiGANTestImpl, PytorchTestCase):
    # Run the shared HiFi GAN test suite in float64 on CUDA.
    # Renamed from ``HiFiGANFloat64CPUTest``: the device under test is CUDA, and the old
    # name duplicated the class name used by the CPU test file.
    dtype = torch.float64
    device = torch.device("cuda")
import importlib
import os
import subprocess
import sys
import torch
from parameterized import parameterized
from torchaudio.prototype.models import (
hifigan_generator,
hifigan_generator_v1,
hifigan_generator_v2,
hifigan_generator_v3,
)
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class HiFiGANTestImpl(TestBaseMixin):
    """Shared test suite for the HiFi GAN generator.

    Concrete subclasses set ``dtype`` and ``device`` (CPU/CUDA, float32/float64).
    """

    def _get_model_config(self):
        """Return the keyword arguments used to construct the model under test."""
        return {
            "upsample_rates": (8, 8, 4),
            "upsample_kernel_sizes": (16, 16, 8),
            "upsample_initial_channel": 256,
            "resblock_kernel_sizes": (3, 5, 7),
            "resblock_dilation_sizes": ((1, 2), (2, 6), (3, 12)),
            "resblock_type": 2,
            "in_channels": 80,
            "lrelu_slope": 0.1,
        }

    def _get_input_config(self):
        """Return the shape parameters for the test input tensor."""
        model_config = self._get_model_config()
        return {
            "batch_size": 7,
            "in_channels": model_config["in_channels"],
            "time_length": 10,
        }

    def _get_model(self):
        """Construct the model under test on the target device/dtype, in eval mode."""
        return hifigan_generator(**self._get_model_config()).to(device=self.device, dtype=self.dtype).eval()

    def _get_inputs(self):
        """Generate a random input tensor of shape ``(batch_size, in_channels, time_length)``."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        time_length = input_config["time_length"]
        in_channels = input_config["in_channels"]
        input = torch.rand(batch_size, in_channels, time_length).to(device=self.device, dtype=self.dtype)
        return input

    def _import_original_impl(self):
        """Clone the original implementation of HiFi GAN and import necessary objects. Used in a test below checking
        that output of our implementation matches the original one.

        Returns:
            (type, type): ``AttrDict`` and ``Generator`` from the cloned original implementation.
        """
        module_name = "hifigan_cloned"
        path_cloned = "/tmp/" + module_name
        # Reuse an existing clone from a previous run; otherwise clone and pin to a known commit.
        # ``check=True`` makes a failed clone/checkout fail loudly here rather than as an obscure
        # import error below.
        if not os.path.isdir(path_cloned):
            subprocess.run(["git", "clone", "https://github.com/jik876/hifi-gan.git", path_cloned], check=True)
            subprocess.run(
                ["git", "checkout", "4769534d45265d52a904b850da5a622601885777"], cwd=path_cloned, check=True
            )
        # Make sure imports work in the cloned code. Module "utils" is imported inside "models.py" in the cloned code,
        # so we need to delete "utils" from the modules cache - a module with this name is already imported by another
        # test
        sys.path.insert(0, "/tmp")
        sys.path.insert(0, path_cloned)
        if "utils" in sys.modules:
            del sys.modules["utils"]
        env = importlib.import_module(module_name + ".env")
        models = importlib.import_module(module_name + ".models")
        return env.AttrDict, models.Generator

    def setUp(self):
        super().setUp()
        torch.random.manual_seed(31)
        # Import code necessary for test_original_implementation_match
        self.AttrDict, self.Generator = self._import_original_impl()

    def tearDown(self):
        # ``sys.path`` was modified on test setup (two entries prepended), revert the modifications
        sys.path.pop(0)
        sys.path.pop(0)

    @parameterized.expand([(hifigan_generator_v1,), (hifigan_generator_v2,), (hifigan_generator_v3,)])
    def test_smoke(self, factory_func):
        r"""Verify that model architectures V1, V2, V3 can be constructed and applied on inputs"""
        model = factory_func().to(device=self.device, dtype=self.dtype)
        input = self._get_inputs()
        model(input)

    def test_torchscript_consistency_forward(self):
        r"""Verify that scripting the model does not change the behavior of method `forward`."""
        inputs = self._get_inputs()
        original_model = self._get_model()
        scripted_model = torch_script(original_model).eval()
        # Run twice to make sure the model is stateless across calls.
        for _ in range(2):
            ref_out = original_model(inputs)
            scripted_out = scripted_model(inputs)
            self.assertEqual(ref_out, scripted_out)

    def test_output_shape_forward(self):
        r"""Check that method `forward` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        model_config = self._get_model_config()
        batch_size = input_config["batch_size"]
        time_length = input_config["time_length"]
        inputs = self._get_inputs()
        model = self._get_model()
        total_upsample_rate = 1  # Use loop instead of math.prod for compatibility with Python 3.7
        for upsample_rate in model_config["upsample_rates"]:
            total_upsample_rate *= upsample_rate
        # Run twice to make sure the model is stateless across calls.
        for _ in range(2):
            out = model(inputs)
            self.assertEqual(
                (batch_size, 1, total_upsample_rate * time_length),
                out.shape,
            )

    def test_original_implementation_match(self):
        r"""Check that output of our implementation matches the original one."""
        model_config = self._get_model_config()
        model_config = self.AttrDict(model_config)
        # The original implementation selects the resblock type via a string-valued "resblock" field.
        model_config.resblock = "1" if model_config.resblock_type == 1 else "2"
        model_ref = self.Generator(model_config).to(device=self.device, dtype=self.dtype)
        # Weight norm must be removed so that the state dict matches our implementation's parameters.
        model_ref.remove_weight_norm()

        inputs = self._get_inputs()
        model = self._get_model()
        model.load_state_dict(model_ref.state_dict())

        ref_output = model_ref(inputs)
        output = model(inputs)
        self.assertEqual(ref_output, output)
...@@ -8,6 +8,13 @@ from ._conformer_wav2vec2 import ( ...@@ -8,6 +8,13 @@ from ._conformer_wav2vec2 import (
) )
from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
from .conv_emformer import ConvEmformer from .conv_emformer import ConvEmformer
from .hifi_gan import (
hifigan_generator,
hifigan_generator_v1,
hifigan_generator_v2,
hifigan_generator_v3,
HiFiGANGenerator,
)
from .rnnt import conformer_rnnt_base, conformer_rnnt_model from .rnnt import conformer_rnnt_base, conformer_rnnt_model
__all__ = [ __all__ = [
...@@ -22,4 +29,9 @@ __all__ = [ ...@@ -22,4 +29,9 @@ __all__ = [
"ConformerWav2Vec2PretrainModel", "ConformerWav2Vec2PretrainModel",
"emformer_hubert_base", "emformer_hubert_base",
"emformer_hubert_model", "emformer_hubert_model",
"HiFiGANGenerator",
"hifigan_generator_v1",
"hifigan_generator_v2",
"hifigan_generator_v3",
"hifigan_generator",
] ]
"""
MIT License
Copyright (c) 2020 Jungil Kong
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
class HiFiGANGenerator(torch.nn.Module):
    """Generator part of *HiFi GAN* :cite:`NEURIPS2020_c5d73680`.

    Source: https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/models.py#L75

    Alternates transposed-convolution upsampling layers with groups of residual blocks; the outputs of
    each group are averaged before the next upsampling stage.

    Note:
        To build the model, please use one of the factory functions: :py:func:`hifigan_generator`,
        :py:func:`hifigan_generator_v1`, :py:func:`hifigan_generator_v2`, :py:func:`hifigan_generator_v3`.

    Args:
        in_channels (int): Number of channels in the input features.
        upsample_rates (tuple of ``int``): Factors by which each upsampling layer increases the time dimension.
        upsample_initial_channel (int): Number of channels in the input feature tensor.
        upsample_kernel_sizes (tuple of ``int``): Kernel size for each upsampling layer.
        resblock_kernel_sizes (tuple of ``int``): Kernel size for each residual block.
        resblock_dilation_sizes (tuple of tuples of ``int``): Dilation sizes for each 1D convolutional layer in each
            residual block. For resblock type 1 inner tuples should have length 3, because there are 3
            convolutions in each layer. For resblock type 2 they should have length 2.
        resblock_type (int, 1 or 2): Determines whether ``ResBlock1`` or ``ResBlock2`` will be used.
        lrelu_slope (float): Slope of leaky ReLUs in activations.
    """

    def __init__(
        self,
        in_channels: int,
        upsample_rates: Tuple[int, ...],
        upsample_initial_channel: int,
        upsample_kernel_sizes: Tuple[int, ...],
        resblock_kernel_sizes: Tuple[int, ...],
        resblock_dilation_sizes: Tuple[Tuple[int, ...], ...],
        resblock_type: int,
        lrelu_slope: float,
    ):
        super(HiFiGANGenerator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # Pre-convolution: maps input features to the initial channel width; padding=3 with kernel 7
        # preserves the time length.
        self.conv_pre = Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
        resblock = ResBlock1 if resblock_type == 1 else ResBlock2

        # Each upsampling stage halves the channel count while expanding time by its rate.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                ConvTranspose1d(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    k,
                    u,
                    padding=(k - u) // 2,
                )
            )

        # ``num_kernels`` residual blocks per upsampling stage, stored flat; ``forward`` indexes them
        # as ``i * num_kernels + j`` and averages each stage's group.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for (k, d) in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d, lrelu_slope))

        # Post-convolution: collapses the remaining channels to a single waveform channel.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3)
        self.lrelu_slope = lrelu_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Feature input tensor of shape `(batch_size, num_channels, time_length)`.

        Returns:
            Tensor of shape `(batch_size, 1, time_length * upsample_rate)`, where `upsample_rate` is the product
            of upsample rates for all layers.
        """
        x = self.conv_pre(x)
        for i, upsampling_layer in enumerate(self.ups):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = upsampling_layer(x)
            # Average the outputs of this stage's residual blocks.
            xs = torch.zeros_like(x)
            for j in range(self.num_kernels):
                # The TorchScript interface annotation lets scripting handle a ModuleList that may
                # hold either ResBlock1 or ResBlock2 instances.
                res_block: ResBlockInterface = self.resblocks[i * self.num_kernels + j]
                xs += res_block.forward(x)
            x = xs / self.num_kernels
        # NOTE(review): this call uses F.leaky_relu's default negative slope rather than
        # self.lrelu_slope, mirroring the referenced original implementation - confirm intentional.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
@torch.jit.interface
class ResBlockInterface(torch.nn.Module):
    """Interface for ResBlock - necessary to make type annotations in ``HiFiGANGenerator.forward`` compatible
    with TorchScript
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Interface stub only: both ResBlock1.forward and ResBlock2.forward conform to this signature.
        pass
class ResBlock1(torch.nn.Module):
    """Residual block of type 1 for HiFiGAN Generator :cite:`NEURIPS2020_c5d73680`.

    Each of the three layers applies a dilated convolution followed by a dilation-1 convolution,
    with leaky-ReLU activations, and adds the result back onto the input.

    Args:
        channels (int): Number of channels in the input features.
        kernel_size (int, optional): Kernel size for 1D convolutions. (Default: ``3``)
        dilation (tuple of 3 ``int``, optional): Dilations for each 1D convolution. (Default: ``(1, 3, 5)``)
        lrelu_slope (float): Slope of leaky ReLUs in activations.
    """

    def __init__(
        self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5), lrelu_slope: float = 0.1
    ):
        super(ResBlock1, self).__init__()
        # First convolution of each layer uses the corresponding dilation from ``dilation``.
        self.convs1 = nn.ModuleList(
            [
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
                for d in dilation
            ]
        )
        # Second convolution of each layer always uses dilation 1.
        self.convs2 = nn.ModuleList(
            [
                Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
                for _ in dilation
            ]
        )
        self.lrelu_slope = lrelu_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): input of shape ``(batch_size, channels, time_length)``.

        Returns:
            Tensor of the same shape as input.
        """
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            residual = x
            x = F.leaky_relu(x, self.lrelu_slope)
            x = dilated_conv(x)
            x = F.leaky_relu(x, self.lrelu_slope)
            x = plain_conv(x)
            x = x + residual
        return x
class ResBlock2(torch.nn.Module):
    """Residual block of type 2 for HiFiGAN Generator :cite:`NEURIPS2020_c5d73680`.

    A lighter variant of :class:`ResBlock1`: each layer is a single dilated convolution with a
    leaky-ReLU activation, added back onto the input.

    Args:
        channels (int): Number of channels in the input features.
        kernel_size (int, optional): Kernel size for 1D convolutions. (Default: ``3``)
        dilation (tuple of 2 ``int``, optional): Dilations for each 1D convolution. (Default: ``(1, 3)``)
        lrelu_slope (float): Slope of leaky ReLUs in activations.
    """

    def __init__(
        self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3), lrelu_slope: float = 0.1
    ):
        super(ResBlock2, self).__init__()
        # One convolution per entry of ``dilation``; padding keeps the time length unchanged.
        self.convs = nn.ModuleList(
            [
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
                for d in dilation
            ]
        )
        self.lrelu_slope = lrelu_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): input of shape ``(batch_size, channels, time_length)``.

        Returns:
            Tensor of the same shape as input.
        """
        for conv in self.convs:
            x = conv(F.leaky_relu(x, self.lrelu_slope)) + x
        return x
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Find padding for which 1D convolution preserves the input shape.

    Args:
        kernel_size (int): Kernel size of the convolution.
        dilation (int, optional): Dilation of the convolution. (Default: ``1``)

    Returns:
        int: Padding to apply on each side so a stride-1 convolution keeps the time length.
    """
    # The effective kernel extent is dilation * (kernel_size - 1) + 1; "same" padding is half of
    # (extent - 1). Integer floor division avoids the float round-trip of ``int((k*d - d) / 2)``,
    # which is inexact for very large arguments.
    return (kernel_size - 1) * dilation // 2
def hifigan_generator(
    in_channels: int,
    upsample_rates: Tuple[int, ...],
    upsample_initial_channel: int,
    upsample_kernel_sizes: Tuple[int, ...],
    resblock_kernel_sizes: Tuple[int, ...],
    resblock_dilation_sizes: Tuple[Tuple[int, ...], ...],
    resblock_type: int,
    lrelu_slope: float,
) -> HiFiGANGenerator:
    r"""Builds HiFi GAN Generator :cite:`NEURIPS2020_c5d73680`.

    Args:
        in_channels (int): See :py:class:`HiFiGANGenerator`.
        upsample_rates (tuple of ``int``): See :py:class:`HiFiGANGenerator`.
        upsample_initial_channel (int): See :py:class:`HiFiGANGenerator`.
        upsample_kernel_sizes (tuple of ``int``): See :py:class:`HiFiGANGenerator`.
        resblock_kernel_sizes (tuple of ``int``): See :py:class:`HiFiGANGenerator`.
        resblock_dilation_sizes (tuple of tuples of ``int``): See :py:class:`HiFiGANGenerator`.
        resblock_type (int, 1 or 2): See :py:class:`HiFiGANGenerator`.
        lrelu_slope (float): See :py:class:`HiFiGANGenerator`.

    Returns:
        HiFiGANGenerator: generated model.
    """
    return HiFiGANGenerator(
        upsample_rates=upsample_rates,
        resblock_kernel_sizes=resblock_kernel_sizes,
        resblock_dilation_sizes=resblock_dilation_sizes,
        resblock_type=resblock_type,
        upsample_initial_channel=upsample_initial_channel,
        upsample_kernel_sizes=upsample_kernel_sizes,
        in_channels=in_channels,
        lrelu_slope=lrelu_slope,
    )
def hifigan_generator_v1() -> HiFiGANGenerator:
    r"""Builds HiFiGAN Generator with V1 architecture :cite:`NEURIPS2020_c5d73680`.

    Returns:
        HiFiGANGenerator: generated model.
    """
    # V1 configuration: 512 initial channels with type-1 residual blocks.
    v1_config = {
        "upsample_rates": (8, 8, 2, 2),
        "upsample_kernel_sizes": (16, 16, 4, 4),
        "upsample_initial_channel": 512,
        "resblock_kernel_sizes": (3, 7, 11),
        "resblock_dilation_sizes": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        "resblock_type": 1,
        "in_channels": 80,
        "lrelu_slope": 0.1,
    }
    return hifigan_generator(**v1_config)
def hifigan_generator_v2() -> HiFiGANGenerator:
    r"""Builds HiFiGAN Generator with V2 architecture :cite:`NEURIPS2020_c5d73680`.

    Returns:
        HiFiGANGenerator: generated model.
    """
    # V2 configuration: same layout as V1 but only 128 initial channels.
    v2_config = {
        "upsample_rates": (8, 8, 2, 2),
        "upsample_kernel_sizes": (16, 16, 4, 4),
        "upsample_initial_channel": 128,
        "resblock_kernel_sizes": (3, 7, 11),
        "resblock_dilation_sizes": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        "resblock_type": 1,
        "in_channels": 80,
        "lrelu_slope": 0.1,
    }
    return hifigan_generator(**v2_config)
def hifigan_generator_v3() -> HiFiGANGenerator:
    r"""Builds HiFiGAN Generator with V3 architecture :cite:`NEURIPS2020_c5d73680`.

    Returns:
        HiFiGANGenerator: generated model.
    """
    # V3 configuration: three upsampling stages with type-2 residual blocks.
    v3_config = {
        "upsample_rates": (8, 8, 4),
        "upsample_kernel_sizes": (16, 16, 8),
        "upsample_initial_channel": 256,
        "resblock_kernel_sizes": (3, 5, 7),
        "resblock_dilation_sizes": ((1, 2), (2, 6), (3, 12)),
        "resblock_type": 2,
        "in_channels": 80,
        "lrelu_slope": 0.1,
    }
    return hifigan_generator(**v3_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment