Commit 614a7dac authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add feature mean square value to HuBERT Pretrain model output (#2128)

Summary:
In [Fairseq](https://github.com/pytorch/fairseq/blob/main/examples/hubert/config/pretrain/hubert_base_librispeech.yaml#L48), training applies an additional penalty loss besides the cross-entropy losses. This PR adds the mean square value of the extracted features to the model output to support that penalty loss.
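As a rough illustration of how the new third output could be consumed, the sketch below combines the two cross-entropy terms with the feature penalty. The function name `pretrain_loss`, its arguments, and the weight values are hypothetical and are not part of this PR or of Fairseq; the targets are passed in explicitly because how they are derived is outside the scope of this change.

```python
import torch
import torch.nn.functional as F


def pretrain_loss(logit_m, target_m, logit_u, target_u, features_pen,
                  masked_weight=1.0, unmasked_weight=0.0, feature_weight=10.0):
    """Hypothetical helper: weighted sum of cross-entropy terms and the feature penalty."""
    loss_m = F.cross_entropy(logit_m, target_m)  # loss on masked frames
    loss_u = F.cross_entropy(logit_u, target_u)  # loss on unmasked frames
    # features_pen is the scalar the model now returns: x.float().pow(2).mean()
    return masked_weight * loss_m + unmasked_weight * loss_u + feature_weight * features_pen


# Stand-in tensors; in practice these would come from HuBERTPretrainModel.forward.
torch.manual_seed(0)
features = torch.randn(2, 5, 8)                # (batch, frames, feature_dim), assumed sizes
features_pen = features.float().pow(2).mean()  # same expression as in the diff
logit_m, target_m = torch.randn(4, 10), torch.randint(0, 10, (4,))
logit_u, target_u = torch.randn(6, 10), torch.randint(0, 10, (6,))

loss = pretrain_loss(logit_m, target_m, logit_u, target_u, features_pen)
```

Keeping the penalty as a separate output, rather than folding it into the logits, leaves the choice of weight to the training script.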

Pull Request resolved: https://github.com/pytorch/audio/pull/2128

Reviewed By: mthrok

Differential Revision: D33403972

Pulled By: nateanl

fbshipit-source-id: f08fefa2c975a847c6075171b310f57c1980309d
parent df0175e8
@@ -172,15 +172,19 @@ class HuBERTPretrainModel(Module):
have valid length. Default: ``None``.
Returns:
-(Tensor, Tensor):
+(Tensor, Tensor, Tensor):
Tensor
The masked sequences of probability distribution (in logit).
Shape: `(masked_frames, num labels)`.
Tensor
The unmasked sequence of probability distribution (in logit).
Shape: `(unmasked_frames, num labels)`.
+Tensor
+The feature mean square value for additional penalty loss.
+Shape: `(1,)`.
"""
x, lengths = self.wav2vec2.feature_extractor(waveforms, audio_lengths)
+features_pen = x.float().pow(2).mean()
if lengths is not None:
padding_mask = components._get_padding_mask(x, lengths)
else:
@@ -188,7 +192,8 @@ class HuBERTPretrainModel(Module):
x, attention_mask = self.wav2vec2.encoder._preprocess(x, lengths)
x, mask = self.mask_generator(x, padding_mask)
x = self.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
-if padding_mask:
+assert x.shape[1] == labels.shape[1], "The length of label must match that of HuBERT model output"
+if padding_mask is not None:
mask_m = torch.logical_and(~padding_mask, mask)
mask_u = torch.logical_and(~padding_mask, ~mask_m)
else:
@@ -197,7 +202,7 @@ class HuBERTPretrainModel(Module):
logit_m, logit_u = self.logit_generator(x, labels, mask_m, mask_u)
-return logit_m, logit_u
+return logit_m, logit_u, features_pen
def wav2vec2_model(
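
The change from `if padding_mask:` to `if padding_mask is not None:` in the second hunk matters because PyTorch refuses to convert a multi-element tensor to a boolean. The following is a minimal sketch of that behavior with made-up mask values; the helper `truthiness_fails` is hypothetical and exists only for the demonstration.

```python
import torch


def truthiness_fails(mask) -> bool:
    """Return True if evaluating `if mask:` would raise for this value."""
    try:
        bool(mask)  # this is what `if mask:` does implicitly
    except RuntimeError:
        return True
    return False


# Made-up example: True marks padded frames / frames selected for masking.
padding_mask = torch.tensor([[False, False, True], [False, True, True]])
mask = torch.tensor([[True, False, True], [True, True, False]])

# The explicit None check works for both a tensor and None.
if padding_mask is not None:
    mask_m = torch.logical_and(~padding_mask, mask)     # masked, non-padded frames
    mask_u = torch.logical_and(~padding_mask, ~mask_m)  # unmasked, non-padded frames
```

With a two-dimensional mask, `truthiness_fails(padding_mask)` is true, while `truthiness_fails(None)` is false, so the old condition only behaved when no lengths were supplied.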