Commit 576b02b1 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

[Example] Refactor BucketizeBatchSampler and HuBERTDataset (#2150)

Summary:
- Rename `BucketizeSampler` to `BucketizeBatchSampler` (a usage sketch of the renamed sampler follows this list).
- Fix bugs in `BucketizeBatchSampler`.
- Adjust `HuBERTDataSet` based on the latest `BucketizeBatchSampler`.
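For reference, a minimal sketch of driving the renamed sampler on its own. This is not part of the commit; the import path, the synthetic lengths, and the chosen argument values are illustrative assumptions based on the signature shown in the diff below.

```python
import torch

# Hypothetical import path; the sampler is re-exported from the example's dataset package.
from dataset import BucketizeBatchSampler

# Synthetic per-utterance lengths (in samples) standing in for a real dataset.
lengths = torch.randint(16000, 250000, (1000,)).tolist()

sampler = BucketizeBatchSampler(
    lengths,
    num_buckets=100,
    batch_size=16,
    shuffle=True,
    drop_last=True,  # new argument introduced by this PR
)

for batch_indices in sampler:
    # In batch_size mode with drop_last=True, every yielded batch is full.
    assert len(batch_indices) == 16
```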

Pull Request resolved: https://github.com/pytorch/audio/pull/2150

Reviewed By: mthrok

Differential Revision: D33689963

Pulled By: nateanl

fbshipit-source-id: 203764e9af5b7577ba08ebaa30ba5da3b67fb7e7
parent 984b169e
__init__.py

 from .hubert_dataset import (
-    BucketizeSampler,
+    BucketizeBatchSampler,
     CollateFnHubert,
     HuBERTDataSet,
 )
 __all__ = [
-    "BucketizeSampler",
+    "BucketizeBatchSampler",
     "CollateFnHubert",
     "HuBERTDataSet",
 ]
hubert_dataset.py

-import random
 from pathlib import Path
 from typing import Dict, Iterator, List, Optional, Tuple, Union
@@ -9,8 +8,8 @@ from torch import Tensor
 from torch.utils.data import BatchSampler, Dataset

-class BucketizeSampler(BatchSampler):
-    """Buketize sampler for data with different lengths to reduce number of paddings.
+class BucketizeBatchSampler(BatchSampler):
+    """Buketized BatchSampler for sequential data with different lengths to reduce number of paddings.

     Args:
         lengths (List[int]): The lengths of the samples in the dataset.
@@ -22,13 +21,19 @@ class BucketizeSampler(BatchSampler):
         max_token_count (int or None, optional): The max number of tokens in one mini-batch.
             (Default: ``None``)
         batch_size (int or None, optional): The number of samples in one mini-batch.
             (Default: ``None``)
         shuffle (bool, optional): Whether to shuffle buckets for non-monotonic length sampling.
-            (Default True)
+            (Default: True)
+        drop_last (bool, optional): If ``True``, the sampler will drop the last batch if
+            its size would be less than ``batch_size``
+            (Default: False)
+
+    Note:
+        ``max_token_count`` and ``batch_size`` are mutually exclusive. Only one argument of the two
+        should have value.
-    Note: If ``max_token_count`` is not ``None``, the ``batch_size`` couldn't be set since
-        the lengths of samples are unknown, the batch size may be different for different
-        mini-batches.
+
+    Note:
+        ``drop_last`` is only valid when ``batch_size`` argument is given.
     """

     def __init__(
@@ -40,6 +45,7 @@ class BucketizeSampler(BatchSampler):
         max_token_count: Optional[int] = None,
         batch_size: Optional[int] = None,
         shuffle: bool = True,
+        drop_last: bool = False,
     ) -> None:
         if max_len is None:
             max_len = max(lengths)
@@ -48,8 +54,13 @@ class BucketizeSampler(BatchSampler):
             raise AssertionError("``min_len`` should be non-negative and smaller than ``max_len``")
         if max_token_count is not None and batch_size is not None:
             raise AssertionError("The ``max_token_count`` and ``batch_size`` can't be both set.")
+        if max_token_count is None and batch_size is None:
+            raise AssertionError("One of ``max_token_count`` or ``batch_size`` must be set.")
+        if max_token_count is not None:
+            assert (
+                max_len <= max_token_count
+            ), "The ``max_token_count`` must be greater than or equal to the maximum value of ``lengths``."
         # Filter out samples which are outside the bounds of [min_len, max_len]
-        # sort to minimize gap when bucketizing.
         filtered_length_idx = [(length, i) for i, length in enumerate(lengths) if min_len <= length <= max_len]
         if len(filtered_length_idx) == 0:
             raise AssertionError("``lengths`` cannot be empty after filtering.")
@@ -58,16 +69,16 @@ class BucketizeSampler(BatchSampler):
         self.indices = [e[1] for e in sorted_filtered_length_idx]
         self.max_token_count = max_token_count
         self.batch_size = batch_size
-        self.buckets = self._get_buckets(self.lengths, self.indices, num_buckets, min_len, max_len)
         self.shuffle = shuffle
+        self.drop_last = drop_last
+        self.buckets = self._get_buckets(self.lengths, num_buckets, min_len, max_len)
+        self.iter_list = []
+        self._update_iter_list()

-    def _get_buckets(
-        self, lengths: List[int], indices: List[int], num_buckets: int, min_len: int, max_len: int
-    ) -> Dict[int, Tensor]:
+    def _get_buckets(self, lengths: List[int], num_buckets: int, min_len: int, max_len: int) -> Dict[int, Tensor]:
         """Generate buckets based on the dataset.

         Args:
             lengths (List[int]): The lengths of the samples in the dataset.
-            indices (List[int]): The indices of the samples in the original dataset.
             num_buckets (int): The number of buckets.
             min_len (int): The lower bound of the evenly spaced length intervals to determine bucket width.
             max_len (int): The upper bound of the evenly spaced length intervals to determine bucket width.
@@ -77,57 +88,49 @@ class BucketizeSampler(BatchSampler):
             the Tensor of corresponding sample indices.
         """
         buckets = {}
-        boundaries = [min_len - 1]
-        interval = (max_len - min_len) // num_buckets
-        for i in range(1, num_buckets):
-            boundaries.append(min_len + i * interval)
-        boundaries.append(max_len + 1)
-        bucket_ids = torch.bucketize(torch.tensor(lengths), torch.tensor(boundaries))
-        for i in indices:
-            bucket_id = bucket_ids[i]
+        boundaries = torch.linspace(min_len - 1, max_len + 1, num_buckets + 1)
+        bucket_ids = torch.bucketize(torch.tensor(lengths), boundaries)
+        for i in range(bucket_ids.size(0)):
+            bucket_id = int(bucket_ids[i])
             if bucket_id in buckets:
                 buckets[bucket_id].append(i)
             else:
                 buckets[bucket_id] = [i]
         for k in buckets:
-            if self.shuffle:
-                random.shuffle(buckets[k])
             buckets[k] = torch.as_tensor(buckets[k], dtype=torch.int)
+        buckets = {k: v for k, v in sorted(buckets.items())}
         return buckets

-    def __iter__(self) -> Iterator[List[int]]:
-        iter_list = []
+    def _update_iter_list(self) -> None:
+        self.iter_list = []
         total_len = 0
         batch = []
-        if self.max_token_count:
-            for k in self.buckets.keys():
-                for i in range(self.buckets[k].size(0)):
-                    index = self.buckets[k][i]
-                    if total_len > self.max_token_count:
-                        iter_list.append(batch)
-                        batch = [index]
-                        total_len = self.lengths[index]
-                    else:
-                        batch.append(index)
-                        total_len += self.lengths[index]
-        else:
-            for k in self.buckets.keys():
-                for i in range(self.buckets[k].size(0)):
-                    index = self.buckets[k][i]
-                    if total_len == self.batch_size:
-                        iter_list.append(batch)
-                        batch = [index]
-                        total_len = 1
-                    else:
-                        batch.append(index)
-                        total_len += 1
-        for batch in iter_list:
-            yield batch
+        max_batch_size = self.max_token_count if self.max_token_count else self.batch_size
+        for k in self.buckets:
+            for i in range(self.buckets[k].size(0)):
+                index = int(self.buckets[k][i])
+                sample_length = self.lengths[index] if self.max_token_count else 1
+                if total_len + sample_length <= max_batch_size:
+                    batch.append(self.indices[index])
+                    total_len += sample_length
+                else:
+                    self.iter_list.append(batch)
+                    batch = [self.indices[index]]
+                    total_len = sample_length
+        if len(batch) > 0 and (self.max_token_count or not self.drop_last):
+            self.iter_list.append(batch)
+
+    def __iter__(self) -> Iterator[List[int]]:
+        if self.shuffle:
+            for k in self.buckets:
+                self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0))]
+            self._update_iter_list()
+
+        return iter(self.iter_list)

     def __len__(self):
-        return len(self.data_source)
+        if self.batch_size or (self.max_token_count and not self.shuffle):
+            return len(self.iter_list)

 class HuBERTDataSet(Dataset):
@@ -137,8 +140,6 @@ class HuBERTDataSet(Dataset):
         exp_dir (str or Path): The root directory of the ``.tsv`` file list.
         dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
         subset (str): The subset of the dataset. Options: [``train``, ``valid``].
-        min_sample (int): The minimum number of audio samples in the dataset. (Default: 32000)
-        max_sample (int): The maximum number of audio samples in the dataset. (Default: 250000)
     """

     def __init__(
@@ -146,13 +147,11 @@ class HuBERTDataSet(Dataset):
         exp_dir: Union[str, Path],
         dataset: str,
         subset: str,
-        min_sample: int = 32000,
-        max_sample: int = 250000,
     ) -> None:
         self.exp_dir = Path(exp_dir)
         tsv_dir = self.exp_dir / "tsv"
         label_dir = self.exp_dir / "label"
-        f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset, min_sample, max_sample)
+        f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset)
         self.f_list, self.ind_list, self.len_list = f_list, ind_list, len_list
         self.labels = self._load_labels(label_dir, dataset, subset)
@@ -164,20 +163,16 @@ class HuBERTDataSet(Dataset):
         tsv_dir: Path,
         dataset: str,
         subset: str,
-        min_sample: int,
-        max_sample: int,
     ) -> Tuple[List[Path], List[int], List[int]]:
         """Get the list of paths for iteration.

         Args:
             tsv_dir (Path): The root directory of the ``.tsv`` file list.
             dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
             subset (str): The subset of the dataset. Options: [``train``, ``valid``].
-            min_sample (int): The minimum number of audio samples in the dataset.
-            max_sample (int): The maximum number of audio samples in the dataset.

         Returns:
             (numpy.array) List of file paths.
-            (numpy.array) List of indices that qualify ``min_sample`` <= length <= ``max_sample``.
+            (numpy.array) List of indices.
             (numpy.array) List of waveform lengths.
         """
         f_ind_len_list = []
@@ -187,9 +182,7 @@ class HuBERTDataSet(Dataset):
                 path, nsample = line.split("\t")
                 path = f"{root}/{path}"
                 nsample = int(nsample)
-                if min_sample <= nsample <= max_sample:
-                    f_ind_len_list.append((path, index, nsample))
-        f_ind_len_list.sort(key=lambda x: x[2])  # sort the file lists by the sequence length
+                f_ind_len_list.append((path, index, nsample))
         f_list, ind_list, len_list = [], [], []
         for ele in f_ind_len_list:
             f_list.append(ele[0])
@@ -220,7 +213,7 @@ class HuBERTDataSet(Dataset):
         Returns:
             (np.array): The numpy arrary that contains the labels for each audio file.
         """
-        with open(label_dir / f"{dataset}_{subset}.pt") as f:
+        with open(label_dir / f"label_{subset}.pt") as f:
             labels = [line.rstrip() for line in f]
             labels = [labels[i] for i in self.ind_list]
             return np.asarray(labels, dtype=np.string_)
...
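For orientation, a minimal sketch of wiring the refactored pieces together. The experiment directory, the import path, the `num_buckets`/`max_token_count` values, and the trivial `collate_fn=list` placeholder are assumptions for illustration; `CollateFnHubert` is the collate function actually exported by the example, but its arguments are not shown in this diff, so it is omitted here.

```python
from torch.utils.data import DataLoader

# Hypothetical import path for the example's dataset package shown above.
from dataset import BucketizeBatchSampler, HuBERTDataSet

# "./exp" must contain the tsv/ and label/ subdirectories expected by HuBERTDataSet.
dataset = HuBERTDataSet("./exp", "librispeech", "train")  # min_sample/max_sample removed by this PR

sampler = BucketizeBatchSampler(
    dataset.len_list,           # per-sample waveform lengths kept by the dataset
    num_buckets=1000,
    max_token_count=3_200_000,  # must be >= max(dataset.len_list) per the new assertion
    shuffle=True,
)

# Placeholder collate that simply returns the list of samples in each batch.
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=list)

for batch in loader:
    ...  # feed the batch to the training step
```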