Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
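For context before the per-file hunks: the `arc lint --apply-patches` run in the summary applies the repository's configured formatters, which (as the hunks show) normalize string literals to double quotes, add trailing commas to wrapped literals, re-wrap signatures, and reorder imports. A minimal before/after sketch of the effect (hypothetical snippet, not code from this repository):

# Before the lint run (single quotes, hanging signature):
GREETING = 'hello'
def make_url(
        base,
        path
):
    return f'{base}/{path}'

# After the lint run (double quotes, compact signature); behavior is unchanged:
GREETING = "hello"
def make_url(base, path):
    return f"{base}/{path}"

The hunks below show this same mechanical transformation applied across torchaudio.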
from .interface import Tacotron2TTSBundle
from .impl import (
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
)
from .interface import Tacotron2TTSBundle
__all__ = [
'Tacotron2TTSBundle',
'TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH',
'TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH',
'TACOTRON2_WAVERNN_CHAR_LJSPEECH',
'TACOTRON2_WAVERNN_PHONE_LJSPEECH',
"Tacotron2TTSBundle",
"TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
"TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
"TACOTRON2_WAVERNN_CHAR_LJSPEECH",
"TACOTRON2_WAVERNN_PHONE_LJSPEECH",
]
from dataclasses import dataclass
import re
from dataclasses import dataclass
from typing import Union, Optional, Dict, Any, Tuple, List
import torch
from torch import Tensor
from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import Tacotron2, WaveRNN
from torchaudio.functional import mu_law_decoding
from torchaudio.models import Tacotron2, WaveRNN
from torchaudio.transforms import InverseMelScale, GriffinLim
from . import utils
from .interface import Tacotron2TTSBundle
__all__ = []
_BASE_URL = 'https://download.pytorch.org/torchaudio/models'
_BASE_URL = "https://download.pytorch.org/torchaudio/models"
################################################################################
......@@ -44,8 +44,7 @@ class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor):
super().__init__()
self._tokens = utils._get_phones()
self._mapping = {p: i for i, p in enumerate(self._tokens)}
self._phonemizer = utils._load_phonemizer(
'en_us_cmudict_forward.pt', dl_kwargs=dl_kwargs)
self._phonemizer = utils._load_phonemizer("en_us_cmudict_forward.pt", dl_kwargs=dl_kwargs)
self._pattern = r"(\[[A-Z]+?\]|[_!'(),.:;? -])"
@property
......@@ -57,9 +56,9 @@ class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor):
texts = [texts]
indices = []
for phones in self._phonemizer(texts, lang='en_us'):
for phones in self._phonemizer(texts, lang="en_us"):
# '[F][UW][B][AA][R]!' -> ['F', 'UW', 'B', 'AA', 'R', '!']
ret = [re.sub(r'[\[\]]', '', r) for r in re.findall(self._pattern, phones)]
ret = [re.sub(r"[\[\]]", "", r) for r in re.findall(self._pattern, phones)]
indices.append([self._mapping[p] for p in ret])
return utils._to_tensor(indices)
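As the inline comment above shows, the pattern splits a phonemized string into bracketed phones and punctuation before the brackets are stripped. A standalone check of that behavior (editor's illustration reusing the same pattern, not part of the diff):

import re

pattern = r"(\[[A-Z]+?\]|[_!'(),.:;? -])"
phones = "[F][UW][B][AA][R]!"
tokens = [re.sub(r"[\[\]]", "", t) for t in re.findall(pattern, phones)]
print(tokens)  # ['F', 'UW', 'B', 'AA', 'R', '!']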
......@@ -68,12 +67,9 @@ class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor):
# Pipeline implementation - Vocoder
################################################################################
class _WaveRNNVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
def __init__(
self,
model: WaveRNN,
min_level_db: Optional[float] = -100
):
def __init__(self, model: WaveRNN, min_level_db: Optional[float] = -100):
super().__init__()
self._sample_rate = 22050
self._model = model
......@@ -104,10 +100,10 @@ class _GriffinLimVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
n_stft=(1024 // 2 + 1),
n_mels=80,
sample_rate=self.sample_rate,
f_min=0.,
f_max=8000.,
f_min=0.0,
f_max=8000.0,
mel_scale="slaney",
norm='slaney',
norm="slaney",
)
self._griffin_lim = GriffinLim(
n_fft=1024,
......@@ -151,7 +147,7 @@ class _Tacotron2Mixin:
def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2:
model = Tacotron2(**self._tacotron2_params)
url = f'{_BASE_URL}/{self._tacotron2_path}'
url = f"{_BASE_URL}/{self._tacotron2_path}"
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
model.load_state_dict(state_dict)
......@@ -170,7 +166,7 @@ class _WaveRNNMixin:
def _get_wavernn(self, *, dl_kwargs=None):
model = WaveRNN(**self._wavernn_params)
url = f'{_BASE_URL}/{self._wavernn_path}'
url = f"{_BASE_URL}/{self._wavernn_path}"
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
model.load_state_dict(state_dict)
......@@ -214,11 +210,10 @@ class _Tacotron2GriffinLimPhoneBundle(_GriffinLimMixin, _Tacotron2Mixin, _PhoneM
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH = _Tacotron2GriffinLimCharBundle(
_tacotron2_path='tacotron2_english_characters_1500_epochs_ljspeech.pth',
_tacotron2_path="tacotron2_english_characters_1500_epochs_ljspeech.pth",
_tacotron2_params=utils._get_taco_params(n_symbols=38),
)
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.__doc__ = (
'''Character-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.__doc__ = """Character-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.transforms.GriffinLim`.
The text processor encodes the input texts character-by-character.
......@@ -254,14 +249,13 @@ Example - "The examination and testimony of the experts enabled the Commission t
<source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH_v2.wav" type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
''') # noqa: E501
""" # noqa: E501
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH = _Tacotron2GriffinLimPhoneBundle(
_tacotron2_path='tacotron2_english_phonemes_1500_epochs_ljspeech.pth',
_tacotron2_path="tacotron2_english_phonemes_1500_epochs_ljspeech.pth",
_tacotron2_params=utils._get_taco_params(n_symbols=96),
)
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.__doc__ = (
'''Phoneme-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.__doc__ = """Phoneme-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.transforms.GriffinLim`.
The text processor encodes the input texts based on phonemes.
......@@ -302,16 +296,15 @@ Example - "The examination and testimony of the experts enabled the Commission t
Your browser does not support the <code>audio</code> element.
</audio>
''') # noqa: E501
""" # noqa: E501
TACOTRON2_WAVERNN_CHAR_LJSPEECH = _Tacotron2WaveRNNCharBundle(
_tacotron2_path='tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth',
_tacotron2_path="tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth",
_tacotron2_params=utils._get_taco_params(n_symbols=38),
_wavernn_path='wavernn_10k_epochs_8bits_ljspeech.pth',
_wavernn_path="wavernn_10k_epochs_8bits_ljspeech.pth",
_wavernn_params=utils._get_wrnn_params(),
)
TACOTRON2_WAVERNN_CHAR_LJSPEECH.__doc__ = (
'''Character-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
TACOTRON2_WAVERNN_CHAR_LJSPEECH.__doc__ = """Character-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.models.WaveRNN`.
The text processor encodes the input texts character-by-character.
......@@ -350,16 +343,15 @@ Example - "The examination and testimony of the experts enabled the Commission t
<source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH_v2.wav" type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
''') # noqa: E501
""" # noqa: E501
TACOTRON2_WAVERNN_PHONE_LJSPEECH = _Tacotron2WaveRNNPhoneBundle(
_tacotron2_path='tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth',
_tacotron2_path="tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth",
_tacotron2_params=utils._get_taco_params(n_symbols=96),
_wavernn_path='wavernn_10k_epochs_8bits_ljspeech.pth',
_wavernn_path="wavernn_10k_epochs_8bits_ljspeech.pth",
_wavernn_params=utils._get_wrnn_params(),
)
TACOTRON2_WAVERNN_PHONE_LJSPEECH.__doc__ = (
'''Phoneme-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
TACOTRON2_WAVERNN_PHONE_LJSPEECH.__doc__ = """Phoneme-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.models.WaveRNN`.
The text processor encodes the input texts based on phonemes.
......@@ -403,4 +395,4 @@ Example - "The examination and testimony of the experts enabled the Commission t
<source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH_v2.wav" type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
''') # noqa: E501
""" # noqa: E501
......@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
from typing import Union, List, Tuple, Optional
from torch import Tensor
from torchaudio.models import Tacotron2
......
import os
import logging
import os
import torch
from torchaudio._internal import (
download_url_to_file,
module_utils as _mod_utils,
......@@ -11,44 +10,44 @@ from torchaudio._internal import (
def _get_chars():
return (
'_',
'-',
'!',
"_",
"-",
"!",
"'",
'(',
')',
',',
'.',
':',
';',
'?',
' ',
'a',
'b',
'c',
'd',
'e',
'f',
'g',
'h',
'i',
'j',
'k',
'l',
'm',
'n',
'o',
'p',
'q',
'r',
's',
't',
'u',
'v',
'w',
'x',
'y',
'z',
"(",
")",
",",
".",
":",
";",
"?",
" ",
"a",
"b",
"c",
"d",
"e",
"f",
"g",
"h",
"i",
"j",
"k",
"l",
"m",
"n",
"o",
"p",
"q",
"r",
"s",
"t",
"u",
"v",
"w",
"x",
"y",
"z",
)
......@@ -149,7 +148,7 @@ def _get_phones():
"W",
"Y",
"Z",
"ZH"
"ZH",
)
......@@ -161,18 +160,18 @@ def _to_tensor(indices):
def _load_phonemizer(file, dl_kwargs):
if not _mod_utils.is_module_available('dp'):
raise RuntimeError('DeepPhonemizer is not installed. Please install it.')
if not _mod_utils.is_module_available("dp"):
raise RuntimeError("DeepPhonemizer is not installed. Please install it.")
from dp.phonemizer import Phonemizer
# By default, dp issues DEBUG level log.
logger = logging.getLogger('dp')
logger = logging.getLogger("dp")
orig_level = logger.level
logger.setLevel(logging.INFO)
try:
url = f'https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/{file}'
directory = os.path.join(torch.hub.get_dir(), 'checkpoints')
url = f"https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/{file}"
directory = os.path.join(torch.hub.get_dir(), "checkpoints")
os.makedirs(directory, exist_ok=True)
path = os.path.join(directory, file)
if not os.path.exists(path):
......@@ -192,41 +191,41 @@ def _unnormalize_waveform(waveform: torch.Tensor, bits: int) -> torch.Tensor:
def _get_taco_params(n_symbols):
return {
'mask_padding': False,
'n_mels': 80,
'n_frames_per_step': 1,
'symbol_embedding_dim': 512,
'encoder_embedding_dim': 512,
'encoder_n_convolution': 3,
'encoder_kernel_size': 5,
'decoder_rnn_dim': 1024,
'decoder_max_step': 2000,
'decoder_dropout': 0.1,
'decoder_early_stopping': True,
'attention_rnn_dim': 1024,
'attention_hidden_dim': 128,
'attention_location_n_filter': 32,
'attention_location_kernel_size': 31,
'attention_dropout': 0.1,
'prenet_dim': 256,
'postnet_n_convolution': 5,
'postnet_kernel_size': 5,
'postnet_embedding_dim': 512,
'gate_threshold': 0.5,
'n_symbol': n_symbols,
"mask_padding": False,
"n_mels": 80,
"n_frames_per_step": 1,
"symbol_embedding_dim": 512,
"encoder_embedding_dim": 512,
"encoder_n_convolution": 3,
"encoder_kernel_size": 5,
"decoder_rnn_dim": 1024,
"decoder_max_step": 2000,
"decoder_dropout": 0.1,
"decoder_early_stopping": True,
"attention_rnn_dim": 1024,
"attention_hidden_dim": 128,
"attention_location_n_filter": 32,
"attention_location_kernel_size": 31,
"attention_dropout": 0.1,
"prenet_dim": 256,
"postnet_n_convolution": 5,
"postnet_kernel_size": 5,
"postnet_embedding_dim": 512,
"gate_threshold": 0.5,
"n_symbol": n_symbols,
}
def _get_wrnn_params():
return {
'upsample_scales': [5, 5, 11],
'n_classes': 2 ** 8, # n_bits = 8
'hop_length': 275,
'n_res_block': 10,
'n_rnn': 512,
'n_fc': 512,
'kernel_size': 5,
'n_freq': 80,
'n_hidden': 128,
'n_output': 128
"upsample_scales": [5, 5, 11],
"n_classes": 2 ** 8, # n_bits = 8
"hop_length": 275,
"n_res_block": 10,
"n_rnn": 512,
"n_fc": 512,
"kernel_size": 5,
"n_freq": 80,
"n_hidden": 128,
"n_output": 128,
}
......@@ -2,9 +2,9 @@ from dataclasses import dataclass
from typing import Dict, Tuple, Any
import torch
from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import wav2vec2_model, Wav2Vec2Model
from . import utils
......@@ -43,6 +43,7 @@ class Wav2Vec2Bundle:
>>> # Extract acoustic features
>>> features, _ = model.extract_features(waveform)
""" # noqa: E501
_path: str
_params: Dict[str, Any]
_sample_rate: float
......@@ -56,7 +57,7 @@ class Wav2Vec2Bundle:
return self._sample_rate
def _get_state_dict(self, dl_kwargs):
url = f'https://download.pytorch.org/torchaudio/models/{self._path}'
url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
return state_dict
......@@ -120,13 +121,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
>>> # `ctc_decode` is for illustration purpose only
>>> transcripts = ctc_decode(emissions, labels)
""" # noqa: E501
_labels: Tuple[str]
_remove_aux_axis: Tuple[int] = (1, 2, 3)
def get_labels(
self,
*,
blank: str = '-',
self,
*,
blank: str = "-",
) -> Tuple[str]:
"""The output class labels (only applicable to fine-tuned bundles)
......@@ -161,17 +163,17 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
# that resembles mistake.
# The label `1` shows up in the training dataset of German (1 out of 16M),
# English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M)
for key in ['aux.weight', 'aux.bias']:
for key in ["aux.weight", "aux.bias"]:
t = state_dict[key]
state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in self._remove_aux_axis])
return state_dict
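The `Wav2Vec2ASRBundle` docstring above deliberately leaves `ctc_decode` as a placeholder. A minimal greedy CTC decoder over a single utterance's emission and the `get_labels()` output could look like this (editor's sketch, not part of the library; it assumes the blank token sits at index 0 and `|` is the word delimiter, as in the bundles defined below):

import torch

def greedy_ctc_decode(emission: torch.Tensor, labels, blank: str = "-") -> str:
    # emission: (num_frames, num_labels) log-probabilities for one utterance.
    indices = torch.argmax(emission, dim=-1).tolist()
    chars = []
    previous = None
    for idx in indices:
        if idx != previous:  # collapse consecutive repeats
            chars.append(labels[idx])
        previous = idx
    # Drop blanks and turn the "|" word delimiter into spaces.
    return "".join(c for c in chars if c != blank).replace("|", " ").strip()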
WAV2VEC2_BASE = Wav2Vec2Bundle(
_path='wav2vec2_fairseq_base_ls960.pth',
_path="wav2vec2_fairseq_base_ls960.pth",
_params={
'extractor_mode': 'group_norm',
'extractor_conv_layer_config': [
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
......@@ -180,19 +182,19 @@ WAV2VEC2_BASE = Wav2Vec2Bundle(
(512, 2, 2),
(512, 2, 2),
],
'extractor_conv_bias': False,
'encoder_embed_dim': 768,
'encoder_projection_dropout': 0.1,
'encoder_pos_conv_kernel': 128,
'encoder_pos_conv_groups': 16,
'encoder_num_layers': 12,
'encoder_num_heads': 12,
'encoder_attention_dropout': 0.1,
'encoder_ff_interm_features': 3072,
'encoder_ff_interm_dropout': 0.0,
'encoder_dropout': 0.1,
'encoder_layer_norm_first': False,
'encoder_layer_drop': 0.05,
"extractor_conv_bias": False,
"encoder_embed_dim": 768,
"encoder_projection_dropout": 0.1,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 12,
"encoder_num_heads": 12,
"encoder_attention_dropout": 0.1,
"encoder_ff_interm_features": 3072,
"encoder_ff_interm_dropout": 0.0,
"encoder_dropout": 0.1,
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.05,
"aux_num_out": None,
},
_sample_rate=16000,
......@@ -212,10 +214,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
_path='wav2vec2_fairseq_base_ls960_asr_ll10m.pth',
_path="wav2vec2_fairseq_base_ls960_asr_ll10m.pth",
_params={
'extractor_mode': 'group_norm',
'extractor_conv_layer_config': [
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
......@@ -224,19 +226,19 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
(512, 2, 2),
(512, 2, 2),
],
'extractor_conv_bias': False,
'encoder_embed_dim': 768,
'encoder_projection_dropout': 0.1,
'encoder_pos_conv_kernel': 128,
'encoder_pos_conv_groups': 16,
'encoder_num_layers': 12,
'encoder_num_heads': 12,
'encoder_attention_dropout': 0.1,
'encoder_ff_interm_features': 3072,
'encoder_ff_interm_dropout': 0.0,
'encoder_dropout': 0.1,
'encoder_layer_norm_first': False,
'encoder_layer_drop': 0.05,
"extractor_conv_bias": False,
"encoder_embed_dim": 768,
"encoder_projection_dropout": 0.1,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 12,
"encoder_num_heads": 12,
"encoder_attention_dropout": 0.1,
"encoder_ff_interm_features": 3072,
"encoder_ff_interm_dropout": 0.0,
"encoder_dropout": 0.1,
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.05,
"aux_num_out": 29,
},
_labels=utils._get_en_labels(),
......@@ -258,10 +260,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
'wav2vec2_fairseq_base_ls960_asr_ls100.pth',
"wav2vec2_fairseq_base_ls960_asr_ls100.pth",
{
'extractor_mode': 'group_norm',
'extractor_conv_layer_config': [
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
......@@ -270,19 +272,19 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
(512, 2, 2),
(512, 2, 2),
],
'extractor_conv_bias': False,
'encoder_embed_dim': 768,
'encoder_projection_dropout': 0.1,
'encoder_pos_conv_kernel': 128,
'encoder_pos_conv_groups': 16,
'encoder_num_layers': 12,
'encoder_num_heads': 12,
'encoder_attention_dropout': 0.1,
'encoder_ff_interm_features': 3072,
'encoder_ff_interm_dropout': 0.0,
'encoder_dropout': 0.1,
'encoder_layer_norm_first': False,
'encoder_layer_drop': 0.05,
"extractor_conv_bias": False,
"encoder_embed_dim": 768,
"encoder_projection_dropout": 0.1,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 12,
"encoder_num_heads": 12,
"encoder_attention_dropout": 0.1,
"encoder_ff_interm_features": 3072,
"encoder_ff_interm_dropout": 0.0,
"encoder_dropout": 0.1,
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.05,
"aux_num_out": 29,
},
_labels=utils._get_en_labels(),
......@@ -304,7 +306,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
'wav2vec2_fairseq_base_ls960_asr_ls960.pth',
"wav2vec2_fairseq_base_ls960_asr_ls960.pth",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
......@@ -349,7 +351,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
WAV2VEC2_LARGE = Wav2Vec2Bundle(
'wav2vec2_fairseq_large_ls960.pth',
"wav2vec2_fairseq_large_ls960.pth",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
......@@ -393,7 +395,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
'wav2vec2_fairseq_large_ls960_asr_ll10m.pth',
"wav2vec2_fairseq_large_ls960_asr_ll10m.pth",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
......@@ -439,7 +441,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
'wav2vec2_fairseq_large_ls960_asr_ls100.pth',
"wav2vec2_fairseq_large_ls960_asr_ls100.pth",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
......@@ -485,7 +487,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
'wav2vec2_fairseq_large_ls960_asr_ls960.pth',
"wav2vec2_fairseq_large_ls960_asr_ls960.pth",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
......@@ -530,7 +532,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
'wav2vec2_fairseq_large_lv60k.pth',
"wav2vec2_fairseq_large_lv60k.pth",
{
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
......@@ -574,7 +576,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
'wav2vec2_fairseq_large_lv60k_asr_ll10m.pth',
"wav2vec2_fairseq_large_lv60k_asr_ll10m.pth",
{
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
......@@ -620,7 +622,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
'wav2vec2_fairseq_large_lv60k_asr_ls100.pth',
"wav2vec2_fairseq_large_lv60k_asr_ls100.pth",
{
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
......@@ -666,7 +668,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
'wav2vec2_fairseq_large_lv60k_asr_ls960.pth',
"wav2vec2_fairseq_large_lv60k_asr_ls960.pth",
{
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
......@@ -713,7 +715,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
'wav2vec2_fairseq_large_xlsr53.pth',
"wav2vec2_fairseq_large_xlsr53.pth",
{
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
......@@ -760,10 +762,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
HUBERT_BASE = Wav2Vec2Bundle(
'hubert_fairseq_base_ls960.pth',
"hubert_fairseq_base_ls960.pth",
{
'extractor_mode': 'group_norm',
'extractor_conv_layer_config': [
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
......@@ -772,20 +774,20 @@ HUBERT_BASE = Wav2Vec2Bundle(
(512, 2, 2),
(512, 2, 2),
],
'extractor_conv_bias': False,
'encoder_embed_dim': 768,
'encoder_projection_dropout': 0.1,
'encoder_pos_conv_kernel': 128,
'encoder_pos_conv_groups': 16,
'encoder_num_layers': 12,
'encoder_num_heads': 12,
'encoder_attention_dropout': 0.1,
'encoder_ff_interm_features': 3072,
'encoder_ff_interm_dropout': 0.0,
'encoder_dropout': 0.1,
'encoder_layer_norm_first': False,
'encoder_layer_drop': 0.05,
'aux_num_out': None,
"extractor_conv_bias": False,
"encoder_embed_dim": 768,
"encoder_projection_dropout": 0.1,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 12,
"encoder_num_heads": 12,
"encoder_attention_dropout": 0.1,
"encoder_ff_interm_features": 3072,
"encoder_ff_interm_dropout": 0.0,
"encoder_dropout": 0.1,
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.05,
"aux_num_out": None,
},
_sample_rate=16000,
)
......@@ -804,10 +806,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
HUBERT_LARGE = Wav2Vec2Bundle(
'hubert_fairseq_large_ll60k.pth',
"hubert_fairseq_large_ll60k.pth",
{
'extractor_mode': 'layer_norm',
'extractor_conv_layer_config': [
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
......@@ -816,20 +818,20 @@ HUBERT_LARGE = Wav2Vec2Bundle(
(512, 2, 2),
(512, 2, 2),
],
'extractor_conv_bias': False,
'encoder_embed_dim': 1024,
'encoder_projection_dropout': 0.0,
'encoder_pos_conv_kernel': 128,
'encoder_pos_conv_groups': 16,
'encoder_num_layers': 24,
'encoder_num_heads': 16,
'encoder_attention_dropout': 0.0,
'encoder_ff_interm_features': 4096,
'encoder_ff_interm_dropout': 0.0,
'encoder_dropout': 0.0,
'encoder_layer_norm_first': True,
'encoder_layer_drop': 0.0,
'aux_num_out': None,
"extractor_conv_bias": False,
"encoder_embed_dim": 1024,
"encoder_projection_dropout": 0.0,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 24,
"encoder_num_heads": 16,
"encoder_attention_dropout": 0.0,
"encoder_ff_interm_features": 4096,
"encoder_ff_interm_dropout": 0.0,
"encoder_dropout": 0.0,
"encoder_layer_norm_first": True,
"encoder_layer_drop": 0.0,
"aux_num_out": None,
},
_sample_rate=16000,
)
......@@ -848,10 +850,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
HUBERT_XLARGE = Wav2Vec2Bundle(
'hubert_fairseq_xlarge_ll60k.pth',
"hubert_fairseq_xlarge_ll60k.pth",
{
'extractor_mode': 'layer_norm',
'extractor_conv_layer_config': [
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
......@@ -860,20 +862,20 @@ HUBERT_XLARGE = Wav2Vec2Bundle(
(512, 2, 2),
(512, 2, 2),
],
'extractor_conv_bias': False,
'encoder_embed_dim': 1280,
'encoder_projection_dropout': 0.0,
'encoder_pos_conv_kernel': 128,
'encoder_pos_conv_groups': 16,
'encoder_num_layers': 48,
'encoder_num_heads': 16,
'encoder_attention_dropout': 0.0,
'encoder_ff_interm_features': 5120,
'encoder_ff_interm_dropout': 0.0,
'encoder_dropout': 0.0,
'encoder_layer_norm_first': True,
'encoder_layer_drop': 0.0,
'aux_num_out': None,
"extractor_conv_bias": False,
"encoder_embed_dim": 1280,
"encoder_projection_dropout": 0.0,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 48,
"encoder_num_heads": 16,
"encoder_attention_dropout": 0.0,
"encoder_ff_interm_features": 5120,
"encoder_ff_interm_dropout": 0.0,
"encoder_dropout": 0.0,
"encoder_layer_norm_first": True,
"encoder_layer_drop": 0.0,
"aux_num_out": None,
},
_sample_rate=16000,
)
......@@ -892,10 +894,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
'hubert_fairseq_large_ll60k_asr_ls960.pth',
"hubert_fairseq_large_ll60k_asr_ls960.pth",
{
'extractor_mode': 'layer_norm',
'extractor_conv_layer_config': [
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
......@@ -904,20 +906,20 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
(512, 2, 2),
(512, 2, 2),
],
'extractor_conv_bias': False,
'encoder_embed_dim': 1024,
'encoder_projection_dropout': 0.0,
'encoder_pos_conv_kernel': 128,
'encoder_pos_conv_groups': 16,
'encoder_num_layers': 24,
'encoder_num_heads': 16,
'encoder_attention_dropout': 0.0,
'encoder_ff_interm_features': 4096,
'encoder_ff_interm_dropout': 0.1,
'encoder_dropout': 0.0,
'encoder_layer_norm_first': True,
'encoder_layer_drop': 0.1,
'aux_num_out': 29,
"extractor_conv_bias": False,
"encoder_embed_dim": 1024,
"encoder_projection_dropout": 0.0,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 24,
"encoder_num_heads": 16,
"encoder_attention_dropout": 0.0,
"encoder_ff_interm_features": 4096,
"encoder_ff_interm_dropout": 0.1,
"encoder_dropout": 0.0,
"encoder_layer_norm_first": True,
"encoder_layer_drop": 0.1,
"aux_num_out": 29,
},
_labels=utils._get_en_labels(),
_sample_rate=16000,
......@@ -939,10 +941,10 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
'hubert_fairseq_xlarge_ll60k_asr_ls960.pth',
"hubert_fairseq_xlarge_ll60k_asr_ls960.pth",
{
'extractor_mode': 'layer_norm',
'extractor_conv_layer_config': [
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
......@@ -951,20 +953,20 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
(512, 2, 2),
(512, 2, 2),
],
'extractor_conv_bias': False,
'encoder_embed_dim': 1280,
'encoder_projection_dropout': 0.0,
'encoder_pos_conv_kernel': 128,
'encoder_pos_conv_groups': 16,
'encoder_num_layers': 48,
'encoder_num_heads': 16,
'encoder_attention_dropout': 0.0,
'encoder_ff_interm_features': 5120,
'encoder_ff_interm_dropout': 0.1,
'encoder_dropout': 0.0,
'encoder_layer_norm_first': True,
'encoder_layer_drop': 0.1,
'aux_num_out': 29,
"extractor_conv_bias": False,
"encoder_embed_dim": 1280,
"encoder_projection_dropout": 0.0,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 48,
"encoder_num_heads": 16,
"encoder_attention_dropout": 0.0,
"encoder_ff_interm_features": 5120,
"encoder_ff_interm_dropout": 0.1,
"encoder_dropout": 0.0,
"encoder_layer_norm_first": True,
"encoder_layer_drop": 0.1,
"aux_num_out": 29,
},
_labels=utils._get_en_labels(),
_sample_rate=16000,
......@@ -987,7 +989,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
VOXPOPULI_ASR_BASE_10K_DE = Wav2Vec2ASRBundle(
'wav2vec2_voxpopuli_base_10k_asr_de.pt',
"wav2vec2_voxpopuli_base_10k_asr_de.pt",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
......@@ -1034,7 +1036,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
VOXPOPULI_ASR_BASE_10K_EN = Wav2Vec2ASRBundle(
'wav2vec2_voxpopuli_base_10k_asr_en.pt',
"wav2vec2_voxpopuli_base_10k_asr_en.pt",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
......@@ -1059,7 +1061,7 @@ VOXPOPULI_ASR_BASE_10K_EN = Wav2Vec2ASRBundle(
"encoder_dropout": 0.0,
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.1,
"aux_num_out": 28
"aux_num_out": 28,
},
_labels=utils._get_vp_en_labels(),
_sample_rate=16000,
......@@ -1081,7 +1083,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
VOXPOPULI_ASR_BASE_10K_ES = Wav2Vec2ASRBundle(
'wav2vec2_voxpopuli_base_10k_asr_es.pt',
"wav2vec2_voxpopuli_base_10k_asr_es.pt",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
......@@ -1106,7 +1108,7 @@ VOXPOPULI_ASR_BASE_10K_ES = Wav2Vec2ASRBundle(
"encoder_dropout": 0.0,
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.1,
"aux_num_out": 35
"aux_num_out": 35,
},
_labels=utils._get_es_labels(),
_sample_rate=16000,
......@@ -1127,7 +1129,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
""" # noqa: E501
VOXPOPULI_ASR_BASE_10K_FR = Wav2Vec2ASRBundle(
'wav2vec2_voxpopuli_base_10k_asr_fr.pt',
"wav2vec2_voxpopuli_base_10k_asr_fr.pt",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
......@@ -1152,7 +1154,7 @@ VOXPOPULI_ASR_BASE_10K_FR = Wav2Vec2ASRBundle(
"encoder_dropout": 0.0,
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.1,
"aux_num_out": 43
"aux_num_out": 43,
},
_labels=utils._get_fr_labels(),
_sample_rate=16000,
......@@ -1173,7 +1175,7 @@ Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
VOXPOPULI_ASR_BASE_10K_IT = Wav2Vec2ASRBundle(
'wav2vec2_voxpopuli_base_10k_asr_it.pt',
"wav2vec2_voxpopuli_base_10k_asr_it.pt",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
......
def _get_en_labels():
return (
'|',
'E',
'T',
'A',
'O',
'N',
'I',
'H',
'S',
'R',
'D',
'L',
'U',
'M',
'W',
'C',
'F',
'G',
'Y',
'P',
'B',
'V',
'K',
"|",
"E",
"T",
"A",
"O",
"N",
"I",
"H",
"S",
"R",
"D",
"L",
"U",
"M",
"W",
"C",
"F",
"G",
"Y",
"P",
"B",
"V",
"K",
"'",
'X',
'J',
'Q',
'Z',
"X",
"J",
"Q",
"Z",
)
......
import math
import torch
from typing import List, Optional, Tuple
import torch
__all__ = ["Conformer"]
......@@ -12,9 +13,9 @@ PADDING_IDX = 1
def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
batch_size = lengths.shape[0]
max_length = int(torch.max(lengths).item())
padding_mask = torch.arange(
max_length, device=lengths.device, dtype=lengths.dtype
).expand(batch_size, max_length) >= lengths.unsqueeze(1)
padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
batch_size, max_length
) >= lengths.unsqueeze(1)
return padding_mask
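For intuition (editor's note, not part of the diff): the helper above marks every position at or beyond a sequence's length as padding, for example:

import torch

lengths = torch.tensor([2, 4])
mask = torch.arange(int(lengths.max())).expand(len(lengths), -1) >= lengths.unsqueeze(1)
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])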
......@@ -31,12 +32,8 @@ def _get_sinusoidal_embeddings(
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
t = (
torch.arange(half_dim, dtype=torch.float) * -math.log(10000) / (half_dim - 1)
).exp()
embedding_t = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
1
) * t.unsqueeze(0)
t = (torch.arange(half_dim, dtype=torch.float) * -math.log(10000) / (half_dim - 1)).exp()
embedding_t = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * t.unsqueeze(0)
embeddings = torch.cat([embedding_t.sin(), embedding_t.cos()], dim=1)
if embedding_dim % 2 == 1:
embeddings = torch.cat([embeddings, torch.zeros(num_embeddings, 1)], dim=1)
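For reference, the embedding that Section 3.5 of "Attention Is All You Need" describes is, for position pos and dimension index i out of d:

PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i/d}\right), \qquad PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i/d}\right)

The code above follows the fairseq convention: it normalizes the exponent by half_dim - 1 and concatenates all sine terms before all cosine terms instead of interleaving them, which changes the layout but not the construction.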
......@@ -64,13 +61,16 @@ class ConvolutionModule(torch.nn.Module):
dropout: float = 0.0,
) -> None:
super().__init__()
assert (
depthwise_kernel_size - 1
) % 2 == 0, "depthwise_kernel_size must be odd to achieve 'SAME' padding."
assert (depthwise_kernel_size - 1) % 2 == 0, "depthwise_kernel_size must be odd to achieve 'SAME' padding."
self.layer_norm = torch.nn.LayerNorm(input_dim)
self.sequential = torch.nn.Sequential(
torch.nn.Conv1d(
input_dim, 2 * num_channels, 1, stride=1, padding=0, bias=bias,
input_dim,
2 * num_channels,
1,
stride=1,
padding=0,
bias=bias,
),
torch.nn.GLU(dim=1),
torch.nn.Conv1d(
......@@ -85,7 +85,12 @@ class ConvolutionModule(torch.nn.Module):
torch.nn.BatchNorm1d(num_channels),
torch.nn.SiLU(),
torch.nn.Conv1d(
num_channels, input_dim, kernel_size=1, stride=1, padding=0, bias=bias,
num_channels,
input_dim,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
),
torch.nn.Dropout(dropout),
)
......@@ -159,9 +164,7 @@ class ConformerLayer(torch.nn.Module):
self.ffn1 = FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim)
self.self_attn = torch.nn.MultiheadAttention(
input_dim, num_attention_heads, dropout=dropout
)
self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout)
self.self_attn_dropout = torch.nn.Dropout(dropout)
self.conv_module = ConvolutionModule(
......@@ -173,9 +176,7 @@ class ConformerLayer(torch.nn.Module):
self.ffn2 = FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
self.final_layer_norm = torch.nn.LayerNorm(input_dim)
def forward(
self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]
) -> torch.Tensor:
def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
r"""
Args:
input (torch.Tensor): input, with shape `(T, B, D)`.
......@@ -256,9 +257,7 @@ class Conv1dSubsampler(torch.nn.Module):
out = ((out.float() - 1) / 2 + 1).floor().long()
return out.to(torch.int32)
def forward(
self, input: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
input (torch.Tensor): input frames, with shape `(B, T_in, in_channels)`.
......@@ -289,15 +288,11 @@ class SinusoidalPositionalEmbedding(torch.nn.Module):
init_size (int, optional): initial embedding count. (Default: 1024)
"""
def __init__(
self, embedding_dim: int, padding_idx: int = 0, init_size: int = 1024
) -> None:
def __init__(self, embedding_dim: int, padding_idx: int = 0, init_size: int = 1024) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.embeddings = _get_sinusoidal_embeddings(
init_size, embedding_dim, padding_idx
)
self.embeddings = _get_sinusoidal_embeddings(init_size, embedding_dim, padding_idx)
def forward(self, input: torch.Tensor) -> torch.Tensor:
r"""
......@@ -310,14 +305,10 @@ class SinusoidalPositionalEmbedding(torch.nn.Module):
B, T = input.shape
max_pos = self.padding_idx + 1 + T
if max_pos > self.embeddings.size(0):
self.embeddings = _get_sinusoidal_embeddings(
max_pos, self.embedding_dim, self.padding_idx
)
self.embeddings = _get_sinusoidal_embeddings(max_pos, self.embedding_dim, self.padding_idx)
self.embeddings = self.embeddings.to(input)
positions = _make_positions(input, self.padding_idx)
return (
self.embeddings.index_select(0, positions.view(-1)).view(B, T, -1).detach()
)
return self.embeddings.index_select(0, positions.view(-1)).view(B, T, -1).detach()
class Conformer(torch.nn.Module):
......@@ -370,16 +361,17 @@ class Conformer(torch.nn.Module):
super().__init__()
self.subsample = Conv1dSubsampler(
input_dim, conv_channels, conformer_layer_input_dim, conv_kernel_sizes,
input_dim,
conv_channels,
conformer_layer_input_dim,
conv_kernel_sizes,
)
self.position_embedding = SinusoidalPositionalEmbedding(
conformer_layer_input_dim,
padding_idx=PADDING_IDX,
init_size=max_source_positions + PADDING_IDX + 1,
)
self.linear = torch.nn.Linear(
conformer_layer_input_dim, conformer_layer_input_dim
)
self.linear = torch.nn.Linear(conformer_layer_input_dim, conformer_layer_input_dim)
self.dropout = torch.nn.Dropout(dropout)
self.conformer_layers = torch.nn.ModuleList(
[
......@@ -394,9 +386,7 @@ class Conformer(torch.nn.Module):
]
)
def forward(
self, input: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
input (torch.Tensor): with shape `(B, T_in, input_dim)`.
......
......@@ -10,9 +10,9 @@ __all__ = ["Emformer"]
def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
batch_size = lengths.shape[0]
max_length = int(torch.max(lengths).item())
padding_mask = torch.arange(
max_length, device=lengths.device, dtype=lengths.dtype
).expand(batch_size, max_length) >= lengths.unsqueeze(1)
padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
batch_size, max_length
) >= lengths.unsqueeze(1)
return padding_mask
......@@ -30,15 +30,8 @@ def _gen_padding_mask(
padding_mask = None
else:
right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
left_context_blocks_length = (
left_context_key.size(0) if left_context_key is not None else 0
)
klengths = (
lengths
+ mems.size(0)
+ right_context_blocks_length
+ left_context_blocks_length
)
left_context_blocks_length = left_context_key.size(0) if left_context_key is not None else 0
klengths = lengths + mems.size(0) + right_context_blocks_length + left_context_blocks_length
padding_mask = _lengths_to_padding_mask(lengths=klengths)
return padding_mask
......@@ -54,9 +47,7 @@ def _get_activation_module(activation: str) -> torch.nn.Module:
raise ValueError(f"Unsupported activation {activation}")
def _get_weight_init_gains(
weight_init_scale_strategy: Optional[str], num_layers: int
) -> List[Optional[float]]:
def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
if weight_init_scale_strategy is None:
return [None for _ in range(num_layers)]
elif weight_init_scale_strategy == "depthwise":
......@@ -64,17 +55,13 @@ def _get_weight_init_gains(
elif weight_init_scale_strategy == "constant":
return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
else:
raise ValueError(
f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}"
)
raise ValueError(f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}")
def _gen_attention_mask_block(
col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
) -> torch.Tensor:
assert len(col_widths) == len(
col_mask
), "Length of col_widths must match that of col_mask"
assert len(col_widths) == len(col_mask), "Length of col_widths must match that of col_mask"
mask_block = [
torch.ones(num_rows, col_width, device=device)
......@@ -110,9 +97,7 @@ class _EmformerAttention(torch.nn.Module):
super().__init__()
if input_dim % num_heads != 0:
raise ValueError(
f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads})."
)
raise ValueError(f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads}).")
self.input_dim = input_dim
self.num_heads = num_heads
......@@ -127,23 +112,15 @@ class _EmformerAttention(torch.nn.Module):
self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True)
if weight_init_gain:
torch.nn.init.xavier_uniform_(
self.emb_to_key_value.weight, gain=weight_init_gain
)
torch.nn.init.xavier_uniform_(
self.emb_to_query.weight, gain=weight_init_gain
)
torch.nn.init.xavier_uniform_(self.emb_to_key_value.weight, gain=weight_init_gain)
torch.nn.init.xavier_uniform_(self.emb_to_query.weight, gain=weight_init_gain)
def _gen_key_value(
self, input: torch.Tensor, mems: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def _gen_key_value(self, input: torch.Tensor, mems: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
T, _, _ = input.shape
summary_length = mems.size(0) + 1
right_ctx_utterance_block = input[: T - summary_length]
mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block])
key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(
chunks=2, dim=2
)
key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(chunks=2, dim=2)
return key, value
def _gen_attention_probs(
......@@ -153,27 +130,17 @@ class _EmformerAttention(torch.nn.Module):
padding_mask: Optional[torch.Tensor],
) -> torch.Tensor:
attention_weights_float = attention_weights.float()
attention_weights_float = attention_weights_float.masked_fill(
attention_mask.unsqueeze(0), self.negative_inf
)
attention_weights_float = attention_weights_float.masked_fill(attention_mask.unsqueeze(0), self.negative_inf)
T = attention_weights.size(1)
B = attention_weights.size(0) // self.num_heads
if padding_mask is not None:
attention_weights_float = attention_weights_float.view(
B, self.num_heads, T, -1
)
attention_weights_float = attention_weights_float.view(B, self.num_heads, T, -1)
attention_weights_float = attention_weights_float.masked_fill(
padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf
)
attention_weights_float = attention_weights_float.view(
B * self.num_heads, T, -1
)
attention_probs = torch.nn.functional.softmax(
attention_weights_float, dim=-1
).type_as(attention_weights)
return torch.nn.functional.dropout(
attention_probs, p=float(self.dropout), training=self.training
)
attention_weights_float = attention_weights_float.view(B * self.num_heads, T, -1)
attention_probs = torch.nn.functional.softmax(attention_weights_float, dim=-1).type_as(attention_weights)
return torch.nn.functional.dropout(attention_probs, p=float(self.dropout), training=self.training)
def _forward_impl(
self,
......@@ -193,9 +160,7 @@ class _EmformerAttention(torch.nn.Module):
query = self.emb_to_query(torch.cat([right_context, utterance, summary]))
# Compute key and value with [mems, right context, utterance].
key, value = self.emb_to_key_value(
torch.cat([mems, right_context, utterance])
).chunk(chunks=2, dim=2)
key, value = self.emb_to_key_value(torch.cat([mems, right_context, utterance])).chunk(chunks=2, dim=2)
if left_context_key is not None and left_context_val is not None:
right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
......@@ -203,37 +168,29 @@ class _EmformerAttention(torch.nn.Module):
[
key[: mems.size(0) + right_context_blocks_length],
left_context_key,
key[mems.size(0) + right_context_blocks_length:],
key[mems.size(0) + right_context_blocks_length :],
],
)
value = torch.cat(
[
value[: mems.size(0) + right_context_blocks_length],
left_context_val,
value[mems.size(0) + right_context_blocks_length:],
value[mems.size(0) + right_context_blocks_length :],
],
)
# Compute attention weights from query, key, and value.
reshaped_query, reshaped_key, reshaped_value = [
tensor.contiguous()
.view(-1, B * self.num_heads, self.input_dim // self.num_heads)
.transpose(0, 1)
tensor.contiguous().view(-1, B * self.num_heads, self.input_dim // self.num_heads).transpose(0, 1)
for tensor in [query, key, value]
]
attention_weights = torch.bmm(
reshaped_query * self.scaling, reshaped_key.transpose(1, 2)
)
attention_weights = torch.bmm(reshaped_query * self.scaling, reshaped_key.transpose(1, 2))
# Compute padding mask.
padding_mask = _gen_padding_mask(
utterance, right_context, summary, lengths, mems, left_context_key
)
padding_mask = _gen_padding_mask(utterance, right_context, summary, lengths, mems, left_context_key)
# Compute attention probabilities.
attention_probs = self._gen_attention_probs(
attention_weights, attention_mask, padding_mask
)
attention_probs = self._gen_attention_probs(attention_weights, attention_mask, padding_mask)
# Compute attention.
attention = torch.bmm(attention_probs, reshaped_value)
......@@ -249,7 +206,7 @@ class _EmformerAttention(torch.nn.Module):
summary_length = summary.size(0)
output_right_context = output_right_context_mems[: T - summary_length]
output_mems = output_right_context_mems[T - summary_length:]
output_mems = output_right_context_mems[T - summary_length :]
if self.tanh_on_mem:
output_mems = torch.tanh(output_mems)
else:
......@@ -291,9 +248,7 @@ class _EmformerAttention(torch.nn.Module):
Tensor
updated memory elements, with shape `(M, B, D)`.
"""
output, output_mems, _, _ = self._forward_impl(
utterance, lengths, right_context, summary, mems, attention_mask
)
output, output_mems, _, _ = self._forward_impl(utterance, lengths, right_context, summary, mems, attention_mask)
return output, output_mems[:-1]
@torch.jit.export
......@@ -338,15 +293,8 @@ class _EmformerAttention(torch.nn.Module):
attention value computed for left context and utterance.
"""
query_dim = right_context.size(0) + utterance.size(0) + summary.size(0)
key_dim = (
right_context.size(0)
+ utterance.size(0)
+ mems.size(0)
+ left_context_key.size(0)
)
attention_mask = torch.zeros(query_dim, key_dim).to(
dtype=torch.bool, device=utterance.device
)
key_dim = right_context.size(0) + utterance.size(0) + mems.size(0) + left_context_key.size(0)
attention_mask = torch.zeros(query_dim, key_dim).to(dtype=torch.bool, device=utterance.device)
attention_mask[-1, : mems.size(0)] = True
output, output_mems, key, value = self._forward_impl(
utterance,
......@@ -361,8 +309,8 @@ class _EmformerAttention(torch.nn.Module):
return (
output,
output_mems,
key[mems.size(0) + right_context.size(0):],
value[mems.size(0) + right_context.size(0):],
key[mems.size(0) + right_context.size(0) :],
value[mems.size(0) + right_context.size(0) :],
)
......@@ -410,9 +358,7 @@ class _EmformerLayer(torch.nn.Module):
negative_inf=negative_inf,
)
self.dropout = torch.nn.Dropout(dropout)
self.memory_op = torch.nn.AvgPool1d(
kernel_size=segment_length, stride=segment_length, ceil_mode=True
)
self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)
activation_module = _get_activation_module(activation)
self.pos_ff = torch.nn.Sequential(
......@@ -433,18 +379,10 @@ class _EmformerLayer(torch.nn.Module):
self.use_mem = max_memory_size > 0
def _init_state(
self, batch_size: int, device: Optional[torch.device]
) -> List[torch.Tensor]:
empty_memory = torch.zeros(
self.max_memory_size, batch_size, self.input_dim, device=device
)
left_context_key = torch.zeros(
self.left_context_length, batch_size, self.input_dim, device=device
)
left_context_val = torch.zeros(
self.left_context_length, batch_size, self.input_dim, device=device
)
def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
return [empty_memory, left_context_key, left_context_val, past_length]
......@@ -453,12 +391,10 @@ class _EmformerLayer(torch.nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
past_length = state[3][0][0].item()
past_left_context_length = min(self.left_context_length, past_length)
past_mem_length = min(
self.max_memory_size, math.ceil(past_length / self.segment_length)
)
pre_mems = state[0][self.max_memory_size - past_mem_length:]
lc_key = state[1][self.left_context_length - past_left_context_length:]
lc_val = state[2][self.left_context_length - past_left_context_length:]
past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
pre_mems = state[0][self.max_memory_size - past_mem_length :]
lc_key = state[1][self.left_context_length - past_left_context_length :]
lc_val = state[2][self.left_context_length - past_left_context_length :]
return pre_mems, lc_key, lc_val
def _pack_state(
......@@ -471,9 +407,9 @@ class _EmformerLayer(torch.nn.Module):
) -> List[torch.Tensor]:
new_k = torch.cat([state[1], next_k])
new_v = torch.cat([state[2], next_v])
state[0] = torch.cat([state[0], mems])[-self.max_memory_size:]
state[1] = new_k[new_k.shape[0] - self.left_context_length:]
state[2] = new_v[new_v.shape[0] - self.left_context_length:]
state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
state[1] = new_k[new_k.shape[0] - self.left_context_length :]
state[2] = new_v[new_v.shape[0] - self.left_context_length :]
state[3] = state[3] + update_length
return state
......@@ -493,7 +429,7 @@ class _EmformerLayer(torch.nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]:
layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance]))
return (
layer_norm_input[right_context.size(0):],
layer_norm_input[right_context.size(0) :],
layer_norm_input[: right_context.size(0)],
)
......@@ -501,7 +437,7 @@ class _EmformerLayer(torch.nn.Module):
self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
rc_output = self._process_attention_output(rc_output, utterance, right_context)
return rc_output[right_context.size(0):], rc_output[: right_context.size(0)]
return rc_output[right_context.size(0) :], rc_output[: right_context.size(0)]
def _apply_attention_forward(
self,
......@@ -512,9 +448,7 @@ class _EmformerLayer(torch.nn.Module):
attention_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if attention_mask is None:
raise ValueError(
"attention_mask must be not None when for_inference is False"
)
raise ValueError("attention_mask must be not None when for_inference is False")
if self.use_mem:
summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
......@@ -602,9 +536,7 @@ class _EmformerLayer(torch.nn.Module):
mems,
attention_mask,
)
output_utterance, output_right_context = self._apply_post_attention_ffn(
rc_output, utterance, right_context
)
output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
return output_utterance, output_right_context, output_mems
@torch.jit.export
......@@ -652,9 +584,7 @@ class _EmformerLayer(torch.nn.Module):
rc_output, output_mems, output_state = self._apply_attention_infer(
layer_norm_utterance, lengths, layer_norm_right_context, mems, state
)
output_utterance, output_right_context = self._apply_post_attention_ffn(
rc_output, utterance, right_context
)
output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
return output_utterance, output_right_context, output_state, output_mems
......@@ -708,12 +638,12 @@ class Emformer(torch.nn.Module):
self.use_mem = max_memory_size > 0
self.memory_op = torch.nn.AvgPool1d(
kernel_size=segment_length, stride=segment_length, ceil_mode=True,
kernel_size=segment_length,
stride=segment_length,
ceil_mode=True,
)
weight_init_gains = _get_weight_init_gains(
weight_init_scale_strategy, num_layers
)
weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
self.emformer_layers = torch.nn.ModuleList(
[
_EmformerLayer(
......@@ -747,12 +677,10 @@ class Emformer(torch.nn.Module):
start = (seg_idx + 1) * self.segment_length
end = start + self.right_context_length
right_context_blocks.append(input[start:end])
right_context_blocks.append(input[T - self.right_context_length:])
right_context_blocks.append(input[T - self.right_context_length :])
return torch.cat(right_context_blocks)
def _gen_attention_mask_col_widths(
self, seg_idx: int, utterance_length: int
) -> List[int]:
def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
num_segs = math.ceil(utterance_length / self.segment_length)
rc = self.right_context_length
lc = self.left_context_length
......@@ -830,19 +758,13 @@ class Emformer(torch.nn.Module):
query_mask.append(query_mask_block)
if s_cols_mask is not None:
summary_mask_block = _gen_attention_mask_block(
col_widths, s_cols_mask, 1, input.device
)
summary_mask_block = _gen_attention_mask_block(col_widths, s_cols_mask, 1, input.device)
summary_mask.append(summary_mask_block)
attention_mask = (
1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])
).to(torch.bool)
attention_mask = (1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])).to(torch.bool)
return attention_mask
def forward(
self, input: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Forward pass for training.
B: batch size;
......@@ -874,9 +796,7 @@ class Emformer(torch.nn.Module):
)
output = utterance
for layer in self.emformer_layers:
output, right_context, mems = layer(
output, lengths, right_context, mems, attention_mask
)
output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
return output.permute(1, 0, 2), lengths
@torch.jit.export
......
......@@ -15,13 +15,12 @@ class _TimeReduction(torch.nn.Module):
Args:
stride (int): number of frames to merge for each output frame.
"""
def __init__(self, stride: int) -> None:
super().__init__()
self.stride = stride
def forward(
self, input: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Forward pass.
B: batch size;
......@@ -64,6 +63,7 @@ class _CustomLSTM(torch.nn.Module):
layer_norm_epsilon (float, optional): value of epsilon to use in
layer normalization layers (Default: 1e-5)
"""
def __init__(
self,
input_dim: int,
......@@ -179,7 +179,9 @@ class _Transcriber(torch.nn.Module):
) -> None:
super().__init__()
self.input_linear = torch.nn.Linear(
input_dim, time_reduction_input_dim, bias=False,
input_dim,
time_reduction_input_dim,
bias=False,
)
self.time_reduction = _TimeReduction(time_reduction_stride)
transformer_input_dim = time_reduction_input_dim * time_reduction_stride
......@@ -200,9 +202,7 @@ class _Transcriber(torch.nn.Module):
self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
self.layer_norm = torch.nn.LayerNorm(output_dim)
def forward(
self, input: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Forward pass for training.
B: batch size;
......@@ -225,12 +225,8 @@ class _Transcriber(torch.nn.Module):
number of valid elements for i-th batch element in output frame sequences.
"""
input_linear_out = self.input_linear(input)
time_reduction_out, time_reduction_lengths = self.time_reduction(
input_linear_out, lengths
)
transformer_out, transformer_lengths = self.transformer(
time_reduction_out, time_reduction_lengths
)
time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
output_linear_out = self.output_linear(transformer_out)
layer_norm_out = self.layer_norm(output_linear_out)
return layer_norm_out, transformer_lengths
......@@ -271,9 +267,7 @@ class _Transcriber(torch.nn.Module):
of ``infer``.
"""
input_linear_out = self.input_linear(input)
time_reduction_out, time_reduction_lengths = self.time_reduction(
input_linear_out, lengths
)
time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
(
transformer_out,
transformer_lengths,
......@@ -299,6 +293,7 @@ class _Predictor(torch.nn.Module):
lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)
"""
def __init__(
self,
num_symbols: int,
......@@ -368,9 +363,7 @@ class _Predictor(torch.nn.Module):
lstm_out = input_layer_norm_out
state_out: List[List[torch.Tensor]] = []
for layer_idx, lstm in enumerate(self.lstm_layers):
lstm_out, lstm_state_out = lstm(
lstm_out, None if state is None else state[layer_idx]
)
lstm_out, lstm_state_out = lstm(lstm_out, None if state is None else state[layer_idx])
lstm_out = self.dropout(lstm_out)
state_out.append(lstm_state_out)
......@@ -426,10 +419,7 @@ class _Joiner(torch.nn.Module):
output target lengths, with shape `(B,)` and i-th element representing
number of valid elements along dim 2 for i-th batch element in joint network output.
"""
joint_encodings = (
source_encodings.unsqueeze(2).contiguous()
+ target_encodings.unsqueeze(1).contiguous()
)
joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
relu_out = self.relu(joint_encodings)
output = self.linear(relu_out)
return output, source_lengths, target_lengths
......@@ -447,9 +437,7 @@ class RNNT(torch.nn.Module):
joiner (torch.nn.Module): joint network.
"""
def __init__(
self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner
) -> None:
def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
super().__init__()
self.transcriber = transcriber
self.predictor = predictor
......@@ -500,10 +488,13 @@ class RNNT(torch.nn.Module):
of ``forward``.
"""
source_encodings, source_lengths = self.transcriber(
input=sources, lengths=source_lengths,
input=sources,
lengths=source_lengths,
)
target_encodings, target_lengths, predictor_state = self.predictor(
input=targets, lengths=target_lengths, state=predictor_state,
input=targets,
lengths=target_lengths,
state=predictor_state,
)
output, source_lengths, target_lengths = self.joiner(
source_encodings=source_encodings,
......@@ -558,7 +549,9 @@ class RNNT(torch.nn.Module):
@torch.jit.export
def transcribe(
self, sources: torch.Tensor, source_lengths: torch.Tensor,
self,
sources: torch.Tensor,
source_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Applies transcription network to sources in non-streaming mode.
......
......@@ -35,21 +35,14 @@ def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
for i in range(len(hypos[0].state)):
batched_state_components: List[torch.Tensor] = []
for j in range(len(hypos[0].state[i])):
batched_state_components.append(
torch.cat([hypo.state[i][j] for hypo in hypos])
)
batched_state_components.append(torch.cat([hypo.state[i][j] for hypo in hypos]))
states.append(batched_state_components)
return states
def _slice_state(
states: List[List[torch.Tensor]], idx: int, device: torch.device
) -> List[List[torch.Tensor]]:
def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
idx_tensor = torch.tensor([idx], device=device)
return [
[state.index_select(0, idx_tensor) for state in state_tuple]
for state_tuple in states
]
return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
def _default_hypo_sort_key(hypo: Hypothesis) -> float:
......@@ -57,18 +50,14 @@ def _default_hypo_sort_key(hypo: Hypothesis) -> float:
def _compute_updated_scores(
hypos: List[Hypothesis], next_token_probs: torch.Tensor, beam_width: int,
hypos: List[Hypothesis],
next_token_probs: torch.Tensor,
beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hypo_scores = torch.tensor([h.score for h in hypos]).unsqueeze(1)
nonblank_scores = (
hypo_scores + next_token_probs[:, :-1]
) # [beam_width, num_tokens - 1]
nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(
beam_width
)
nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(
nonblank_scores.shape[1], rounding_mode="trunc"
)
nonblank_scores = hypo_scores + next_token_probs[:, :-1] # [beam_width, num_tokens - 1]
nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
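_compute_updated_scores flattens the per-hypothesis score matrix before topk, then recovers the hypothesis row and token column with truncating division and modulo. A standalone sketch of that index arithmetic (sizes are illustrative):

import torch

beam_width, num_tokens = 4, 10
scores = torch.randn(beam_width, num_tokens - 1)  # non-blank scores per hypothesis

best_scores, best_flat_idx = scores.reshape(-1).topk(beam_width)
best_hypo_idx = best_flat_idx.div(scores.shape[1], rounding_mode="trunc")  # row = hypothesis
best_token = best_flat_idx % scores.shape[1]                               # column = token

# Each recovered (hypothesis, token) pair indexes back to the selected score.
assert torch.equal(scores[best_hypo_idx, best_token], best_scores)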
......@@ -114,9 +103,7 @@ class RNNTBeamSearch(torch.nn.Module):
self.step_max_tokens = step_max_tokens
def _init_b_hypos(
self, hypo: Optional[Hypothesis], device: torch.device
) -> List[Hypothesis]:
def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
if hypo is not None:
token = hypo.tokens[-1]
state = hypo.state
......@@ -125,9 +112,7 @@ class RNNTBeamSearch(torch.nn.Module):
state = None
one_tensor = torch.tensor([1], device=device)
pred_out, _, pred_state = self.model.predict(
torch.tensor([[token]], device=device), one_tensor, state
)
pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
init_hypo = Hypothesis(
tokens=[token],
predictor_out=pred_out[0].detach(),
......@@ -150,9 +135,7 @@ class RNNTBeamSearch(torch.nn.Module):
predictor_out,
torch.tensor([1] * len(hypos), device=device),
) # [beam_width, 1, 1, num_tokens]
joined_out = torch.nn.functional.log_softmax(
joined_out / self.temperature, dim=3
)
joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
joined_out[:, :, :, :4].add_(-99999) # blank out invalid tokens
return joined_out[:, 0, 0]
......@@ -220,9 +203,7 @@ class RNNTBeamSearch(torch.nn.Module):
new_scores.append(score)
if base_hypos:
new_hypos = self._gen_new_hypos(
base_hypos, new_tokens, new_scores, t, device
)
new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
else:
new_hypos: List[Hypothesis] = []
......@@ -239,7 +220,9 @@ class RNNTBeamSearch(torch.nn.Module):
tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
states = _batch_state(base_hypos)
pred_out, _, pred_states = self.model.predict(
tgt_tokens, torch.tensor([1] * len(base_hypos), device=device), states,
tgt_tokens,
torch.tensor([1] * len(base_hypos), device=device),
states,
)
new_hypos: List[Hypothesis] = []
for i, h_a in enumerate(base_hypos):
......@@ -258,7 +241,10 @@ class RNNTBeamSearch(torch.nn.Module):
return new_hypos
def _search(
self, enc_out: torch.Tensor, hypo: Optional[Hypothesis], beam_width: int,
self,
enc_out: torch.Tensor,
hypo: Optional[Hypothesis],
beam_width: int,
) -> List[Hypothesis]:
n_time_steps = enc_out.shape[1]
device = enc_out.device
......@@ -272,33 +258,35 @@ class RNNTBeamSearch(torch.nn.Module):
symbols_current_t = 0
while a_hypos:
next_token_probs = self._gen_next_token_probs(
enc_out[:, t: t + 1], a_hypos, device
)
next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
next_token_probs = next_token_probs.cpu()
b_hypos = self._gen_b_hypos(
b_hypos, a_hypos, next_token_probs, key_to_b_hypo,
b_hypos,
a_hypos,
next_token_probs,
key_to_b_hypo,
)
if symbols_current_t == self.step_max_tokens:
break
a_hypos = self._gen_a_hypos(
a_hypos, b_hypos, next_token_probs, t, beam_width, device,
a_hypos,
b_hypos,
next_token_probs,
t,
beam_width,
device,
)
if a_hypos:
symbols_current_t += 1
_, sorted_idx = torch.tensor(
[self.hypo_sort_key(hypo) for hypo in b_hypos]
).topk(beam_width)
_, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
b_hypos = [b_hypos[idx] for idx in sorted_idx]
return b_hypos
def forward(
self, input: torch.Tensor, length: torch.Tensor, beam_width: int
) -> List[Hypothesis]:
def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
r"""Performs beam search for the given input sequence.
T: number of frames;
......@@ -319,9 +307,7 @@ class RNNTBeamSearch(torch.nn.Module):
if input.dim() == 2:
input = input.unsqueeze(0)
assert length.shape == () or length.shape == (
1,
), "length must be of shape () or (1,)"
assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
if input.dim() == 0:
input = input.unsqueeze(0)
......@@ -367,9 +353,7 @@ class RNNTBeamSearch(torch.nn.Module):
if input.dim() == 2:
input = input.unsqueeze(0)
assert length.shape == () or length.shape == (
1,
), "length must be of shape () or (1,)"
assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
if input.dim() == 0:
input = input.unsqueeze(0)
......
from torchaudio._internal import module_utils as _mod_utils
from .sox_effects import (
init_sox_effects,
shutdown_sox_effects,
......@@ -10,13 +11,14 @@ from .sox_effects import (
if _mod_utils.is_sox_available():
import atexit
init_sox_effects()
atexit.register(shutdown_sox_effects)
__all__ = [
'init_sox_effects',
'shutdown_sox_effects',
'effect_names',
'apply_effects_tensor',
'apply_effects_file',
"init_sox_effects",
"shutdown_sox_effects",
"effect_names",
"apply_effects_tensor",
"apply_effects_file",
]
......@@ -2,7 +2,6 @@ import os
from typing import List, Tuple, Optional
import torch
import torchaudio
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.utils.sox_utils import list_effects
......@@ -53,10 +52,10 @@ def effect_names() -> List[str]:
@_mod_utils.requires_sox()
def apply_effects_tensor(
tensor: torch.Tensor,
sample_rate: int,
effects: List[List[str]],
channels_first: bool = True,
tensor: torch.Tensor,
sample_rate: int,
effects: List[List[str]],
channels_first: bool = True,
) -> Tuple[torch.Tensor, int]:
"""Apply sox effects to given Tensor
......@@ -149,17 +148,16 @@ def apply_effects_tensor(
>>> waveform, sample_rate = transform(waveform, input_sample_rate)
>>> assert sample_rate == 8000
"""
return torch.ops.torchaudio.sox_effects_apply_effects_tensor(
tensor, sample_rate, effects, channels_first)
return torch.ops.torchaudio.sox_effects_apply_effects_tensor(tensor, sample_rate, effects, channels_first)
@_mod_utils.requires_sox()
def apply_effects_file(
path: str,
effects: List[List[str]],
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
path: str,
effects: List[List[str]],
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
"""Apply sox effects to the audio file and load the resulting data as Tensor
......@@ -265,9 +263,7 @@ def apply_effects_file(
>>> pass
"""
if not torch.jit.is_scripting():
if hasattr(path, 'read'):
return torchaudio._torchaudio.apply_effects_fileobj(
path, effects, normalize, channels_first, format)
if hasattr(path, "read"):
return torchaudio._torchaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format)
path = os.fspath(path)
return torch.ops.torchaudio.sox_effects_apply_effects_file(
path, effects, normalize, channels_first, format)
return torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format)
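For reference, a hedged usage sketch of the two entry points reformatted above; the effect chain, sample rate, and file path are illustrative, and sox support must be available:

import torch
import torchaudio

# In-memory variant: apply a sox effect chain to a Tensor.
waveform = torch.rand(1, 16000) * 2 - 1                  # synthetic 1-channel, 1-second signal
effects = [["lowpass", "-1", "300"], ["rate", "8000"]]
out, out_sr = torchaudio.sox_effects.apply_effects_tensor(waveform, 16000, effects, channels_first=True)

# File variant: decode and apply the same chain while loading ("audio.wav" is a placeholder path).
# out, out_sr = torchaudio.sox_effects.apply_effects_file("audio.wav", effects, normalize=True)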
......@@ -14,31 +14,31 @@ from .functional.functional import (
)
__all__ = [
'Spectrogram',
'InverseSpectrogram',
'GriffinLim',
'AmplitudeToDB',
'MelScale',
'InverseMelScale',
'MelSpectrogram',
'MFCC',
'LFCC',
'MuLawEncoding',
'MuLawDecoding',
'Resample',
'TimeStretch',
'Fade',
'FrequencyMasking',
'TimeMasking',
'SlidingWindowCmn',
'Vad',
'SpectralCentroid',
'Vol',
'ComputeDeltas',
'PitchShift',
'RNNTLoss',
'PSD',
'MVDR',
"Spectrogram",
"InverseSpectrogram",
"GriffinLim",
"AmplitudeToDB",
"MelScale",
"InverseMelScale",
"MelSpectrogram",
"MFCC",
"LFCC",
"MuLawEncoding",
"MuLawDecoding",
"Resample",
"TimeStretch",
"Fade",
"FrequencyMasking",
"TimeMasking",
"SlidingWindowCmn",
"Vad",
"SpectralCentroid",
"Vol",
"ComputeDeltas",
"PitchShift",
"RNNTLoss",
"PSD",
"MVDR",
]
......@@ -73,21 +73,23 @@ class Spectrogram(torch.nn.Module):
>>> spectrogram = transform(waveform)
"""
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
def __init__(self,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: Optional[float] = 2.,
normalized: bool = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: Optional[bool] = None) -> None:
__constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"]
def __init__(
self,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: Optional[float] = 2.0,
normalized: bool = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: Optional[bool] = None,
) -> None:
super(Spectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
......@@ -95,7 +97,7 @@ class Spectrogram(torch.nn.Module):
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window)
self.register_buffer("window", window)
self.pad = pad
self.power = power
self.normalized = normalized
......@@ -162,19 +164,21 @@ class InverseSpectrogram(torch.nn.Module):
>>> transform = transforms.InverseSpectrogram(n_fft=512)
>>> waveform = transform(spectrogram, length)
"""
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
def __init__(self,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
normalized: bool = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True) -> None:
__constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"]
def __init__(
self,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
normalized: bool = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
) -> None:
super(InverseSpectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
......@@ -182,7 +186,7 @@ class InverseSpectrogram(torch.nn.Module):
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window)
self.register_buffer("window", window)
self.pad = pad
self.normalized = normalized
self.center = center
......@@ -242,31 +246,32 @@ class GriffinLim(torch.nn.Module):
>>> transform = transforms.GriffinLim(n_fft=512)
>>> waveform = transform(spectrogram)
"""
__constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power',
'length', 'momentum', 'rand_init']
def __init__(self,
n_fft: int = 400,
n_iter: int = 32,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: float = 2.,
wkwargs: Optional[dict] = None,
momentum: float = 0.99,
length: Optional[int] = None,
rand_init: bool = True) -> None:
__constants__ = ["n_fft", "n_iter", "win_length", "hop_length", "power", "length", "momentum", "rand_init"]
def __init__(
self,
n_fft: int = 400,
n_iter: int = 32,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: float = 2.0,
wkwargs: Optional[dict] = None,
momentum: float = 0.99,
length: Optional[int] = None,
rand_init: bool = True,
) -> None:
super(GriffinLim, self).__init__()
assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
assert momentum >= 0, 'momentum={} < 0'.format(momentum)
assert momentum < 1, "momentum={} > 1 can be unstable".format(momentum)
assert momentum >= 0, "momentum={} < 0".format(momentum)
self.n_fft = n_fft
self.n_iter = n_iter
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window)
self.register_buffer("window", window)
self.length = length
self.power = power
self.momentum = momentum / (1 + momentum)
......@@ -282,8 +287,18 @@ class GriffinLim(torch.nn.Module):
Returns:
Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
"""
return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
self.n_iter, self.momentum, self.length, self.rand_init)
return F.griffinlim(
specgram,
self.window,
self.n_fft,
self.hop_length,
self.win_length,
self.power,
self.n_iter,
self.momentum,
self.length,
self.rand_init,
)
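A short round-trip sketch pairing GriffinLim with Spectrogram (parameters are illustrative; the two transforms should use matching n_fft and power):

import torch
import torchaudio.transforms as T

waveform = torch.rand(1, 16000) * 2 - 1                      # synthetic mono audio
spec = T.Spectrogram(n_fft=512, power=2.0)(waveform)         # magnitude-squared spectrogram
reconstructed = T.GriffinLim(n_fft=512, power=2.0, n_iter=32)(spec)
print(reconstructed.shape)                                   # (1, time), roughly the input length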
class AmplitudeToDB(torch.nn.Module):
......@@ -299,15 +314,15 @@ class AmplitudeToDB(torch.nn.Module):
top_db (float or None, optional): minimum negative cut-off in decibels. A reasonable
number is 80. (Default: ``None``)
"""
__constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']
__constants__ = ["multiplier", "amin", "ref_value", "db_multiplier"]
def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None:
def __init__(self, stype: str = "power", top_db: Optional[float] = None) -> None:
super(AmplitudeToDB, self).__init__()
self.stype = stype
if top_db is not None and top_db < 0:
raise ValueError('top_db must be positive value')
raise ValueError("top_db must be positive value")
self.top_db = top_db
self.multiplier = 10.0 if stype == 'power' else 20.0
self.multiplier = 10.0 if stype == "power" else 20.0
self.amin = 1e-10
self.ref_value = 1.0
self.db_multiplier = math.log10(max(self.amin, self.ref_value))
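A minimal usage sketch of AmplitudeToDB on a power spectrogram (values and shapes are illustrative):

import torch
import torchaudio.transforms as T

power_spec = torch.rand(1, 201, 100)                 # synthetic power spectrogram (channel, freq, time)
to_db = T.AmplitudeToDB(stype="power", top_db=80.0)
spec_db = to_db(power_spec)                          # log-compressed, clamped to an 80 dB dynamic range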
......@@ -344,16 +359,18 @@ class MelScale(torch.nn.Module):
:py:func:`torchaudio.functional.melscale_fbanks` - The function used to
generate the filter banks.
"""
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
def __init__(self,
n_mels: int = 128,
sample_rate: int = 16000,
f_min: float = 0.,
f_max: Optional[float] = None,
n_stft: int = 201,
norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
__constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]
def __init__(
self,
n_mels: int = 128,
sample_rate: int = 16000,
f_min: float = 0.0,
f_max: Optional[float] = None,
n_stft: int = 201,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> None:
super(MelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
......@@ -362,11 +379,9 @@ class MelScale(torch.nn.Module):
self.norm = norm
self.mel_scale = mel_scale
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = F.melscale_fbanks(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
self.mel_scale)
self.register_buffer('fb', fb)
assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(f_min, self.f_max)
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale)
self.register_buffer("fb", fb)
def forward(self, specgram: Tensor) -> Tensor:
r"""
......@@ -404,21 +419,32 @@ class InverseMelScale(torch.nn.Module):
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
"""
__constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
'tolerance_change', 'sgdargs']
def __init__(self,
n_stft: int,
n_mels: int = 128,
sample_rate: int = 16000,
f_min: float = 0.,
f_max: Optional[float] = None,
max_iter: int = 100000,
tolerance_loss: float = 1e-5,
tolerance_change: float = 1e-8,
sgdargs: Optional[dict] = None,
norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
__constants__ = [
"n_stft",
"n_mels",
"sample_rate",
"f_min",
"f_max",
"max_iter",
"tolerance_loss",
"tolerance_change",
"sgdargs",
]
def __init__(
self,
n_stft: int,
n_mels: int = 128,
sample_rate: int = 16000,
f_min: float = 0.0,
f_max: Optional[float] = None,
max_iter: int = 100000,
tolerance_loss: float = 1e-5,
tolerance_change: float = 1e-8,
sgdargs: Optional[dict] = None,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> None:
super(InverseMelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
......@@ -427,13 +453,12 @@ class InverseMelScale(torch.nn.Module):
self.max_iter = max_iter
self.tolerance_loss = tolerance_loss
self.tolerance_change = tolerance_change
self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}
self.sgdargs = sgdargs or {"lr": 0.1, "momentum": 0.9}
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(f_min, self.f_max)
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate,
norm, mel_scale)
self.register_buffer('fb', fb)
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale)
self.register_buffer("fb", fb)
def forward(self, melspec: Tensor) -> Tensor:
r"""
......@@ -452,12 +477,13 @@ class InverseMelScale(torch.nn.Module):
melspec = melspec.transpose(-1, -2)
assert self.n_mels == n_mels
specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
dtype=melspec.dtype, device=melspec.device)
specgram = torch.rand(
melspec.size()[0], time, freq, requires_grad=True, dtype=melspec.dtype, device=melspec.device
)
optim = torch.optim.SGD([specgram], **self.sgdargs)
loss = float('inf')
loss = float("inf")
for _ in range(self.max_iter):
optim.zero_grad()
diff = melspec - specgram.matmul(self.fb)
......@@ -527,26 +553,28 @@ class MelSpectrogram(torch.nn.Module):
:py:func:`torchaudio.functional.melscale_fbanks` - The function used to
generate the filter banks.
"""
__constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']
def __init__(self,
sample_rate: int = 16000,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
f_min: float = 0.,
f_max: Optional[float] = None,
pad: int = 0,
n_mels: int = 128,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: float = 2.,
normalized: bool = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
__constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad", "n_mels", "f_min"]
def __init__(
self,
sample_rate: int = 16000,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
f_min: float = 0.0,
f_max: Optional[float] = None,
pad: int = 0,
n_mels: int = 128,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: float = 2.0,
normalized: bool = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> None:
super(MelSpectrogram, self).__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
......@@ -558,19 +586,21 @@ class MelSpectrogram(torch.nn.Module):
self.n_mels = n_mels # number of mel frequency bins
self.f_max = f_max
self.f_min = f_min
self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
hop_length=self.hop_length,
pad=self.pad, window_fn=window_fn, power=self.power,
normalized=self.normalized, wkwargs=wkwargs,
center=center, pad_mode=pad_mode, onesided=onesided)
self.spectrogram = Spectrogram(
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
pad=self.pad,
window_fn=window_fn,
power=self.power,
normalized=self.normalized,
wkwargs=wkwargs,
center=center,
pad_mode=pad_mode,
onesided=onesided,
)
self.mel_scale = MelScale(
self.n_mels,
self.sample_rate,
self.f_min,
self.f_max,
self.n_fft // 2 + 1,
norm,
mel_scale
self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, norm, mel_scale
)
def forward(self, waveform: Tensor) -> Tensor:
......@@ -609,33 +639,35 @@ class MFCC(torch.nn.Module):
:py:func:`torchaudio.functional.melscale_fbanks` - The function used to
generate the filter banks.
"""
__constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
def __init__(self,
sample_rate: int = 16000,
n_mfcc: int = 40,
dct_type: int = 2,
norm: str = 'ortho',
log_mels: bool = False,
melkwargs: Optional[dict] = None) -> None:
__constants__ = ["sample_rate", "n_mfcc", "dct_type", "top_db", "log_mels"]
def __init__(
self,
sample_rate: int = 16000,
n_mfcc: int = 40,
dct_type: int = 2,
norm: str = "ortho",
log_mels: bool = False,
melkwargs: Optional[dict] = None,
) -> None:
super(MFCC, self).__init__()
supported_dct_types = [2]
if dct_type not in supported_dct_types:
raise ValueError('DCT type not supported: {}'.format(dct_type))
raise ValueError("DCT type not supported: {}".format(dct_type))
self.sample_rate = sample_rate
self.n_mfcc = n_mfcc
self.dct_type = dct_type
self.norm = norm
self.top_db = 80.0
self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)
self.amplitude_to_DB = AmplitudeToDB("power", self.top_db)
melkwargs = melkwargs or {}
self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
if self.n_mfcc > self.MelSpectrogram.n_mels:
raise ValueError('Cannot select more MFCC coefficients than # mel bins')
raise ValueError("Cannot select more MFCC coefficients than # mel bins")
dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
self.register_buffer('dct_mat', dct_mat)
self.register_buffer("dct_mat", dct_mat)
self.log_mels = log_mels
def forward(self, waveform: Tensor) -> Tensor:
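A brief usage sketch of the MelSpectrogram and MFCC transforms above (parameters are illustrative; melkwargs is forwarded to the internal MelSpectrogram):

import torch
import torchaudio.transforms as T

waveform = torch.rand(1, 16000) * 2 - 1
mel = T.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80)(waveform)                       # (1, 80, frames)
mfcc = T.MFCC(sample_rate=16000, n_mfcc=13, melkwargs={"n_fft": 400, "n_mels": 80})(waveform)   # (1, 13, frames)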
......@@ -685,22 +717,24 @@ class LFCC(torch.nn.Module):
:py:func:`torchaudio.functional.linear_fbanks` - The function used to
generate the filter banks.
"""
__constants__ = ['sample_rate', 'n_filter', 'n_lfcc', 'dct_type', 'top_db', 'log_lf']
def __init__(self,
sample_rate: int = 16000,
n_filter: int = 128,
f_min: float = 0.,
f_max: Optional[float] = None,
n_lfcc: int = 40,
dct_type: int = 2,
norm: str = 'ortho',
log_lf: bool = False,
speckwargs: Optional[dict] = None) -> None:
__constants__ = ["sample_rate", "n_filter", "n_lfcc", "dct_type", "top_db", "log_lf"]
def __init__(
self,
sample_rate: int = 16000,
n_filter: int = 128,
f_min: float = 0.0,
f_max: Optional[float] = None,
n_lfcc: int = 40,
dct_type: int = 2,
norm: str = "ortho",
log_lf: bool = False,
speckwargs: Optional[dict] = None,
) -> None:
super(LFCC, self).__init__()
supported_dct_types = [2]
if dct_type not in supported_dct_types:
raise ValueError('DCT type not supported: {}'.format(dct_type))
raise ValueError("DCT type not supported: {}".format(dct_type))
self.sample_rate = sample_rate
self.f_min = f_min
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
......@@ -709,13 +743,13 @@ class LFCC(torch.nn.Module):
self.dct_type = dct_type
self.norm = norm
self.top_db = 80.0
self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)
self.amplitude_to_DB = AmplitudeToDB("power", self.top_db)
speckwargs = speckwargs or {}
self.Spectrogram = Spectrogram(**speckwargs)
if self.n_lfcc > self.Spectrogram.n_fft:
raise ValueError('Cannot select more LFCC coefficients than # fft bins')
raise ValueError("Cannot select more LFCC coefficients than # fft bins")
filter_mat = F.linear_fbanks(
n_freqs=self.Spectrogram.n_fft // 2 + 1,
......@@ -727,7 +761,7 @@ class LFCC(torch.nn.Module):
self.register_buffer("filter_mat", filter_mat)
dct_mat = F.create_dct(self.n_lfcc, self.n_filter, self.norm)
self.register_buffer('dct_mat', dct_mat)
self.register_buffer("dct_mat", dct_mat)
self.log_lf = log_lf
def forward(self, waveform: Tensor) -> Tensor:
......@@ -770,7 +804,7 @@ class MuLawEncoding(torch.nn.Module):
>>> mulawtrans = transform(waveform)
"""
__constants__ = ['quantization_channels']
__constants__ = ["quantization_channels"]
def __init__(self, quantization_channels: int = 256) -> None:
super(MuLawEncoding, self).__init__()
......@@ -802,7 +836,7 @@ class MuLawDecoding(torch.nn.Module):
>>> transform = torchaudio.transforms.MuLawDecoding(quantization_channels=512)
>>> mulawtrans = transform(waveform)
"""
__constants__ = ['quantization_channels']
__constants__ = ["quantization_channels"]
def __init__(self, quantization_channels: int = 256) -> None:
super(MuLawDecoding, self).__init__()
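The two mu-law transforms above are inverses up to quantization error; a small round-trip sketch (256 channels corresponds to 8-bit companding, and the input is assumed to lie in [-1, 1]):

import torch
import torchaudio.transforms as T

waveform = torch.rand(1, 16000) * 2 - 1
encoded = T.MuLawEncoding(quantization_channels=256)(waveform)   # integer codes in [0, 255]
decoded = T.MuLawDecoding(quantization_channels=256)(encoded)    # back to [-1, 1], with quantization error
print((waveform - decoded).abs().max())                          # small but non-zero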
......@@ -853,15 +887,15 @@ class Resample(torch.nn.Module):
"""
def __init__(
self,
orig_freq: int = 16000,
new_freq: int = 16000,
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
beta: Optional[float] = None,
*,
dtype: Optional[torch.dtype] = None,
self,
orig_freq: int = 16000,
new_freq: int = 16000,
resampling_method: str = "sinc_interpolation",
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
beta: Optional[float] = None,
*,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
......@@ -875,10 +909,16 @@ class Resample(torch.nn.Module):
if self.orig_freq != self.new_freq:
kernel, self.width = _get_sinc_resample_kernel(
self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta, dtype=dtype)
self.register_buffer('kernel', kernel)
self.orig_freq,
self.new_freq,
self.gcd,
self.lowpass_filter_width,
self.rolloff,
self.resampling_method,
beta,
dtype=dtype,
)
self.register_buffer("kernel", kernel)
def forward(self, waveform: Tensor) -> Tensor:
r"""
......@@ -890,9 +930,7 @@ class Resample(torch.nn.Module):
"""
if self.orig_freq == self.new_freq:
return waveform
return _apply_sinc_resample_kernel(
waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd, self.kernel, self.width)
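A usage sketch of the Resample module above (frequencies are illustrative; the sinc kernel is built once in the constructor and reused in forward):

import torch
import torchaudio.transforms as T

resampler = T.Resample(orig_freq=44100, new_freq=16000)
waveform = torch.rand(1, 44100) * 2 - 1              # one second at 44.1 kHz (synthetic)
resampled = resampler(waveform)                      # roughly 16000 samples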
class ComputeDeltas(torch.nn.Module):
......@@ -904,7 +942,7 @@ class ComputeDeltas(torch.nn.Module):
win_length (int, optional): The window length used for computing delta. (Default: ``5``)
mode (str, optional): Mode parameter passed to padding. (Default: ``'replicate'``)
"""
__constants__ = ['win_length']
__constants__ = ["win_length"]
def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
super(ComputeDeltas, self).__init__()
......@@ -954,19 +992,16 @@ class TimeStretch(torch.nn.Module):
:alt: Spectrogram stretched by 0.9
"""
__constants__ = ['fixed_rate']
__constants__ = ["fixed_rate"]
def __init__(self,
hop_length: Optional[int] = None,
n_freq: int = 201,
fixed_rate: Optional[float] = None) -> None:
def __init__(self, hop_length: Optional[int] = None, n_freq: int = 201, fixed_rate: Optional[float] = None) -> None:
super(TimeStretch, self).__init__()
self.fixed_rate = fixed_rate
n_fft = (n_freq - 1) * 2
hop_length = hop_length if hop_length is not None else n_fft // 2
self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])
self.register_buffer("phase_advance", torch.linspace(0, math.pi * hop_length, n_freq)[..., None])
def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
r"""
......@@ -983,8 +1018,7 @@ class TimeStretch(torch.nn.Module):
"""
if overriding_rate is None:
if self.fixed_rate is None:
raise ValueError(
"If no fixed_rate is specified, must pass a valid rate to the forward method.")
raise ValueError("If no fixed_rate is specified, must pass a valid rate to the forward method.")
rate = self.fixed_rate
else:
rate = overriding_rate
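TimeStretch operates on complex spectrograms; a hedged sketch with a synthetic complex input (sizes and rate are illustrative, and n_freq must equal n_fft // 2 + 1 of the spectrogram that produced it):

import torch
import torchaudio.transforms as T

complex_spec = torch.randn(1, 201, 100, dtype=torch.cfloat)   # (channel, freq, frames)
stretch = T.TimeStretch(n_freq=201, fixed_rate=1.2)
stretched = stretch(complex_spec)                             # fewer frames: time compressed by the rate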
......@@ -1007,10 +1041,7 @@ class Fade(torch.nn.Module):
>>> faded_waveform = transform(waveform)
"""
def __init__(self,
fade_in_len: int = 0,
fade_out_len: int = 0,
fade_shape: str = "linear") -> None:
def __init__(self, fade_in_len: int = 0, fade_out_len: int = 0, fade_shape: str = "linear") -> None:
super(Fade, self).__init__()
self.fade_in_len = fade_in_len
self.fade_out_len = fade_out_len
......@@ -1026,11 +1057,7 @@ class Fade(torch.nn.Module):
"""
waveform_length = waveform.size()[-1]
device = waveform.device
return (
self._fade_in(waveform_length, device)
* self._fade_out(waveform_length, device)
* waveform
)
return self._fade_in(waveform_length, device) * self._fade_out(waveform_length, device) * waveform
def _fade_in(self, waveform_length: int, device: torch.device) -> Tensor:
fade = torch.linspace(0, 1, self.fade_in_len, device=device)
......@@ -1043,7 +1070,7 @@ class Fade(torch.nn.Module):
fade = torch.pow(2, (fade - 1)) * fade
if self.fade_shape == "logarithmic":
fade = torch.log10(.1 + fade) + 1
fade = torch.log10(0.1 + fade) + 1
if self.fade_shape == "quarter_sine":
fade = torch.sin(fade * math.pi / 2)
......@@ -1058,10 +1085,10 @@ class Fade(torch.nn.Module):
ones = torch.ones(waveform_length - self.fade_out_len, device=device)
if self.fade_shape == "linear":
fade = - fade + 1
fade = -fade + 1
if self.fade_shape == "exponential":
fade = torch.pow(2, - fade) * (1 - fade)
fade = torch.pow(2, -fade) * (1 - fade)
if self.fade_shape == "logarithmic":
fade = torch.log10(1.1 - fade) + 1
......@@ -1084,7 +1111,7 @@ class _AxisMasking(torch.nn.Module):
iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
This option is applicable only when the input tensor is 4D.
"""
__constants__ = ['mask_param', 'axis', 'iid_masks']
__constants__ = ["mask_param", "axis", "iid_masks"]
def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:
......@@ -1093,7 +1120,7 @@ class _AxisMasking(torch.nn.Module):
self.axis = axis
self.iid_masks = iid_masks
def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
def forward(self, specgram: Tensor, mask_value: float = 0.0) -> Tensor:
r"""
Args:
specgram (Tensor): Tensor of dimension `(..., freq, time)`.
......@@ -1180,12 +1207,12 @@ class Vol(torch.nn.Module):
gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``)
"""
def __init__(self, gain: float, gain_type: str = 'amplitude'):
def __init__(self, gain: float, gain_type: str = "amplitude"):
super(Vol, self).__init__()
self.gain = gain
self.gain_type = gain_type
if gain_type in ['amplitude', 'power'] and gain < 0:
if gain_type in ["amplitude", "power"] and gain < 0:
raise ValueError("If gain_type = amplitude or power, gain must be positive.")
def forward(self, waveform: Tensor) -> Tensor:
......@@ -1221,11 +1248,9 @@ class SlidingWindowCmn(torch.nn.Module):
norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
"""
def __init__(self,
cmn_window: int = 600,
min_cmn_window: int = 100,
center: bool = False,
norm_vars: bool = False) -> None:
def __init__(
self, cmn_window: int = 600, min_cmn_window: int = 100, center: bool = False, norm_vars: bool = False
) -> None:
super().__init__()
self.cmn_window = cmn_window
self.min_cmn_window = min_cmn_window
......@@ -1240,8 +1265,7 @@ class SlidingWindowCmn(torch.nn.Module):
Returns:
Tensor: Tensor of spectrogram of dimension `(..., time, freq)`.
"""
cmn_specgram = F.sliding_window_cmn(
specgram, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
cmn_specgram = F.sliding_window_cmn(specgram, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
return cmn_specgram
......@@ -1297,24 +1321,26 @@ class Vad(torch.nn.Module):
- http://sox.sourceforge.net/sox.html
"""
def __init__(self,
sample_rate: int,
trigger_level: float = 7.0,
trigger_time: float = 0.25,
search_time: float = 1.0,
allowed_gap: float = 0.25,
pre_trigger_time: float = 0.0,
boot_time: float = .35,
noise_up_time: float = .1,
noise_down_time: float = .01,
noise_reduction_amount: float = 1.35,
measure_freq: float = 20.0,
measure_duration: Optional[float] = None,
measure_smooth_time: float = .4,
hp_filter_freq: float = 50.,
lp_filter_freq: float = 6000.,
hp_lifter_freq: float = 150.,
lp_lifter_freq: float = 2000.) -> None:
def __init__(
self,
sample_rate: int,
trigger_level: float = 7.0,
trigger_time: float = 0.25,
search_time: float = 1.0,
allowed_gap: float = 0.25,
pre_trigger_time: float = 0.0,
boot_time: float = 0.35,
noise_up_time: float = 0.1,
noise_down_time: float = 0.01,
noise_reduction_amount: float = 1.35,
measure_freq: float = 20.0,
measure_duration: Optional[float] = None,
measure_smooth_time: float = 0.4,
hp_filter_freq: float = 50.0,
lp_filter_freq: float = 6000.0,
hp_lifter_freq: float = 150.0,
lp_lifter_freq: float = 2000.0,
) -> None:
super().__init__()
self.sample_rate = sample_rate
......@@ -1386,23 +1412,25 @@ class SpectralCentroid(torch.nn.Module):
>>> transform = transforms.SpectralCentroid(sample_rate)
>>> spectral_centroid = transform(waveform) # (channel, time)
"""
__constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad']
def __init__(self,
sample_rate: int,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None) -> None:
__constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad"]
def __init__(
self,
sample_rate: int,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None,
) -> None:
super(SpectralCentroid, self).__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window)
self.register_buffer("window", window)
self.pad = pad
def forward(self, waveform: Tensor) -> Tensor:
......@@ -1414,8 +1442,9 @@ class SpectralCentroid(torch.nn.Module):
Tensor: Spectral Centroid of size `(..., time)`.
"""
return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length,
self.win_length)
return F.spectral_centroid(
waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length, self.win_length
)
class PitchShift(torch.nn.Module):
......@@ -1438,17 +1467,19 @@ class PitchShift(torch.nn.Module):
>>> transform = transforms.PitchShift(sample_rate, 4)
>>> waveform_shift = transform(waveform) # (channel, time)
"""
__constants__ = ['sample_rate', 'n_steps', 'bins_per_octave', 'n_fft', 'win_length', 'hop_length']
def __init__(self,
sample_rate: int,
n_steps: int,
bins_per_octave: int = 12,
n_fft: int = 512,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None) -> None:
__constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"]
def __init__(
self,
sample_rate: int,
n_steps: int,
bins_per_octave: int = 12,
n_fft: int = 512,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None,
) -> None:
super(PitchShift, self).__init__()
self.n_steps = n_steps
self.bins_per_octave = bins_per_octave
......@@ -1457,7 +1488,7 @@ class PitchShift(torch.nn.Module):
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 4
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window)
self.register_buffer("window", window)
def forward(self, waveform: Tensor) -> Tensor:
r"""
......@@ -1468,8 +1499,16 @@ class PitchShift(torch.nn.Module):
Tensor: The pitch-shifted audio of shape `(..., time)`.
"""
return F.pitch_shift(waveform, self.sample_rate, self.n_steps, self.bins_per_octave, self.n_fft,
self.win_length, self.hop_length, self.window)
return F.pitch_shift(
waveform,
self.sample_rate,
self.n_steps,
self.bins_per_octave,
self.n_fft,
self.win_length,
self.hop_length,
self.window,
)
class RNNTLoss(torch.nn.Module):
......@@ -1506,7 +1545,7 @@ class RNNTLoss(torch.nn.Module):
def __init__(
self,
blank: int = -1,
clamp: float = -1.,
clamp: float = -1.0,
reduction: str = "mean",
):
super().__init__()
......@@ -1532,15 +1571,7 @@ class RNNTLoss(torch.nn.Module):
Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
otherwise scalar.
"""
return F.rnnt_loss(
logits,
targets,
logit_lengths,
target_lengths,
self.blank,
self.clamp,
self.reduction
)
return F.rnnt_loss(logits, targets, logit_lengths, target_lengths, self.blank, self.clamp, self.reduction)
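A shape-only sketch of the RNNTLoss call above, assuming the default blank index of -1 (the last class) and int32 targets and lengths; all sizes are illustrative:

import torch
import torchaudio.transforms as T

B, max_T, max_U, C = 2, 10, 4, 20                              # batch, frames, target length, classes
logits = torch.randn(B, max_T, max_U + 1, C)                   # joiner output
targets = torch.randint(0, C - 1, (B, max_U), dtype=torch.int32)
logit_lengths = torch.full((B,), max_T, dtype=torch.int32)
target_lengths = torch.full((B,), max_U, dtype=torch.int32)

loss = T.RNNTLoss(blank=-1, reduction="mean")(logits, targets, logit_lengths, target_lengths)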
def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
......@@ -1557,8 +1588,7 @@ def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch
torch.Tensor: trace of the input Tensor
"""
assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert input.shape[dim1] == input.shape[dim2],\
"The size of ``dim1`` and ``dim2`` must be the same."
assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
return input.sum(dim=-1)
......@@ -1684,8 +1714,11 @@ class MVDR(torch.nn.Module):
online: bool = False,
):
super().__init__()
assert solution in ["ref_channel", "stv_evd", "stv_power"],\
"Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]."
assert solution in [
"ref_channel",
"stv_evd",
"stv_power",
], "Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]."
self.ref_channel = ref_channel
self.solution = solution
self.multi_mask = multi_mask
......@@ -1698,10 +1731,10 @@ class MVDR(torch.nn.Module):
psd_n: torch.Tensor = torch.zeros(1)
mask_sum_s: torch.Tensor = torch.zeros(1)
mask_sum_n: torch.Tensor = torch.zeros(1)
self.register_buffer('psd_s', psd_s)
self.register_buffer('psd_n', psd_n)
self.register_buffer('mask_sum_s', mask_sum_s)
self.register_buffer('mask_sum_n', mask_sum_n)
self.register_buffer("psd_s", psd_s)
self.register_buffer("psd_n", psd_n)
self.register_buffer("mask_sum_s", mask_sum_s)
self.register_buffer("mask_sum_n", mask_sum_n)
def _get_updated_mvdr_vector(
self,
......@@ -1710,7 +1743,7 @@ class MVDR(torch.nn.Module):
mask_s: torch.Tensor,
mask_n: torch.Tensor,
reference_vector: torch.Tensor,
solution: str = 'ref_channel',
solution: str = "ref_channel",
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
......@@ -1788,7 +1821,7 @@ class MVDR(torch.nn.Module):
psd_s: torch.Tensor,
psd_n: torch.Tensor,
reference_vector: torch.Tensor,
solution: str = 'ref_channel',
solution: str = "ref_channel",
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
......@@ -1851,10 +1884,7 @@ class MVDR(torch.nn.Module):
return stv
def _get_steering_vector_power(
self,
psd_s: torch.Tensor,
psd_n: torch.Tensor,
reference_vector: torch.Tensor
self, psd_s: torch.Tensor, psd_n: torch.Tensor, reference_vector: torch.Tensor
) -> torch.Tensor:
r"""Estimate the steering vector by the power method.
......@@ -1876,11 +1906,7 @@ class MVDR(torch.nn.Module):
stv = torch.matmul(psd_s, stv)
return stv
def _apply_beamforming_vector(
self,
specgram: torch.Tensor,
beamform_vector: torch.Tensor
) -> torch.Tensor:
def _apply_beamforming_vector(self, specgram: torch.Tensor, beamform_vector: torch.Tensor) -> torch.Tensor:
r"""Apply the beamforming weight to the noisy STFT
Args:
specgram (torch.tensor): multi-channel noisy STFT
......@@ -1896,12 +1922,7 @@ class MVDR(torch.nn.Module):
specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram])
return specgram_enhanced
def _tik_reg(
self,
mat: torch.Tensor,
reg: float = 1e-7,
eps: float = 1e-8
) -> torch.Tensor:
def _tik_reg(self, mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
"""Perform Tikhonov regularization (only modifying real part).
Args:
mat (torch.Tensor): input matrix (..., channel, channel)
......@@ -1922,10 +1943,7 @@ class MVDR(torch.nn.Module):
return mat
def forward(
self,
specgram: torch.Tensor,
mask_s: torch.Tensor,
mask_n: Optional[torch.Tensor] = None
self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Perform MVDR beamforming.
......@@ -1946,9 +1964,7 @@ class MVDR(torch.nn.Module):
"""
dtype = specgram.dtype
if specgram.ndim < 3:
raise ValueError(
f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}"
)
raise ValueError(f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}")
if not specgram.is_complex():
raise ValueError(
f"The type of ``specgram`` tensor must be ``torch.cfloat`` or ``torch.cdouble``.\
......@@ -1958,9 +1974,7 @@ class MVDR(torch.nn.Module):
specgram = specgram.cdouble() # Convert specgram to ``torch.cdouble``.
if mask_n is None:
warnings.warn(
"``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``."
)
warnings.warn("``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``.")
mask_n = 1 - mask_s
shape = specgram.size()
......@@ -1977,33 +1991,15 @@ class MVDR(torch.nn.Module):
psd_s = self.psd(specgram, mask_s) # (..., freq, time, channel, channel)
psd_n = self.psd(specgram, mask_n) # (..., freq, time, channel, channel)
u = torch.zeros(
specgram.size()[:-2],
device=specgram.device,
dtype=torch.cdouble
) # (..., channel)
u = torch.zeros(specgram.size()[:-2], device=specgram.device, dtype=torch.cdouble) # (..., channel)
u[..., self.ref_channel].fill_(1)
if self.online:
w_mvdr = self._get_updated_mvdr_vector(
psd_s,
psd_n,
mask_s,
mask_n,
u,
self.solution,
self.diag_loading,
self.diag_eps
psd_s, psd_n, mask_s, mask_n, u, self.solution, self.diag_loading, self.diag_eps
)
else:
w_mvdr = self._get_mvdr_vector(
psd_s,
psd_n,
u,
self.solution,
self.diag_loading,
self.diag_eps
)
w_mvdr = self._get_mvdr_vector(psd_s, psd_n, u, self.solution, self.diag_loading, self.diag_eps)
specgram_enhanced = self._apply_beamforming_vector(specgram, w_mvdr)
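A hedged end-to-end sketch of the MVDR module above, using random complex STFT coefficients and a random speech mask (shapes are illustrative; when mask_n is omitted, 1 - mask_s is used as the noise mask):

import torch
import torchaudio.transforms as T

channel, freq, time = 4, 201, 100
specgram = torch.randn(channel, freq, time, dtype=torch.cdouble)   # multi-channel noisy STFT
mask_s = torch.rand(freq, time)                                    # speech mask in [0, 1]

mvdr = T.MVDR(ref_channel=0, solution="ref_channel", multi_mask=False)
enhanced = mvdr(specgram, mask_s)                                  # single-channel enhanced STFT (freq, time)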
......
from torchaudio._internal import module_utils as _mod_utils
from . import (
sox_utils,
)
from torchaudio._internal import module_utils as _mod_utils
if _mod_utils.is_sox_available():
......