Commit 09639680 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Fix CollateFn in HuBERT pre-training recipe (#2296)

Summary:
- When cropping the waveform and the corresponding label, we use the formula `torch.div(audio_start - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor")` to align the audio start and label start indices. However, the value can sometimes be negative, which results in an empty label. After zero-padding, such a training example hurts performance (i.e., the labels are all zero for the input waveform).
This PR fixes the bug by checking whether `label_start` is negative and clamping it to zero if so (see the sketch below).
- If `pad` is True, `length` should be the length of each individual waveform instead of the max length in the mini-batch. Fix it so that the model ignores the padded portion during pre-training.
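To make the first fix concrete, here is a small self-contained sketch (not part of the patch) of the old and new label-start computations, using the recipe's constants (16 samples per millisecond, a 25 ms receptive field, and a 20 ms stride); the offset of 100 samples is an arbitrary example:

import math

import torch

kernel_size = 25  # receptive field of the feature extractor, in milliseconds
stride = 20  # hop between feature frames, in milliseconds
sample_rate = 16  # samples per millisecond at 16 kHz

# A crop offset smaller than one receptive field (400 samples).
audio_start = torch.tensor(100)

# Old computation: the floor division can go negative ...
label_start = torch.div(audio_start - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor")
print(label_start.item())  # -1

# ... and a negative start makes the label slice empty, so the example trains against all-zero labels.
label = torch.arange(50)
print(label[label_start : label_start + 2].numel())  # 0

# New computation: add one and clamp at zero so the first valid label frame is kept.
label_offset = max(math.floor((audio_start.item() - kernel_size * sample_rate) / (stride * sample_rate)) + 1, 0)
print(label_offset)  # 0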

Pull Request resolved: https://github.com/pytorch/audio/pull/2296

Reviewed By: mthrok

Differential Revision: D36323217

Pulled By: nateanl

fbshipit-source-id: 1ffa71e39bbc0e8dee55c3b829911bc2e785b423
parent 595dc5d3
+import math
 from pathlib import Path
 from typing import Dict, Iterator, List, Optional, Tuple, Union
@@ -303,17 +304,57 @@ class HuBERTDataSet(Dataset):
         return (waveform, label, length)
+
+
+def _crop_audio_label(
+    waveform: Tensor,
+    label: Tensor,
+    length: Tensor,
+    num_frames: int,
+    rand_crop: bool,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """Crop the audio and label at the same time.
+
+    Args:
+        waveform (Tensor): The waveform Tensor with dimensions `(1, time)`.
+        label (Tensor): The label Tensor with dimensions `(1, seq)`.
+        length (Tensor): The length Tensor with dimension `(1,)`.
+        num_frames (int): The final length of the waveform.
+        rand_crop (bool): if ``rand_crop`` is True, the starting index of the
+            waveform and label is random if the length is longer than the minimum
+            length in the mini-batch.
+
+    Returns:
+        (Tuple(Tensor, Tensor, Tensor)): Returns the Tensors for the waveform,
+            label, and the waveform length.
+    """
+    kernel_size = 25
+    stride = 20
+    sample_rate = 16  # 16 samples per millisecond
+    frame_offset = 0
+    waveform = waveform[0]
+    if waveform.size(0) > num_frames and rand_crop:
+        diff = waveform.size(0) - num_frames
+        frame_offset = torch.randint(diff, size=(1,))
+    elif waveform.size(0) < num_frames:
+        num_frames = waveform.size(0)
+    # Map the waveform offset/length to label indices; clamp the offset at zero
+    # so a small frame_offset can no longer produce a negative (empty) label slice.
+    label_offset = max(math.floor((frame_offset - kernel_size * sample_rate) / (stride * sample_rate)) + 1, 0)
+    num_label = math.floor((num_frames - kernel_size * sample_rate) / (stride * sample_rate)) + 1
+    waveform = waveform[frame_offset : frame_offset + num_frames]
+    label = label[label_offset : label_offset + num_label]
+    length = num_frames
+    return waveform, label, length


 class CollateFnHubert:
     """The collate class for HuBERT pre-training and fine-tuning.

     Args:
         feature_type (str): The type of features for KMeans clustering.
             Options: [``mfcc``, ``hubert``].
-        pad (bool): If ``pad`` is True, the waveforms and labels will be padded
-            to the max length in the mini-batch. If ``pad`` is False, the waveforms
+        pad (bool): If ``True``, the waveforms and labels will be padded to the
+            max length in the mini-batch. If ``pad`` is False, the waveforms
             and labels will be cropped to the minimum length in the mini-batch.
             (Default: False)
-        rand_crop (bool): if ``rand_crop`` is True, the starting index of the
-            waveform and label is random if the length is longer than the minimum
+        rand_crop (bool): if ``True``, the starting index of the waveform
+            and label is random if the length is longer than the minimum
             length in the mini-batch.
     """
@@ -327,7 +368,7 @@ class CollateFnHubert:
         self.pad = pad
         self.rand_crop = rand_crop

-    def __call__(self, batch: Tuple[Tensor, Tensor, int]) -> Tuple[Tensor, Tensor, Tensor]:
+    def __call__(self, batch: List[Tuple[Tensor, Tensor, int]]) -> Tuple[Tensor, Tensor, Tensor]:
         """
         Args:
             batch (List[Tuple(Tensor, Tensor, int)]):
@@ -335,65 +376,34 @@ class CollateFnHubert:
         Returns:
             (Tuple(Tensor, Tensor, Tensor)):
-                The Tensor of waveforms of dimension `[batch, time]`.
-                The Tensor of labels of dimension `[batch, seq]`.
-                The Tensor of audio lengths of dimension `[batch,]`.
+                The Tensor of waveforms with dimensions `(batch, time)`.
+                The Tensor of labels with dimensions `(batch, seq)`.
+                The Tensor of audio lengths with dimension `(batch,)`.
         """
-        audio_sizes = [sample[0].shape[1] for sample in batch]
         if self.pad:
-            audio_size = max(audio_sizes)
+            num_frames = max([sample[0].shape[1] for sample in batch])
         else:
-            audio_size = min(audio_sizes)
+            num_frames = min([sample[0].shape[1] for sample in batch])
         waveforms, labels, lengths = [], [], []
         for sample in batch:
             waveform, label, length = sample
+            # The MFCC feature is 10 ms per frame, while the HuBERT transformer output
+            # is 20 ms per frame. Downsample the KMeans label if it was generated from MFCC features.
             if self.feature_type == "mfcc":
                 label = label[::2]
-            waveform, label, length = self._collate_audio_label(waveform, label, length, audio_size, self.rand_crop)
+            waveform, label, length = _crop_audio_label(waveform, label, length, num_frames, self.rand_crop)
             waveforms.append(waveform)
             lengths.append(length)
             labels.append(label)
-        data = torch.zeros(len(batch), audio_size)
-        for i in range(len(waveforms)):
-            data[i][0 : waveforms[i].shape[1]] = waveforms[i][0]
-        lengths = torch.tensor(lengths)
+        # make sure the shapes are the same when zero-padding is not applied
+        if not self.pad:
+            assert all(
+                [waveform.shape[0] == waveforms[0].shape[0] for waveform in waveforms]
+            ), "The dimensions of the waveforms should be identical in the same batch."
+            assert all(
+                [label.shape[0] == labels[0].shape[0] for label in labels]
+            ), "The dimensions of the labels should be identical in the same batch."
+        waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
-        return data, labels, lengths
+        lengths = torch.tensor(lengths)
+        return waveforms, labels, lengths

-    def _collate_audio_label(
-        self,
-        waveform: Tensor,
-        label: Tensor,
-        length: Tensor,
-        audio_size: int,
-        rand_crop: bool,
-    ) -> Tuple[Tensor, Tensor, Tensor]:
-        """Collate the audio and label at the same time.
-
-        Args:
-            waveform (Tensor): The waveform Tensor of dimension `[1, time]`.
-            label (Tensor): The label Tensor of dimension `[1, seq]`.
-            length (Tensor): The length Tensor of dimension `[1,]`.
-            audio_size (int): The final length of the waveform.
-            rand_crop (bool): if ``rand_crop`` is True, the starting index of the
-                waveform and label is random if the length is longer than the minimum
-                length in the mini-batch.
-
-        Returns:
-            (Tuple(Tensor, Tensor, Tensor)): Returns the Tensors for the waveform,
-                label, and the waveform length.
-        """
-        kernel_size = 25
-        stride = 20
-        sample_rate = 16  # 16 per millisecond
-        if waveform.shape[1] > audio_size:
-            diff = waveform.size(1) - audio_size
-            audio_start = torch.randint(diff, size=(1,)) if rand_crop else 0
-            label_start = torch.div(
-                audio_start - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor"
-            )
-            label_size = torch.div(audio_size - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor")
-            waveform = waveform[:, audio_start : audio_start + audio_size]
-            label = label[label_start : label_start + label_size]
-            length = audio_size
-        return waveform, label, length
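For orientation, the short sketch below shows how the collate class is typically wired into a DataLoader; the HuBERTDataSet constructor arguments and the batch size are placeholders for illustration, not values taken from this commit:

from torch.utils.data import DataLoader

from dataset.hubert_dataset import CollateFnHubert, HuBERTDataSet

# Hypothetical arguments; HuBERTDataSet's exact signature is not shown in this diff.
dataset = HuBERTDataSet("./exp", "librispeech", "train")

# pad=False crops every sample to the shortest waveform in the mini-batch;
# pad=True zero-pads to the longest waveform and keeps each sample's true length.
collate_fn = CollateFnHubert(feature_type="mfcc", pad=False, rand_crop=True)

loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
waveforms, labels, lengths = next(iter(loader))  # (batch, time), (batch, seq), (batch,)

The commit also adds a unit test that exercises _crop_audio_label directly against the frame count of the HuBERT feature extractor: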
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "examples", "hubert"))

import torch
from dataset.hubert_dataset import _crop_audio_label
from parameterized import parameterized
from torchaudio.models import hubert_base
from torchaudio_unittest.common_utils import get_whitenoise, TorchaudioTestCase


class TestCropAudioLabel(TorchaudioTestCase):
    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        torch.random.manual_seed(31)

    @parameterized.expand(
        [
            (400,),
            (800,),
        ]
    )
    def test_zero_offset(self, num_frames):
        """Test the _crop_audio_label method with zero frame offset.

        Given the ``num_frames`` argument, the method returns the first ``num_frames`` samples of the waveform,
        the corresponding labels, and the length of the cropped waveform.
        The cropped waveform should be identical to the first ``num_frames`` samples of the original waveform.
        The length of the cropped waveform should be identical to ``num_frames``.
        The number of labels should match the number of frames in the HuBERT transformer layer output.
        """
        sample_rate = 16000
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05)
        length = waveform.shape[1]
        label = torch.rand(50)
        model = hubert_base()
        waveform_out, label_out, length = _crop_audio_label(waveform, label, length, num_frames, rand_crop=False)
        hubert_feat = model.extract_features(waveform_out.unsqueeze(0), num_layers=1)[0][0]
        self.assertEqual(waveform_out.shape[0], num_frames, length)
        self.assertEqual(waveform_out, waveform[0, :num_frames])
        self.assertEqual(label_out.shape[0], hubert_feat.shape[1])

    @parameterized.expand(
        [
            (400,),
            (800,),
        ]
    )
    def test_rand_crop(self, num_frames):
        """Test the _crop_audio_label method with a random frame offset.

        Given the ``num_frames`` argument, the method returns ``num_frames`` samples of the waveform
        starting at a random offset, the corresponding labels, and the length of the cropped waveform.
        The length of the cropped waveform should be identical to ``num_frames``.
        The number of labels should match the number of frames in the HuBERT transformer layer output.
        """
        sample_rate = 16000
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05)
        length = waveform.shape[1]
        label = torch.rand(50)
        model = hubert_base()
        waveform_out, label_out, length = _crop_audio_label(waveform, label, length, num_frames, rand_crop=True)
        hubert_feat = model.extract_features(waveform_out.unsqueeze(0), num_layers=1)[0][0]
        self.assertEqual(waveform_out.shape[0], num_frames, length)
        self.assertEqual(label_out.shape[0], hubert_feat.shape[1])
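As a sanity check on the parameterized values (400 and 800 samples), the label-count formula used by _crop_audio_label mirrors the frame count of HuBERT's convolutional front end (a 25 ms window with a 20 ms hop at 16 kHz), which is exactly what these tests compare against:

import math


def expected_num_label(num_frames, kernel_size=25, stride=20, sample_rate=16):
    # One label per 20 ms hop once a full 25 ms window fits into the crop,
    # i.e. the same arithmetic as _crop_audio_label.
    return math.floor((num_frames - kernel_size * sample_rate) / (stride * sample_rate)) + 1


print(expected_num_label(400))  # 1, matching hubert_feat.shape[1] for a 400-sample crop
print(expected_num_label(800))  # 2, matching hubert_feat.shape[1] for an 800-sample crop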