Commit 8c16529b authored by Elijah Rippeth, committed by Facebook GitHub Bot

[Example] abstracts BucketizeSampler to be usable outside of HuBERT example. (#2147)

Summary:
This PR:

- Replaces the `data_source` argument with `lengths`.
- Adds a `shuffle` argument to decide whether to shuffle the samples in the buckets.
- Adds `min_len` and `max_len` arguments to filter out samples shorter than `min_len` or longer than `max_len`.
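A minimal usage sketch of the new interface (the lengths, argument values, and import are illustrative placeholders, not part of this PR):

```python
# Assumes BucketizeSampler is imported from the example's dataset module.
lengths = [31000, 47000, 15000, 52000]  # one length per dataset sample

sampler = BucketizeSampler(
    lengths,
    num_buckets=2,
    min_len=20000,           # filters out the 15000-length sample
    max_token_count=100000,  # mutually exclusive with batch_size
    shuffle=True,
)
for batch_indices in sampler:  # each batch is a list of sample indices
    ...
```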

cc nateanl

Pull Request resolved: https://github.com/pytorch/audio/pull/2147

Reviewed By: carolineechen

Differential Revision: D33454369

Pulled By: nateanl

fbshipit-source-id: 3835169ec7f808f8dd9650e7f183f79091efe886
parent b73f5d67
 import random
 from pathlib import Path
-from typing import (
-    Dict,
-    Iterator,
-    List,
-    Optional,
-    Tuple,
-    Union,
-)
+from typing import Dict, Iterator, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torchaudio
 from torch import Tensor
-from torch.utils.data import Dataset, BatchSampler
+from torch.utils.data import BatchSampler, Dataset

 class BucketizeSampler(BatchSampler):
-    """Buketize sampler for data with different lengths to reduce number of paddings.
+    """Bucketize sampler for data with different lengths to reduce the amount of padding.

     Args:
-        data_source (Dataset): The dataset to sample
+        lengths (List[int]): The lengths of the samples in the dataset.
         num_buckets (int): The number of buckets to split the data samples.
+        min_len (int, optional): The minimum sample length to keep.
+            (Default: 0)
+        max_len (int or None, optional): The maximum sample length to keep. Inferred if not provided.
+            (Default: ``None``)
         max_token_count (int or None, optional): The max number of tokens in one mini-batch.
             (Default: ``None``)
         batch_size (int or None, optional): The number of samples in one mini-batch.
             (Default: ``None``)
+        shuffle (bool, optional): Whether to shuffle buckets for non-monotonic length sampling.
+            (Default: ``True``)

     Note: If ``max_token_count`` is not ``None``, ``batch_size`` can't be set; because
         sample lengths vary, the batch size may be different for different
@@ -34,46 +33,66 @@ class BucketizeSampler(BatchSampler):
     def __init__(
         self,
-        data_source: Dataset,
+        lengths: List[int],
         num_buckets: int,
+        min_len: int = 0,
+        max_len: Optional[int] = None,
         max_token_count: Optional[int] = None,
         batch_size: Optional[int] = None,
+        shuffle: bool = True,
     ) -> None:
+        if max_len is None:
+            max_len = max(lengths)
+        if not (0 <= min_len <= max_len):
+            raise AssertionError("``min_len`` must be non-negative and no greater than ``max_len``.")
         if max_token_count is not None and batch_size is not None:
             raise AssertionError("The ``max_token_count`` and ``batch_size`` can't be both set.")
-        self.data_source = data_source
+        # ``self.shuffle`` must be assigned before ``_get_buckets`` runs, because
+        # that method reads it when shuffling each bucket.
+        self.shuffle = shuffle
+        # Filter out samples outside the bounds of [min_len, max_len], then
+        # sort by length to minimize padding within each bucket.
+        filtered_length_idx = [(length, i) for i, length in enumerate(lengths) if min_len <= length <= max_len]
+        if len(filtered_length_idx) == 0:
+            raise AssertionError("``lengths`` cannot be empty after filtering.")
+        sorted_filtered_length_idx = sorted(filtered_length_idx, key=lambda x: x[0])
+        self.lengths = [e[0] for e in sorted_filtered_length_idx]
+        self.indices = [e[1] for e in sorted_filtered_length_idx]
         self.max_token_count = max_token_count
         self.batch_size = batch_size
-        self.buckets = self._get_buckets(self.data_source, num_buckets)
+        self.buckets = self._get_buckets(self.lengths, self.indices, num_buckets, min_len, max_len)

-    def _get_buckets(self, data_source: Dataset, num_buckets: int) -> Dict[int, Tensor]:
+    def _get_buckets(
+        self, lengths: List[int], indices: List[int], num_buckets: int, min_len: int, max_len: int
+    ) -> Dict[int, Tensor]:
         """Generate buckets based on the dataset.

         Args:
-            data_source (Dataset): The dataset object to bucketize.
+            lengths (List[int]): The lengths of the samples in the dataset.
+            indices (List[int]): The indices of the samples in the original dataset.
             num_buckets (int): The number of buckets.
+            min_len (int): The lower bound of the evenly spaced length intervals that determine bucket width.
+            max_len (int): The upper bound of the evenly spaced length intervals that determine bucket width.

         Returns:
             (dict[int, Tensor]): A dictionary in which the key is the bucket index, the value is
                 the Tensor of corresponding sample indices.
         """
         buckets = {}
-        len_list = data_source.len_list
-        min_len, max_len = min(len_list), max(len_list)
         boundaries = [min_len - 1]
         interval = (max_len - min_len) // num_buckets
         for i in range(1, num_buckets):
             boundaries.append(min_len + i * interval)
         boundaries.append(max_len + 1)
-        bucket_ids = torch.bucketize(torch.tensor(len_list), torch.tensor(boundaries))
-        for i, _ in enumerate(len_list):
+        bucket_ids = torch.bucketize(torch.tensor(lengths), torch.tensor(boundaries))
+        # ``bucket_ids`` is positional over the sorted ``lengths``, so bucket the
+        # positions; they index into both ``self.lengths`` and ``self.indices``.
+        for i in range(len(lengths)):
             bucket_id = bucket_ids[i]
             if bucket_id in buckets:
                 buckets[bucket_id].append(i)
             else:
                 buckets[bucket_id] = [i]
         for k in buckets:
-            random.shuffle(buckets[k])
+            if self.shuffle:
+                random.shuffle(buckets[k])
             buckets[k] = torch.as_tensor(buckets[k], dtype=torch.int)
         return buckets
@@ -81,7 +100,6 @@ class BucketizeSampler(BatchSampler):
         iter_list = []
         total_len = 0
         batch = []
-        len_list = self.data_source.len_list
         if self.max_token_count:
             for k in self.buckets.keys():
                 for i in range(self.buckets[k].size(0)):
@@ -89,10 +107,10 @@ class BucketizeSampler(BatchSampler):
                     if total_len > self.max_token_count:
                         iter_list.append(batch)
                         batch = [index]
-                        total_len = len_list[index]
+                        total_len = self.lengths[index]
                     else:
                         batch.append(index)
-                        total_len += len_list[index]
+                        total_len += self.lengths[index]
         else:
             for k in self.buckets.keys():
                 for i in range(self.buckets[k].size(0)):
...
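For reference, here is a standalone sketch of the boundary scheme that `_get_buckets` feeds to `torch.bucketize`, with made-up numbers:

```python
import torch

min_len, max_len, num_buckets = 10, 50, 2

boundaries = [min_len - 1]                     # [9]
interval = (max_len - min_len) // num_buckets  # 20
for i in range(1, num_buckets):
    boundaries.append(min_len + i * interval)  # [9, 30]
boundaries.append(max_len + 1)                 # [9, 30, 51]

lengths = torch.tensor([10, 25, 30, 49])
print(torch.bucketize(lengths, torch.tensor(boundaries)))
# tensor([1, 1, 1, 2]): lengths in (9, 30] land in bucket 1, (30, 51] in bucket 2
```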