Commit 6de7bb98 authored by Zhaoheng Ni

Improve wav2vec2/hubert model for pre-training (#2716)

Summary:
This PR improves the Wav2Vec2/HuBERT model for pre-training.

- Initialization of the positional embedding and Transformer modules is essential for pre-training. The accuracy on unmasked frames should be higher than on masked frames, since predicting them is an easier task, but without the initialization the accuracy on masked frames ends up higher than on unmasked frames.
  Performance is compared after two epochs of training with 16 GPUs:
  - With model initialization, the accuracies of masked/unmasked frames are 0.08/0.11.
  - Without model initialization, the accuracies of masked/unmasked frames are 0.06/0.04.
- After adding the model initialization, the gradient easily overflows (i.e., becomes `nan`). In the paper [Self-Supervised Learning for Speech Recognition with Intermediate Layer Supervision](https://arxiv.org/abs/2112.08778), the authors propose a simple but effective way to mitigate the overflow: scale down the query before multiplying it with the key, and subtract the row-wise maximum from the attention weights before the softmax (subtracting a constant does not change the output of softmax). This guarantees the values fed to the softmax cannot overflow. A sketch of these two steps is shown after this list.
- In the original fairseq implementation, the mask indices are generated with `numpy.random.choice`. Here, the `torch.multinomial` call is replaced with `torch.randperm` (cc carolineechen); see the comparison sketch below.
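
Below is a minimal sketch (not the torchaudio implementation) of the two stabilization steps, assuming query/key tensors shaped `(batch, num_heads, seq_len, head_dim)`; the helper name `stable_attention_weights` is hypothetical:

```python
import torch


def stable_attention_weights(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    head_dim = q.size(-1)
    scaling = head_dim**-0.5
    # 1. scale q *before* the matmul so the intermediate products stay small (especially in fp16).
    weights = (scaling * q) @ k.transpose(-2, -1)  # (batch, num_heads, seq_len, seq_len)
    # 2. subtract the row-wise max; softmax is shift-invariant, so the result is unchanged,
    #    but exp() can no longer overflow.
    weights = weights - weights.max(dim=-1, keepdim=True)[0]
    return torch.nn.functional.softmax(weights, dim=-1)
```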

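For the mask sampling change, an illustrative comparison of the old and new calls (the names `sz`, `min_len`, and `num_mask` mirror the variables in `_compute_mask_indices`; the concrete values here are made up):

```python
import torch

sz, min_len, num_mask = 100, 10, 5

# old: sample `num_mask` distinct indices without replacement via multinomial
mask_idc_old = torch.multinomial(torch.ones((sz - min_len,)), num_samples=num_mask, replacement=False)

# new: shuffle the index range and keep the first `num_mask` entries
mask_idc_new = torch.randperm(sz - min_len)[:num_mask]

# both draw `num_mask` distinct indices uniformly from [0, sz - min_len)
assert mask_idc_old.shape == mask_idc_new.shape == (num_mask,)
```
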
Other improvements within training scripts will be included in a separate PR.

Pull Request resolved: https://github.com/pytorch/audio/pull/2716

Reviewed By: xiaohui-zhang

Differential Revision: D39832189

Pulled By: nateanl

fbshipit-source-id: f4d2a473a79ad63add2dd16624bd155d5ce4de27
parent 8b2fbf28
@@ -8,6 +8,36 @@ from torch.nn import Module, Parameter
_LG = logging.getLogger(__name__)


+def _init_transformer_params(module):
+    """
+    Initialize the weights of Transformer module in Wav2Vec2/HuBERT.
+    If the module is ``nn.Linear``, normalize the weight with mean 0 and standard deviation 0.02.
+    If ``bias`` is set to ``True`` in the module, set ``bias`` to 0.
+    If the module is ``nn.Embedding``, normalize the weight with mean 0 and standard deviation 0.02.
+    If ``padding_idx`` is not None, set the weight of padding to 0.
+
+    Note:
+        This method corresponds to
+        `init_bert_params
+        <https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/transformer_sentence_encoder.py#L21>`__
+        in the original ``fairseq`` implementation.
+    """
+
+    def normal_(data):
+        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
+
+    if isinstance(module, nn.Linear):
+        normal_(module.weight.data)
+        if module.bias is not None:
+            module.bias.data.zero_()
+    if isinstance(module, nn.Embedding):
+        normal_(module.weight.data)
+        if module.padding_idx is not None:
+            module.weight.data[module.padding_idx].zero_()


class LayerNorm(nn.LayerNorm):
    """Layer norm with transpose"""

@@ -174,6 +204,7 @@ class ConvolutionalPositionalEmbedding(Module):
            padding=kernel_size // 2,
            groups=groups,
        )
        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
        self.num_remove: int = 1 if kernel_size % 2 == 0 else 0

@@ -267,9 +298,14 @@ class SelfAttention(Module):
        k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
        v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd

-        weights = self.scaling * (q @ k)  # B, nH, L, L
+        # scale down q to avoid value overflow.
+        weights = (self.scaling * q) @ k  # B, nH, L, L
        if attention_mask is not None:
            weights += attention_mask
+        # subtracting a constant value from the tensor won't change the output of softmax.
+        # apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
+        # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
+        weights = weights - weights.max(dim=-1, keepdim=True)[0]

        weights = torch.nn.functional.softmax(weights, dim=-1)
        weights = self.dropout(weights)

@@ -817,8 +853,7 @@ def _compute_mask_indices(
        if sz - min_len <= num_mask:
            min_len = sz - num_mask - 1

-        mask_idc = torch.multinomial(torch.ones((sz - min_len,)), num_samples=num_mask, replacement=False)
+        mask_idc = torch.randperm(sz - min_len)[:num_mask]
        mask_idc = torch.tensor(
            [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]
        )

@@ -828,15 +863,7 @@ def _compute_mask_indices(
    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
-            mask_idc = torch.index_select(
-                mask_idc,
-                0,
-                torch.multinomial(
-                    torch.ones((mask_idc.shape[0],)),
-                    num_samples=min_len,
-                    replacement=False,
-                ),
-            )
+            mask_idc = mask_idc[torch.randperm(len(mask_idc))[:min_len].long()]
        mask[i, mask_idc] = True

    return mask

...
+import math
from typing import List, Optional, Tuple

import torch

@@ -681,6 +682,27 @@ def hubert_xlarge(
    )


+def _init_hubert_pretrain_model(module):
+    if isinstance(module, components.LayerNorm):
+        torch.nn.init.kaiming_normal_(module.conv.weight)
+    elif isinstance(module, components.ConvolutionalPositionalEmbedding):
+        # normalize the weight to normal distribution.
+        std = math.sqrt(4.0 / (module.embed_dim * module.kernel_size))
+        torch.nn.init.normal_(module.conv.weight, mean=0.0, std=std)
+        torch.nn.init.constant_(module.conv.bias, 0.0)
+    elif isinstance(module, components.SelfAttention):
+        # normalize the query, key, value, and out_proj parameters in self attention module.
+        torch.nn.init.xavier_uniform_(module.k_proj.weight, gain=1 / math.sqrt(2))
+        torch.nn.init.xavier_uniform_(module.v_proj.weight, gain=1 / math.sqrt(2))
+        torch.nn.init.xavier_uniform_(module.q_proj.weight, gain=1 / math.sqrt(2))
+        torch.nn.init.xavier_uniform_(module.out_proj.weight)
+        torch.nn.init.constant_(module.out_proj.bias, 0.0)
+    elif isinstance(module, components.Transformer):
+        module.apply(components._init_transformer_params)
+    else:
+        pass


def hubert_pretrain_model(
    extractor_mode: str,
    extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],

@@ -963,12 +985,15 @@ def hubert_pretrain_model(
        skip_masked,
        skip_nomask,
    )
-    return HuBERTPretrainModel(
+    model = HuBERTPretrainModel(
        wav2vec2=wav2vec2,
        mask_generator=mask_generator,
        logit_generator=logit_generator,
        feature_grad_mult=feature_grad_mult,
    )
+    # initialize the model for pre-training
+    model.apply(_init_hubert_pretrain_model)
+    return model


def hubert_pretrain_base(
...