Commit 09639680 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Fix CollateFn in HuBERT pre-training recipe (#2296)

Summary:
- When cropping the waveform and the corresponding label, we use the formula `torch.div(audio_start - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor")` to align the audio start and label start indices. However, this value can sometimes be negative, which results in an empty label: after zero-padding, the labels are all zero for the input waveform, and such training examples hurt performance. This PR fixes the bug by checking whether `label_start` is negative and clamping it to zero if so (see the sketch after this list).
- If `pad` is True, `length` should be the length of each individual waveform rather than the max length in the mini-batch. Fix it so that the model ignores the padded portion during pre-training.
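
The offset arithmetic behind the first fix can be checked in isolation. Below is a minimal sketch in plain Python, reusing the recipe's constants (a 25 ms convolution window, a 20 ms stride, 16 samples per millisecond); the helper name `label_offset` is ours, introduced only for illustration:

import math

KERNEL_SIZE = 25  # convolution window of the feature extractor, in ms
STRIDE = 20  # convolution stride, in ms
SAMPLE_RATE = 16  # samples per millisecond at 16 kHz

def label_offset(frame_offset: int) -> int:
    # Same arithmetic as the fixed recipe: map a waveform offset (in samples)
    # to the index of the first label frame covered by the crop.
    raw = math.floor((frame_offset - KERNEL_SIZE * SAMPLE_RATE) / (STRIDE * SAMPLE_RATE)) + 1
    # Without this clamp, small offsets produce a negative index and the
    # label slice comes out empty or misaligned.
    return max(raw, 0)

print(label_offset(0))  # raw = floor(-400 / 320) + 1 = -1, clamped to 0
print(label_offset(400))  # raw = floor(0 / 320) + 1 = 1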

Pull Request resolved: https://github.com/pytorch/audio/pull/2296

Reviewed By: mthrok

Differential Revision: D36323217

Pulled By: nateanl

fbshipit-source-id: 1ffa71e39bbc0e8dee55c3b829911bc2e785b423
parent 595dc5d3
dataset/hubert_dataset.py
+import math
 from pathlib import Path
 from typing import Dict, Iterator, List, Optional, Tuple, Union
@@ -303,17 +304,57 @@ class HuBERTDataSet(Dataset):
         return (waveform, label, length)


+def _crop_audio_label(
+    waveform: Tensor,
+    label: Tensor,
+    length: Tensor,
+    num_frames: int,
+    rand_crop: bool,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """Crop the audio and label at the same time.
+
+    Args:
+        waveform (Tensor): The waveform Tensor with dimensions `(1, time)`.
+        label (Tensor): The label Tensor with dimensions `(1, seq)`.
+        length (Tensor): The length Tensor with dimension `(1,)`.
+        num_frames (int): The final length of the waveform.
+        rand_crop (bool): if ``rand_crop`` is True, the starting index of the
+            waveform and label is random if the length is longer than the minimum
+            length in the mini-batch.
+
+    Returns:
+        (Tuple(Tensor, Tensor, Tensor)): Returns the Tensors for the waveform,
+            label, and the waveform length.
+    """
+    kernel_size = 25
+    stride = 20
+    sample_rate = 16  # 16 samples per millisecond
+    frame_offset = 0
+    waveform = waveform[0]
+    if waveform.size(0) > num_frames and rand_crop:
+        diff = waveform.size(0) - num_frames
+        frame_offset = torch.randint(diff, size=(1,))
+    elif waveform.size(0) < num_frames:
+        num_frames = waveform.size(0)
+    label_offset = max(math.floor((frame_offset - kernel_size * sample_rate) / (stride * sample_rate)) + 1, 0)
+    num_label = math.floor((num_frames - kernel_size * sample_rate) / (stride * sample_rate)) + 1
+    waveform = waveform[frame_offset : frame_offset + num_frames]
+    label = label[label_offset : label_offset + num_label]
+    length = num_frames
+    return waveform, label, length
+
+
 class CollateFnHubert:
     """The collate class for HuBERT pre-training and fine-tuning.

     Args:
         feature_type (str): The type of features for KMeans clustering.
             Options: [``mfcc``, ``hubert``].
-        pad (bool): If ``pad`` is True, the waveforms and labels will be padded
-            to the max length in the mini-batch. If ``pad`` is False, the waveforms
+        pad (bool): If ``True``, the waveforms and labels will be padded to the
+            max length in the mini-batch. If ``pad`` is False, the waveforms
             and labels will be cropped to the minimum length in the mini-batch.
             (Default: False)
-        rand_crop (bool): if ``rand_crop`` is True, the starting index of the
-            waveform and label is random if the length is longer than the minimum
+        rand_crop (bool): if ``True``, the starting index of the waveform
+            and label is random if the length is longer than the minimum
             length in the mini-batch.
     """
@@ -327,7 +368,7 @@ class CollateFnHubert:
         self.pad = pad
         self.rand_crop = rand_crop

-    def __call__(self, batch: Tuple[Tensor, Tensor, int]) -> Tuple[Tensor, Tensor, Tensor]:
+    def __call__(self, batch: List[Tuple[Tensor, Tensor, int]]) -> Tuple[Tensor, Tensor, Tensor]:
         """
         Args:
             batch (List[Tuple(Tensor, Tensor, int)]):
@@ -335,65 +376,34 @@ class CollateFnHubert:
         Returns:
             (Tuple(Tensor, Tensor, Tensor)):
-                The Tensor of waveforms of dimension `[batch, time]`.
-                The Tensor of labels of dimension `[batch, seq]`.
-                The Tensor of audio lengths of dimension `[batch,]`.
+                The Tensor of waveforms with dimensions `(batch, time)`.
+                The Tensor of labels with dimensions `(batch, seq)`.
+                The Tensor of audio lengths with dimension `(batch,)`.
         """
-        audio_sizes = [sample[0].shape[1] for sample in batch]
         if self.pad:
-            audio_size = max(audio_sizes)
+            num_frames = max([sample[0].shape[1] for sample in batch])
         else:
-            audio_size = min(audio_sizes)
+            num_frames = min([sample[0].shape[1] for sample in batch])
         waveforms, labels, lengths = [], [], []
         for sample in batch:
             waveform, label, length = sample
             # The MFCC feature is 10ms per frame, while the HuBERT transformer output
             # is 20ms per frame. Downsample the KMeans label if it is generated from MFCC features.
             if self.feature_type == "mfcc":
                 label = label[::2]
-            waveform, label, length = self._collate_audio_label(waveform, label, length, audio_size, self.rand_crop)
+            waveform, label, length = _crop_audio_label(waveform, label, length, num_frames, self.rand_crop)
             waveforms.append(waveform)
             lengths.append(length)
             labels.append(label)
-        data = torch.zeros(len(batch), audio_size)
-        for i in range(len(waveforms)):
-            data[i][0 : waveforms[i].shape[1]] = waveforms[i][0]
-        lengths = torch.tensor(lengths)
+        # make sure the shapes are the same if zero-padding is not applied
+        if not self.pad:
+            assert all(
+                [waveform.shape[0] == waveforms[0].shape[0] for waveform in waveforms]
+            ), "The dimensions of the waveforms should be identical in the same batch."
+            assert all(
+                [label.shape[0] == labels[0].shape[0] for label in labels]
+            ), "The dimensions of the labels should be identical in the same batch."
+        waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
+        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
-        return data, labels, lengths
-
-    def _collate_audio_label(
-        self,
-        waveform: Tensor,
-        label: Tensor,
-        length: Tensor,
-        audio_size: int,
-        rand_crop: bool,
-    ) -> Tuple[Tensor, Tensor, Tensor]:
-        """Collate the audio and label at the same time.
-
-        Args:
-            waveform (Tensor): The waveform Tensor of dimension `[1, time]`.
-            label (Tensor): The label Tensor of dimension `[1, seq]`.
-            length (Tensor): The length Tensor of dimension `[1,]`.
-            audio_size (int): The final length of the waveform.
-            rand_crop (bool): if ``rand_crop`` is True, the starting index of the
-                waveform and label is random if the length is longer than the minimum
-                length in the mini-batch.
-
-        Returns:
-            (Tuple(Tensor, Tensor, Tensor)): Returns the Tensors for the waveform,
-                label, and the waveform length.
-        """
-        kernel_size = 25
-        stride = 20
-        sample_rate = 16  # 16 per millisecond
-        if waveform.shape[1] > audio_size:
-            diff = waveform.size(1) - audio_size
-            audio_start = torch.randint(diff, size=(1,)) if rand_crop else 0
-            label_start = torch.div(
-                audio_start - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor"
-            )
-            label_size = torch.div(audio_size - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor")
-            waveform = waveform[:, audio_start : audio_start + audio_size]
-            label = label[label_start : label_start + label_size]
-            length = audio_size
-        return waveform, label, length
+        lengths = torch.tensor(lengths)
+        return waveforms, labels, lengths
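
The second fix matters downstream: with `pad=True`, each entry of `lengths` is now the true length of its (possibly cropped) waveform rather than the padded batch width, which is what allows the model to ignore the padded tail. A minimal sketch of how per-item lengths typically turn into a padding mask (illustrative only; this is not code from the recipe):

import torch

# Hypothetical batch: three waveforms zero-padded to the max length (8 samples).
lengths = torch.tensor([8, 5, 6])
max_len = int(lengths.max())

# True where a position holds real audio, False where it is padding.
mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
# tensor([[ True,  True,  True,  True,  True,  True,  True,  True],
#         [ True,  True,  True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True,  True, False, False]])

The commit also adds a new unit-test file exercising `_crop_audio_label`: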
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "examples", "hubert"))

import torch
from dataset.hubert_dataset import _crop_audio_label
from parameterized import parameterized
from torchaudio.models import hubert_base
from torchaudio_unittest.common_utils import get_whitenoise, TorchaudioTestCase


class TestCropAudioLabel(TorchaudioTestCase):
    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        torch.random.manual_seed(31)

    @parameterized.expand(
        [
            (400,),
            (800,),
        ]
    )
    def test_zero_offset(self, num_frames):
        """Test the _crop_audio_label method with zero frame offset.

        Given the ``num_frames`` argument, the method returns the first ``num_frames`` samples of the waveform,
        the corresponding labels, and the length of the cropped waveform.
        The cropped waveform should be identical to the first ``num_frames`` samples of the original waveform.
        The length of the cropped waveform should be identical to ``num_frames``.
        The number of label frames should match the frame dimension of the HuBERT transformer output.
        """
        sample_rate = 16000
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05)
        length = waveform.shape[1]
        label = torch.rand(50)
        model = hubert_base()
        waveform_out, label_out, length = _crop_audio_label(waveform, label, length, num_frames, rand_crop=False)
        hubert_feat = model.extract_features(waveform_out.unsqueeze(0), num_layers=1)[0][0]
        self.assertEqual(waveform_out.shape[0], num_frames)
        self.assertEqual(length, num_frames)
        self.assertEqual(waveform_out, waveform[0, :num_frames])
        self.assertEqual(label_out.shape[0], hubert_feat.shape[1])
    @parameterized.expand(
        [
            (400,),
            (800,),
        ]
    )
    def test_rand_crop(self, num_frames):
        """Test the _crop_audio_label method with a random frame offset.

        Given the ``num_frames`` argument, the method returns ``num_frames`` samples of the waveform
        starting from a random offset, the corresponding labels, and the length of the cropped waveform.
        The length of the cropped waveform should be identical to ``num_frames``.
        The number of label frames should match the frame dimension of the HuBERT transformer output.
        """
        sample_rate = 16000
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05)
        length = waveform.shape[1]
        label = torch.rand(50)
        model = hubert_base()
        waveform_out, label_out, length = _crop_audio_label(waveform, label, length, num_frames, rand_crop=True)
        hubert_feat = model.extract_features(waveform_out.unsqueeze(0), num_layers=1)[0][0]
        self.assertEqual(waveform_out.shape[0], num_frames)
        self.assertEqual(length, num_frames)
        self.assertEqual(label_out.shape[0], hubert_feat.shape[1])