"vscode:/vscode.git/clone" did not exist on "c94ea5fff57b3ec8893f073dce61388c163afb67"
Unverified Commit 56f3b927 authored by moto's avatar moto Committed by GitHub
Browse files

Allow the customization of axis exclusion for ASR head (#1932)

In Wav2Vec2 ASR pipelines, the `get_model` method performs on-the-fly model
surgery to remove unused dimensions common to all the Wav2Vec2 models trained
with fairseq.

In VoxPopuli, there seems to be an extra dimension introduced due to
some issue in the preprocessing stage.

For example, the label `1` shows up in the training dataset of German (1 out of 16M),
English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M).

This code change allows the customization of excluded dimensions for such cases.
parent d35ea80e
......@@ -55,6 +55,12 @@ class Wav2Vec2Bundle:
"""
return self._sample_rate
def _get_state_dict(self, dl_kwargs):
    """Download and return the pretrained state dict of this bundle.

    Args:
        dl_kwargs (dict or None): Keyword arguments forwarded to
            ``load_state_dict_from_url``. ``None`` is treated as "no extra
            arguments".

    Returns:
        Whatever ``load_state_dict_from_url`` returns for the bundle's URL.
    """
    # The weight file location is derived from the bundle's ``_path``.
    weights_url = f'https://download.pytorch.org/torchaudio/models/{self._path}'
    extra_kwargs = dl_kwargs if dl_kwargs is not None else {}
    return load_state_dict_from_url(weights_url, **extra_kwargs)
def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx
"""get_model(self, *, dl_kwargs=None) -> torchaudio.models.Wav2Vec2Model
......@@ -68,18 +74,7 @@ class Wav2Vec2Bundle:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
"""
model = wav2vec2_model(**self._params)
url = f'https://download.pytorch.org/torchaudio/models/{self._path}'
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
if model.aux is not None:
# For ASR task, the parameter originated from fairseq has unrelated dimensions at index 1, 2, 3
# It's originated from fairseq but not used, so we remove it here.
# https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
for key in ['aux.weight', 'aux.bias']:
t = state_dict[key]
state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in (1, 2, 3)])
model.load_state_dict(state_dict)
model.load_state_dict(self._get_state_dict(dl_kwargs))
model.eval()
return model
......@@ -126,6 +121,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
>>> transcripts = ctc_decode(emissions, labels)
""" # noqa: E501
_labels: Tuple[str]
_remove_aux_axis: Tuple[int] = (1, 2, 3)
def get_labels(
self,
......@@ -151,6 +147,25 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
""" # noqa: E501
return (blank, *self._labels)
def _get_state_dict(self, dl_kwargs):
    """Fetch the pretrained state dict and drop excluded entries of the ASR head.

    The state dict is obtained from the parent implementation. When
    ``self._remove_aux_axis`` is non-empty, the entries of ``aux.weight`` and
    ``aux.bias`` whose index along the first dimension is listed in
    ``self._remove_aux_axis`` are removed.

    Args:
        dl_kwargs (dict or None): Forwarded unchanged to the parent
            ``_get_state_dict``.

    Returns:
        The state dict, with ``aux.weight`` and ``aux.bias`` replaced by their
        reduced versions when any index is excluded.
    """
    state_dict = super()._get_state_dict(dl_kwargs)
    excluded = self._remove_aux_axis
    if not excluded:
        # Nothing to strip; use the downloaded weights as they are.
        return state_dict

    # Remove the seemingly unnecessary axes.
    #
    # For the ASR task, the pretrained weights originating from fairseq have
    # unrelated dimensions at indices 1, 2 and 3. They come from fairseq's
    # Dictionary implementation, which was intended for NLP tasks, and are
    # not used during ASR training.
    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
    #
    # In addition, some pretrained weights originating from VoxPopuli have an
    # extra dimension that is almost never used and looks like a mistake.
    # The label `1` shows up in the training dataset of German (1 out of 16M),
    # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M).
    for name in ('aux.weight', 'aux.bias'):
        param = state_dict[name]
        kept = [param[idx] for idx in range(param.size(0)) if idx not in excluded]
        state_dict[name] = torch.stack(kept)
    return state_dict
WAV2VEC2_BASE = Wav2Vec2Bundle(
_path='wav2vec2_fairseq_base_ls960.pth',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment