Commit 576b02b1 authored by Zhaoheng Ni and committed by Facebook GitHub Bot

[Example] Refactor BucketizeBatchSampler and HuBERTDataset (#2150)

Summary:
- Rename `BucketizeSampler` to `BucketizeBatchSampler`
- Fix bugs in `BucketizeBatchSampler`
- Adjust HuBERTDataset based on the latest `BucketizeBatchSampler`.

Pull Request resolved: https://github.com/pytorch/audio/pull/2150

Reviewed By: mthrok

Differential Revision: D33689963

Pulled By: nateanl

fbshipit-source-id: 203764e9af5b7577ba08ebaa30ba5da3b67fb7e7
parent 984b169e
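
Before the diff itself, a minimal usage sketch of the renamed sampler (not part of the commit; the ``dataset`` package path and all numeric values are illustrative assumptions):

```python
import random

import torch
from torch.utils.data import DataLoader

from dataset import BucketizeBatchSampler  # package path assumed from the __init__.py below

# Toy stand-in data: variable-length 1-D tensors and their lengths.
lengths = [random.randint(100, 1000) for _ in range(200)]
waveforms = [torch.randn(n) for n in lengths]

# Bucket samples of similar length and cap each mini-batch by a total-length budget.
sampler = BucketizeBatchSampler(lengths, num_buckets=10, max_token_count=4000)

# ``batch_sampler`` yields lists of dataset indices; this collate_fn keeps each batch as a list.
loader = DataLoader(waveforms, batch_sampler=sampler, collate_fn=lambda batch: batch)
for batch in loader:
    assert sum(t.numel() for t in batch) <= 4000  # every batch respects the token budget
```
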
from .hubert_dataset import (
- BucketizeSampler,
+ BucketizeBatchSampler,
CollateFnHubert,
HuBERTDataSet,
)
__all__ = [
"BucketizeSampler",
"BucketizeBatchSampler",
"CollateFnHubert",
"HuBERTDataSet",
]
import random
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple, Union
@@ -9,8 +8,8 @@ from torch import Tensor
from torch.utils.data import BatchSampler, Dataset
- class BucketizeSampler(BatchSampler):
- """Bucketize sampler for data with different lengths to reduce number of paddings.
+ class BucketizeBatchSampler(BatchSampler):
+ """Bucketized BatchSampler for sequential data with different lengths to reduce number of paddings.
Args:
lengths (List[int]): The lengths of the samples in the dataset.
@@ -22,13 +21,19 @@ class BucketizeSampler(BatchSampler):
max_token_count (int or None, optional): The max number of tokens in one mini-batch.
(Default: ``None``)
batch_size (int or None, optional): The number of samples in one mini-batch.
(Default: ``None``)
shuffle (bool, optional): Whether to shuffle buckets for non-monotonic length sampling.
- (Default True)
+ (Default: True)
+ drop_last (bool, optional): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``
+ (Default: False)
+ Note:
+ ``max_token_count`` and ``batch_size`` are mutually exclusive. Only one argument of the two
+ should have value.
- Note: If ``max_token_count`` is not ``None``, the ``batch_size`` couldn't be set since
- the lengths of samples are unknown, the batch size may be different for different
- mini-batches.
+ Note:
+ ``drop_last`` is only valid when ``batch_size`` argument is given.
"""
def __init__(
@@ -40,6 +45,7 @@ class BucketizeSampler(BatchSampler):
max_token_count: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = True,
+ drop_last: bool = False,
) -> None:
if max_len is None:
max_len = max(lengths)
@@ -48,8 +54,13 @@ class BucketizeSampler(BatchSampler):
raise AssertionError("``min_len`` should be non-negative and smaller than ``max_len``")
if max_token_count is not None and batch_size is not None:
raise AssertionError("The ``max_token_count`` and ``batch_size`` can't be both set.")
+ if max_token_count is None and batch_size is None:
+ raise AssertionError("One of ``max_token_count`` or ``batch_size`` must be set.")
+ if max_token_count is not None:
+ assert (
+ max_len <= max_token_count
+ ), "The ``max_token_count`` must be greater than or equal to the maximum value of ``lengths``."
# Filter out samples which are outside the bounds of [min_len, max_len]
- # sort to minimize gap when bucketizing.
filtered_length_idx = [(length, i) for i, length in enumerate(lengths) if min_len <= length <= max_len]
if len(filtered_length_idx) == 0:
raise AssertionError("``lengths`` cannot be empty after filtering.")
@@ -58,16 +69,16 @@ class BucketizeSampler(BatchSampler):
self.indices = [e[1] for e in sorted_filtered_length_idx]
self.max_token_count = max_token_count
self.batch_size = batch_size
- self.buckets = self._get_buckets(self.lengths, self.indices, num_buckets, min_len, max_len)
self.shuffle = shuffle
+ self.drop_last = drop_last
+ self.buckets = self._get_buckets(self.lengths, num_buckets, min_len, max_len)
+ self.iter_list = []
+ self._update_iter_list()
- def _get_buckets(
- self, lengths: List[int], indices: List[int], num_buckets: int, min_len: int, max_len: int
- ) -> Dict[int, Tensor]:
+ def _get_buckets(self, lengths: List[int], num_buckets: int, min_len: int, max_len: int) -> Dict[int, Tensor]:
"""Generate buckets based on the dataset.
Args:
lengths (List[int]): The lengths of the samples in the dataset.
- indices (List[int]): The indices of the samples in the original dataset.
num_buckets (int): The number of buckets.
min_len (int): The lower bound of the evenly spaced length intervals to determine bucket width.
max_len (int): The upper bound of the evenly spaced length intervals to determine bucket width.
@@ -77,57 +88,49 @@ class BucketizeSampler(BatchSampler):
the Tensor of corresponding sample indices.
"""
buckets = {}
- boundaries = [min_len - 1]
- interval = (max_len - min_len) // num_buckets
- for i in range(1, num_buckets):
- boundaries.append(min_len + i * interval)
- boundaries.append(max_len + 1)
- bucket_ids = torch.bucketize(torch.tensor(lengths), torch.tensor(boundaries))
- for i in indices:
- bucket_id = bucket_ids[i]
+ boundaries = torch.linspace(min_len - 1, max_len + 1, num_buckets + 1)
+ bucket_ids = torch.bucketize(torch.tensor(lengths), boundaries)
+ for i in range(bucket_ids.size(0)):
+ bucket_id = int(bucket_ids[i])
if bucket_id in buckets:
buckets[bucket_id].append(i)
else:
buckets[bucket_id] = [i]
for k in buckets:
- if self.shuffle:
- random.shuffle(buckets[k])
buckets[k] = torch.as_tensor(buckets[k], dtype=torch.int)
+ buckets = {k: v for k, v in sorted(buckets.items())}
return buckets
- def __iter__(self) -> Iterator[List[int]]:
- iter_list = []
+ def _update_iter_list(self) -> None:
+ self.iter_list = []
total_len = 0
batch = []
- if self.max_token_count:
- for k in self.buckets.keys():
- for i in range(self.buckets[k].size(0)):
- index = self.buckets[k][i]
- if total_len > self.max_token_count:
- iter_list.append(batch)
- batch = [index]
- total_len = self.lengths[index]
- else:
- batch.append(index)
- total_len += self.lengths[index]
- else:
- for k in self.buckets.keys():
- for i in range(self.buckets[k].size(0)):
- index = self.buckets[k][i]
- if total_len == self.batch_size:
- iter_list.append(batch)
- batch = [index]
- total_len = 1
- else:
- batch.append(index)
- total_len += 1
- for batch in iter_list:
- yield batch
+ max_batch_size = self.max_token_count if self.max_token_count else self.batch_size
+ for k in self.buckets:
+ for i in range(self.buckets[k].size(0)):
+ index = int(self.buckets[k][i])
+ sample_length = self.lengths[index] if self.max_token_count else 1
+ if total_len + sample_length <= max_batch_size:
+ batch.append(self.indices[index])
+ total_len += sample_length
+ else:
+ self.iter_list.append(batch)
+ batch = [self.indices[index]]
+ total_len = sample_length
+ if len(batch) > 0 and (self.max_token_count or not self.drop_last):
+ self.iter_list.append(batch)
+ def __iter__(self) -> Iterator[List[int]]:
+ if self.shuffle:
+ for k in self.buckets:
+ self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0))]
+ self._update_iter_list()
+ return iter(self.iter_list)
def __len__(self):
- return len(self.data_source)
+ if self.batch_size or (self.max_token_count and not self.shuffle):
+ return len(self.iter_list)
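
As a side note on the docstring above (an illustrative sketch, not from the commit; all values are made up): ``max_token_count`` and ``batch_size`` select two mutually exclusive batching modes, and ``drop_last`` only has an effect in the fixed ``batch_size`` mode.

```python
from dataset import BucketizeBatchSampler  # package path assumed

lengths = list(range(100, 303))  # 203 samples with lengths 100..302

# Token-budget mode: batch sizes vary, but the summed lengths stay within the budget.
token_sampler = BucketizeBatchSampler(lengths, num_buckets=5, max_token_count=1200, shuffle=False)
assert all(sum(lengths[i] for i in batch) <= 1200 for batch in token_sampler)

# Fixed-size mode: every batch holds exactly 8 indices; the short trailing batch is dropped.
fixed_sampler = BucketizeBatchSampler(lengths, num_buckets=5, batch_size=8, drop_last=True)
assert all(len(batch) == 8 for batch in fixed_sampler)
```
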
class HuBERTDataSet(Dataset):
@@ -137,8 +140,6 @@ class HuBERTDataSet(Dataset):
exp_dir (str or Path): The root directory of the ``.tsv`` file list.
dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
subset (str): The subset of the dataset. Options: [``train``, ``valid``].
- min_sample (int): The minimum number of audio samples in the dataset. (Default: 32000)
- max_sample (int): The maximum number of audio samples in the dataset. (Default: 250000)
"""
def __init__(
@@ -146,13 +147,11 @@
exp_dir: Union[str, Path],
dataset: str,
subset: str,
- min_sample: int = 32000,
- max_sample: int = 250000,
) -> None:
self.exp_dir = Path(exp_dir)
tsv_dir = self.exp_dir / "tsv"
label_dir = self.exp_dir / "label"
- f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset, min_sample, max_sample)
+ f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset)
self.f_list, self.ind_list, self.len_list = f_list, ind_list, len_list
self.labels = self._load_labels(label_dir, dataset, subset)
@@ -164,20 +163,16 @@
tsv_dir: Path,
dataset: str,
subset: str,
- min_sample: int,
- max_sample: int,
) -> Tuple[List[Path], List[int], List[int]]:
"""Get the list of paths for iteration.
Args:
tsv_dir (Path): The root directory of the ``.tsv`` file list.
dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
subset (str): The subset of the dataset. Options: [``train``, ``valid``].
- min_sample (int): The minimum number of audio samples in the dataset.
- max_sample (int): The maximum number of audio samples in the dataset.
Returns:
(numpy.array) List of file paths.
- (numpy.array) List of indices that qualify ``min_sample`` <= length <= ``max_sample``.
+ (numpy.array) List of indices.
(numpy.array) List of waveform lengths.
"""
f_ind_len_list = []
@@ -187,9 +182,7 @@
path, nsample = line.split("\t")
path = f"{root}/{path}"
nsample = int(nsample)
- if min_sample <= nsample <= max_sample:
- f_ind_len_list.append((path, index, nsample))
- f_ind_len_list.sort(key=lambda x: x[2]) # sort the file lists by the sequence length
+ f_ind_len_list.append((path, index, nsample))
f_list, ind_list, len_list = [], [], []
for ele in f_ind_len_list:
f_list.append(ele[0])
@@ -220,7 +213,7 @@
Returns:
(np.array): The numpy array that contains the labels for each audio file.
"""
with open(label_dir / f"{dataset}_{subset}.pt") as f:
with open(label_dir / f"label_{subset}.pt") as f:
labels = [line.rstrip() for line in f]
labels = [labels[i] for i in self.ind_list]
return np.asarray(labels, dtype=np.string_)
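
Finally, a hedged sketch of how the refactored dataset and sampler could be wired together after this change (paths, dataset and subset names, and sampler values are placeholders; ``CollateFnHubert``'s constructor arguments are not visible in this diff, so it is only mentioned in a comment):

```python
from torch.utils.data import DataLoader

from dataset import BucketizeBatchSampler, HuBERTDataSet  # package path assumed

# The dataset no longer filters by min_sample/max_sample; length filtering now happens
# in the sampler via ``min_len``/``max_len``.
dataset = HuBERTDataSet("./exp", "librispeech", "train")  # expects ./exp/tsv and ./exp/label

sampler = BucketizeBatchSampler(
    dataset.len_list,           # waveform lengths read from the .tsv file
    num_buckets=1000,           # illustrative
    min_len=32000,              # previously min_sample on the dataset
    max_len=250000,             # previously max_sample on the dataset
    max_token_count=3_200_000,  # illustrative per-batch budget in audio samples
)

# In practice a collate function (e.g. ``CollateFnHubert`` from this module) would also be
# passed to pad or crop the variable-length waveforms before batching.
loader = DataLoader(dataset, batch_sampler=sampler)
```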