Unverified commit ec4837dc, authored by moto, committed by GitHub

[BC-breaking] Remove unused dimension from pretrained Wav2Vec2 ASR (#1914)

* [BC-breaking] Remove unused dimension from pretrained Wav2Vec2 ASR

The pretrained Wav2Vec2 ASR weights that originate from fairseq have
extra dimensions that have nothing to do with the ASR task.

https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L18-L37

These dimensions are masked out during the loss computation, as in

https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L128

This change removes them.

* Use '-' for the blank token representation.
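For illustration, here is a minimal standalone sketch of the pruning this change performs; the `prune_aux` helper name and the 768-dim feature size are hypothetical, chosen only for demonstration:

    # fairseq's Dictionary reserves indices 0-3 for '<s>', '<pad>', '</s>', '<unk>'.
    # For CTC fine-tuning, index 0 doubles as the blank token, while indices 1-3
    # are masked out of the loss, so their output rows are dead weights.
    import torch

    def prune_aux(state_dict, unused=(1, 2, 3)):
        # Drop the output rows of the final linear layer ('aux') that
        # correspond to the unused fairseq tokens.
        for key in ['aux.weight', 'aux.bias']:
            t = state_dict[key]
            state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in unused])
        return state_dict

    # 32 output classes from fairseq become 32 - 3 = 29 after pruning.
    sd = {'aux.weight': torch.randn(32, 768), 'aux.bias': torch.randn(32)}
    sd = prune_aux(sd)
    assert sd['aux.weight'].shape == (29, 768) and sd['aux.bias'].shape == (29,)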
parent ec125053
@@ -4,8 +4,9 @@ import pytest
 class GreedyCTCDecoder(torch.nn.Module):
-    def __init__(self, labels):
+    def __init__(self, labels, blank: int = 0):
         super().__init__()
+        self.blank = blank
         self.labels = labels
 
     def forward(self, logits: torch.Tensor) -> str:
@@ -21,9 +22,8 @@ class GreedyCTCDecoder(torch.nn.Module):
         best_path = torch.unique_consecutive(best_path, dim=-1)
         hypothesis = []
         for i in best_path:
-            char = self.labels[i]
-            if char not in ['<s>', '<pad>']:
-                hypothesis.append(char)
+            if i != self.blank:
+                hypothesis.append(self.labels[i])
         return ''.join(hypothesis)
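As a quick illustrative check of the new blank-based decoding (indices made up for demonstration):

    import torch

    labels = ('-', '|', 'H', 'E', 'L', 'O')  # blank at index 0
    # Per-frame argmax indices: H H - E - L L - L O
    best_path = torch.tensor([2, 2, 0, 3, 0, 4, 4, 0, 4, 5])
    # Collapse repeats; the blank between the two L runs keeps the double letter.
    best_path = torch.unique_consecutive(best_path)
    print(''.join(labels[i] for i in best_path.tolist() if i != 0))  # -> 'HELLO'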
......
 from dataclasses import dataclass
 from typing import Dict, Tuple, Any
+import torch
 from torch.hub import load_state_dict_from_url
 from torchaudio.models import wav2vec2_model, Wav2Vec2Model
@@ -68,6 +69,14 @@ class Wav2Vec2Bundle:
         url = f'https://download.pytorch.org/torchaudio/models/{self._path}'
         dl_kwargs = {} if dl_kwargs is None else dl_kwargs
         state_dict = load_state_dict_from_url(url, **dl_kwargs)
+        if model.aux is not None:
+            # For the ASR task, the parameters originating from fairseq have
+            # unrelated output dimensions at indices 1, 2, 3, which are never
+            # used, so we remove them here.
+            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
+            for key in ['aux.weight', 'aux.bias']:
+                t = state_dict[key]
+                state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in (1, 2, 3)])
         model.load_state_dict(state_dict)
         model.eval()
         return model
@@ -102,7 +111,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
         >>> # Check the corresponding labels of the output.
         >>> labels = bundle.get_labels()
         >>> print(labels)
-        ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+        ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
         >>>
         >>> # Resample audio to the expected sampling rate
         >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
@@ -119,20 +128,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
     def get_labels(
             self,
             *,
-            bos: str = '<s>',
-            pad: str = '<pad>',
-            eos: str = '</s>',
-            unk: str = '<unk>',
+            blank: str = '-',
     ) -> Tuple[str]:
         """The output class labels (only applicable to fine-tuned bundles)
 
-        The first four tokens are BOS, padding, EOS and UNK tokens and they can be customized.
+        The first token is the blank token, and it is customizable.
 
         Args:
-            bos (str, optional): Beginning of sentence token. (default: ``'<s>'``)
-            pad (str, optional): Padding token. (default: ``'<pad>'``)
-            eos (str, optional): End of sentence token. (default: ``'</s>'``)
-            unk (str, optional): Token for unknown class. (default: ``'<unk>'``)
+            blank (str, optional): Blank token. (default: ``'-'``)
 
         Returns:
             Tuple[str]:
@@ -142,11 +145,9 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
         Example
             >>> import torchaudio
             >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
-            ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+            ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
         """  # noqa: E501
         if self._labels is None:
             raise ValueError('Pre-trained models do not have labels.')
-        return (bos, pad, eos, unk, *self._labels)
+        return (blank, *self._labels)
 
 
 def _get_labels():
@@ -252,7 +253,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
         'encoder_dropout': 0.1,
         'encoder_layer_norm_first': False,
         'encoder_layer_drop': 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -298,7 +299,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
         'encoder_dropout': 0.1,
         'encoder_layer_norm_first': False,
         'encoder_layer_drop': 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -344,7 +345,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.1,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -433,7 +434,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -479,7 +480,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -525,7 +526,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -614,7 +615,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -660,7 +661,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -706,7 +707,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -932,7 +933,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
         'encoder_dropout': 0.0,
         'encoder_layer_norm_first': True,
         'encoder_layer_drop': 0.1,
-        'aux_num_out': 32,
+        'aux_num_out': 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -979,7 +980,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
         'encoder_dropout': 0.0,
         'encoder_layer_norm_first': True,
         'encoder_layer_drop': 0.1,
-        'aux_num_out': 32,
+        'aux_num_out': 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
......
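End to end, the simplified labels can be used like this. A minimal sketch, assuming a torchaudio build that includes this change, with the bundle accessed via `torchaudio.models` as in the docstring above; 'speech.wav' is a hypothetical input file:

    import torch
    import torchaudio

    bundle = torchaudio.models.WAV2VEC2_ASR_BASE_960H
    model = bundle.get_model()
    labels = bundle.get_labels()  # ('-', '|', 'E', 'T', ...), blank at index 0

    waveform, sample_rate = torchaudio.load('speech.wav')  # hypothetical input
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

    with torch.no_grad():
        emissions, _ = model(waveform)

    # Greedy CTC decoding: best class per frame, collapse repeats, drop blanks.
    indices = torch.unique_consecutive(emissions[0].argmax(dim=-1)).tolist()
    print(''.join(labels[i] for i in indices if i != 0).replace('|', ' '))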