"vscode:/vscode.git/clone" did not exist on "e012579d3dedfcd472e81ee7b7ba2cf30168afc8"
Commit 614a7dac authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add feature mean square value to HuBERT Pretrain model output (#2128)

Summary:
In [Fairseq](https://github.com/pytorch/fairseq/blob/main/examples/hubert/config/pretrain/hubert_base_librispeech.yaml#L48), the training applies an additional penalty loss besides the cross-entropy losses. This PR adds the feature's mean square value to the model output to support such a penalty loss (a usage sketch follows the commit metadata below).

Pull Request resolved: https://github.com/pytorch/audio/pull/2128

Reviewed By: mthrok

Differential Revision: D33403972

Pulled By: nateanl

fbshipit-source-id: f08fefa2c975a847c6075171b310f57c1980309d
parent df0175e8
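For orientation, here is a minimal sketch of how the new third return value could feed a weighted penalty term during pretraining. It is not part of this commit: the function name `pretrain_loss`, the weight value of 10.0 (mirroring Fairseq's `loss_weights`), and the assumption that the positive class sits at index 0 of each logit row (as in Fairseq's `compute_nce`) are all illustrative choices layered on the assumed forward signature `forward(waveforms, labels, audio_lengths)`.

```python
import torch
import torch.nn.functional as F


def pretrain_loss(model, waveforms, labels, audio_lengths, feature_pen_weight=10.0):
    # Assumed forward signature: forward(waveforms, labels, audio_lengths),
    # returning (logit_m, logit_u, features_pen) after this commit.
    logit_m, logit_u, features_pen = model(waveforms, labels, audio_lengths)

    # Cross-entropy over masked frames; the positive entry is assumed to sit
    # at index 0 of each logit row (Fairseq compute_nce convention).
    targets_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device)
    loss = F.cross_entropy(logit_m, targets_m, reduction="sum")

    # The feature mean-square value acts as a weighted L2-style penalty on the
    # convolutional feature extractor output.
    return loss + feature_pen_weight * features_pen
```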
@@ -172,15 +172,19 @@ class HuBERTPretrainModel(Module):
                 have valid length. Default: ``None``.
 
         Returns:
-            (Tensor, Tensor):
+            (Tensor, Tensor, Tensor):
             Tensor
                 The masked sequences of probability distribution (in logit).
                 Shape: `(masked_frames, num labels)`.
             Tensor
                 The unmasked sequence of probability distribution (in logit).
                 Shape: `(unmasked_frames, num labels)`.
+            Tensor
+                The feature mean value for additional penalty loss.
+                Shape: `(1,)`.
         """
         x, lengths = self.wav2vec2.feature_extractor(waveforms, audio_lengths)
+        features_pen = x.float().pow(2).mean()
         if lengths is not None:
             padding_mask = components._get_padding_mask(x, lengths)
         else:
@@ -188,7 +192,8 @@ class HuBERTPretrainModel(Module):
         x, attention_mask = self.wav2vec2.encoder._preprocess(x, lengths)
         x, mask = self.mask_generator(x, padding_mask)
         x = self.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
-        if padding_mask:
+        assert x.shape[1] == labels.shape[1], "The length of label must match that of HuBERT model output"
+        if padding_mask is not None:
             mask_m = torch.logical_and(~padding_mask, mask)
             mask_u = torch.logical_and(~padding_mask, ~mask_m)
         else:
@@ -197,7 +202,7 @@ class HuBERTPretrainModel(Module):
         logit_m, logit_u = self.logit_generator(x, labels, mask_m, mask_u)
-        return logit_m, logit_u
+        return logit_m, logit_u, features_pen
 
 
 def wav2vec2_model(
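The second hunk also tightens the padding-mask check. A small, self-contained illustration of why `if padding_mask:` is problematic for tensors, assuming a hypothetical `(batch, frames)` boolean mask like the one `components._get_padding_mask` produces:

```python
import torch

# Hypothetical (batch, frames) boolean padding mask (shape is an assumption).
padding_mask = torch.zeros(2, 100, dtype=torch.bool)

try:
    if padding_mask:  # old check: tensor truthiness is ambiguous for >1 element
        pass
except RuntimeError as err:
    print(err)  # e.g. "Boolean value of Tensor with more than one element is ambiguous"

if padding_mask is not None:  # new check: only distinguishes the no-lengths case
    print("padding mask provided")
```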