Commit d8a5a11d authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Fix _init_hubert_pretrain_model (#2886)

Summary:
Addresses https://github.com/pytorch/audio/issues/2885

In the `_init_hubert_pretrain_model` function, which initializes the HuBERT pretrain models, `kaiming_normal_` should be applied to `ConvLayerBlock` modules instead of `LayerNorm` layers. This PR fixes that and adds more unit tests.

Pull Request resolved: https://github.com/pytorch/audio/pull/2886

Reviewed By: hwangjeff

Differential Revision: D41713801

Pulled By: nateanl

fbshipit-source-id: ed199baf7504d06bbf2d31c522ae708a75426a2d
parent 55e9978a
import math
import os import os
from typing import Tuple from typing import Tuple
...@@ -7,6 +8,9 @@ from parameterized import parameterized ...@@ -7,6 +8,9 @@ from parameterized import parameterized
from torchaudio.models.wav2vec2 import ( from torchaudio.models.wav2vec2 import (
hubert_base, hubert_base,
hubert_large, hubert_large,
hubert_pretrain_base,
hubert_pretrain_large,
hubert_pretrain_xlarge,
hubert_xlarge, hubert_xlarge,
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
...@@ -47,6 +51,15 @@ factory_funcs_wavlm = parameterized.expand( ...@@ -47,6 +51,15 @@ factory_funcs_wavlm = parameterized.expand(
name_func=_name_func, name_func=_name_func,
) )
# Parameterization decorator that runs a test once per HuBERT pretrain factory
# function; each case passes the factory as the single test argument.
factory_funcs_hubert_pretrain = parameterized.expand(
    [
        (factory,)
        for factory in (
            hubert_pretrain_base,
            hubert_pretrain_large,
            hubert_pretrain_xlarge,
        )
    ],
    name_func=_name_func,
)
class TestWav2Vec2Model(TorchaudioTestCase): class TestWav2Vec2Model(TorchaudioTestCase):
def _smoke_test(self, model, device, dtype): def _smoke_test(self, model, device, dtype):
...@@ -398,3 +411,127 @@ class TestWavLMModel(TorchaudioTestCase): ...@@ -398,3 +411,127 @@ class TestWavLMModel(TorchaudioTestCase):
def test_quantize_torchscript(self, factory_func): def test_quantize_torchscript(self, factory_func):
"""Quantized WavLM model should be scriptable""" """Quantized WavLM model should be scriptable"""
self._test_quantize_torchscript(factory_func(aux_num_out=32)) self._test_quantize_torchscript(factory_func(aux_num_out=32))
def _compute_label_frame(audio_frame: int) -> int:
"""Compute number of frames in the label tensor based on
the number of frames in the audio tensor."""
kernel_size = 25
stride = 20
sample_rate = 16 # 16 per millisecond
label_frame = math.floor((audio_frame - kernel_size * sample_rate) / (stride * sample_rate)) + 1
return label_frame
class TestHuBERTPretrainModel(TorchaudioTestCase):
    """Smoke, feature-extraction and dynamic-quantization tests for the
    HuBERT pretrain model factory functions."""

    def _smoke_test(self, model, device, dtype):
        """Run a single forward pass of ``model`` on random inputs.

        The model is moved to ``device``/``dtype`` and put in eval mode; the
        check only verifies that the call completes without raising.
        """
        model = model.to(device=device, dtype=dtype).eval()

        batch_size, num_frames = 3, 1024
        waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
        label_shape = [batch_size, _compute_label_frame(num_frames)]
        labels = torch.randint(low=0, high=100, size=label_shape, device=device)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size], device=device)

        model(waveforms, labels, lengths)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        """Forward pass of the base pretrain model works on CPU."""
        self._smoke_test(hubert_pretrain_base(), torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        """Forward pass of the base pretrain model works on CUDA."""
        self._smoke_test(hubert_pretrain_base(), torch.device("cuda"), dtype)

    def _feature_extractor_test(self, model):
        """Check ``extract_features`` on the wrapped ``wav2vec2`` module.

        Verifies the number and basic shape of the returned features, and that
        limiting ``num_layers`` yields a prefix of the unrestricted result.
        """
        batch_size, num_frames = 3, 1024
        backbone = model.wav2vec2
        backbone.eval()
        layer_count = len(backbone.encoder.transformer.layers)
        expected_lengths_shape = torch.Size([batch_size])

        waveforms = torch.randn(batch_size, num_frames)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size])

        # Not providing num_layers returns all the intermediate features from
        # transformer layers
        all_features, out_lengths = backbone.extract_features(waveforms, lengths, num_layers=None)
        assert len(all_features) == layer_count
        for layer_output in all_features:
            assert layer_output.ndim == 3
            assert layer_output.shape[0] == batch_size
        assert out_lengths.shape == expected_lengths_shape

        # Limiting the number of layers to `limit`.
        for limit in range(1, layer_count + 1):
            truncated, out_lengths = backbone.extract_features(waveforms, lengths, num_layers=limit)
            assert len(truncated) == limit
            for reference, candidate in zip(all_features, truncated):
                self.assertEqual(reference, candidate)
            assert out_lengths.shape == expected_lengths_shape

    @factory_funcs_hubert_pretrain
    def test_extract_feature(self, factory_func):
        """`extract_features` method does not fail"""
        self._feature_extractor_test(factory_func())

    def _test_quantize_smoke_test(self, model):
        """Dynamically quantize the ``Linear`` layers of ``model`` and run one
        forward pass of the quantized module on random inputs."""
        model.eval()
        batch_size, num_frames = 3, 1024

        # Remove the weight normalization forward hook
        model.wav2vec2.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
        quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)

        # A lazy way to check that Modules are different
        quantized_repr, float_repr = str(quantized), str(model)
        assert quantized_repr != float_repr, "Dynamic quantization did not modify the module."

        waveforms = torch.randn(batch_size, num_frames)
        label_shape = [batch_size, _compute_label_frame(num_frames)]
        labels = torch.randint(low=0, high=100, size=label_shape)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size])

        # Unpacking into three names also checks that the call yields exactly
        # three outputs.
        _, _, _ = quantized(waveforms, labels, lengths)

    @factory_funcs_hubert_pretrain
    @skipIfNoQengine
    def test_quantize(self, factory_func):
        """HuBERTPretrainModel should support basic quantization"""
        self._test_quantize_smoke_test(factory_func())
...@@ -683,7 +683,7 @@ def hubert_xlarge( ...@@ -683,7 +683,7 @@ def hubert_xlarge(
def _init_hubert_pretrain_model(module): def _init_hubert_pretrain_model(module):
if isinstance(module, components.LayerNorm): if isinstance(module, components.ConvLayerBlock):
torch.nn.init.kaiming_normal_(module.conv.weight) torch.nn.init.kaiming_normal_(module.conv.weight)
elif isinstance(module, components.ConvolutionalPositionalEmbedding): elif isinstance(module, components.ConvolutionalPositionalEmbedding):
# normalize the weight to normal distribution. # normalize the weight to normal distribution.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment