Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
from .impl import (
    TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
    TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
    TACOTRON2_WAVERNN_CHAR_LJSPEECH,
    TACOTRON2_WAVERNN_PHONE_LJSPEECH,
)
from .interface import Tacotron2TTSBundle
__all__ = [
    "Tacotron2TTSBundle",
    "TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
    "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
    "TACOTRON2_WAVERNN_CHAR_LJSPEECH",
    "TACOTRON2_WAVERNN_PHONE_LJSPEECH",
]
...
import re
from dataclasses import dataclass
from typing import Union, Optional, Dict, Any, Tuple, List
import torch
from torch import Tensor
from torchaudio._internal import load_state_dict_from_url
from torchaudio.functional import mu_law_decoding
from torchaudio.models import Tacotron2, WaveRNN
from torchaudio.transforms import InverseMelScale, GriffinLim
from . import utils
from .interface import Tacotron2TTSBundle
__all__ = []
_BASE_URL = "https://download.pytorch.org/torchaudio/models"
################################################################################
@@ -44,8 +44,7 @@ class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor):
        super().__init__()
        self._tokens = utils._get_phones()
        self._mapping = {p: i for i, p in enumerate(self._tokens)}
        self._phonemizer = utils._load_phonemizer("en_us_cmudict_forward.pt", dl_kwargs=dl_kwargs)
        self._pattern = r"(\[[A-Z]+?\]|[_!'(),.:;? -])"
    @property
@@ -57,9 +56,9 @@ class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor):
            texts = [texts]
        indices = []
        for phones in self._phonemizer(texts, lang="en_us"):
            # '[F][UW][B][AA][R]!' -> ['F', 'UW', 'B', 'AA', 'R', '!']
            ret = [re.sub(r"[\[\]]", "", r) for r in re.findall(self._pattern, phones)]
            indices.append([self._mapping[p] for p in ret])
        return utils._to_tensor(indices)
@@ -68,12 +67,9 @@ class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor):
# Pipeline implementation - Vocoder
################################################################################
class _WaveRNNVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
    def __init__(self, model: WaveRNN, min_level_db: Optional[float] = -100):
        super().__init__()
        self._sample_rate = 22050
        self._model = model
@@ -104,10 +100,10 @@ class _GriffinLimVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
            n_stft=(1024 // 2 + 1),
            n_mels=80,
            sample_rate=self.sample_rate,
            f_min=0.0,
            f_max=8000.0,
            mel_scale="slaney",
            norm="slaney",
        )
        self._griffin_lim = GriffinLim(
            n_fft=1024,
@@ -151,7 +147,7 @@ class _Tacotron2Mixin:
    def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2:
        model = Tacotron2(**self._tacotron2_params)
        url = f"{_BASE_URL}/{self._tacotron2_path}"
        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        model.load_state_dict(state_dict)
@@ -170,7 +166,7 @@ class _WaveRNNMixin:
    def _get_wavernn(self, *, dl_kwargs=None):
        model = WaveRNN(**self._wavernn_params)
        url = f"{_BASE_URL}/{self._wavernn_path}"
        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        model.load_state_dict(state_dict)
@@ -214,11 +210,10 @@ class _Tacotron2GriffinLimPhoneBundle(_GriffinLimMixin, _Tacotron2Mixin, _PhoneM
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH = _Tacotron2GriffinLimCharBundle(
    _tacotron2_path="tacotron2_english_characters_1500_epochs_ljspeech.pth",
    _tacotron2_params=utils._get_taco_params(n_symbols=38),
)
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.__doc__ = """Character-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.transforms.GriffinLim`.
The text processor encodes the input texts character-by-character.
@@ -254,14 +249,13 @@ Example - "The examination and testimony of the experts enabled the Commission t
            <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH_v2.wav" type="audio/wav">
            Your browser does not support the <code>audio</code> element.
        </audio>
"""  # noqa: E501
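A minimal usage sketch of this bundle, for orientation while reading the diff (the authoritative `Example` block is elided from this hunk; the call signatures below are assumptions based on the pipeline API of this release):

# --- illustrative usage sketch, not part of the diff ---
import torch
import torchaudio

bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH
processor = bundle.get_text_processor()  # character-level text encoder
tacotron2 = bundle.get_tacotron2()       # downloads the checkpoint named above on first use
vocoder = bundle.get_vocoder()           # InverseMelScale + GriffinLim, no extra download

with torch.inference_mode():
    tokens, lengths = processor("Hello world!")
    specgram, spec_lengths, _ = tacotron2.infer(tokens, lengths)
    waveforms, wave_lengths = vocoder(specgram, spec_lengths)
torchaudio.save("output.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate)
# --- end sketch ---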
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH = _Tacotron2GriffinLimPhoneBundle(
    _tacotron2_path="tacotron2_english_phonemes_1500_epochs_ljspeech.pth",
    _tacotron2_params=utils._get_taco_params(n_symbols=96),
)
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.__doc__ = """Phoneme-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.transforms.GriffinLim`.
The text processor encodes the input texts based on phoneme.
@@ -302,16 +296,15 @@ Example - "The examination and testimony of the experts enabled the Commission t
            Your browser does not support the <code>audio</code> element.
        </audio>
"""  # noqa: E501
TACOTRON2_WAVERNN_CHAR_LJSPEECH = _Tacotron2WaveRNNCharBundle(
    _tacotron2_path="tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth",
    _tacotron2_params=utils._get_taco_params(n_symbols=38),
    _wavernn_path="wavernn_10k_epochs_8bits_ljspeech.pth",
    _wavernn_params=utils._get_wrnn_params(),
)
TACOTRON2_WAVERNN_CHAR_LJSPEECH.__doc__ = """Character-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.models.WaveRNN`.
The text processor encodes the input texts character-by-character.
@@ -350,16 +343,15 @@ Example - "The examination and testimony of the experts enabled the Commission t
            <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH_v2.wav" type="audio/wav">
            Your browser does not support the <code>audio</code> element.
        </audio>
"""  # noqa: E501
TACOTRON2_WAVERNN_PHONE_LJSPEECH = _Tacotron2WaveRNNPhoneBundle(
    _tacotron2_path="tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth",
    _tacotron2_params=utils._get_taco_params(n_symbols=96),
    _wavernn_path="wavernn_10k_epochs_8bits_ljspeech.pth",
    _wavernn_params=utils._get_wrnn_params(),
)
TACOTRON2_WAVERNN_PHONE_LJSPEECH.__doc__ = """Phoneme-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.models.WaveRNN`.
The text processor encodes the input texts based on phoneme.
@@ -403,4 +395,4 @@ Example - "The examination and testimony of the experts enabled the Commission t
            <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH_v2.wav" type="audio/wav">
            Your browser does not support the <code>audio</code> element.
        </audio>
"""  # noqa: E501
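The phoneme-based bundles above differ from the character-based ones only in their text processor, which goes through DeepPhonemizer (see `_load_phonemizer` in the utils hunk below) and therefore needs an optional dependency; a hedged sketch:

# --- illustrative sketch, not part of the diff; assumes the optional
# DeepPhonemizer package ("dp", PyPI name deep-phonemizer) is installed,
# since _load_phonemizer raises RuntimeError without it ---
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
processor = bundle.get_text_processor()      # downloads en_us_cmudict_forward.pt on first use
tokens, lengths = processor("Hello world!")  # text is phonemized first, e.g. "[HH][AH][L][OW] [W][ER][L][D]!"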
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
from typing import Union, List, Tuple, Optional
from torch import Tensor
from torchaudio.models import Tacotron2
...
import logging
import os
import torch
from torchaudio._internal import (
    download_url_to_file,
    module_utils as _mod_utils,
@@ -11,44 +10,44 @@ from torchaudio._internal import (
def _get_chars():
    return (
        "_",
        "-",
        "!",
        "'",
        "(",
        ")",
        ",",
        ".",
        ":",
        ";",
        "?",
        " ",
        "a",
        "b",
        "c",
        "d",
        "e",
        "f",
        "g",
        "h",
        "i",
        "j",
        "k",
        "l",
        "m",
        "n",
        "o",
        "p",
        "q",
        "r",
        "s",
        "t",
        "u",
        "v",
        "w",
        "x",
        "y",
        "z",
    )
@@ -149,7 +148,7 @@ def _get_phones():
        "W",
        "Y",
        "Z",
        "ZH",
    )
@@ -161,18 +160,18 @@ def _to_tensor(indices):
def _load_phonemizer(file, dl_kwargs):
    if not _mod_utils.is_module_available("dp"):
        raise RuntimeError("DeepPhonemizer is not installed. Please install it.")
    from dp.phonemizer import Phonemizer
    # By default, dp issues DEBUG level log.
    logger = logging.getLogger("dp")
    orig_level = logger.level
    logger.setLevel(logging.INFO)
    try:
        url = f"https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/{file}"
        directory = os.path.join(torch.hub.get_dir(), "checkpoints")
        os.makedirs(directory, exist_ok=True)
        path = os.path.join(directory, file)
        if not os.path.exists(path):
@@ -192,41 +191,41 @@ def _unnormalize_waveform(waveform: torch.Tensor, bits: int) -> torch.Tensor:
def _get_taco_params(n_symbols):
    return {
        "mask_padding": False,
        "n_mels": 80,
        "n_frames_per_step": 1,
        "symbol_embedding_dim": 512,
        "encoder_embedding_dim": 512,
        "encoder_n_convolution": 3,
        "encoder_kernel_size": 5,
        "decoder_rnn_dim": 1024,
        "decoder_max_step": 2000,
        "decoder_dropout": 0.1,
        "decoder_early_stopping": True,
        "attention_rnn_dim": 1024,
        "attention_hidden_dim": 128,
        "attention_location_n_filter": 32,
        "attention_location_kernel_size": 31,
        "attention_dropout": 0.1,
        "prenet_dim": 256,
        "postnet_n_convolution": 5,
        "postnet_kernel_size": 5,
        "postnet_embedding_dim": 512,
        "gate_threshold": 0.5,
        "n_symbol": n_symbols,
    }
def _get_wrnn_params():
    return {
        "upsample_scales": [5, 5, 11],
        "n_classes": 2 ** 8,  # n_bits = 8
        "hop_length": 275,
        "n_res_block": 10,
        "n_rnn": 512,
        "n_fc": 512,
        "kernel_size": 5,
        "n_freq": 80,
        "n_hidden": 128,
        "n_output": 128,
    }
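One config value worth calling out: `"n_classes": 2 ** 8` matches the 8-bit setup of the WaveRNN checkpoint named above (wavernn_10k_epochs_8bits_ljspeech.pth), so the vocoder has to map a 256-way output back to audio. A hedged sketch of that last step (the vocoder's forward pass itself is elided from this diff; `class_indices` is a hypothetical name for the per-sample class chosen from the 256-way output):

# --- illustrative sketch, not part of the diff; assumes the 256 classes are
# mu-law codes, expanded with torchaudio.functional.mu_law_decoding
# (imported at the top of impl.py above) ---
waveform = torchaudio.functional.mu_law_decoding(class_indices, quantization_channels=256)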
@@ -2,9 +2,9 @@ from dataclasses import dataclass
from typing import Dict, Tuple, Any
import torch
from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import wav2vec2_model, Wav2Vec2Model
from . import utils
@@ -43,6 +43,7 @@ class Wav2Vec2Bundle:
        >>> # Extract acoustic features
        >>> features, _ = model.extract_features(waveform)
    """  # noqa: E501
    _path: str
    _params: Dict[str, Any]
    _sample_rate: float
@@ -56,7 +57,7 @@ class Wav2Vec2Bundle:
        return self._sample_rate
    def _get_state_dict(self, dl_kwargs):
        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        return state_dict
@@ -120,13 +121,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
        >>> # `ctc_decode` is for illustration purpose only
        >>> transcripts = ctc_decode(emissions, labels)
    """  # noqa: E501
    _labels: Tuple[str]
    _remove_aux_axis: Tuple[int] = (1, 2, 3)
    def get_labels(
        self,
        *,
        blank: str = "-",
    ) -> Tuple[str]:
        """The output class labels (only applicable to fine-tuned bundles)
@@ -161,17 +163,17 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
            # that resembles mistake.
            # The label `1` shows up in the training dataset of German (1 out of 16M),
            # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M)
            for key in ["aux.weight", "aux.bias"]:
                t = state_dict[key]
                state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in self._remove_aux_axis])
        return state_dict
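The `ctc_decode` in the docstring above is only a placeholder; a minimal greedy CTC decoding sketch that works with the ASR bundles defined below (assumes index 0 of `get_labels()` is the CTC blank, which is what the `blank="-"` default produces):

# --- illustrative sketch, not part of the diff ---
import torch
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()
labels = bundle.get_labels()          # ("-", "|", "E", "T", ...); "-" is the blank, "|" the word separator
waveform, sr = torchaudio.load("speech.wav")
waveform = torchaudio.functional.resample(waveform, sr, bundle.sample_rate)
with torch.inference_mode():
    emissions, _ = model(waveform)    # (batch, frames, num_labels)
best = torch.unique_consecutive(emissions[0].argmax(dim=-1))  # best label per frame, repeats collapsed
transcript = "".join(labels[i] for i in best.tolist() if labels[i] != "-").replace("|", " ")
# --- end sketch ---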
WAV2VEC2_BASE = Wav2Vec2Bundle(
    _path="wav2vec2_fairseq_base_ls960.pth",
    _params={
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
@@ -180,19 +182,19 @@ WAV2VEC2_BASE = Wav2Vec2Bundle(
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": False,
        "encoder_embed_dim": 768,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 12,
        "encoder_num_heads": 12,
        "encoder_attention_dropout": 0.1,
        "encoder_ff_interm_features": 3072,
        "encoder_ff_interm_dropout": 0.0,
        "encoder_dropout": 0.1,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.05,
        "aux_num_out": None,
    },
    _sample_rate=16000,
@@ -212,10 +214,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
"""  # noqa: E501
WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
    _path="wav2vec2_fairseq_base_ls960_asr_ll10m.pth",
    _params={
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
@@ -224,19 +226,19 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": False,
        "encoder_embed_dim": 768,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 12,
        "encoder_num_heads": 12,
        "encoder_attention_dropout": 0.1,
        "encoder_ff_interm_features": 3072,
        "encoder_ff_interm_dropout": 0.0,
        "encoder_dropout": 0.1,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.05,
        "aux_num_out": 29,
    },
    _labels=utils._get_en_labels(),
@@ -258,10 +260,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
    "wav2vec2_fairseq_base_ls960_asr_ls100.pth",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
@@ -270,19 +272,19 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": False,
        "encoder_embed_dim": 768,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 12,
        "encoder_num_heads": 12,
        "encoder_attention_dropout": 0.1,
        "encoder_ff_interm_features": 3072,
        "encoder_ff_interm_dropout": 0.0,
        "encoder_dropout": 0.1,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.05,
        "aux_num_out": 29,
    },
    _labels=utils._get_en_labels(),
@@ -304,7 +306,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
    "wav2vec2_fairseq_base_ls960_asr_ls960.pth",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
@@ -349,7 +351,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
WAV2VEC2_LARGE = Wav2Vec2Bundle(
    "wav2vec2_fairseq_large_ls960.pth",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
@@ -393,7 +395,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
"""  # noqa: E501
WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
    "wav2vec2_fairseq_large_ls960_asr_ll10m.pth",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
@@ -439,7 +441,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
    "wav2vec2_fairseq_large_ls960_asr_ls100.pth",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
@@ -485,7 +487,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
    "wav2vec2_fairseq_large_ls960_asr_ls960.pth",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
@@ -530,7 +532,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
    "wav2vec2_fairseq_large_lv60k.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
@@ -574,7 +576,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
"""  # noqa: E501
WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
    "wav2vec2_fairseq_large_lv60k_asr_ll10m.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
@@ -620,7 +622,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
    "wav2vec2_fairseq_large_lv60k_asr_ls100.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
@@ -666,7 +668,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
    "wav2vec2_fairseq_large_lv60k_asr_ls960.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
@@ -713,7 +715,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
    "wav2vec2_fairseq_large_xlsr53.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
@@ -760,10 +762,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
"""  # noqa: E501
HUBERT_BASE = Wav2Vec2Bundle(
    "hubert_fairseq_base_ls960.pth",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
@@ -772,20 +774,20 @@ HUBERT_BASE = Wav2Vec2Bundle(
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": False,
        "encoder_embed_dim": 768,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 12,
        "encoder_num_heads": 12,
        "encoder_attention_dropout": 0.1,
        "encoder_ff_interm_features": 3072,
        "encoder_ff_interm_dropout": 0.0,
        "encoder_dropout": 0.1,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.05,
        "aux_num_out": None,
    },
    _sample_rate=16000,
)
@@ -804,10 +806,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
"""  # noqa: E501
HUBERT_LARGE = Wav2Vec2Bundle(
    "hubert_fairseq_large_ll60k.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
@@ -816,20 +818,20 @@ HUBERT_LARGE = Wav2Vec2Bundle(
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": False,
        "encoder_embed_dim": 1024,
        "encoder_projection_dropout": 0.0,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 24,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.0,
        "encoder_ff_interm_features": 4096,
        "encoder_ff_interm_dropout": 0.0,
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": True,
        "encoder_layer_drop": 0.0,
        "aux_num_out": None,
    },
    _sample_rate=16000,
)
@@ -848,10 +850,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
"""  # noqa: E501
HUBERT_XLARGE = Wav2Vec2Bundle(
    "hubert_fairseq_xlarge_ll60k.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
@@ -860,20 +862,20 @@ HUBERT_XLARGE = Wav2Vec2Bundle(
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": False,
        "encoder_embed_dim": 1280,
        "encoder_projection_dropout": 0.0,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 48,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.0,
        "encoder_ff_interm_features": 5120,
        "encoder_ff_interm_dropout": 0.0,
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": True,
        "encoder_layer_drop": 0.0,
        "aux_num_out": None,
    },
    _sample_rate=16000,
)
@@ -892,10 +894,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
"""  # noqa: E501
HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
    "hubert_fairseq_large_ll60k_asr_ls960.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
@@ -904,20 +906,20 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": False,
        "encoder_embed_dim": 1024,
        "encoder_projection_dropout": 0.0,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 24,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.0,
        "encoder_ff_interm_features": 4096,
        "encoder_ff_interm_dropout": 0.1,
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": True,
        "encoder_layer_drop": 0.1,
        "aux_num_out": 29,
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
@@ -939,10 +941,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
    "hubert_fairseq_xlarge_ll60k_asr_ls960.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
@@ -951,20 +953,20 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": False,
        "encoder_embed_dim": 1280,
        "encoder_projection_dropout": 0.0,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 48,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.0,
        "encoder_ff_interm_features": 5120,
        "encoder_ff_interm_dropout": 0.1,
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": True,
        "encoder_layer_drop": 0.1,
        "aux_num_out": 29,
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
@@ -987,7 +989,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
VOXPOPULI_ASR_BASE_10K_DE = Wav2Vec2ASRBundle(
    "wav2vec2_voxpopuli_base_10k_asr_de.pt",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
@@ -1034,7 +1036,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
VOXPOPULI_ASR_BASE_10K_EN = Wav2Vec2ASRBundle(
    "wav2vec2_voxpopuli_base_10k_asr_en.pt",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
@@ -1059,7 +1061,7 @@ VOXPOPULI_ASR_BASE_10K_EN = Wav2Vec2ASRBundle(
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.1,
        "aux_num_out": 28,
    },
    _labels=utils._get_vp_en_labels(),
    _sample_rate=16000,
@@ -1081,7 +1083,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
VOXPOPULI_ASR_BASE_10K_ES = Wav2Vec2ASRBundle(
    "wav2vec2_voxpopuli_base_10k_asr_es.pt",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
@@ -1106,7 +1108,7 @@ VOXPOPULI_ASR_BASE_10K_ES = Wav2Vec2ASRBundle(
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.1,
        "aux_num_out": 35,
    },
    _labels=utils._get_es_labels(),
    _sample_rate=16000,
@@ -1127,7 +1129,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
"""  # noqa: E501
VOXPOPULI_ASR_BASE_10K_FR = Wav2Vec2ASRBundle(
    "wav2vec2_voxpopuli_base_10k_asr_fr.pt",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
@@ -1152,7 +1154,7 @@ VOXPOPULI_ASR_BASE_10K_FR = Wav2Vec2ASRBundle(
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.1,
        "aux_num_out": 43,
    },
    _labels=utils._get_fr_labels(),
    _sample_rate=16000,
@@ -1173,7 +1175,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
VOXPOPULI_ASR_BASE_10K_IT = Wav2Vec2ASRBundle(
    "wav2vec2_voxpopuli_base_10k_asr_it.pt",
    {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": [
...
def _get_en_labels():
    return (
        "|",
        "E",
        "T",
        "A",
        "O",
        "N",
        "I",
        "H",
        "S",
        "R",
        "D",
        "L",
        "U",
        "M",
        "W",
        "C",
        "F",
        "G",
        "Y",
        "P",
        "B",
        "V",
        "K",
        "'",
        "X",
        "J",
        "Q",
        "Z",
    )
...
import math
from typing import List, Optional, Tuple
import torch
__all__ = ["Conformer"]
@@ -12,9 +13,9 @@ PADDING_IDX = 1
def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    batch_size = lengths.shape[0]
    max_length = int(torch.max(lengths).item())
    padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
        batch_size, max_length
    ) >= lengths.unsqueeze(1)
    return padding_mask
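A quick worked example of the reflowed expression above (hypothetical input):

# lengths = torch.tensor([2, 4]) gives
# torch.arange(4).expand(2, 4) >= lengths.unsqueeze(1)
#   -> tensor([[False, False,  True,  True],
#              [False, False, False, False]])
# i.e. True marks the padded positions past each sequence's length.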
@@ -31,12 +32,8 @@ def _get_sinusoidal_embeddings(
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    half_dim = embedding_dim // 2
    t = (torch.arange(half_dim, dtype=torch.float) * -math.log(10000) / (half_dim - 1)).exp()
    embedding_t = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * t.unsqueeze(0)
    embeddings = torch.cat([embedding_t.sin(), embedding_t.cos()], dim=1)
    if embedding_dim % 2 == 1:
        embeddings = torch.cat([embeddings, torch.zeros(num_embeddings, 1)], dim=1)
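For reference, the reflowed expression computes the fairseq-style sinusoidal table (note the `half_dim - 1` in the exponent rather than the exact Section 3.5 denominator), with d = embedding_dim and position p:

e_{p,\,i} = \sin\big(p \cdot 10000^{-i/(d/2-1)}\big), \qquad e_{p,\,i+d/2} = \cos\big(p \cdot 10000^{-i/(d/2-1)}\big), \qquad i = 0, \dots, d/2-1

and a zero column is appended when d is odd.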
@@ -64,13 +61,16 @@ class ConvolutionModule(torch.nn.Module):
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        assert (depthwise_kernel_size - 1) % 2 == 0, "depthwise_kernel_size must be odd to achieve 'SAME' padding."
        self.layer_norm = torch.nn.LayerNorm(input_dim)
        self.sequential = torch.nn.Sequential(
            torch.nn.Conv1d(
                input_dim,
                2 * num_channels,
                1,
                stride=1,
                padding=0,
                bias=bias,
            ),
            torch.nn.GLU(dim=1),
            torch.nn.Conv1d(
@@ -85,7 +85,12 @@ class ConvolutionModule(torch.nn.Module):
            torch.nn.BatchNorm1d(num_channels),
            torch.nn.SiLU(),
            torch.nn.Conv1d(
                num_channels,
                input_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=bias,
            ),
            torch.nn.Dropout(dropout),
        )
@@ -159,9 +164,7 @@ class ConformerLayer(torch.nn.Module):
        self.ffn1 = FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
        self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim)
        self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout)
        self.self_attn_dropout = torch.nn.Dropout(dropout)
        self.conv_module = ConvolutionModule(
@@ -173,9 +176,7 @@ class ConformerLayer(torch.nn.Module):
        self.ffn2 = FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
        self.final_layer_norm = torch.nn.LayerNorm(input_dim)
    def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): input, with shape `(T, B, D)`.
@@ -256,9 +257,7 @@ class Conv1dSubsampler(torch.nn.Module):
            out = ((out.float() - 1) / 2 + 1).floor().long()
        return out.to(torch.int32)
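The reflowed length formula above is equivalent to ceil(length / 2) and is applied once per stride-2 convolution of the subsampler (the enclosing loop is elided from this hunk); a worked example:

# floor((L - 1) / 2 + 1) == ceil(L / 2); with two conv layers:
# L = 101 -> 51 -> 26, so a 101-frame input yields 26 subsampled frames.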
    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input (torch.Tensor): input frames, with shape `(B, T_in, in_channels)`.
@@ -289,15 +288,11 @@ class SinusoidalPositionalEmbedding(torch.nn.Module):
        init_size (int, optional): initial embedding count. (Default: 1024)
    """
    def __init__(self, embedding_dim: int, padding_idx: int = 0, init_size: int = 1024) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.embeddings = _get_sinusoidal_embeddings(init_size, embedding_dim, padding_idx)
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        r"""
@@ -310,14 +305,10 @@ class SinusoidalPositionalEmbedding(torch.nn.Module):
        B, T = input.shape
        max_pos = self.padding_idx + 1 + T
        if max_pos > self.embeddings.size(0):
            self.embeddings = _get_sinusoidal_embeddings(max_pos, self.embedding_dim, self.padding_idx)
        self.embeddings = self.embeddings.to(input)
        positions = _make_positions(input, self.padding_idx)
        return self.embeddings.index_select(0, positions.view(-1)).view(B, T, -1).detach()
class Conformer(torch.nn.Module):
@@ -370,16 +361,17 @@ class Conformer(torch.nn.Module):
        super().__init__()
        self.subsample = Conv1dSubsampler(
            input_dim,
            conv_channels,
            conformer_layer_input_dim,
            conv_kernel_sizes,
        )
        self.position_embedding = SinusoidalPositionalEmbedding(
            conformer_layer_input_dim,
            padding_idx=PADDING_IDX,
            init_size=max_source_positions + PADDING_IDX + 1,
        )
        self.linear = torch.nn.Linear(conformer_layer_input_dim, conformer_layer_input_dim)
        self.dropout = torch.nn.Dropout(dropout)
        self.conformer_layers = torch.nn.ModuleList(
            [
@@ -394,9 +386,7 @@ class Conformer(torch.nn.Module):
            ]
        )
    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input (torch.Tensor): with shape `(B, T_in, input_dim)`.
...
...@@ -10,9 +10,9 @@ __all__ = ["Emformer"] ...@@ -10,9 +10,9 @@ __all__ = ["Emformer"]
def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor: def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
batch_size = lengths.shape[0] batch_size = lengths.shape[0]
max_length = int(torch.max(lengths).item()) max_length = int(torch.max(lengths).item())
padding_mask = torch.arange( padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
max_length, device=lengths.device, dtype=lengths.dtype batch_size, max_length
).expand(batch_size, max_length) >= lengths.unsqueeze(1) ) >= lengths.unsqueeze(1)
return padding_mask return padding_mask
...@@ -30,15 +30,8 @@ def _gen_padding_mask( ...@@ -30,15 +30,8 @@ def _gen_padding_mask(
padding_mask = None padding_mask = None
else: else:
right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0) right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
left_context_blocks_length = ( left_context_blocks_length = left_context_key.size(0) if left_context_key is not None else 0
left_context_key.size(0) if left_context_key is not None else 0 klengths = lengths + mems.size(0) + right_context_blocks_length + left_context_blocks_length
)
klengths = (
lengths
+ mems.size(0)
+ right_context_blocks_length
+ left_context_blocks_length
)
padding_mask = _lengths_to_padding_mask(lengths=klengths) padding_mask = _lengths_to_padding_mask(lengths=klengths)
return padding_mask return padding_mask
...@@ -54,9 +47,7 @@ def _get_activation_module(activation: str) -> torch.nn.Module: ...@@ -54,9 +47,7 @@ def _get_activation_module(activation: str) -> torch.nn.Module:
raise ValueError(f"Unsupported activation {activation}") raise ValueError(f"Unsupported activation {activation}")
def _get_weight_init_gains( def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
weight_init_scale_strategy: Optional[str], num_layers: int
) -> List[Optional[float]]:
if weight_init_scale_strategy is None: if weight_init_scale_strategy is None:
return [None for _ in range(num_layers)] return [None for _ in range(num_layers)]
elif weight_init_scale_strategy == "depthwise": elif weight_init_scale_strategy == "depthwise":
...@@ -64,17 +55,13 @@ def _get_weight_init_gains( ...@@ -64,17 +55,13 @@ def _get_weight_init_gains(
elif weight_init_scale_strategy == "constant": elif weight_init_scale_strategy == "constant":
return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)] return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
else: else:
raise ValueError( raise ValueError(f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}")
f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}"
)
def _gen_attention_mask_block( def _gen_attention_mask_block(
col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
) -> torch.Tensor: ) -> torch.Tensor:
assert len(col_widths) == len( assert len(col_widths) == len(col_mask), "Length of col_widths must match that of col_mask"
col_mask
), "Length of col_widths must match that of col_mask"
mask_block = [ mask_block = [
torch.ones(num_rows, col_width, device=device) torch.ones(num_rows, col_width, device=device)
...@@ -110,9 +97,7 @@ class _EmformerAttention(torch.nn.Module): ...@@ -110,9 +97,7 @@ class _EmformerAttention(torch.nn.Module):
super().__init__() super().__init__()
if input_dim % num_heads != 0: if input_dim % num_heads != 0:
raise ValueError( raise ValueError(f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads}).")
f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads})."
)
self.input_dim = input_dim self.input_dim = input_dim
self.num_heads = num_heads self.num_heads = num_heads
...@@ -127,23 +112,15 @@ class _EmformerAttention(torch.nn.Module): ...@@ -127,23 +112,15 @@ class _EmformerAttention(torch.nn.Module):
self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True) self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True)
if weight_init_gain: if weight_init_gain:
torch.nn.init.xavier_uniform_( torch.nn.init.xavier_uniform_(self.emb_to_key_value.weight, gain=weight_init_gain)
self.emb_to_key_value.weight, gain=weight_init_gain torch.nn.init.xavier_uniform_(self.emb_to_query.weight, gain=weight_init_gain)
)
torch.nn.init.xavier_uniform_(
self.emb_to_query.weight, gain=weight_init_gain
)
def _gen_key_value( def _gen_key_value(self, input: torch.Tensor, mems: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
self, input: torch.Tensor, mems: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
T, _, _ = input.shape T, _, _ = input.shape
summary_length = mems.size(0) + 1 summary_length = mems.size(0) + 1
right_ctx_utterance_block = input[: T - summary_length] right_ctx_utterance_block = input[: T - summary_length]
mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block]) mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block])
key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(chunks=2, dim=2)
return key, value return key, value
def _gen_attention_probs( def _gen_attention_probs(
...@@ -153,27 +130,17 @@ class _EmformerAttention(torch.nn.Module): ...@@ -153,27 +130,17 @@ class _EmformerAttention(torch.nn.Module):
padding_mask: Optional[torch.Tensor], padding_mask: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
attention_weights_float = attention_weights.float() attention_weights_float = attention_weights.float()
attention_weights_float = attention_weights_float.masked_fill(attention_mask.unsqueeze(0), self.negative_inf)
T = attention_weights.size(1) T = attention_weights.size(1)
B = attention_weights.size(0) // self.num_heads B = attention_weights.size(0) // self.num_heads
if padding_mask is not None: if padding_mask is not None:
attention_weights_float = attention_weights_float.view(B, self.num_heads, T, -1)
attention_weights_float = attention_weights_float.masked_fill( attention_weights_float = attention_weights_float.masked_fill(
padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf
) )
attention_weights_float = attention_weights_float.view(B * self.num_heads, T, -1)
attention_probs = torch.nn.functional.softmax(attention_weights_float, dim=-1).type_as(attention_weights)
return torch.nn.functional.dropout(attention_probs, p=float(self.dropout), training=self.training)
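A minimal sketch of the masked-softmax pattern used by `_gen_attention_probs` above; the tensor sizes and the mask are invented for illustration, and -1e8 stands in for `self.negative_inf`.

import torch

weights = torch.randn(8, 6, 10)              # (B * num_heads, query_len, key_len), made-up sizes
mask = torch.zeros(6, 10, dtype=torch.bool)
mask[:, -2:] = True                          # pretend the last two key positions are disallowed

masked = weights.float().masked_fill(mask.unsqueeze(0), -1e8)
probs = torch.nn.functional.softmax(masked, dim=-1).type_as(weights)
print(probs[..., -2:].abs().max())           # ~0: masked positions receive no attention mass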
def _forward_impl( def _forward_impl(
self, self,
...@@ -193,9 +160,7 @@ class _EmformerAttention(torch.nn.Module): ...@@ -193,9 +160,7 @@ class _EmformerAttention(torch.nn.Module):
query = self.emb_to_query(torch.cat([right_context, utterance, summary])) query = self.emb_to_query(torch.cat([right_context, utterance, summary]))
# Compute key and value with [mems, right context, utterance]. # Compute key and value with [mems, right context, utterance].
key, value = self.emb_to_key_value(torch.cat([mems, right_context, utterance])).chunk(chunks=2, dim=2)
if left_context_key is not None and left_context_val is not None: if left_context_key is not None and left_context_val is not None:
right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0) right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
...@@ -203,37 +168,29 @@ class _EmformerAttention(torch.nn.Module): ...@@ -203,37 +168,29 @@ class _EmformerAttention(torch.nn.Module):
[ [
key[: mems.size(0) + right_context_blocks_length], key[: mems.size(0) + right_context_blocks_length],
left_context_key, left_context_key,
key[mems.size(0) + right_context_blocks_length:], key[mems.size(0) + right_context_blocks_length :],
], ],
) )
value = torch.cat( value = torch.cat(
[ [
value[: mems.size(0) + right_context_blocks_length], value[: mems.size(0) + right_context_blocks_length],
left_context_val, left_context_val,
value[mems.size(0) + right_context_blocks_length:], value[mems.size(0) + right_context_blocks_length :],
], ],
) )
# Compute attention weights from query, key, and value. # Compute attention weights from query, key, and value.
reshaped_query, reshaped_key, reshaped_value = [ reshaped_query, reshaped_key, reshaped_value = [
tensor.contiguous().view(-1, B * self.num_heads, self.input_dim // self.num_heads).transpose(0, 1)
for tensor in [query, key, value] for tensor in [query, key, value]
] ]
attention_weights = torch.bmm(reshaped_query * self.scaling, reshaped_key.transpose(1, 2))
# Compute padding mask. # Compute padding mask.
padding_mask = _gen_padding_mask(utterance, right_context, summary, lengths, mems, left_context_key)
# Compute attention probabilities. # Compute attention probabilities.
attention_probs = self._gen_attention_probs(attention_weights, attention_mask, padding_mask)
# Compute attention. # Compute attention.
attention = torch.bmm(attention_probs, reshaped_value) attention = torch.bmm(attention_probs, reshaped_value)
...@@ -249,7 +206,7 @@ class _EmformerAttention(torch.nn.Module): ...@@ -249,7 +206,7 @@ class _EmformerAttention(torch.nn.Module):
summary_length = summary.size(0) summary_length = summary.size(0)
output_right_context = output_right_context_mems[: T - summary_length] output_right_context = output_right_context_mems[: T - summary_length]
output_mems = output_right_context_mems[T - summary_length:] output_mems = output_right_context_mems[T - summary_length :]
if self.tanh_on_mem: if self.tanh_on_mem:
output_mems = torch.tanh(output_mems) output_mems = torch.tanh(output_mems)
else: else:
...@@ -291,9 +248,7 @@ class _EmformerAttention(torch.nn.Module): ...@@ -291,9 +248,7 @@ class _EmformerAttention(torch.nn.Module):
Tensor Tensor
updated memory elements, with shape `(M, B, D)`. updated memory elements, with shape `(M, B, D)`.
""" """
output, output_mems, _, _ = self._forward_impl(utterance, lengths, right_context, summary, mems, attention_mask)
return output, output_mems[:-1] return output, output_mems[:-1]
@torch.jit.export @torch.jit.export
...@@ -338,15 +293,8 @@ class _EmformerAttention(torch.nn.Module): ...@@ -338,15 +293,8 @@ class _EmformerAttention(torch.nn.Module):
attention value computed for left context and utterance. attention value computed for left context and utterance.
""" """
query_dim = right_context.size(0) + utterance.size(0) + summary.size(0) query_dim = right_context.size(0) + utterance.size(0) + summary.size(0)
key_dim = right_context.size(0) + utterance.size(0) + mems.size(0) + left_context_key.size(0)
attention_mask = torch.zeros(query_dim, key_dim).to(dtype=torch.bool, device=utterance.device)
attention_mask[-1, : mems.size(0)] = True attention_mask[-1, : mems.size(0)] = True
output, output_mems, key, value = self._forward_impl( output, output_mems, key, value = self._forward_impl(
utterance, utterance,
...@@ -361,8 +309,8 @@ class _EmformerAttention(torch.nn.Module): ...@@ -361,8 +309,8 @@ class _EmformerAttention(torch.nn.Module):
return ( return (
output, output,
output_mems, output_mems,
key[mems.size(0) + right_context.size(0):], key[mems.size(0) + right_context.size(0) :],
value[mems.size(0) + right_context.size(0):], value[mems.size(0) + right_context.size(0) :],
) )
...@@ -410,9 +358,7 @@ class _EmformerLayer(torch.nn.Module): ...@@ -410,9 +358,7 @@ class _EmformerLayer(torch.nn.Module):
negative_inf=negative_inf, negative_inf=negative_inf,
) )
self.dropout = torch.nn.Dropout(dropout) self.dropout = torch.nn.Dropout(dropout)
self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)
activation_module = _get_activation_module(activation) activation_module = _get_activation_module(activation)
self.pos_ff = torch.nn.Sequential( self.pos_ff = torch.nn.Sequential(
...@@ -433,18 +379,10 @@ class _EmformerLayer(torch.nn.Module): ...@@ -433,18 +379,10 @@ class _EmformerLayer(torch.nn.Module):
self.use_mem = max_memory_size > 0 self.use_mem = max_memory_size > 0
def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device) past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
return [empty_memory, left_context_key, left_context_val, past_length] return [empty_memory, left_context_key, left_context_val, past_length]
...@@ -453,12 +391,10 @@ class _EmformerLayer(torch.nn.Module): ...@@ -453,12 +391,10 @@ class _EmformerLayer(torch.nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
past_length = state[3][0][0].item() past_length = state[3][0][0].item()
past_left_context_length = min(self.left_context_length, past_length) past_left_context_length = min(self.left_context_length, past_length)
past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
pre_mems = state[0][self.max_memory_size - past_mem_length :]
lc_key = state[1][self.left_context_length - past_left_context_length :]
lc_val = state[2][self.left_context_length - past_left_context_length :]
return pre_mems, lc_key, lc_val return pre_mems, lc_key, lc_val
def _pack_state( def _pack_state(
...@@ -471,9 +407,9 @@ class _EmformerLayer(torch.nn.Module): ...@@ -471,9 +407,9 @@ class _EmformerLayer(torch.nn.Module):
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
new_k = torch.cat([state[1], next_k]) new_k = torch.cat([state[1], next_k])
new_v = torch.cat([state[2], next_v]) new_v = torch.cat([state[2], next_v])
state[0] = torch.cat([state[0], mems])[-self.max_memory_size:] state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
state[1] = new_k[new_k.shape[0] - self.left_context_length:] state[1] = new_k[new_k.shape[0] - self.left_context_length :]
state[2] = new_v[new_v.shape[0] - self.left_context_length:] state[2] = new_v[new_v.shape[0] - self.left_context_length :]
state[3] = state[3] + update_length state[3] = state[3] + update_length
return state return state
...@@ -493,7 +429,7 @@ class _EmformerLayer(torch.nn.Module): ...@@ -493,7 +429,7 @@ class _EmformerLayer(torch.nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance])) layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance]))
return ( return (
layer_norm_input[right_context.size(0):], layer_norm_input[right_context.size(0) :],
layer_norm_input[: right_context.size(0)], layer_norm_input[: right_context.size(0)],
) )
...@@ -501,7 +437,7 @@ class _EmformerLayer(torch.nn.Module): ...@@ -501,7 +437,7 @@ class _EmformerLayer(torch.nn.Module):
self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
rc_output = self._process_attention_output(rc_output, utterance, right_context) rc_output = self._process_attention_output(rc_output, utterance, right_context)
return rc_output[right_context.size(0):], rc_output[: right_context.size(0)] return rc_output[right_context.size(0) :], rc_output[: right_context.size(0)]
def _apply_attention_forward( def _apply_attention_forward(
self, self,
...@@ -512,9 +448,7 @@ class _EmformerLayer(torch.nn.Module): ...@@ -512,9 +448,7 @@ class _EmformerLayer(torch.nn.Module):
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if attention_mask is None: if attention_mask is None:
raise ValueError("attention_mask must be not None when for_inference is False")
if self.use_mem: if self.use_mem:
summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
...@@ -602,9 +536,7 @@ class _EmformerLayer(torch.nn.Module): ...@@ -602,9 +536,7 @@ class _EmformerLayer(torch.nn.Module):
mems, mems,
attention_mask, attention_mask,
) )
output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
return output_utterance, output_right_context, output_mems return output_utterance, output_right_context, output_mems
@torch.jit.export @torch.jit.export
...@@ -652,9 +584,7 @@ class _EmformerLayer(torch.nn.Module): ...@@ -652,9 +584,7 @@ class _EmformerLayer(torch.nn.Module):
rc_output, output_mems, output_state = self._apply_attention_infer( rc_output, output_mems, output_state = self._apply_attention_infer(
layer_norm_utterance, lengths, layer_norm_right_context, mems, state layer_norm_utterance, lengths, layer_norm_right_context, mems, state
) )
output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
return output_utterance, output_right_context, output_state, output_mems return output_utterance, output_right_context, output_state, output_mems
...@@ -708,12 +638,12 @@ class Emformer(torch.nn.Module): ...@@ -708,12 +638,12 @@ class Emformer(torch.nn.Module):
self.use_mem = max_memory_size > 0 self.use_mem = max_memory_size > 0
self.memory_op = torch.nn.AvgPool1d(
kernel_size=segment_length,
stride=segment_length,
ceil_mode=True,
)
weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
self.emformer_layers = torch.nn.ModuleList( self.emformer_layers = torch.nn.ModuleList(
[ [
_EmformerLayer( _EmformerLayer(
...@@ -747,12 +677,10 @@ class Emformer(torch.nn.Module): ...@@ -747,12 +677,10 @@ class Emformer(torch.nn.Module):
start = (seg_idx + 1) * self.segment_length start = (seg_idx + 1) * self.segment_length
end = start + self.right_context_length end = start + self.right_context_length
right_context_blocks.append(input[start:end]) right_context_blocks.append(input[start:end])
right_context_blocks.append(input[T - self.right_context_length:]) right_context_blocks.append(input[T - self.right_context_length :])
return torch.cat(right_context_blocks) return torch.cat(right_context_blocks)
def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
num_segs = math.ceil(utterance_length / self.segment_length) num_segs = math.ceil(utterance_length / self.segment_length)
rc = self.right_context_length rc = self.right_context_length
lc = self.left_context_length lc = self.left_context_length
...@@ -830,19 +758,13 @@ class Emformer(torch.nn.Module): ...@@ -830,19 +758,13 @@ class Emformer(torch.nn.Module):
query_mask.append(query_mask_block) query_mask.append(query_mask_block)
if s_cols_mask is not None: if s_cols_mask is not None:
summary_mask_block = _gen_attention_mask_block(col_widths, s_cols_mask, 1, input.device)
summary_mask.append(summary_mask_block) summary_mask.append(summary_mask_block)
attention_mask = (1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])).to(torch.bool)
return attention_mask return attention_mask
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Forward pass for training. r"""Forward pass for training.
B: batch size; B: batch size;
...@@ -874,9 +796,7 @@ class Emformer(torch.nn.Module): ...@@ -874,9 +796,7 @@ class Emformer(torch.nn.Module):
) )
output = utterance output = utterance
for layer in self.emformer_layers: for layer in self.emformer_layers:
output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
return output.permute(1, 0, 2), lengths return output.permute(1, 0, 2), lengths
@torch.jit.export @torch.jit.export
......
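A possible end-to-end usage sketch for the Emformer encoder reformatted above, assuming it is exposed as torchaudio.models.Emformer with the documented (input_dim, num_heads, ffn_dim, num_layers, segment_length, ...) constructor; all sizes are arbitrary.

import torch
from torchaudio.models import Emformer

emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
input = torch.rand(128, 400, 512)            # (B, T + right_context_length, D)
lengths = torch.randint(1, 200, (128,))      # valid frames per batch element
output, output_lengths = emformer(input, lengths)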
...@@ -15,13 +15,12 @@ class _TimeReduction(torch.nn.Module): ...@@ -15,13 +15,12 @@ class _TimeReduction(torch.nn.Module):
Args: Args:
stride (int): number of frames to merge for each output frame. stride (int): number of frames to merge for each output frame.
""" """
def __init__(self, stride: int) -> None: def __init__(self, stride: int) -> None:
super().__init__() super().__init__()
self.stride = stride self.stride = stride
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Forward pass. r"""Forward pass.
B: batch size; B: batch size;
...@@ -64,6 +63,7 @@ class _CustomLSTM(torch.nn.Module): ...@@ -64,6 +63,7 @@ class _CustomLSTM(torch.nn.Module):
layer_norm_epsilon (float, optional): value of epsilon to use in layer_norm_epsilon (float, optional): value of epsilon to use in
layer normalization layers (Default: 1e-5) layer normalization layers (Default: 1e-5)
""" """
def __init__( def __init__(
self, self,
input_dim: int, input_dim: int,
...@@ -179,7 +179,9 @@ class _Transcriber(torch.nn.Module): ...@@ -179,7 +179,9 @@ class _Transcriber(torch.nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.input_linear = torch.nn.Linear(
input_dim,
time_reduction_input_dim,
bias=False,
)
self.time_reduction = _TimeReduction(time_reduction_stride) self.time_reduction = _TimeReduction(time_reduction_stride)
transformer_input_dim = time_reduction_input_dim * time_reduction_stride transformer_input_dim = time_reduction_input_dim * time_reduction_stride
...@@ -200,9 +202,7 @@ class _Transcriber(torch.nn.Module): ...@@ -200,9 +202,7 @@ class _Transcriber(torch.nn.Module):
self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim) self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
self.layer_norm = torch.nn.LayerNorm(output_dim) self.layer_norm = torch.nn.LayerNorm(output_dim)
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Forward pass for training. r"""Forward pass for training.
B: batch size; B: batch size;
...@@ -225,12 +225,8 @@ class _Transcriber(torch.nn.Module): ...@@ -225,12 +225,8 @@ class _Transcriber(torch.nn.Module):
number of valid elements for i-th batch element in output frame sequences. number of valid elements for i-th batch element in output frame sequences.
""" """
input_linear_out = self.input_linear(input) input_linear_out = self.input_linear(input)
time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
output_linear_out = self.output_linear(transformer_out) output_linear_out = self.output_linear(transformer_out)
layer_norm_out = self.layer_norm(output_linear_out) layer_norm_out = self.layer_norm(output_linear_out)
return layer_norm_out, transformer_lengths return layer_norm_out, transformer_lengths
...@@ -271,9 +267,7 @@ class _Transcriber(torch.nn.Module): ...@@ -271,9 +267,7 @@ class _Transcriber(torch.nn.Module):
of ``infer``. of ``infer``.
""" """
input_linear_out = self.input_linear(input) input_linear_out = self.input_linear(input)
time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
( (
transformer_out, transformer_out,
transformer_lengths, transformer_lengths,
...@@ -299,6 +293,7 @@ class _Predictor(torch.nn.Module): ...@@ -299,6 +293,7 @@ class _Predictor(torch.nn.Module):
lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0) lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)
""" """
def __init__( def __init__(
self, self,
num_symbols: int, num_symbols: int,
...@@ -368,9 +363,7 @@ class _Predictor(torch.nn.Module): ...@@ -368,9 +363,7 @@ class _Predictor(torch.nn.Module):
lstm_out = input_layer_norm_out lstm_out = input_layer_norm_out
state_out: List[List[torch.Tensor]] = [] state_out: List[List[torch.Tensor]] = []
for layer_idx, lstm in enumerate(self.lstm_layers): for layer_idx, lstm in enumerate(self.lstm_layers):
lstm_out, lstm_state_out = lstm(lstm_out, None if state is None else state[layer_idx])
lstm_out = self.dropout(lstm_out) lstm_out = self.dropout(lstm_out)
state_out.append(lstm_state_out) state_out.append(lstm_state_out)
...@@ -426,10 +419,7 @@ class _Joiner(torch.nn.Module): ...@@ -426,10 +419,7 @@ class _Joiner(torch.nn.Module):
output target lengths, with shape `(B,)` and i-th element representing output target lengths, with shape `(B,)` and i-th element representing
number of valid elements along dim 2 for i-th batch element in joint network output. number of valid elements along dim 2 for i-th batch element in joint network output.
""" """
joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
relu_out = self.relu(joint_encodings) relu_out = self.relu(joint_encodings)
output = self.linear(relu_out) output = self.linear(relu_out)
return output, source_lengths, target_lengths return output, source_lengths, target_lengths
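A toy illustration of the joiner broadcast above: a (B, T, D) source encoding and a (B, U, D) target encoding combine into a (B, T, U, D) lattice before the ReLU and the final linear projection; the shapes are invented.

import torch

source_encodings = torch.rand(2, 5, 16)      # (B, T, D)
target_encodings = torch.rand(2, 3, 16)      # (B, U, D)
joint = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1)
print(joint.shape)                           # torch.Size([2, 5, 3, 16])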
...@@ -447,9 +437,7 @@ class RNNT(torch.nn.Module): ...@@ -447,9 +437,7 @@ class RNNT(torch.nn.Module):
joiner (torch.nn.Module): joint network. joiner (torch.nn.Module): joint network.
""" """
def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
super().__init__() super().__init__()
self.transcriber = transcriber self.transcriber = transcriber
self.predictor = predictor self.predictor = predictor
...@@ -500,10 +488,13 @@ class RNNT(torch.nn.Module): ...@@ -500,10 +488,13 @@ class RNNT(torch.nn.Module):
of ``forward``. of ``forward``.
""" """
source_encodings, source_lengths = self.transcriber(
input=sources,
lengths=source_lengths,
)
target_encodings, target_lengths, predictor_state = self.predictor(
input=targets,
lengths=target_lengths,
state=predictor_state,
)
output, source_lengths, target_lengths = self.joiner( output, source_lengths, target_lengths = self.joiner(
source_encodings=source_encodings, source_encodings=source_encodings,
...@@ -558,7 +549,9 @@ class RNNT(torch.nn.Module): ...@@ -558,7 +549,9 @@ class RNNT(torch.nn.Module):
@torch.jit.export @torch.jit.export
def transcribe(
self,
sources: torch.Tensor,
source_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Applies transcription network to sources in non-streaming mode. r"""Applies transcription network to sources in non-streaming mode.
......
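For context, the joint-network output produced by the RNNT module above is typically fed to an RNN-T loss; a hedged sketch using torchaudio.transforms.RNNTLoss (listed in the transforms __all__ further down this diff) with made-up shapes, random logits, and the default blank index:

import torch
from torchaudio.transforms import RNNTLoss

B, T, U, V = 2, 50, 10, 29                       # batch, frames, target length, vocab size (arbitrary)
logits = torch.rand(B, T, U + 1, V)              # joiner output
targets = torch.randint(0, V - 1, (B, U), dtype=torch.int32)   # exclude the (last-index) blank
logit_lengths = torch.full((B,), T, dtype=torch.int32)
target_lengths = torch.full((B,), U, dtype=torch.int32)
loss = RNNTLoss()(logits, targets, logit_lengths, target_lengths)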
...@@ -35,21 +35,14 @@ def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]: ...@@ -35,21 +35,14 @@ def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
for i in range(len(hypos[0].state)): for i in range(len(hypos[0].state)):
batched_state_components: List[torch.Tensor] = [] batched_state_components: List[torch.Tensor] = []
for j in range(len(hypos[0].state[i])): for j in range(len(hypos[0].state[i])):
batched_state_components.append(torch.cat([hypo.state[i][j] for hypo in hypos]))
states.append(batched_state_components) states.append(batched_state_components)
return states return states
def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
idx_tensor = torch.tensor([idx], device=device)
return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
def _default_hypo_sort_key(hypo: Hypothesis) -> float: def _default_hypo_sort_key(hypo: Hypothesis) -> float:
...@@ -57,18 +50,14 @@ def _default_hypo_sort_key(hypo: Hypothesis) -> float: ...@@ -57,18 +50,14 @@ def _default_hypo_sort_key(hypo: Hypothesis) -> float:
def _compute_updated_scores( def _compute_updated_scores(
hypos: List[Hypothesis],
next_token_probs: torch.Tensor,
beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hypo_scores = torch.tensor([h.score for h in hypos]).unsqueeze(1) hypo_scores = torch.tensor([h.score for h in hypos]).unsqueeze(1)
nonblank_scores = hypo_scores + next_token_probs[:, :-1]  # [beam_width, num_tokens - 1]
nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1] nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
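A toy illustration of the flattened top-k bookkeeping above: after reshape(-1).topk, integer division recovers which hypothesis each score came from and the remainder recovers the token index; the numbers are invented.

import torch

scores = torch.tensor([[0.1, 0.9, 0.3],
                       [0.8, 0.2, 0.7]])     # (beam_width=2, num_tokens - 1 = 3)
best, flat_idx = scores.reshape(-1).topk(2)  # values [0.9, 0.8], flat indices [1, 3]
hypo_idx = flat_idx.div(scores.shape[1], rounding_mode="trunc")   # tensor([0, 1])
token_idx = flat_idx % scores.shape[1]                            # tensor([1, 0])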
...@@ -114,9 +103,7 @@ class RNNTBeamSearch(torch.nn.Module): ...@@ -114,9 +103,7 @@ class RNNTBeamSearch(torch.nn.Module):
self.step_max_tokens = step_max_tokens self.step_max_tokens = step_max_tokens
def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
if hypo is not None: if hypo is not None:
token = hypo.tokens[-1] token = hypo.tokens[-1]
state = hypo.state state = hypo.state
...@@ -125,9 +112,7 @@ class RNNTBeamSearch(torch.nn.Module): ...@@ -125,9 +112,7 @@ class RNNTBeamSearch(torch.nn.Module):
state = None state = None
one_tensor = torch.tensor([1], device=device) one_tensor = torch.tensor([1], device=device)
pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
init_hypo = Hypothesis( init_hypo = Hypothesis(
tokens=[token], tokens=[token],
predictor_out=pred_out[0].detach(), predictor_out=pred_out[0].detach(),
...@@ -150,9 +135,7 @@ class RNNTBeamSearch(torch.nn.Module): ...@@ -150,9 +135,7 @@ class RNNTBeamSearch(torch.nn.Module):
predictor_out, predictor_out,
torch.tensor([1] * len(hypos), device=device), torch.tensor([1] * len(hypos), device=device),
) # [beam_width, 1, 1, num_tokens] ) # [beam_width, 1, 1, num_tokens]
joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
joined_out[:, :, :, :4].add_(-99999) # blank out invalid tokens joined_out[:, :, :, :4].add_(-99999) # blank out invalid tokens
return joined_out[:, 0, 0] return joined_out[:, 0, 0]
...@@ -220,9 +203,7 @@ class RNNTBeamSearch(torch.nn.Module): ...@@ -220,9 +203,7 @@ class RNNTBeamSearch(torch.nn.Module):
new_scores.append(score) new_scores.append(score)
if base_hypos: if base_hypos:
new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
else: else:
new_hypos: List[Hypothesis] = [] new_hypos: List[Hypothesis] = []
...@@ -239,7 +220,9 @@ class RNNTBeamSearch(torch.nn.Module): ...@@ -239,7 +220,9 @@ class RNNTBeamSearch(torch.nn.Module):
tgt_tokens = torch.tensor([[token] for token in tokens], device=device) tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
states = _batch_state(base_hypos) states = _batch_state(base_hypos)
pred_out, _, pred_states = self.model.predict(
tgt_tokens,
torch.tensor([1] * len(base_hypos), device=device),
states,
)
new_hypos: List[Hypothesis] = [] new_hypos: List[Hypothesis] = []
for i, h_a in enumerate(base_hypos): for i, h_a in enumerate(base_hypos):
...@@ -258,7 +241,10 @@ class RNNTBeamSearch(torch.nn.Module): ...@@ -258,7 +241,10 @@ class RNNTBeamSearch(torch.nn.Module):
return new_hypos return new_hypos
def _search( def _search(
self,
enc_out: torch.Tensor,
hypo: Optional[Hypothesis],
beam_width: int,
) -> List[Hypothesis]: ) -> List[Hypothesis]:
n_time_steps = enc_out.shape[1] n_time_steps = enc_out.shape[1]
device = enc_out.device device = enc_out.device
...@@ -272,33 +258,35 @@ class RNNTBeamSearch(torch.nn.Module): ...@@ -272,33 +258,35 @@ class RNNTBeamSearch(torch.nn.Module):
symbols_current_t = 0 symbols_current_t = 0
while a_hypos: while a_hypos:
next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
next_token_probs = next_token_probs.cpu() next_token_probs = next_token_probs.cpu()
b_hypos = self._gen_b_hypos( b_hypos = self._gen_b_hypos(
b_hypos,
a_hypos,
next_token_probs,
key_to_b_hypo,
) )
if symbols_current_t == self.step_max_tokens: if symbols_current_t == self.step_max_tokens:
break break
a_hypos = self._gen_a_hypos( a_hypos = self._gen_a_hypos(
a_hypos,
b_hypos,
next_token_probs,
t,
beam_width,
device,
) )
if a_hypos: if a_hypos:
symbols_current_t += 1 symbols_current_t += 1
_, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
b_hypos = [b_hypos[idx] for idx in sorted_idx] b_hypos = [b_hypos[idx] for idx in sorted_idx]
return b_hypos return b_hypos
def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
r"""Performs beam search for the given input sequence. r"""Performs beam search for the given input sequence.
T: number of frames; T: number of frames;
...@@ -319,9 +307,7 @@ class RNNTBeamSearch(torch.nn.Module): ...@@ -319,9 +307,7 @@ class RNNTBeamSearch(torch.nn.Module):
if input.dim() == 2: if input.dim() == 2:
input = input.unsqueeze(0) input = input.unsqueeze(0)
assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
if input.dim() == 0: if input.dim() == 0:
input = input.unsqueeze(0) input = input.unsqueeze(0)
...@@ -367,9 +353,7 @@ class RNNTBeamSearch(torch.nn.Module): ...@@ -367,9 +353,7 @@ class RNNTBeamSearch(torch.nn.Module):
if input.dim() == 2: if input.dim() == 2:
input = input.unsqueeze(0) input = input.unsqueeze(0)
assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
if input.dim() == 0: if input.dim() == 0:
input = input.unsqueeze(0) input = input.unsqueeze(0)
......
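A hedged usage sketch for the beam-search decoder in this file; `rnnt_model` stands for an already-constructed (and trained) RNNT instance, the import path is assumed, and the blank id, feature shape, and beam width are illustrative values rather than anything taken from this diff.

import torch
from torchaudio.models import RNNTBeamSearch   # assumed public import path

decoder = RNNTBeamSearch(rnnt_model, blank=4096)   # blank id depends on the vocabulary
features = torch.rand(128, 80)                     # (T, input_dim) for a single utterance
length = torch.tensor([128])
hypotheses = decoder(features, length, beam_width=10)
best_tokens = hypotheses[0].tokens                 # token ids of the highest-ranked hypothesis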
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from .sox_effects import ( from .sox_effects import (
init_sox_effects, init_sox_effects,
shutdown_sox_effects, shutdown_sox_effects,
...@@ -10,13 +11,14 @@ from .sox_effects import ( ...@@ -10,13 +11,14 @@ from .sox_effects import (
if _mod_utils.is_sox_available(): if _mod_utils.is_sox_available():
import atexit import atexit
init_sox_effects() init_sox_effects()
atexit.register(shutdown_sox_effects) atexit.register(shutdown_sox_effects)
__all__ = [ __all__ = [
'init_sox_effects', "init_sox_effects",
'shutdown_sox_effects', "shutdown_sox_effects",
'effect_names', "effect_names",
'apply_effects_tensor', "apply_effects_tensor",
'apply_effects_file', "apply_effects_file",
] ]
...@@ -2,7 +2,6 @@ import os ...@@ -2,7 +2,6 @@ import os
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
import torch import torch
import torchaudio import torchaudio
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio.utils.sox_utils import list_effects from torchaudio.utils.sox_utils import list_effects
...@@ -53,10 +52,10 @@ def effect_names() -> List[str]: ...@@ -53,10 +52,10 @@ def effect_names() -> List[str]:
@_mod_utils.requires_sox() @_mod_utils.requires_sox()
def apply_effects_tensor( def apply_effects_tensor(
tensor: torch.Tensor, tensor: torch.Tensor,
sample_rate: int, sample_rate: int,
effects: List[List[str]], effects: List[List[str]],
channels_first: bool = True, channels_first: bool = True,
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
"""Apply sox effects to given Tensor """Apply sox effects to given Tensor
...@@ -149,17 +148,16 @@ def apply_effects_tensor( ...@@ -149,17 +148,16 @@ def apply_effects_tensor(
>>> waveform, sample_rate = transform(waveform, input_sample_rate) >>> waveform, sample_rate = transform(waveform, input_sample_rate)
>>> assert sample_rate == 8000 >>> assert sample_rate == 8000
""" """
return torch.ops.torchaudio.sox_effects_apply_effects_tensor(tensor, sample_rate, effects, channels_first)
@_mod_utils.requires_sox() @_mod_utils.requires_sox()
def apply_effects_file( def apply_effects_file(
path: str, path: str,
effects: List[List[str]], effects: List[List[str]],
normalize: bool = True, normalize: bool = True,
channels_first: bool = True, channels_first: bool = True,
format: Optional[str] = None, format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
"""Apply sox effects to the audio file and load the resulting data as Tensor """Apply sox effects to the audio file and load the resulting data as Tensor
...@@ -265,9 +263,7 @@ def apply_effects_file( ...@@ -265,9 +263,7 @@ def apply_effects_file(
>>> pass >>> pass
""" """
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
if hasattr(path, 'read'): if hasattr(path, "read"):
return torchaudio._torchaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format)
path = os.fspath(path) path = os.fspath(path)
return torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format)
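A small usage sketch for the sox-effects entry points reformatted above; the effect chain, sample rate, and waveform are arbitrary, and it assumes torchaudio was built with sox support.

import torch
from torchaudio.sox_effects import apply_effects_tensor

waveform = torch.rand(1, 16000) * 0.1        # (channels, frames): one second of noise at 16 kHz
effects = [["gain", "-n"], ["rate", "8000"]] # normalize gain, then resample to 8 kHz
out_waveform, out_sample_rate = apply_effects_tensor(waveform, 16000, effects)
assert out_sample_rate == 8000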
...@@ -14,31 +14,31 @@ from .functional.functional import ( ...@@ -14,31 +14,31 @@ from .functional.functional import (
) )
__all__ = [ __all__ = [
'Spectrogram', "Spectrogram",
'InverseSpectrogram', "InverseSpectrogram",
'GriffinLim', "GriffinLim",
'AmplitudeToDB', "AmplitudeToDB",
'MelScale', "MelScale",
'InverseMelScale', "InverseMelScale",
'MelSpectrogram', "MelSpectrogram",
'MFCC', "MFCC",
'LFCC', "LFCC",
'MuLawEncoding', "MuLawEncoding",
'MuLawDecoding', "MuLawDecoding",
'Resample', "Resample",
'TimeStretch', "TimeStretch",
'Fade', "Fade",
'FrequencyMasking', "FrequencyMasking",
'TimeMasking', "TimeMasking",
'SlidingWindowCmn', "SlidingWindowCmn",
'Vad', "Vad",
'SpectralCentroid', "SpectralCentroid",
'Vol', "Vol",
'ComputeDeltas', "ComputeDeltas",
'PitchShift', "PitchShift",
'RNNTLoss', "RNNTLoss",
'PSD', "PSD",
'MVDR', "MVDR",
] ]
...@@ -73,21 +73,23 @@ class Spectrogram(torch.nn.Module): ...@@ -73,21 +73,23 @@ class Spectrogram(torch.nn.Module):
>>> spectrogram = transform(waveform) >>> spectrogram = transform(waveform)
""" """
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] __constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"]
def __init__(
self,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: Optional[float] = 2.0,
normalized: bool = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: Optional[bool] = None,
) -> None:
super(Spectrogram, self).__init__() super(Spectrogram, self).__init__()
self.n_fft = n_fft self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1 # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
...@@ -95,7 +97,7 @@ class Spectrogram(torch.nn.Module): ...@@ -95,7 +97,7 @@ class Spectrogram(torch.nn.Module):
self.win_length = win_length if win_length is not None else n_fft self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2 self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window) self.register_buffer("window", window)
self.pad = pad self.pad = pad
self.power = power self.power = power
self.normalized = normalized self.normalized = normalized
...@@ -162,19 +164,21 @@ class InverseSpectrogram(torch.nn.Module): ...@@ -162,19 +164,21 @@ class InverseSpectrogram(torch.nn.Module):
>>> transform = transforms.InverseSpectrogram(n_fft=512) >>> transform = transforms.InverseSpectrogram(n_fft=512)
>>> waveform = transform(spectrogram, length) >>> waveform = transform(spectrogram, length)
""" """
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] __constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"]
def __init__(
self,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
normalized: bool = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
) -> None:
super(InverseSpectrogram, self).__init__() super(InverseSpectrogram, self).__init__()
self.n_fft = n_fft self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1 # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
...@@ -182,7 +186,7 @@ class InverseSpectrogram(torch.nn.Module): ...@@ -182,7 +186,7 @@ class InverseSpectrogram(torch.nn.Module):
self.win_length = win_length if win_length is not None else n_fft self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2 self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window) self.register_buffer("window", window)
self.pad = pad self.pad = pad
self.normalized = normalized self.normalized = normalized
self.center = center self.center = center
...@@ -242,31 +246,32 @@ class GriffinLim(torch.nn.Module): ...@@ -242,31 +246,32 @@ class GriffinLim(torch.nn.Module):
>>> transform = transforms.GriffinLim(n_fft=512) >>> transform = transforms.GriffinLim(n_fft=512)
>>> waveform = transform(spectrogram) >>> waveform = transform(spectrogram)
""" """
__constants__ = ["n_fft", "n_iter", "win_length", "hop_length", "power", "length", "momentum", "rand_init"]
def __init__(
self,
n_fft: int = 400,
n_iter: int = 32,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: float = 2.0,
wkwargs: Optional[dict] = None,
momentum: float = 0.99,
length: Optional[int] = None,
rand_init: bool = True,
) -> None:
super(GriffinLim, self).__init__() super(GriffinLim, self).__init__()
assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum) assert momentum < 1, "momentum={} > 1 can be unstable".format(momentum)
assert momentum >= 0, 'momentum={} < 0'.format(momentum) assert momentum >= 0, "momentum={} < 0".format(momentum)
self.n_fft = n_fft self.n_fft = n_fft
self.n_iter = n_iter self.n_iter = n_iter
self.win_length = win_length if win_length is not None else n_fft self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2 self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window) self.register_buffer("window", window)
self.length = length self.length = length
self.power = power self.power = power
self.momentum = momentum / (1 + momentum) self.momentum = momentum / (1 + momentum)
...@@ -282,8 +287,18 @@ class GriffinLim(torch.nn.Module): ...@@ -282,8 +287,18 @@ class GriffinLim(torch.nn.Module):
Returns: Returns:
Tensor: waveform of (..., time), where time equals the ``length`` parameter if given. Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
""" """
return F.griffinlim(
specgram,
self.window,
self.n_fft,
self.hop_length,
self.win_length,
self.power,
self.n_iter,
self.momentum,
self.length,
self.rand_init,
)
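A round-trip sketch for the two transforms above: a power spectrogram followed by Griffin-Lim phase reconstruction; the parameter values and input are arbitrary.

import torch
from torchaudio.transforms import Spectrogram, GriffinLim

waveform = torch.rand(1, 16000)                            # (channel, time)
spec = Spectrogram(n_fft=512, power=2.0)(waveform)         # (1, 257, frames) power spectrogram
reconstructed = GriffinLim(n_fft=512, n_iter=32, power=2.0)(spec)
print(spec.shape, reconstructed.shape)                     # waveform of roughly the original length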
class AmplitudeToDB(torch.nn.Module): class AmplitudeToDB(torch.nn.Module):
...@@ -299,15 +314,15 @@ class AmplitudeToDB(torch.nn.Module): ...@@ -299,15 +314,15 @@ class AmplitudeToDB(torch.nn.Module):
top_db (float or None, optional): minimum negative cut-off in decibels. A reasonable top_db (float or None, optional): minimum negative cut-off in decibels. A reasonable
number is 80. (Default: ``None``) number is 80. (Default: ``None``)
""" """
__constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier'] __constants__ = ["multiplier", "amin", "ref_value", "db_multiplier"]
def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None: def __init__(self, stype: str = "power", top_db: Optional[float] = None) -> None:
super(AmplitudeToDB, self).__init__() super(AmplitudeToDB, self).__init__()
self.stype = stype self.stype = stype
if top_db is not None and top_db < 0: if top_db is not None and top_db < 0:
raise ValueError('top_db must be positive value') raise ValueError("top_db must be positive value")
self.top_db = top_db self.top_db = top_db
self.multiplier = 10.0 if stype == 'power' else 20.0 self.multiplier = 10.0 if stype == "power" else 20.0
self.amin = 1e-10 self.amin = 1e-10
self.ref_value = 1.0 self.ref_value = 1.0
self.db_multiplier = math.log10(max(self.amin, self.ref_value)) self.db_multiplier = math.log10(max(self.amin, self.ref_value))
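The constants set above feed the dB conversion applied in forward (not part of this hunk); a hedged sketch of that conversion for a power spectrogram with ref_value = 1.0, mirroring torchaudio.functional.amplitude_to_DB:

import torch

power_spec = torch.tensor([1e-12, 1.0, 100.0])
multiplier, amin, db_multiplier = 10.0, 1e-10, 0.0
db = multiplier * torch.log10(torch.clamp(power_spec, min=amin)) - multiplier * db_multiplier
print(db)                                    # tensor([-100.,    0.,   20.])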
...@@ -344,16 +359,18 @@ class MelScale(torch.nn.Module): ...@@ -344,16 +359,18 @@ class MelScale(torch.nn.Module):
:py:func:`torchaudio.functional.melscale_fbanks` - The function used to :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
generate the filter banks. generate the filter banks.
""" """
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max'] __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]
def __init__(
self,
n_mels: int = 128,
sample_rate: int = 16000,
f_min: float = 0.0,
f_max: Optional[float] = None,
n_stft: int = 201,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> None:
super(MelScale, self).__init__() super(MelScale, self).__init__()
self.n_mels = n_mels self.n_mels = n_mels
self.sample_rate = sample_rate self.sample_rate = sample_rate
...@@ -362,11 +379,9 @@ class MelScale(torch.nn.Module): ...@@ -362,11 +379,9 @@ class MelScale(torch.nn.Module):
self.norm = norm self.norm = norm
self.mel_scale = mel_scale self.mel_scale = mel_scale
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(f_min, self.f_max)
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale)
self.register_buffer("fb", fb)
def forward(self, specgram: Tensor) -> Tensor: def forward(self, specgram: Tensor) -> Tensor:
r""" r"""
...@@ -404,21 +419,32 @@ class InverseMelScale(torch.nn.Module): ...@@ -404,21 +419,32 @@ class InverseMelScale(torch.nn.Module):
(area normalization). (Default: ``None``) (area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
""" """
__constants__ = [
"n_stft",
"n_mels",
"sample_rate",
"f_min",
"f_max",
"max_iter",
"tolerance_loss",
"tolerance_change",
"sgdargs",
]
def __init__(
self,
n_stft: int,
n_mels: int = 128,
sample_rate: int = 16000,
f_min: float = 0.0,
f_max: Optional[float] = None,
max_iter: int = 100000,
tolerance_loss: float = 1e-5,
tolerance_change: float = 1e-8,
sgdargs: Optional[dict] = None,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> None:
super(InverseMelScale, self).__init__() super(InverseMelScale, self).__init__()
self.n_mels = n_mels self.n_mels = n_mels
self.sample_rate = sample_rate self.sample_rate = sample_rate
...@@ -427,13 +453,12 @@ class InverseMelScale(torch.nn.Module): ...@@ -427,13 +453,12 @@ class InverseMelScale(torch.nn.Module):
self.max_iter = max_iter self.max_iter = max_iter
self.tolerance_loss = tolerance_loss self.tolerance_loss = tolerance_loss
self.tolerance_change = tolerance_change self.tolerance_change = tolerance_change
self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9} self.sgdargs = sgdargs or {"lr": 0.1, "momentum": 0.9}
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(f_min, self.f_max)
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale)
self.register_buffer("fb", fb)
def forward(self, melspec: Tensor) -> Tensor: def forward(self, melspec: Tensor) -> Tensor:
r""" r"""
...@@ -452,12 +477,13 @@ class InverseMelScale(torch.nn.Module): ...@@ -452,12 +477,13 @@ class InverseMelScale(torch.nn.Module):
melspec = melspec.transpose(-1, -2) melspec = melspec.transpose(-1, -2)
assert self.n_mels == n_mels assert self.n_mels == n_mels
specgram = torch.rand(
melspec.size()[0], time, freq, requires_grad=True, dtype=melspec.dtype, device=melspec.device
)
optim = torch.optim.SGD([specgram], **self.sgdargs) optim = torch.optim.SGD([specgram], **self.sgdargs)
loss = float('inf') loss = float("inf")
for _ in range(self.max_iter): for _ in range(self.max_iter):
optim.zero_grad() optim.zero_grad()
diff = melspec - specgram.matmul(self.fb) diff = melspec - specgram.matmul(self.fb)
...@@ -527,26 +553,28 @@ class MelSpectrogram(torch.nn.Module): ...@@ -527,26 +553,28 @@ class MelSpectrogram(torch.nn.Module):
:py:func:`torchaudio.functional.melscale_fbanks` - The function used to :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
generate the filter banks. generate the filter banks.
""" """
__constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min'] __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad", "n_mels", "f_min"]
def __init__(
self,
sample_rate: int = 16000,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
f_min: float = 0.0,
f_max: Optional[float] = None,
pad: int = 0,
n_mels: int = 128,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: float = 2.0,
normalized: bool = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> None:
super(MelSpectrogram, self).__init__() super(MelSpectrogram, self).__init__()
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.n_fft = n_fft self.n_fft = n_fft
...@@ -558,19 +586,21 @@ class MelSpectrogram(torch.nn.Module): ...@@ -558,19 +586,21 @@ class MelSpectrogram(torch.nn.Module):
self.n_mels = n_mels # number of mel frequency bins self.n_mels = n_mels # number of mel frequency bins
self.f_max = f_max self.f_max = f_max
self.f_min = f_min self.f_min = f_min
self.spectrogram = Spectrogram(
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
pad=self.pad,
window_fn=window_fn,
power=self.power,
normalized=self.normalized,
wkwargs=wkwargs,
center=center,
pad_mode=pad_mode,
onesided=onesided,
)
self.mel_scale = MelScale( self.mel_scale = MelScale(
self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, norm, mel_scale
) )
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
...@@ -609,33 +639,35 @@ class MFCC(torch.nn.Module): ...@@ -609,33 +639,35 @@ class MFCC(torch.nn.Module):
:py:func:`torchaudio.functional.melscale_fbanks` - The function used to :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
generate the filter banks. generate the filter banks.
""" """
__constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] __constants__ = ["sample_rate", "n_mfcc", "dct_type", "top_db", "log_mels"]
def __init__(
self,
sample_rate: int = 16000,
n_mfcc: int = 40,
dct_type: int = 2,
norm: str = "ortho",
log_mels: bool = False,
melkwargs: Optional[dict] = None,
) -> None:
super(MFCC, self).__init__() super(MFCC, self).__init__()
supported_dct_types = [2] supported_dct_types = [2]
if dct_type not in supported_dct_types: if dct_type not in supported_dct_types:
raise ValueError('DCT type not supported: {}'.format(dct_type)) raise ValueError("DCT type not supported: {}".format(dct_type))
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.n_mfcc = n_mfcc self.n_mfcc = n_mfcc
self.dct_type = dct_type self.dct_type = dct_type
self.norm = norm self.norm = norm
self.top_db = 80.0 self.top_db = 80.0
self.amplitude_to_DB = AmplitudeToDB('power', self.top_db) self.amplitude_to_DB = AmplitudeToDB("power", self.top_db)
melkwargs = melkwargs or {} melkwargs = melkwargs or {}
self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs) self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
if self.n_mfcc > self.MelSpectrogram.n_mels: if self.n_mfcc > self.MelSpectrogram.n_mels:
raise ValueError('Cannot select more MFCC coefficients than # mel bins') raise ValueError("Cannot select more MFCC coefficients than # mel bins")
dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm) dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
self.register_buffer('dct_mat', dct_mat) self.register_buffer("dct_mat", dct_mat)
self.log_mels = log_mels self.log_mels = log_mels
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
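A corresponding sketch for MFCC, which wraps MelSpectrogram through the melkwargs dict seen above (parameter values are illustrative, not from the diff):

import torch
import torchaudio

waveform = torch.randn(1, 16000)  # assumed 16 kHz mono input
mfcc = torchaudio.transforms.MFCC(
    sample_rate=16000,
    n_mfcc=40,
    melkwargs={"n_fft": 400, "hop_length": 200, "n_mels": 128},
)
coeffs = mfcc(waveform)  # (channel, n_mfcc, time)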
...@@ -685,22 +717,24 @@ class LFCC(torch.nn.Module): ...@@ -685,22 +717,24 @@ class LFCC(torch.nn.Module):
:py:func:`torchaudio.functional.linear_fbanks` - The function used to :py:func:`torchaudio.functional.linear_fbanks` - The function used to
generate the filter banks. generate the filter banks.
""" """
__constants__ = ['sample_rate', 'n_filter', 'n_lfcc', 'dct_type', 'top_db', 'log_lf'] __constants__ = ["sample_rate", "n_filter", "n_lfcc", "dct_type", "top_db", "log_lf"]
def __init__(self, def __init__(
sample_rate: int = 16000, self,
n_filter: int = 128, sample_rate: int = 16000,
f_min: float = 0., n_filter: int = 128,
f_max: Optional[float] = None, f_min: float = 0.0,
n_lfcc: int = 40, f_max: Optional[float] = None,
dct_type: int = 2, n_lfcc: int = 40,
norm: str = 'ortho', dct_type: int = 2,
log_lf: bool = False, norm: str = "ortho",
speckwargs: Optional[dict] = None) -> None: log_lf: bool = False,
speckwargs: Optional[dict] = None,
) -> None:
super(LFCC, self).__init__() super(LFCC, self).__init__()
supported_dct_types = [2] supported_dct_types = [2]
if dct_type not in supported_dct_types: if dct_type not in supported_dct_types:
raise ValueError('DCT type not supported: {}'.format(dct_type)) raise ValueError("DCT type not supported: {}".format(dct_type))
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.f_min = f_min self.f_min = f_min
self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_max = f_max if f_max is not None else float(sample_rate // 2)
...@@ -709,13 +743,13 @@ class LFCC(torch.nn.Module): ...@@ -709,13 +743,13 @@ class LFCC(torch.nn.Module):
self.dct_type = dct_type self.dct_type = dct_type
self.norm = norm self.norm = norm
self.top_db = 80.0 self.top_db = 80.0
self.amplitude_to_DB = AmplitudeToDB('power', self.top_db) self.amplitude_to_DB = AmplitudeToDB("power", self.top_db)
speckwargs = speckwargs or {} speckwargs = speckwargs or {}
self.Spectrogram = Spectrogram(**speckwargs) self.Spectrogram = Spectrogram(**speckwargs)
if self.n_lfcc > self.Spectrogram.n_fft: if self.n_lfcc > self.Spectrogram.n_fft:
raise ValueError('Cannot select more LFCC coefficients than # fft bins') raise ValueError("Cannot select more LFCC coefficients than # fft bins")
filter_mat = F.linear_fbanks( filter_mat = F.linear_fbanks(
n_freqs=self.Spectrogram.n_fft // 2 + 1, n_freqs=self.Spectrogram.n_fft // 2 + 1,
...@@ -727,7 +761,7 @@ class LFCC(torch.nn.Module): ...@@ -727,7 +761,7 @@ class LFCC(torch.nn.Module):
self.register_buffer("filter_mat", filter_mat) self.register_buffer("filter_mat", filter_mat)
dct_mat = F.create_dct(self.n_lfcc, self.n_filter, self.norm) dct_mat = F.create_dct(self.n_lfcc, self.n_filter, self.norm)
self.register_buffer('dct_mat', dct_mat) self.register_buffer("dct_mat", dct_mat)
self.log_lf = log_lf self.log_lf = log_lf
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
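LFCC follows the same pattern but builds on a plain Spectrogram through speckwargs; a sketch under assumed parameter values:

import torch
import torchaudio

waveform = torch.randn(1, 16000)  # assumed 16 kHz mono input
lfcc = torchaudio.transforms.LFCC(
    sample_rate=16000,
    n_lfcc=40,
    speckwargs={"n_fft": 400, "hop_length": 200},
)
coeffs = lfcc(waveform)  # (channel, n_lfcc, time)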
...@@ -770,7 +804,7 @@ class MuLawEncoding(torch.nn.Module): ...@@ -770,7 +804,7 @@ class MuLawEncoding(torch.nn.Module):
>>> mulawtrans = transform(waveform) >>> mulawtrans = transform(waveform)
""" """
__constants__ = ['quantization_channels'] __constants__ = ["quantization_channels"]
def __init__(self, quantization_channels: int = 256) -> None: def __init__(self, quantization_channels: int = 256) -> None:
super(MuLawEncoding, self).__init__() super(MuLawEncoding, self).__init__()
...@@ -802,7 +836,7 @@ class MuLawDecoding(torch.nn.Module): ...@@ -802,7 +836,7 @@ class MuLawDecoding(torch.nn.Module):
>>> transform = torchaudio.transforms.MuLawDecoding(quantization_channels=512) >>> transform = torchaudio.transforms.MuLawDecoding(quantization_channels=512)
>>> mulawtrans = transform(waveform) >>> mulawtrans = transform(waveform)
""" """
__constants__ = ['quantization_channels'] __constants__ = ["quantization_channels"]
def __init__(self, quantization_channels: int = 256) -> None: def __init__(self, quantization_channels: int = 256) -> None:
super(MuLawDecoding, self).__init__() super(MuLawDecoding, self).__init__()
...@@ -853,15 +887,15 @@ class Resample(torch.nn.Module): ...@@ -853,15 +887,15 @@ class Resample(torch.nn.Module):
""" """
def __init__( def __init__(
self, self,
orig_freq: int = 16000, orig_freq: int = 16000,
new_freq: int = 16000, new_freq: int = 16000,
resampling_method: str = 'sinc_interpolation', resampling_method: str = "sinc_interpolation",
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99, rolloff: float = 0.99,
beta: Optional[float] = None, beta: Optional[float] = None,
*, *,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -875,10 +909,16 @@ class Resample(torch.nn.Module): ...@@ -875,10 +909,16 @@ class Resample(torch.nn.Module):
if self.orig_freq != self.new_freq: if self.orig_freq != self.new_freq:
kernel, self.width = _get_sinc_resample_kernel( kernel, self.width = _get_sinc_resample_kernel(
self.orig_freq, self.new_freq, self.gcd, self.orig_freq,
self.lowpass_filter_width, self.rolloff, self.new_freq,
self.resampling_method, beta, dtype=dtype) self.gcd,
self.register_buffer('kernel', kernel) self.lowpass_filter_width,
self.rolloff,
self.resampling_method,
beta,
dtype=dtype,
)
self.register_buffer("kernel", kernel)
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
...@@ -890,9 +930,7 @@ class Resample(torch.nn.Module): ...@@ -890,9 +930,7 @@ class Resample(torch.nn.Module):
""" """
if self.orig_freq == self.new_freq: if self.orig_freq == self.new_freq:
return waveform return waveform
return _apply_sinc_resample_kernel( return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd, self.kernel, self.width)
waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)
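A minimal Resample sketch matching the forward shown above (sample rates and waveform are illustrative assumptions):

import torch
import torchaudio

waveform = torch.randn(1, 48000)  # assumed 1 s at 48 kHz
resample = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000)
downsampled = resample(waveform)  # (1, 16000)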
class ComputeDeltas(torch.nn.Module): class ComputeDeltas(torch.nn.Module):
...@@ -904,7 +942,7 @@ class ComputeDeltas(torch.nn.Module): ...@@ -904,7 +942,7 @@ class ComputeDeltas(torch.nn.Module):
win_length (int, optional): The window length used for computing delta. (Default: ``5``) win_length (int, optional): The window length used for computing delta. (Default: ``5``)
mode (str, optional): Mode parameter passed to padding. (Default: ``'replicate'``) mode (str, optional): Mode parameter passed to padding. (Default: ``'replicate'``)
""" """
__constants__ = ['win_length'] __constants__ = ["win_length"]
def __init__(self, win_length: int = 5, mode: str = "replicate") -> None: def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
super(ComputeDeltas, self).__init__() super(ComputeDeltas, self).__init__()
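ComputeDeltas operates on feature tensors of shape (..., freq, time); a sketch with assumed dimensions:

import torch
import torchaudio

specgram = torch.randn(1, 40, 100)  # assumed (channel, freq, time)
delta = torchaudio.transforms.ComputeDeltas(win_length=5)
d1 = delta(specgram)  # first-order deltas
d2 = delta(d1)        # delta-deltas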
...@@ -954,19 +992,16 @@ class TimeStretch(torch.nn.Module): ...@@ -954,19 +992,16 @@ class TimeStretch(torch.nn.Module):
:alt: Spectrogram stretched by 0.9 :alt: Spectrogram stretched by 0.9
""" """
__constants__ = ['fixed_rate'] __constants__ = ["fixed_rate"]
def __init__(self, def __init__(self, hop_length: Optional[int] = None, n_freq: int = 201, fixed_rate: Optional[float] = None) -> None:
hop_length: Optional[int] = None,
n_freq: int = 201,
fixed_rate: Optional[float] = None) -> None:
super(TimeStretch, self).__init__() super(TimeStretch, self).__init__()
self.fixed_rate = fixed_rate self.fixed_rate = fixed_rate
n_fft = (n_freq - 1) * 2 n_fft = (n_freq - 1) * 2
hop_length = hop_length if hop_length is not None else n_fft // 2 hop_length = hop_length if hop_length is not None else n_fft // 2
self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None]) self.register_buffer("phase_advance", torch.linspace(0, math.pi * hop_length, n_freq)[..., None])
def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor: def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
r""" r"""
...@@ -983,8 +1018,7 @@ class TimeStretch(torch.nn.Module): ...@@ -983,8 +1018,7 @@ class TimeStretch(torch.nn.Module):
""" """
if overriding_rate is None: if overriding_rate is None:
if self.fixed_rate is None: if self.fixed_rate is None:
raise ValueError( raise ValueError("If no fixed_rate is specified, must pass a valid rate to the forward method.")
"If no fixed_rate is specified, must pass a valid rate to the forward method.")
rate = self.fixed_rate rate = self.fixed_rate
else: else:
rate = overriding_rate rate = overriding_rate
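TimeStretch expects a complex spectrogram whose n_freq and hop_length match its constructor; a sketch under those assumptions:

import torch
import torchaudio

waveform = torch.randn(1, 16000)
spec = torchaudio.transforms.Spectrogram(n_fft=400, power=None)(waveform)  # complex, (1, 201, frames)
stretch = torchaudio.transforms.TimeStretch(hop_length=200, n_freq=201)
slowed = stretch(spec, overriding_rate=0.9)  # rate < 1 lengthens the signal without changing pitch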
...@@ -1007,10 +1041,7 @@ class Fade(torch.nn.Module): ...@@ -1007,10 +1041,7 @@ class Fade(torch.nn.Module):
>>> faded_waveform = transform(waveform) >>> faded_waveform = transform(waveform)
""" """
def __init__(self, def __init__(self, fade_in_len: int = 0, fade_out_len: int = 0, fade_shape: str = "linear") -> None:
fade_in_len: int = 0,
fade_out_len: int = 0,
fade_shape: str = "linear") -> None:
super(Fade, self).__init__() super(Fade, self).__init__()
self.fade_in_len = fade_in_len self.fade_in_len = fade_in_len
self.fade_out_len = fade_out_len self.fade_out_len = fade_out_len
...@@ -1026,11 +1057,7 @@ class Fade(torch.nn.Module): ...@@ -1026,11 +1057,7 @@ class Fade(torch.nn.Module):
""" """
waveform_length = waveform.size()[-1] waveform_length = waveform.size()[-1]
device = waveform.device device = waveform.device
return ( return self._fade_in(waveform_length, device) * self._fade_out(waveform_length, device) * waveform
self._fade_in(waveform_length, device)
* self._fade_out(waveform_length, device)
* waveform
)
def _fade_in(self, waveform_length: int, device: torch.device) -> Tensor: def _fade_in(self, waveform_length: int, device: torch.device) -> Tensor:
fade = torch.linspace(0, 1, self.fade_in_len, device=device) fade = torch.linspace(0, 1, self.fade_in_len, device=device)
...@@ -1043,7 +1070,7 @@ class Fade(torch.nn.Module): ...@@ -1043,7 +1070,7 @@ class Fade(torch.nn.Module):
fade = torch.pow(2, (fade - 1)) * fade fade = torch.pow(2, (fade - 1)) * fade
if self.fade_shape == "logarithmic": if self.fade_shape == "logarithmic":
fade = torch.log10(.1 + fade) + 1 fade = torch.log10(0.1 + fade) + 1
if self.fade_shape == "quarter_sine": if self.fade_shape == "quarter_sine":
fade = torch.sin(fade * math.pi / 2) fade = torch.sin(fade * math.pi / 2)
...@@ -1058,10 +1085,10 @@ class Fade(torch.nn.Module): ...@@ -1058,10 +1085,10 @@ class Fade(torch.nn.Module):
ones = torch.ones(waveform_length - self.fade_out_len, device=device) ones = torch.ones(waveform_length - self.fade_out_len, device=device)
if self.fade_shape == "linear": if self.fade_shape == "linear":
fade = - fade + 1 fade = -fade + 1
if self.fade_shape == "exponential": if self.fade_shape == "exponential":
fade = torch.pow(2, - fade) * (1 - fade) fade = torch.pow(2, -fade) * (1 - fade)
if self.fade_shape == "logarithmic": if self.fade_shape == "logarithmic":
fade = torch.log10(1.1 - fade) + 1 fade = torch.log10(1.1 - fade) + 1
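A Fade sketch using one of the shapes handled above (fade lengths are illustrative assumptions):

import torch
import torchaudio

waveform = torch.randn(1, 16000)
fade = torchaudio.transforms.Fade(fade_in_len=4000, fade_out_len=4000, fade_shape="logarithmic")
faded = fade(waveform)  # first/last 4000 samples ramped in/out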
...@@ -1084,7 +1111,7 @@ class _AxisMasking(torch.nn.Module): ...@@ -1084,7 +1111,7 @@ class _AxisMasking(torch.nn.Module):
iid_masks (bool): Applies iid masks to each of the examples in the batch dimension. iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
This option is applicable only when the input tensor is 4D. This option is applicable only when the input tensor is 4D.
""" """
__constants__ = ['mask_param', 'axis', 'iid_masks'] __constants__ = ["mask_param", "axis", "iid_masks"]
def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None: def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:
...@@ -1093,7 +1120,7 @@ class _AxisMasking(torch.nn.Module): ...@@ -1093,7 +1120,7 @@ class _AxisMasking(torch.nn.Module):
self.axis = axis self.axis = axis
self.iid_masks = iid_masks self.iid_masks = iid_masks
def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor: def forward(self, specgram: Tensor, mask_value: float = 0.0) -> Tensor:
r""" r"""
Args: Args:
specgram (Tensor): Tensor of dimension `(..., freq, time)`. specgram (Tensor): Tensor of dimension `(..., freq, time)`.
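_AxisMasking is the shared base behind the public FrequencyMasking and TimeMasking transforms (not visible in this hunk); a sketch assuming those wrappers and an illustrative spectrogram shape:

import torch
import torchaudio

specgram = torch.randn(1, 201, 400)  # assumed (channel, freq, time)
freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=30)
time_mask = torchaudio.transforms.TimeMasking(time_mask_param=50)
augmented = time_mask(freq_mask(specgram))  # random band/span filled with mask_value 0.0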
...@@ -1180,12 +1207,12 @@ class Vol(torch.nn.Module): ...@@ -1180,12 +1207,12 @@ class Vol(torch.nn.Module):
gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``) gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``)
""" """
def __init__(self, gain: float, gain_type: str = 'amplitude'): def __init__(self, gain: float, gain_type: str = "amplitude"):
super(Vol, self).__init__() super(Vol, self).__init__()
self.gain = gain self.gain = gain
self.gain_type = gain_type self.gain_type = gain_type
if gain_type in ['amplitude', 'power'] and gain < 0: if gain_type in ["amplitude", "power"] and gain < 0:
raise ValueError("If gain_type = amplitude or power, gain must be positive.") raise ValueError("If gain_type = amplitude or power, gain must be positive.")
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
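Vol accepts amplitude, power, or dB gains, rejecting negative values only for the first two; a sketch with assumed gains:

import torch
import torchaudio

waveform = 0.1 * torch.randn(1, 16000)
quieter = torchaudio.transforms.Vol(gain=-6.0, gain_type="db")(waveform)       # attenuate by 6 dB
louder = torchaudio.transforms.Vol(gain=2.0, gain_type="amplitude")(waveform)  # scale amplitude by 2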
...@@ -1221,11 +1248,9 @@ class SlidingWindowCmn(torch.nn.Module): ...@@ -1221,11 +1248,9 @@ class SlidingWindowCmn(torch.nn.Module):
norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false) norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
""" """
def __init__(self, def __init__(
cmn_window: int = 600, self, cmn_window: int = 600, min_cmn_window: int = 100, center: bool = False, norm_vars: bool = False
min_cmn_window: int = 100, ) -> None:
center: bool = False,
norm_vars: bool = False) -> None:
super().__init__() super().__init__()
self.cmn_window = cmn_window self.cmn_window = cmn_window
self.min_cmn_window = min_cmn_window self.min_cmn_window = min_cmn_window
...@@ -1240,8 +1265,7 @@ class SlidingWindowCmn(torch.nn.Module): ...@@ -1240,8 +1265,7 @@ class SlidingWindowCmn(torch.nn.Module):
Returns: Returns:
Tensor: Tensor of spectrogram of dimension `(..., time, freq)`. Tensor: Tensor of spectrogram of dimension `(..., time, freq)`.
""" """
cmn_specgram = F.sliding_window_cmn( cmn_specgram = F.sliding_window_cmn(specgram, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
specgram, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
return cmn_specgram return cmn_specgram
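SlidingWindowCmn consumes features laid out as (..., time, freq), per the forward docstring above; a sketch with assumed dimensions:

import torch
import torchaudio

feats = torch.randn(1, 300, 40)  # assumed (channel, time, freq)
cmvn = torchaudio.transforms.SlidingWindowCmn(cmn_window=600, center=True, norm_vars=False)
normalized = cmvn(feats)  # per-frame mean removed over the sliding window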
...@@ -1297,24 +1321,26 @@ class Vad(torch.nn.Module): ...@@ -1297,24 +1321,26 @@ class Vad(torch.nn.Module):
- http://sox.sourceforge.net/sox.html - http://sox.sourceforge.net/sox.html
""" """
def __init__(self, def __init__(
sample_rate: int, self,
trigger_level: float = 7.0, sample_rate: int,
trigger_time: float = 0.25, trigger_level: float = 7.0,
search_time: float = 1.0, trigger_time: float = 0.25,
allowed_gap: float = 0.25, search_time: float = 1.0,
pre_trigger_time: float = 0.0, allowed_gap: float = 0.25,
boot_time: float = .35, pre_trigger_time: float = 0.0,
noise_up_time: float = .1, boot_time: float = 0.35,
noise_down_time: float = .01, noise_up_time: float = 0.1,
noise_reduction_amount: float = 1.35, noise_down_time: float = 0.01,
measure_freq: float = 20.0, noise_reduction_amount: float = 1.35,
measure_duration: Optional[float] = None, measure_freq: float = 20.0,
measure_smooth_time: float = .4, measure_duration: Optional[float] = None,
hp_filter_freq: float = 50., measure_smooth_time: float = 0.4,
lp_filter_freq: float = 6000., hp_filter_freq: float = 50.0,
hp_lifter_freq: float = 150., lp_filter_freq: float = 6000.0,
lp_lifter_freq: float = 2000.) -> None: hp_lifter_freq: float = 150.0,
lp_lifter_freq: float = 2000.0,
) -> None:
super().__init__() super().__init__()
self.sample_rate = sample_rate self.sample_rate = sample_rate
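A Vad sketch; the synthetic waveform (silence followed by noise) is purely illustrative:

import torch
import torchaudio

sample_rate = 16000
waveform = torch.cat([torch.zeros(1, sample_rate), 0.5 * torch.randn(1, sample_rate)], dim=1)
vad = torchaudio.transforms.Vad(sample_rate=sample_rate, trigger_level=7.0)
trimmed = vad(waveform)  # returns audio starting from the first detected activity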
...@@ -1386,23 +1412,25 @@ class SpectralCentroid(torch.nn.Module): ...@@ -1386,23 +1412,25 @@ class SpectralCentroid(torch.nn.Module):
>>> transform = transforms.SpectralCentroid(sample_rate) >>> transform = transforms.SpectralCentroid(sample_rate)
>>> spectral_centroid = transform(waveform) # (channel, time) >>> spectral_centroid = transform(waveform) # (channel, time)
""" """
__constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad'] __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad"]
def __init__(self, def __init__(
sample_rate: int, self,
n_fft: int = 400, sample_rate: int,
win_length: Optional[int] = None, n_fft: int = 400,
hop_length: Optional[int] = None, win_length: Optional[int] = None,
pad: int = 0, hop_length: Optional[int] = None,
window_fn: Callable[..., Tensor] = torch.hann_window, pad: int = 0,
wkwargs: Optional[dict] = None) -> None: window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None,
) -> None:
super(SpectralCentroid, self).__init__() super(SpectralCentroid, self).__init__()
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.n_fft = n_fft self.n_fft = n_fft
self.win_length = win_length if win_length is not None else n_fft self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2 self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window) self.register_buffer("window", window)
self.pad = pad self.pad = pad
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
...@@ -1414,8 +1442,9 @@ class SpectralCentroid(torch.nn.Module): ...@@ -1414,8 +1442,9 @@ class SpectralCentroid(torch.nn.Module):
Tensor: Spectral Centroid of size `(..., time)`. Tensor: Spectral Centroid of size `(..., time)`.
""" """
return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length, return F.spectral_centroid(
self.win_length) waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length, self.win_length
)
class PitchShift(torch.nn.Module): class PitchShift(torch.nn.Module):
...@@ -1438,17 +1467,19 @@ class PitchShift(torch.nn.Module): ...@@ -1438,17 +1467,19 @@ class PitchShift(torch.nn.Module):
>>> transform = transforms.PitchShift(sample_rate, 4) >>> transform = transforms.PitchShift(sample_rate, 4)
>>> waveform_shift = transform(waveform) # (channel, time) >>> waveform_shift = transform(waveform) # (channel, time)
""" """
__constants__ = ['sample_rate', 'n_steps', 'bins_per_octave', 'n_fft', 'win_length', 'hop_length'] __constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"]
def __init__(self, def __init__(
sample_rate: int, self,
n_steps: int, sample_rate: int,
bins_per_octave: int = 12, n_steps: int,
n_fft: int = 512, bins_per_octave: int = 12,
win_length: Optional[int] = None, n_fft: int = 512,
hop_length: Optional[int] = None, win_length: Optional[int] = None,
window_fn: Callable[..., Tensor] = torch.hann_window, hop_length: Optional[int] = None,
wkwargs: Optional[dict] = None) -> None: window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None,
) -> None:
super(PitchShift, self).__init__() super(PitchShift, self).__init__()
self.n_steps = n_steps self.n_steps = n_steps
self.bins_per_octave = bins_per_octave self.bins_per_octave = bins_per_octave
...@@ -1457,7 +1488,7 @@ class PitchShift(torch.nn.Module): ...@@ -1457,7 +1488,7 @@ class PitchShift(torch.nn.Module):
self.win_length = win_length if win_length is not None else n_fft self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 4 self.hop_length = hop_length if hop_length is not None else self.win_length // 4
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window) self.register_buffer("window", window)
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
...@@ -1468,8 +1499,16 @@ class PitchShift(torch.nn.Module): ...@@ -1468,8 +1499,16 @@ class PitchShift(torch.nn.Module):
Tensor: The pitch-shifted audio of shape `(..., time)`. Tensor: The pitch-shifted audio of shape `(..., time)`.
""" """
return F.pitch_shift(waveform, self.sample_rate, self.n_steps, self.bins_per_octave, self.n_fft, return F.pitch_shift(
self.win_length, self.hop_length, self.window) waveform,
self.sample_rate,
self.n_steps,
self.bins_per_octave,
self.n_fft,
self.win_length,
self.hop_length,
self.window,
)
class RNNTLoss(torch.nn.Module): class RNNTLoss(torch.nn.Module):
...@@ -1506,7 +1545,7 @@ class RNNTLoss(torch.nn.Module): ...@@ -1506,7 +1545,7 @@ class RNNTLoss(torch.nn.Module):
def __init__( def __init__(
self, self,
blank: int = -1, blank: int = -1,
clamp: float = -1., clamp: float = -1.0,
reduction: str = "mean", reduction: str = "mean",
): ):
super().__init__() super().__init__()
...@@ -1532,15 +1571,7 @@ class RNNTLoss(torch.nn.Module): ...@@ -1532,15 +1571,7 @@ class RNNTLoss(torch.nn.Module):
Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch), Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
otherwise scalar. otherwise scalar.
""" """
return F.rnnt_loss( return F.rnnt_loss(logits, targets, logit_lengths, target_lengths, self.blank, self.clamp, self.reduction)
logits,
targets,
logit_lengths,
target_lengths,
self.blank,
self.clamp,
self.reduction
)
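RNNTLoss consumes joiner logits of shape (batch, max time, max target length + 1, classes) with int32 targets and lengths; a sketch with placeholder sizes, assuming the last class index is used as the blank:

import torch
import torchaudio

batch, max_T, max_U, num_classes = 2, 10, 5, 20
logits = torch.randn(batch, max_T, max_U + 1, num_classes, requires_grad=True)
targets = torch.randint(0, num_classes - 1, (batch, max_U), dtype=torch.int32)  # labels exclude the assumed blank index
logit_lengths = torch.full((batch,), max_T, dtype=torch.int32)
target_lengths = torch.full((batch,), max_U, dtype=torch.int32)
loss_fn = torchaudio.transforms.RNNTLoss(blank=num_classes - 1, reduction="mean")
loss = loss_fn(logits, targets, logit_lengths, target_lengths)
loss.backward()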
def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor: def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
...@@ -1557,8 +1588,7 @@ def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch ...@@ -1557,8 +1588,7 @@ def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch
torch.Tensor: trace of the input Tensor torch.Tensor: trace of the input Tensor
""" """
assert input.ndim >= 2, "The dimension of the tensor must be at least 2." assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert input.shape[dim1] == input.shape[dim2],\ assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
"The size of ``dim1`` and ``dim2`` must be the same."
input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2) input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
return input.sum(dim=-1) return input.sum(dim=-1)
...@@ -1684,8 +1714,11 @@ class MVDR(torch.nn.Module): ...@@ -1684,8 +1714,11 @@ class MVDR(torch.nn.Module):
online: bool = False, online: bool = False,
): ):
super().__init__() super().__init__()
assert solution in ["ref_channel", "stv_evd", "stv_power"],\ assert solution in [
"Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]." "ref_channel",
"stv_evd",
"stv_power",
], "Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]."
self.ref_channel = ref_channel self.ref_channel = ref_channel
self.solution = solution self.solution = solution
self.multi_mask = multi_mask self.multi_mask = multi_mask
...@@ -1698,10 +1731,10 @@ class MVDR(torch.nn.Module): ...@@ -1698,10 +1731,10 @@ class MVDR(torch.nn.Module):
psd_n: torch.Tensor = torch.zeros(1) psd_n: torch.Tensor = torch.zeros(1)
mask_sum_s: torch.Tensor = torch.zeros(1) mask_sum_s: torch.Tensor = torch.zeros(1)
mask_sum_n: torch.Tensor = torch.zeros(1) mask_sum_n: torch.Tensor = torch.zeros(1)
self.register_buffer('psd_s', psd_s) self.register_buffer("psd_s", psd_s)
self.register_buffer('psd_n', psd_n) self.register_buffer("psd_n", psd_n)
self.register_buffer('mask_sum_s', mask_sum_s) self.register_buffer("mask_sum_s", mask_sum_s)
self.register_buffer('mask_sum_n', mask_sum_n) self.register_buffer("mask_sum_n", mask_sum_n)
def _get_updated_mvdr_vector( def _get_updated_mvdr_vector(
self, self,
...@@ -1710,7 +1743,7 @@ class MVDR(torch.nn.Module): ...@@ -1710,7 +1743,7 @@ class MVDR(torch.nn.Module):
mask_s: torch.Tensor, mask_s: torch.Tensor,
mask_n: torch.Tensor, mask_n: torch.Tensor,
reference_vector: torch.Tensor, reference_vector: torch.Tensor,
solution: str = 'ref_channel', solution: str = "ref_channel",
diagonal_loading: bool = True, diagonal_loading: bool = True,
diag_eps: float = 1e-7, diag_eps: float = 1e-7,
eps: float = 1e-8, eps: float = 1e-8,
...@@ -1788,7 +1821,7 @@ class MVDR(torch.nn.Module): ...@@ -1788,7 +1821,7 @@ class MVDR(torch.nn.Module):
psd_s: torch.Tensor, psd_s: torch.Tensor,
psd_n: torch.Tensor, psd_n: torch.Tensor,
reference_vector: torch.Tensor, reference_vector: torch.Tensor,
solution: str = 'ref_channel', solution: str = "ref_channel",
diagonal_loading: bool = True, diagonal_loading: bool = True,
diag_eps: float = 1e-7, diag_eps: float = 1e-7,
eps: float = 1e-8, eps: float = 1e-8,
...@@ -1851,10 +1884,7 @@ class MVDR(torch.nn.Module): ...@@ -1851,10 +1884,7 @@ class MVDR(torch.nn.Module):
return stv return stv
def _get_steering_vector_power( def _get_steering_vector_power(
self, self, psd_s: torch.Tensor, psd_n: torch.Tensor, reference_vector: torch.Tensor
psd_s: torch.Tensor,
psd_n: torch.Tensor,
reference_vector: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
r"""Estimate the steering vector by the power method. r"""Estimate the steering vector by the power method.
...@@ -1876,11 +1906,7 @@ class MVDR(torch.nn.Module): ...@@ -1876,11 +1906,7 @@ class MVDR(torch.nn.Module):
stv = torch.matmul(psd_s, stv) stv = torch.matmul(psd_s, stv)
return stv return stv
def _apply_beamforming_vector( def _apply_beamforming_vector(self, specgram: torch.Tensor, beamform_vector: torch.Tensor) -> torch.Tensor:
self,
specgram: torch.Tensor,
beamform_vector: torch.Tensor
) -> torch.Tensor:
r"""Apply the beamforming weight to the noisy STFT r"""Apply the beamforming weight to the noisy STFT
Args: Args:
specgram (torch.tensor): multi-channel noisy STFT specgram (torch.tensor): multi-channel noisy STFT
...@@ -1896,12 +1922,7 @@ class MVDR(torch.nn.Module): ...@@ -1896,12 +1922,7 @@ class MVDR(torch.nn.Module):
specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram]) specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram])
return specgram_enhanced return specgram_enhanced
def _tik_reg( def _tik_reg(self, mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
self,
mat: torch.Tensor,
reg: float = 1e-7,
eps: float = 1e-8
) -> torch.Tensor:
"""Perform Tikhonov regularization (only modifying real part). """Perform Tikhonov regularization (only modifying real part).
Args: Args:
mat (torch.Tensor): input matrix (..., channel, channel) mat (torch.Tensor): input matrix (..., channel, channel)
...@@ -1922,10 +1943,7 @@ class MVDR(torch.nn.Module): ...@@ -1922,10 +1943,7 @@ class MVDR(torch.nn.Module):
return mat return mat
def forward( def forward(
self, self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: Optional[torch.Tensor] = None
specgram: torch.Tensor,
mask_s: torch.Tensor,
mask_n: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
"""Perform MVDR beamforming. """Perform MVDR beamforming.
...@@ -1946,9 +1964,7 @@ class MVDR(torch.nn.Module): ...@@ -1946,9 +1964,7 @@ class MVDR(torch.nn.Module):
""" """
dtype = specgram.dtype dtype = specgram.dtype
if specgram.ndim < 3: if specgram.ndim < 3:
raise ValueError( raise ValueError(f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}")
f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}"
)
if not specgram.is_complex(): if not specgram.is_complex():
raise ValueError( raise ValueError(
f"The type of ``specgram`` tensor must be ``torch.cfloat`` or ``torch.cdouble``.\ f"The type of ``specgram`` tensor must be ``torch.cfloat`` or ``torch.cdouble``.\
...@@ -1958,9 +1974,7 @@ class MVDR(torch.nn.Module): ...@@ -1958,9 +1974,7 @@ class MVDR(torch.nn.Module):
specgram = specgram.cdouble() # Convert specgram to ``torch.cdouble``. specgram = specgram.cdouble() # Convert specgram to ``torch.cdouble``.
if mask_n is None: if mask_n is None:
warnings.warn( warnings.warn("``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``.")
"``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``."
)
mask_n = 1 - mask_s mask_n = 1 - mask_s
shape = specgram.size() shape = specgram.size()
...@@ -1977,33 +1991,15 @@ class MVDR(torch.nn.Module): ...@@ -1977,33 +1991,15 @@ class MVDR(torch.nn.Module):
psd_s = self.psd(specgram, mask_s) # (..., freq, time, channel, channel) psd_s = self.psd(specgram, mask_s) # (..., freq, time, channel, channel)
psd_n = self.psd(specgram, mask_n) # (..., freq, time, channel, channel) psd_n = self.psd(specgram, mask_n) # (..., freq, time, channel, channel)
u = torch.zeros( u = torch.zeros(specgram.size()[:-2], device=specgram.device, dtype=torch.cdouble) # (..., channel)
specgram.size()[:-2],
device=specgram.device,
dtype=torch.cdouble
) # (..., channel)
u[..., self.ref_channel].fill_(1) u[..., self.ref_channel].fill_(1)
if self.online: if self.online:
w_mvdr = self._get_updated_mvdr_vector( w_mvdr = self._get_updated_mvdr_vector(
psd_s, psd_s, psd_n, mask_s, mask_n, u, self.solution, self.diag_loading, self.diag_eps
psd_n,
mask_s,
mask_n,
u,
self.solution,
self.diag_loading,
self.diag_eps
) )
else: else:
w_mvdr = self._get_mvdr_vector( w_mvdr = self._get_mvdr_vector(psd_s, psd_n, u, self.solution, self.diag_loading, self.diag_eps)
psd_s,
psd_n,
u,
self.solution,
self.diag_loading,
self.diag_eps
)
specgram_enhanced = self._apply_beamforming_vector(specgram, w_mvdr) specgram_enhanced = self._apply_beamforming_vector(specgram, w_mvdr)
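A heavily simplified MVDR sketch; the multi-channel STFT, the masks, and their shapes are all assumptions for illustration:

import torch
import torchaudio

specgram = torch.randn(6, 201, 100, dtype=torch.cfloat)  # assumed (channel, freq, time) STFT
mask_s = torch.rand(201, 100)  # assumed time-frequency mask for the target speech
mask_n = 1.0 - mask_s          # complementary mask for noise
mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution="ref_channel")
enhanced = mvdr(specgram, mask_s, mask_n)  # single-channel complex STFT, (freq, time)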
......
from torchaudio._internal import module_utils as _mod_utils
from . import ( from . import (
sox_utils, sox_utils,
) )
from torchaudio._internal import module_utils as _mod_utils
if _mod_utils.is_sox_available(): if _mod_utils.is_sox_available():
......