Commit fd778091 authored by Vladislav Agafonov, committed by Facebook GitHub Bot

Add Wav2Vec2DataModule in self_supervised_learning training recipe (#3081)

Summary:
Add `Wav2Vec2DataModule` in self_supervised_learning training recipe to support Wav2Vec2 pre-training.

Pull Request resolved: https://github.com/pytorch/audio/pull/3081

Reviewed By: mthrok

Differential Revision: D43579239

Pulled By: nateanl

fbshipit-source-id: 3e935eb9a18ef0259a58940ae466cbdc3baf8494
parent c532f35c
@@ -2,4 +2,5 @@ from ._hubert_datamodule import HuBERTDataModule
 
 __all__ = [
     "HuBERTDataModule",
+    "Wav2Vec2DataModule",
 ]
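With the export in place, both data modules resolve from the same package. A two-line sketch, assuming the recipe's package is importable as `data_modules` (the actual import path may differ):

```python
# Assumed package name; adjust to the recipe's actual layout.
from data_modules import HuBERTDataModule, Wav2Vec2DataModule
```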
@@ -324,15 +324,15 @@ class HuBERTDataSet(Dataset):
 def _crop_audio_label(
     waveform: Tensor,
-    label: Tensor,
+    label: Optional[Tensor],
     length: Tensor,
     num_frames: int,
     rand_crop: bool,
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> Tuple[Tensor, Optional[Tensor], Tensor]:
     """Collate the audio and label at the same time.
 
     Args:
         waveform (Tensor): The waveform Tensor with dimensions `(1, time)`.
-        label (Tensor): The label Tensor with dimensions `(1, seq)`.
+        label (Tensor, optional): The label Tensor with dimensions `(1, seq)`.
         length (Tensor): The length Tensor with dimension `(1,)`.
         num_frames (int): The final length of the waveform.
         rand_crop (bool): if ``rand_crop`` is True, the starting index of the
@@ -340,7 +340,7 @@ def _crop_audio_label(
         length in the mini-batch.
 
     Returns:
-        (Tuple(Tensor, Tensor, Tensor)): Returns the Tensors for the waveform,
+        (Tuple(Tensor, (Tensor, optional), Tensor)): Returns the Tensors for the waveform,
         label, and the waveform length.
     """
     kernel_size = 25
@@ -353,13 +353,15 @@ def _crop_audio_label(
         frame_offset = torch.randint(diff, size=(1,))
     elif waveform.size(0) < num_frames:
         num_frames = waveform.size(0)
-    label_offset = max(
-        math.floor((frame_offset - kernel_size * sample_rate) / (stride * sample_rate)) + 1,
-        0,
-    )
-    num_label = math.floor((num_frames - kernel_size * sample_rate) / (stride * sample_rate)) + 1
+
+    if label is not None:
+        label_offset = max(
+            math.floor((frame_offset - kernel_size * sample_rate) / (stride * sample_rate)) + 1,
+            0,
+        )
+        num_label = math.floor((num_frames - kernel_size * sample_rate) / (stride * sample_rate)) + 1
+        label = label[label_offset : label_offset + num_label]
     waveform = waveform[frame_offset : frame_offset + num_frames]
-    label = label[label_offset : label_offset + num_label]
     length = num_frames
 
     return waveform, label, length
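The new `if label is not None` guard lets Wav2Vec2 pre-training, which has no frame-level labels, reuse the HuBERT cropping helper unchanged. The offset arithmetic maps a waveform crop onto the matching label frames: a label frame survives wherever a full analysis window fits inside the crop. A small numeric check of that arithmetic; `kernel_size = 25` is visible in the hunk above, while `stride = 20` (ms) and `sample_rate = 16` (samples per millisecond) are assumed from the recipe's 25 ms window / 20 ms hop framing:

```python
import math

# Assumed framing constants: 25 ms window, 20 ms hop, 16 samples per ms.
kernel_size, stride, sample_rate = 25, 20, 16

def label_window(frame_offset, num_frames):
    # First label frame whose analysis window starts inside the crop.
    label_offset = max(
        math.floor((frame_offset - kernel_size * sample_rate) / (stride * sample_rate)) + 1,
        0,
    )
    # Number of full analysis windows that fit in the cropped waveform.
    num_label = math.floor((num_frames - kernel_size * sample_rate) / (stride * sample_rate)) + 1
    return label_offset, num_label

# Cropping 2 s of 16 kHz audio starting 0.5 s in keeps 99 label frames,
# beginning at label index 24.
print(label_window(frame_offset=8000, num_frames=32000))  # (24, 99)
```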
@@ -430,3 +432,57 @@ class CollateFnHubert:
         lengths = torch.tensor(lengths)
         batch = Batch((waveforms, labels, lengths), (labels,))
         return batch
+
+
+class CollateFnWav2Vec2:
+    """The collate class for Wav2Vec2 pre-training and fine-tuning.
+
+    Args:
+        pad (bool): If ``True``, the waveforms and labels will be padded to the
+            max length in the mini-batch. If ``pad`` is False, the waveforms
+            and labels will be cropped to the minimum length in the mini-batch.
+            (Default: False)
+        rand_crop (bool): if ``True``, the starting index of the waveform
+            and label is random if the length is longer than the minimum
+            length in the mini-batch.
+    """
+
+    def __init__(
+        self,
+        pad: bool = False,
+        rand_crop: bool = True,
+    ) -> None:
+        self.pad = pad
+        self.rand_crop = rand_crop
+
+    def __call__(self, batch: List[Tuple[Tensor, Tensor, int]]) -> Dict:
+        """
+        Args:
+            batch (List[Tuple(Tensor, Tensor, int)]):
+                The list of tuples that contains the waveforms, labels, and audio lengths.
+
+        Returns:
+            Dictionary
+            "input": Tuple of waveforms and lengths.
+                waveforms Tensor with dimensions `(batch, time)`.
+                lengths Tensor with dimension `(batch,)`.
+            "label": None
+        """
+        if self.pad:
+            num_frames = max([sample[0].shape[1] for sample in batch])
+        else:
+            num_frames = min([sample[0].shape[1] for sample in batch])
+        waveforms, lengths = [], []
+        for sample in batch:
+            waveform, length = sample[0], sample[0].size(1)
+            waveform, _, length = _crop_audio_label(waveform, None, length, num_frames, self.rand_crop)
+            waveforms.append(waveform)
+            lengths.append(length)
+        # make sure the shapes are the same if not applying zero-padding
+        if not self.pad:
+            assert all(
+                [waveform.shape[0] == waveforms[0].shape[0] for waveform in waveforms]
+            ), "The dimensions of the waveforms should be identical in the same batch."
+        waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
+        lengths = torch.tensor(lengths)
+        batch = Batch((waveforms, lengths), (None,))
+        return batch
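For intuition, here is the `pad=False` path of the collate on a toy batch, with the random crop replaced by a fixed head-crop for determinism. This sketch re-implements only the shape logic; in the recipe the work is done by `_crop_audio_label`, and the result is wrapped in the `Batch` structure from `_utils`:

```python
import torch

# Two stand-ins for LIBRISPEECH items; only the waveform entry
# (shape (1, time)) matters to the collate function.
batch = [
    (torch.randn(1, 48000), 16000),
    (torch.randn(1, 52000), 16000),
]

# pad=False: crop every waveform to the minimum length in the
# mini-batch so they stack without zero-padding.
num_frames = min(sample[0].shape[1] for sample in batch)
waveforms = torch.stack([sample[0][0, :num_frames] for sample in batch])
lengths = torch.full((len(batch),), num_frames)
print(waveforms.shape, lengths)  # torch.Size([2, 48000]) tensor([48000, 48000])
```

The commit's remaining addition is the new data module itself: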
import torch
from pytorch_lightning import LightningDataModule
from torchaudio.datasets.librispeech import LIBRISPEECH

from ._utils import BucketizeBatchSampler, CollateFnWav2Vec2, DistributedBatchSampler


class Wav2Vec2DataModule(LightningDataModule):
    librispeech_cls = LIBRISPEECH

    def __init__(
        self,
        *,
        dataset_path,
        seconds_per_batch,
        train_shuffle=True,
        num_workers=10,
    ):
        super().__init__()
        self.dataset_path = dataset_path
        self.seconds_per_batch = seconds_per_batch
        self.train_shuffle = train_shuffle
        self.num_workers = num_workers

    def train_dataloader(self):
        dataset = torch.utils.data.ConcatDataset(
            [
                self.librispeech_cls(self.dataset_path, url="train-clean-360"),
                self.librispeech_cls(self.dataset_path, url="train-clean-100"),
                self.librispeech_cls(self.dataset_path, url="train-other-500"),
            ]
        )
        len_list = [d[0].size(1) for d in dataset]
        sampler = BucketizeBatchSampler(
            len_list,
            num_buckets=10000,
            max_token_count=self.seconds_per_batch * 16000,
            min_len=32000,
            max_len=250000,
            shuffle=True,
        )
        sampler = DistributedBatchSampler(sampler, shuffle=self.train_shuffle)
        sampler.set_epoch(self.trainer.current_epoch)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnWav2Vec2(pad=False, rand_crop=True),
            num_workers=self.num_workers,
        )
        return dataloader

    def val_dataloader(self):
        dataset = torch.utils.data.ConcatDataset(
            [
                self.librispeech_cls(self.dataset_path, url="dev-clean"),
                self.librispeech_cls(self.dataset_path, url="dev-other"),
            ]
        )
        len_list = [d[0].size(1) for d in dataset]
        sampler = BucketizeBatchSampler(
            len_list,
            num_buckets=1000,
            max_token_count=self.seconds_per_batch * 16000,
            min_len=32000,
            max_len=250000,
            shuffle=False,
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnWav2Vec2(pad=False, rand_crop=True),
            num_workers=self.num_workers,
        )
        return dataloader
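Finally, a minimal sketch of driving the data module from a training script. `Wav2Vec2Module`, the dataset path, and the trainer flags are hypothetical stand-ins, not part of this commit; `seconds_per_batch=87.5` is illustrative (about 1.4M samples per batch at 16 kHz):

```python
import pytorch_lightning as pl

model = Wav2Vec2Module()  # hypothetical pre-training LightningModule
data_module = Wav2Vec2DataModule(
    dataset_path="/datasets/librispeech",  # hypothetical dataset root
    seconds_per_batch=87.5,
    train_shuffle=True,
    num_workers=10,
)
# The batch sampler shards batches across ranks, so a distributed launch is assumed.
trainer = pl.Trainer(max_steps=400_000, accelerator="gpu", devices=8, strategy="ddp")
trainer.fit(model, datamodule=data_module)
```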