Unverified Commit ec4837dc authored by moto, committed by GitHub

[BC-breaking] Remove unused dimension from pretrained Wav2Vec2 ASR (#1914)

* [BC-breaking] Remove unused dimension from pretrained Wav2Vec2 ASR

The pretrained Wav2Vec2 ASR weights that originate from fairseq have
extra dimensions that have nothing to do with the ASR task:

https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L18-L37

These dimensions are masked out during the loss computation:

https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L128

This change removes them.

* Use '-' for blank token representation.
parent ec125053
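
For context, fairseq's dictionary reserves indices 0-3 for the special tokens '<s>', '<pad>', '</s>' and '<unk>'; in CTC fine-tuning, index 0 doubles as the blank and the other three are never emitted, so the corresponding rows of the final linear layer are dead weight. A minimal sketch of the trimming (illustrative only; trim_aux is a hypothetical helper, not a function added by this commit):

import torch

def trim_aux(state_dict):
    # Drop the output rows for the unused fairseq dictionary entries
    # (index 1: <pad>, 2: </s>, 3: <unk>); index 0 (<s>) survives as the blank.
    for key in ['aux.weight', 'aux.bias']:
        t = state_dict[key]
        keep = [i for i in range(t.size(0)) if i not in (1, 2, 3)]
        state_dict[key] = t[keep]  # e.g. (32, hidden) -> (29, hidden) for the weight
    return state_dict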
@@ -4,8 +4,9 @@ import pytest

 class GreedyCTCDecoder(torch.nn.Module):
-    def __init__(self, labels):
+    def __init__(self, labels, blank: int = 0):
         super().__init__()
+        self.blank = blank
         self.labels = labels

     def forward(self, logits: torch.Tensor) -> str:
@@ -21,9 +22,8 @@ class GreedyCTCDecoder(torch.nn.Module):
         best_path = torch.unique_consecutive(best_path, dim=-1)
         hypothesis = []
         for i in best_path:
-            char = self.labels[i]
-            if char not in ['<s>', '<pad>']:
-                hypothesis.append(char)
+            if i != self.blank:
+                hypothesis.append(self.labels[i])
         return ''.join(hypothesis)
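
A quick usage sketch of the updated decoder (assuming, for illustration, a fine-tuned bundle such as torchaudio.models.WAV2VEC2_ASR_BASE_960H and a waveform already resampled to bundle.sample_rate):

model = bundle.get_model()
emission, _ = model(waveform)                     # (batch, time, num_labels)
decoder = GreedyCTCDecoder(bundle.get_labels())   # blank defaults to index 0, i.e. '-'
transcript = decoder(emission[0])

Because the blank is now identified by index rather than by string matching, the decoder no longer needs to know the '<s>'/'<pad>' spellings.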
 from dataclasses import dataclass
 from typing import Dict, Tuple, Any

+import torch
 from torch.hub import load_state_dict_from_url

 from torchaudio.models import wav2vec2_model, Wav2Vec2Model
@@ -68,6 +69,14 @@ class Wav2Vec2Bundle:
         url = f'https://download.pytorch.org/torchaudio/models/{self._path}'
         dl_kwargs = {} if dl_kwargs is None else dl_kwargs
         state_dict = load_state_dict_from_url(url, **dl_kwargs)
+        if model.aux is not None:
+            # The fine-tuned ASR weights that originate from fairseq carry output
+            # rows at indices 1, 2 and 3 for dictionary entries the ASR task never
+            # uses, so we remove them here.
+            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
+            for key in ['aux.weight', 'aux.bias']:
+                t = state_dict[key]
+                state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in (1, 2, 3)])
         model.load_state_dict(state_dict)
         model.eval()
         return model
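
As a sanity check (illustrative; the shapes assume the BASE architecture, whose encoder width is 768):

# before the trim: aux.weight is (32, 768), aux.bias is (32,)
# after the trim:  aux.weight is (29, 768), aux.bias is (29,)
assert state_dict['aux.weight'].size(0) == 29
assert state_dict['aux.bias'].size(0) == 29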
@@ -102,7 +111,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
         >>> # Check the corresponding labels of the output.
         >>> labels = bundle.get_labels()
         >>> print(labels)
-        ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+        ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
         >>>
         >>> # Resample audio to the expected sampling rate
         >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
@@ -119,20 +128,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
     def get_labels(
             self,
             *,
-            bos: str = '<s>',
-            pad: str = '<pad>',
-            eos: str = '</s>',
-            unk: str = '<unk>',
+            blank: str = '-',
     ) -> Tuple[str]:
         """The output class labels (only applicable to fine-tuned bundles)

-        The first four tokens are BOS, padding, EOS and UNK tokens and they can be customized.
+        The first token is the blank token, and it is customizable.

         Args:
-            bos (str, optional): Beginning of sentence token. (default: ``'<s>'``)
-            pad (str, optional): Padding token. (default: ``'<pad>'``)
-            eos (str, optional): End of sentence token. (default: ``'</s>'``)
-            unk (str, optional): Token for unknown class. (default: ``'<unk>'``)
+            blank (str, optional): Blank token. (default: ``'-'``)

         Returns:
             Tuple[str]:
@@ -142,11 +145,9 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
         Example
             >>> import torchaudio
             >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
-            ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+            ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
         """  # noqa: E501
-        if self._labels is None:
-            raise ValueError('Pre-trained models do not have labels.')
-        return (bos, pad, eos, unk, *self._labels)
+        return (blank, *self._labels)


 def _get_labels():
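
With the trim, get_labels() lines up one-to-one with the model's output dimension, so greedy decoding becomes a direct index lookup (a minimal sketch; 'emission' is a hypothetical (time, num_labels) tensor):

labels = bundle.get_labels()                 # 29 entries, blank '-' at index 0
indices = torch.argmax(emission, dim=-1)     # best label per frame
indices = torch.unique_consecutive(indices)  # collapse repeats
transcript = ''.join(labels[i] for i in indices if i != 0)  # drop blanks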
@@ -252,7 +253,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
         'encoder_dropout': 0.1,
         'encoder_layer_norm_first': False,
         'encoder_layer_drop': 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -298,7 +299,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
         'encoder_dropout': 0.1,
         'encoder_layer_norm_first': False,
         'encoder_layer_drop': 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -344,7 +345,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.1,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -433,7 +434,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -479,7 +480,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -525,7 +526,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -614,7 +615,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -660,7 +661,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -706,7 +707,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -932,7 +933,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
         'encoder_dropout': 0.0,
         'encoder_layer_norm_first': True,
         'encoder_layer_drop': 0.1,
-        'aux_num_out': 32,
+        'aux_num_out': 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -979,7 +980,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
         'encoder_dropout': 0.0,
         'encoder_layer_norm_first': True,
         'encoder_layer_drop': 0.1,
-        'aux_num_out': 32,
+        'aux_num_out': 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
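
In every fine-tuned bundle, aux_num_out drops from 32 to 29: the 32-entry fairseq dictionary (4 special tokens plus 28 characters) minus the three removed entries '<pad>', '</s>' and '<unk>'.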