Commit 7e85f625 authored by moto, committed by Facebook GitHub Bot

Fix FA bundle (#3538)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3538

Reviewed By: huangruizhe

Differential Revision: D48154056

Pulled By: mthrok

fbshipit-source-id: 72f58c501c5302d40f1d059f95bd6fe40d4a52aa
parent e6c89731
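
In short, the fix moves star-token ("wildcard") handling from checkpoint patching to a runtime wrapper: the wrapper now applies log_softmax and appends a zero log-probability star column to the emission, and the aligner's blank index becomes an explicit parameter instead of a hard-coded 0. Notably, the deleted _add_star_dim helper (see the utils.py hunk below) prepended the star row to aux.weight but appended it to aux.bias, which appears to be the inconsistency this commit fixes. For context, the affected MMS_FA bundle is used roughly as follows; a minimal sketch in which the waveform and transcript are placeholders and downloading the weights requires network access:

```python
import torch
import torchaudio

bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model()          # with_star=True by default; emits log-probs after this fix
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

# Placeholder one-second clip; real usage loads speech resampled to bundle.sample_rate.
waveform = torch.randn(1, int(bundle.sample_rate))

with torch.inference_mode():
    emission, _ = model(waveform)                                # (1, frames, classes)
    spans = aligner(emission[0], tokenizer(["hello", "world"]))  # one span list per word
```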
--- a/torchaudio/pipelines/_wav2vec2/aligner.py
+++ b/torchaudio/pipelines/_wav2vec2/aligner.py
@@ -32,12 +32,12 @@ class Tokenizer(ITokenizer):
         return [[self.dictionary[c] for c in word] for word in transcript]
 
 
-def _align_emission_and_tokens(emission: Tensor, tokens: List[int]):
+def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0):
     device = emission.device
     emission = emission.unsqueeze(0)
     targets = torch.tensor([tokens], dtype=torch.int32, device=device)
 
-    aligned_tokens, scores = F.forced_align(emission, targets, 0)
+    aligned_tokens, scores = F.forced_align(emission, targets, blank=blank)
     scores = scores.exp()  # convert back to probability
     aligned_tokens, scores = aligned_tokens[0], scores[0]  # remove batch dimension
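
The change above stops hard-coding the blank index passed to torchaudio.functional.forced_align and threads it through as a parameter. A standalone sketch of that call, with arbitrary shapes and token IDs:

```python
import torch
import torchaudio.functional as F

torch.manual_seed(0)
emission = torch.randn(1, 50, 5).log_softmax(dim=-1)      # (batch=1, frames, classes), log-domain
targets = torch.tensor([[1, 2, 3, 2]], dtype=torch.int32)  # must not contain the blank ID

# `blank` selects which class index acts as the CTC blank symbol.
aligned_tokens, scores = F.forced_align(emission, targets, blank=0)
print(aligned_tokens.shape, scores.shape)  # both torch.Size([1, 50]): one entry per frame
```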
@@ -50,7 +50,7 @@ class IAligner(ABC):
         """Generate list of time-stamped token sequences
 
         Args:
-            emission (Tensor): Sequence of token probability distributions.
+            emission (Tensor): Sequence of token probability distributions in log-domain.
                 Shape: `(time, tokens)`.
             tokens (list of integer sequence): Tokenized transcript.
                 Output from :py:class:`Wav2Vec2FABundle.Tokenizer`.
@@ -75,11 +75,13 @@ def _flatten(nested_list):
 class Aligner(IAligner):
+    def __init__(self, blank):
+        self.blank = blank
+
     def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
         if emission.ndim != 2:
             raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")
 
-        emission = torch.log_softmax(emission, dim=-1)
-        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens))
+        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
         spans = F.merge_tokens(aligned_tokens, scores)
         return _unflatten(spans, [len(ts) for ts in tokens])
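
With the log_softmax call removed here, Aligner now assumes the emission is already in the log-domain; after this commit the bundle's model applies log_softmax itself (see the impl.py and utils.py hunks below). For reference, F.merge_tokens, unchanged above, collapses the frame-level alignment into TokenSpan entries; a small sketch with invented frame data:

```python
import torch
import torchaudio.functional as F

# Frame-level token IDs and per-frame scores, as forced_align would return them
# (values invented for illustration; index 0 is the blank).
aligned_tokens = torch.tensor([0, 1, 1, 0, 2, 2, 2, 0])
scores = torch.tensor([0.9, 0.8, 0.7, 0.9, 0.6, 0.9, 0.8, 0.9])

for span in F.merge_tokens(aligned_tokens, scores, blank=0):
    # e.g. token 1 spans frames [1, 3); span.score aggregates the frame scores
    print(span.token, span.start, span.end, span.score)
```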
--- a/torchaudio/pipelines/_wav2vec2/impl.py
+++ b/torchaudio/pipelines/_wav2vec2/impl.py
@@ -1,3 +1,2 @@
-import copy
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple
@@ -93,7 +92,7 @@ class Wav2Vec2Bundle:
         state_dict = self._get_state_dict(dl_kwargs)
         model.load_state_dict(state_dict)
         if self._normalize_waveform:
-            model = utils._apply_input_layer_norm(model)
+            model = utils._extend_model(model, normalize_waveform=True)
         model.eval()
         return model
@@ -1587,11 +1586,6 @@ class Wav2Vec2FABundle(Wav2Vec2ASRBundle):
         labels = super().get_labels(blank=blank)
         return labels if star is None else (*labels, star)
 
-    def _get_params_with_star(self):
-        params = copy.deepcopy(self._params)
-        params["aux_num_out"] += 1
-        return params
-
     def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:
         """Construct the model and load the pretrained weight.
@@ -1605,13 +1599,19 @@ class Wav2Vec2FABundle(Wav2Vec2ASRBundle):
         Returns:
             Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
 
+        .. note::
+
+           The model created with this method returns probabilities in log-domain
+           (i.e. :py:func:`torch.nn.functional.log_softmax` is applied), whereas
+           the other Wav2Vec2 models return logits.
         """
-        params = self._get_params_with_star() if with_star else self._params
-        model = utils._get_model(self._model_type, params)
-        state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis, with_star)
+        model = utils._get_model(self._model_type, self._params)
+        state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)
         model.load_state_dict(state_dict)
         if self._normalize_waveform:
-            model = utils._apply_input_layer_norm(model)
+            model = utils._extend_model(
+                model, normalize_waveform=self._normalize_waveform, apply_log_softmax=True, append_star=with_star
+            )
         model.eval()
         return model
@@ -1650,7 +1650,7 @@ class Wav2Vec2FABundle(Wav2Vec2ASRBundle):
         Returns:
             Aligner
         """
-        return aligner.Aligner()
+        return aligner.Aligner(blank=0)
 
 
 MMS_FA = Wav2Vec2FABundle(
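
get_aligner now constructs the aligner with an explicit blank=0, matching the bundle's label ordering where the blank token comes first and the star token is appended last (assuming the released MMS_FA label set, which renders blank as "-" and star as "*"):

```python
import torchaudio

labels = torchaudio.pipelines.MMS_FA.get_labels()  # static metadata, no download
print(labels.index("-"), labels[-1])               # 0 and "*"
```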
--- a/torchaudio/pipelines/_wav2vec2/utils.py
+++ b/torchaudio/pipelines/_wav2vec2/utils.py
@@ -24,13 +24,23 @@ class _Wav2Vec2Model(nn.Module):
     This is used for layer normalization at the input
     """
 
-    def __init__(self, model: Wav2Vec2Model):
+    def __init__(self, model: Wav2Vec2Model, normalize_waveform: bool, apply_log_softmax: bool, append_star: bool):
         super().__init__()
         self.model = model
+        self.normalize_waveform = normalize_waveform
+        self.apply_log_softmax = apply_log_softmax
+        self.append_star = append_star
 
     def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
-        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
-        return self.model(waveforms, lengths)
+        if self.normalize_waveform:
+            waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+        output, output_lengths = self.model(waveforms, lengths)
+        if self.apply_log_softmax:
+            output = torch.nn.functional.log_softmax(output, dim=-1)
+        if self.append_star:
+            star_dim = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device)
+            output = torch.cat((output, star_dim), dim=-1)
+        return output, output_lengths
 
     @torch.jit.export
     def extract_features(
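
The append_star branch implements the star token at inference time: a zero column in the log-domain is probability exp(0) = 1 for the wildcard class, replacing the _add_star_dim checkpoint patching removed further down. A pure-torch sketch mirroring the forward logic on a dummy emission:

```python
import torch

logits = torch.randn(1, 10, 28)             # (batch, frames, classes) from the acoustic model
output = torch.log_softmax(logits, dim=-1)  # apply_log_softmax=True
star = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device)
output = torch.cat((output, star), dim=-1)  # append_star=True

# Each frame now sums to ~2.0 in probability: 1.0 from the softmax classes
# plus exp(0) = 1.0 from the star column.
print(output.exp().sum(dim=-1))
```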
@@ -39,13 +49,14 @@ class _Wav2Vec2Model(nn.Module):
         lengths: Optional[Tensor] = None,
         num_layers: Optional[int] = None,
     ) -> Tuple[List[Tensor], Optional[Tensor]]:
-        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+        if self.normalize_waveform:
+            waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
         return self.model.extract_features(waveforms, lengths, num_layers)
 
 
-def _apply_input_layer_norm(module):
-    """Add extra layer_norm to the model"""
-    return _Wav2Vec2Model(module)
+def _extend_model(module, normalize_waveform, apply_log_softmax=False, append_star=False):
+    """Add extra transformations to the model"""
+    return _Wav2Vec2Model(module, normalize_waveform, apply_log_softmax, append_star)
 
 
 def _remove_aux_axes(state_dict, axes):
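
The normalize_waveform branch calls nn.functional.layer_norm with normalized_shape equal to the waveform's full shape, which standardizes all samples in the batch jointly to zero mean and unit variance. A standalone check of that behavior:

```python
import torch
import torch.nn as nn

waveforms = torch.randn(2, 16000) * 3.0 + 1.0  # arbitrary gain and DC offset
normed = nn.functional.layer_norm(waveforms, waveforms.shape)

print(round(normed.mean().item(), 4), round(normed.std().item(), 4))  # ~0.0 and ~1.0
```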
@@ -65,23 +76,13 @@ def _remove_aux_axes(state_dict, axes):
         state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])
 
 
-def _add_star_dim(state_dict):
-    w, b = state_dict["aux.weight"], state_dict["aux.bias"]
-    zeros = torch.zeros((1, w.size(1)), device=w.device, dtype=w.dtype)
-    state_dict["aux.weight"] = torch.cat((zeros, w), dim=0)
-    ones = torch.ones((1,), device=b.device, dtype=b.dtype)
-    state_dict["aux.bias"] = torch.cat((b, ones), dim=0)
-
-
-def _get_state_dict(url, dl_kwargs, remove_axes=None, add_star=False):
+def _get_state_dict(url, dl_kwargs, remove_axes=None):
     if not url.startswith("https"):
         url = f"https://download.pytorch.org/torchaudio/models/{url}"
     dl_kwargs = {} if dl_kwargs is None else dl_kwargs
     state_dict = load_state_dict_from_url(url, **dl_kwargs)
     if remove_axes:
         _remove_aux_axes(state_dict, remove_axes)
-    if add_star:
-        _add_star_dim(state_dict)
     return state_dict