"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "b2d9c3e315405f2b5cfdfa5b93f849d5b27a4109"
Commit b35a5fcf authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Fix ConformerWav2Vec2PretrainModel (#3085)

Summary:
Negative sampling should be applied to the unmasked features at the masked indices. This PR fixes that logic in ConformerWav2Vec2PretrainModel.

Pull Request resolved: https://github.com/pytorch/audio/pull/3085

Reviewed By: mthrok

Differential Revision: D43488570

Pulled By: nateanl

fbshipit-source-id: 3820400d50b74216bb98ca6a40dc6a7acca01564
parent 3267c7ed
......@@ -318,9 +318,14 @@ class ConformerWav2Vec2PretrainModel(Module):
  x = self.wav2vec2.encoder.feature_projection.layer_norm(x)
  x = self.wav2vec2.encoder.feature_projection.dropout(x)
- x, mask_idxs = self.mask_generator(x, padding_mask)
- targets, negs, neg_idxs = self.negative_sampler(x)
+ # Unmasked feature is used to generate positive and negative samples.
+ unmasked_x = x.clone()
+ # Apply masking to x before passing it to Conformer layers.
+ x, mask_idxs = self.mask_generator(x, padding_mask)
+ # Select the frames from masked indices for negative sampling.
+ unmasked_x = unmasked_x[mask_idxs].view(x.shape[0], -1, x.shape[-1])
+ targets, negs, neg_idxs = self.negative_sampler(unmasked_x)
  x = self.wav2vec2.encoder.feature_projection.projection(x)
  x = x.transpose(0, 1)
......
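
The indexing step in this change can be illustrated in isolation. The sketch below is not the torchaudio implementation: the tensor sizes, the hand-written `mask_idxs`, and the zero `mask_embedding` are hypothetical stand-ins for what the mask generator would produce. It also assumes every utterance in the batch has the same number of masked frames, since the `view` call would otherwise fail.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: 2 utterances, 6 frames each, 4 feature dimensions.
batch, frames, dim = 2, 6, 4
x = torch.randn(batch, frames, dim)

# Hypothetical boolean mask with exactly 3 masked frames per utterance
# (an assumption; the reshape below needs equal counts per row).
mask_idxs = torch.tensor(
    [
        [False, True, True, False, False, True],
        [True, False, False, True, True, False],
    ]
)

# Keep a copy of the features before masking overwrites them.
unmasked_x = x.clone()

# Stand-in for the mask generator: replace masked frames in place
# with a fixed embedding (zeros here, purely for illustration).
mask_embedding = torch.zeros(dim)
x[mask_idxs] = mask_embedding

# Select the original features at the masked positions. Boolean indexing
# flattens in row-major order, so the view regroups frames per utterance.
selected = unmasked_x[mask_idxs].view(x.shape[0], -1, x.shape[-1])

print(selected.shape)
```

Here `selected` holds the pre-masking values of the masked frames, which is what the negative sampler receives after this change. Before the change it received the whole masked `x`.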