Commit 5e75c8e8 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Rename generator to vocoder in HiFiGAN model and factory functions (#2955)

Summary:
The generator part of the HiFiGAN model is a vocoder: it converts a mel spectrogram into a waveform. Renaming it to "vocoder" makes its role easier to understand.

Pull Request resolved: https://github.com/pytorch/audio/pull/2955

Reviewed By: carolineechen

Differential Revision: D42348864

Pulled By: nateanl

fbshipit-source-id: c45a2f8d8d205ee381178ae5d37e9790a257e1aa
parent 5428e283
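For reference, a minimal usage sketch of the renamed factory API introduced by this change. The 80-mel-bin input and the frame count are assumptions based on the original HiFi-GAN recipe, not something stated in this diff:

```python
import torch
from torchaudio.prototype.models import hifigan_vocoder_v3  # formerly hifigan_generator_v3

# Build the V3 architecture (randomly initialized weights).
vocoder = hifigan_vocoder_v3().eval()

# Assumed input: a batch of mel spectrograms with 80 mel bins and 128 frames,
# matching the original HiFi-GAN recipe; the actual bin count is set by the model config.
mel = torch.randn(1, 80, 128)

with torch.inference_mode():
    # Output shape is (batch, 1, time), upsampled by the product of upsample_rates.
    waveform = vocoder(mel)
```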
@@ -65,29 +65,29 @@ conformer_wav2vec2_pretrain_large
 .. autofunction:: conformer_wav2vec2_pretrain_large
-HiFiGANGenerator
-~~~~~~~~~~~~~~~~
+HiFiGANVocoder
+~~~~~~~~~~~~~~
-.. autoclass:: HiFiGANGenerator
+.. autoclass:: HiFiGANVocoder
 .. automethod:: forward
-hifigan_generator
-~~~~~~~~~~~~~~~~~
+hifigan_vocoder
+~~~~~~~~~~~~~~~
-.. autofunction:: hifigan_generator
+.. autofunction:: hifigan_vocoder
-hifigan_generator_v1
-~~~~~~~~~~~~~~~~~~~~
+hifigan_vocoder_v1
+~~~~~~~~~~~~~~~~~~
-.. autofunction:: hifigan_generator_v1
+.. autofunction:: hifigan_vocoder_v1
-hifigan_generator_v2
-~~~~~~~~~~~~~~~~~~~~
+hifigan_vocoder_v2
+~~~~~~~~~~~~~~~~~~
-.. autofunction:: hifigan_generator_v2
+.. autofunction:: hifigan_vocoder_v2
-hifigan_generator_v3
-~~~~~~~~~~~~~~~~~~~~
+hifigan_vocoder_v3
+~~~~~~~~~~~~~~~~~~
-.. autofunction:: hifigan_generator_v3
+.. autofunction:: hifigan_vocoder_v3
 import torch
 from parameterized import parameterized
-from torchaudio.prototype.models import (
-    hifigan_generator,
-    hifigan_generator_v1,
-    hifigan_generator_v2,
-    hifigan_generator_v3,
-)
+from torchaudio.prototype.models import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3
 from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH
 from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
@@ -36,7 +31,7 @@ class HiFiGANTestImpl(TestBaseMixin):
 }
 def _get_model(self):
-return hifigan_generator(**self._get_model_config()).to(device=self.device, dtype=self.dtype).eval()
+return hifigan_vocoder(**self._get_model_config()).to(device=self.device, dtype=self.dtype).eval()
 def _get_inputs(self):
 input_config = self._get_input_config()
@@ -51,7 +46,7 @@ class HiFiGANTestImpl(TestBaseMixin):
 super().setUp()
 torch.random.manual_seed(31)
-@parameterized.expand([(hifigan_generator_v1,), (hifigan_generator_v2,), (hifigan_generator_v3,)])
+@parameterized.expand([(hifigan_vocoder_v1,), (hifigan_vocoder_v2,), (hifigan_vocoder_v3,)])
 def test_smoke(self, factory_func):
 r"""Verify that model architectures V1, V2, V3 can be constructed and applied on inputs"""
 model = factory_func().to(device=self.device, dtype=self.dtype)
......
@@ -8,13 +8,7 @@ from ._conformer_wav2vec2 import (
 )
 from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
 from .conv_emformer import ConvEmformer
-from .hifi_gan import (
-    hifigan_generator,
-    hifigan_generator_v1,
-    hifigan_generator_v2,
-    hifigan_generator_v3,
-    HiFiGANGenerator,
-)
+from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
 from .rnnt import conformer_rnnt_base, conformer_rnnt_model
 __all__ = [
@@ -29,9 +23,9 @@ __all__ = [
 "ConformerWav2Vec2PretrainModel",
 "emformer_hubert_base",
 "emformer_hubert_model",
-"HiFiGANGenerator",
-"hifigan_generator_v1",
-"hifigan_generator_v2",
-"hifigan_generator_v3",
-"hifigan_generator",
+"HiFiGANVocoder",
+"hifigan_vocoder_v1",
+"hifigan_vocoder_v2",
+"hifigan_vocoder_v3",
+"hifigan_vocoder",
 ]
@@ -30,13 +30,13 @@ import torch.nn.functional as F
 from torch.nn import Conv1d, ConvTranspose1d
-class HiFiGANGenerator(torch.nn.Module):
+class HiFiGANVocoder(torch.nn.Module):
 """Generator part of *HiFi GAN* :cite:`NEURIPS2020_c5d73680`.
 Source: https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/models.py#L75
 Note:
-To build the model, please use one of the factory functions: :py:func:`hifigan_generator`,
-:py:func:`hifigan_generator_v1`, :py:func:`hifigan_generator_v2`, :py:func:`hifigan_generator_v3`.
+To build the model, please use one of the factory functions: :py:func:`hifigan_vocoder`,
+:py:func:`hifigan_vocoder_v1`, :py:func:`hifigan_vocoder_v2`, :py:func:`hifigan_vocoder_v3`.
 Args:
 in_channels (int): Number of channels in the input features.
@@ -62,7 +62,7 @@ class HiFiGANGenerator(torch.nn.Module):
 resblock_type: int,
 lrelu_slope: float,
 ):
-super(HiFiGANGenerator, self).__init__()
+super(HiFiGANVocoder, self).__init__()
 self.num_kernels = len(resblock_kernel_sizes)
 self.num_upsamples = len(upsample_rates)
 self.conv_pre = Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
@@ -117,7 +117,7 @@ class HiFiGANGenerator(torch.nn.Module):
 @torch.jit.interface
 class ResBlockInterface(torch.nn.Module):
-"""Interface for ResBlock - necessary to make type annotations in ``HiFiGANGenerator.forward`` compatible
+"""Interface for ResBlock - necessary to make type annotations in ``HiFiGANVocoder.forward`` compatible
 with TorchScript
 """
@@ -126,7 +126,7 @@ class ResBlockInterface(torch.nn.Module):
 class ResBlock1(torch.nn.Module):
-"""Residual block of type 1 for HiFiGAN Generator :cite:`NEURIPS2020_c5d73680`.
+"""Residual block of type 1 for HiFiGAN Vocoder :cite:`NEURIPS2020_c5d73680`.
 Args:
 channels (int): Number of channels in the input features.
 kernel_size (int, optional): Kernel size for 1D convolutions. (Default: ``3``)
@@ -193,7 +193,7 @@ class ResBlock1(torch.nn.Module):
 class ResBlock2(torch.nn.Module):
-"""Residual block of type 2 for HiFiGAN Generator :cite:`NEURIPS2020_c5d73680`.
+"""Residual block of type 2 for HiFiGAN Vocoder :cite:`NEURIPS2020_c5d73680`.
 Args:
 channels (int): Number of channels in the input features.
 kernel_size (int, optional): Kernel size for 1D convolutions. (Default: ``3``)
@@ -246,7 +246,7 @@ def get_padding(kernel_size, dilation=1):
 return int((kernel_size * dilation - dilation) / 2)
-def hifigan_generator(
+def hifigan_vocoder(
 in_channels: int,
 upsample_rates: Tuple[int, ...],
 upsample_initial_channel: int,
@@ -255,22 +255,22 @@ def hifigan_generator(
 resblock_dilation_sizes: Tuple[Tuple[int, ...], ...],
 resblock_type: int,
 lrelu_slope: float,
-) -> HiFiGANGenerator:
-r"""Builds HiFi GAN Generator :cite:`NEURIPS2020_c5d73680`.
+) -> HiFiGANVocoder:
+r"""Builds HiFi GAN Vocoder :cite:`NEURIPS2020_c5d73680`.
 Args:
-in_channels (int): See :py:class:`HiFiGANGenerator`.
-upsample_rates (tuple of ``int``): See :py:class:`HiFiGANGenerator`.
-upsample_initial_channel (int): See :py:class:`HiFiGANGenerator`.
-upsample_kernel_sizes (tuple of ``int``): See :py:class:`HiFiGANGenerator`.
-resblock_kernel_sizes (tuple of ``int``): See :py:class:`HiFiGANGenerator`.
-resblock_dilation_sizes (tuple of tuples of ``int``): See :py:class:`HiFiGANGenerator`.
-resblock_type (int, 1 or 2): See :py:class:`HiFiGANGenerator`.
+in_channels (int): See :py:class:`HiFiGANVocoder`.
+upsample_rates (tuple of ``int``): See :py:class:`HiFiGANVocoder`.
+upsample_initial_channel (int): See :py:class:`HiFiGANVocoder`.
+upsample_kernel_sizes (tuple of ``int``): See :py:class:`HiFiGANVocoder`.
+resblock_kernel_sizes (tuple of ``int``): See :py:class:`HiFiGANVocoder`.
+resblock_dilation_sizes (tuple of tuples of ``int``): See :py:class:`HiFiGANVocoder`.
+resblock_type (int, 1 or 2): See :py:class:`HiFiGANVocoder`.
 Returns:
-HiFiGANGenerator: generated model.
+HiFiGANVocoder: generated model.
 """
-return HiFiGANGenerator(
+return HiFiGANVocoder(
 upsample_rates=upsample_rates,
 resblock_kernel_sizes=resblock_kernel_sizes,
 resblock_dilation_sizes=resblock_dilation_sizes,
@@ -282,13 +282,13 @@ def hifigan_generator(
 )
-def hifigan_generator_v1() -> HiFiGANGenerator:
-r"""Builds HiFiGAN Generator with V1 architecture :cite:`NEURIPS2020_c5d73680`.
+def hifigan_vocoder_v1() -> HiFiGANVocoder:
+r"""Builds HiFiGAN Vocoder with V1 architecture :cite:`NEURIPS2020_c5d73680`.
 Returns:
-HiFiGANGenerator: generated model.
+HiFiGANVocoder: generated model.
 """
-return hifigan_generator(
+return hifigan_vocoder(
 upsample_rates=(8, 8, 2, 2),
 upsample_kernel_sizes=(16, 16, 4, 4),
 upsample_initial_channel=512,
@@ -300,13 +300,13 @@ def hifigan_generator_v1() -> HiFiGANGenerator:
 )
-def hifigan_generator_v2() -> HiFiGANGenerator:
-r"""Builds HiFiGAN Generator with V2 architecture :cite:`NEURIPS2020_c5d73680`.
+def hifigan_vocoder_v2() -> HiFiGANVocoder:
+r"""Builds HiFiGAN Vocoder with V2 architecture :cite:`NEURIPS2020_c5d73680`.
 Returns:
-HiFiGANGenerator: generated model.
+HiFiGANVocoder: generated model.
 """
-return hifigan_generator(
+return hifigan_vocoder(
 upsample_rates=(8, 8, 2, 2),
 upsample_kernel_sizes=(16, 16, 4, 4),
 upsample_initial_channel=128,
@@ -318,13 +318,13 @@ def hifigan_generator_v2() -> HiFiGANGenerator:
 )
-def hifigan_generator_v3() -> HiFiGANGenerator:
-r"""Builds HiFiGAN Generator with V3 architecture :cite:`NEURIPS2020_c5d73680`.
+def hifigan_vocoder_v3() -> HiFiGANVocoder:
+r"""Builds HiFiGAN Vocoder with V3 architecture :cite:`NEURIPS2020_c5d73680`.
 Returns:
-HiFiGANGenerator: generated model.
+HiFiGANVocoder: generated model.
 """
-return hifigan_generator(
+return hifigan_vocoder(
 upsample_rates=(8, 8, 4),
 upsample_kernel_sizes=(16, 16, 8),
 upsample_initial_channel=256,
......
@@ -6,14 +6,14 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torchaudio._internal import load_state_dict_from_url
-from torchaudio.prototype.models.hifi_gan import hifigan_generator, HiFiGANGenerator
+from torchaudio.prototype.models.hifi_gan import hifigan_vocoder, HiFiGANVocoder
 from torchaudio.transforms import MelSpectrogram
 @dataclass
 class HiFiGANVocoderBundle:
 """Data class that bundles associated information to use pretrained
-:py:class:`~torchaudio.prototype.models.HiFiGANGenerator`.
+:py:class:`~torchaudio.prototype.models.HiFiGANVocoder`.
 This class provides interfaces for instantiating the pretrained model along with
 the information necessary to retrieve pretrained weights and additional data
@@ -35,7 +35,7 @@ class HiFiGANVocoderBundle:
 >>>
 >>> # Load the HiFiGAN bundle
 >>> vocoder = bundle.get_vocoder()
-Downloading: "https://download.pytorch.org/torchaudio/models/hifigan_generator_v3_ljspeech.pth"
+Downloading: "https://download.pytorch.org/torchaudio/models/hifigan_vocoder_v3_ljspeech.pth"
 100%|████████████| 5.59M/5.59M [00:00<00:00, 18.7MB/s]
 >>>
 >>> # Generate synthetic mel spectrogram
@@ -63,7 +63,7 @@ class HiFiGANVocoderBundle:
 >>>
 >>> # Load HiFiGAN bundle
 >>> vocoder = bundle_hifigan.get_vocoder()
-Downloading: "https://download.pytorch.org/torchaudio/models/hifigan_generator_v3_ljspeech.pth"
+Downloading: "https://download.pytorch.org/torchaudio/models/hifigan_vocoder_v3_ljspeech.pth"
 100%|████████████| 5.59M/5.59M [00:03<00:00, 1.55MB/s]
 >>>
 >>> # Use HiFiGAN to convert mel spectrogram to audio
@@ -82,7 +82,7 @@ class HiFiGANVocoderBundle:
 state_dict = load_state_dict_from_url(url, **dl_kwargs)
 return state_dict
-def get_vocoder(self, *, dl_kwargs=None) -> HiFiGANGenerator:
+def get_vocoder(self, *, dl_kwargs=None) -> HiFiGANVocoder:
 """Construct the HiFiGAN Generator model, which can be used a vocoder, and load the pretrained weight.
 The weight file is downloaded from the internet and cached with
@@ -92,9 +92,9 @@ class HiFiGANVocoderBundle:
 dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
 Returns:
-Variation of :py:class:`~torchaudio.prototype.models.HiFiGANGenerator`.
+Variation of :py:class:`~torchaudio.prototype.models.HiFiGANVocoder`.
 """
-model = hifigan_generator(**self._vocoder_params)
+model = hifigan_vocoder(**self._vocoder_params)
 model.load_state_dict(self._get_state_dict(dl_kwargs))
 model.eval()
 return model
@@ -186,7 +186,7 @@ class _HiFiGANMelSpectrogram(torch.nn.Module):
 HIFIGAN_VOCODER_V3_LJSPEECH = HiFiGANVocoderBundle(
-"hifigan_generator_v3_ljspeech.pth",
+"hifigan_vocoder_v3_ljspeech.pth",
 _vocoder_params={
 "upsample_rates": (8, 8, 4),
 "upsample_kernel_sizes": (16, 16, 8),
@@ -219,7 +219,7 @@ HIFIGAN_VOCODER_V3_LJSPEECH.__doc__ = """HiFiGAN Vocoder pipeline, trained on *T
 <https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/meldataset.py#L49-L72>`_.
 The underlying vocoder is constructed by
-:py:func:`torchaudio.prototype.models.hifigan_generator`. The weights are converted from the ones published
+:py:func:`torchaudio.prototype.models.hifigan_vocoder`. The weights are converted from the ones published
 with the original paper :cite:`NEURIPS2020_c5d73680` under `MIT License
 <https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/LICENSE>`__. See links to
 pre-trained models on `GitHub <https://github.com/jik876/hifi-gan#pretrained-model>`__.
......
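To tie the renamed pipeline pieces together, here is a hedged end-to-end sketch. The synthetic mel shape (80 bins, 200 frames) is illustrative rather than taken from this diff, and the download step requires network access; in practice the mel spectrogram should be produced with the bundle's matching mel parameters:

```python
import torch
from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH

bundle = HIFIGAN_VOCODER_V3_LJSPEECH

# Downloads "hifigan_vocoder_v3_ljspeech.pth" (renamed in this commit) and returns
# an eval-mode HiFiGANVocoder built via hifigan_vocoder(**_vocoder_params).
vocoder = bundle.get_vocoder()

# Assumed synthetic mel spectrogram: (batch, mel_bins, frames).
mel = torch.rand(1, 80, 200)

with torch.inference_mode():
    waveform = vocoder(mel)  # (1, 1, num_samples)
```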