Commit 17cef3f6 authored by Ning Dong, committed by Facebook Github Bot

Black formatting for multi_corpus_sampled_dataset.py (#638)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/638

RT

Reviewed By: liezl200

Differential Revision: D14967268

fbshipit-source-id: 2da361497743d90a841fdbf2a50085136c70b468
parent 8776928c
@@ -28,10 +28,10 @@ class MultiCorpusSampledDataset(FairseqDataset):
     """
     def __init__(
         self,
         datasets: Dict[str, FairseqDataset],
-        sampling_dist: str = 'uniform',
-        default_key: str = ''
+        sampling_dist: str = "uniform",
+        default_key: str = "",
     ):
         super().__init__()
         assert isinstance(datasets, OrderedDict)
@@ -62,34 +62,26 @@ class MultiCorpusSampledDataset(FairseqDataset):
         if self._ordered_indices is None:
             self._ordered_indices = OrderedDict(
                 [
-                    (
-                        key, dataset.ordered_indices()
-                    )
+                    (key, dataset.ordered_indices())
                     for key, dataset in self.datasets.items()
                 ]
             )
         return np.arange(len(self))

-    def _map_index_to_dataset(
-        self,
-        key: int,
-        index: int
-    ):
+    def _map_index_to_dataset(self, key: int, index: int):
         """
         Different underlying datasets have different lengths. In order to ensure
         we are not accessing an index outside the range of the current dataset
         size, we wrap around. This function should be called after we have
         created an ordering for this and all underlying datasets.
         """
-        assert self._ordered_indices is not None, \
-            'Must call MultiCorpusSampledDataset.ordered_indices() first'
+        assert (
+            self._ordered_indices is not None
+        ), "Must call MultiCorpusSampledDataset.ordered_indices() first"
         mapped_index = index % len(self.datasets[key])
         return self._ordered_indices[key][mapped_index]

-    def __getitem__(
-        self,
-        index: int
-    ):
+    def __getitem__(self, index: int):
         """
         Get the item associated with index from each underlying dataset.
         Since index is in the range of [0, TotalNumInstances], we need to
@@ -97,17 +89,12 @@ class MultiCorpusSampledDataset(FairseqDataset):
         """
         return OrderedDict(
             [
-                (
-                    key, dataset[self._map_index_to_dataset(key, index)]
-                )
+                (key, dataset[self._map_index_to_dataset(key, index)])
                 for key, dataset in self.datasets.items()
             ]
         )

-    def collater(
-        self,
-        samples: List[Dict]
-    ):
+    def collater(self, samples: List[Dict]):
         """
         Generate a mini-batch for this dataset.
         To convert this into a regular mini-batch we use the following
@@ -118,35 +105,26 @@ class MultiCorpusSampledDataset(FairseqDataset):
         if len(samples) == 0:
             return None
-        if self.sampling_dist == 'uniform':
+        if self.sampling_dist == "uniform":
             candidates = list(self.datasets.keys())
             selected_key = np.random.choice(candidates, 1).item()
-            selected_samples = [
-                sample[selected_key]
-                for sample in samples
-            ]
+            selected_samples = [sample[selected_key] for sample in samples]
             return self.datasets[selected_key].collater(selected_samples)
         else:
             raise NotImplementedError(
                 "Specified sampling is currently not Implemented."
             )

-    def get_dummy_batch(
-        self,
-        num_tokens: int,
-        max_positions: int,
-    ):
+    def get_dummy_batch(self, num_tokens: int, max_positions: int):
         """
         Return a dummy batch with a given number of tokens. Assumes that the
         max_positions specified is the same for all underlying datasets.
         """
         return self.datasets[self.default_key].get_dummy_batch(
-            num_tokens, max_positions)
+            num_tokens, max_positions
+        )

-    def num_tokens(
-        self,
-        index: int
-    ):
+    def num_tokens(self, index: int):
         """
         Return an example's length (number of tokens), used for batching. Here
         we return the max across all examples at index across all underlying
@@ -157,10 +135,7 @@ class MultiCorpusSampledDataset(FairseqDataset):
             for key, dataset in self.datasets.items()
         )

-    def size(
-        self,
-        index: int
-    ):
+    def size(self, index: int):
         """
         Return an example's size as a float or tuple. Here we return the max
         across all underlying datasets. This value is used when filtering a
@@ -174,14 +149,12 @@ class MultiCorpusSampledDataset(FairseqDataset):
     @property
     def supports_prefetch(self):
         return all(
-            getattr(dataset, 'supports_prefetch', False)
+            getattr(dataset, "supports_prefetch", False)
             for dataset in self.datasets.values()
         )

     def prefetch(self, indices):
         for key, dataset in self.datasets.items():
             dataset.prefetch(
-                [
-                    self._map_index_to_dataset(key, index) for index in indices
-                ]
+                [self._map_index_to_dataset(key, index) for index in indices]
             )
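For context, the behavior the reformatted code implements (corpus-keyed items, modulo wrap-around of indices, uniform corpus selection at collate time) is easiest to see in use. The following is a minimal, hypothetical usage sketch, not part of this commit: the ToyDataset stub and its data are invented for illustration, and the import path is assumed from the file name in the commit title.

from collections import OrderedDict

import numpy as np
import torch

from fairseq.data import FairseqDataset
# Import path assumed from the file name in the commit title.
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset


class ToyDataset(FairseqDataset):
    """Hypothetical stand-in for a real FairseqDataset (e.g. LanguagePairDataset)."""

    def __init__(self, tensors):
        self.tensors = tensors

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def ordered_indices(self):
        return np.arange(len(self))

    def num_tokens(self, index):
        return self.tensors[index].numel()

    def size(self, index):
        return self.tensors[index].numel()

    def collater(self, samples):
        # The toy tensors share a shape, so a plain stack is enough here.
        return torch.stack(samples)


# The datasets argument must be an OrderedDict, as asserted in __init__.
dataset = MultiCorpusSampledDataset(
    OrderedDict(
        [
            ("corpus_a", ToyDataset([torch.zeros(4), torch.zeros(4)])),
            ("corpus_b", ToyDataset([torch.ones(4)])),
        ]
    ),
    sampling_dist="uniform",
    default_key="corpus_a",
)

dataset.ordered_indices()  # must be called before indexing, per the assert
item = dataset[2]  # OrderedDict with one example per corpus; for "corpus_b"
                   # (length 1) the index wraps around: 2 % 1 == 0
batch = dataset.collater([dataset[i] for i in range(len(dataset))])

Note that __getitem__ returns one example per corpus, and collater then draws a single corpus uniformly at random for the whole mini-batch, so each batch is homogeneous in corpus.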