Commit 8207f263 authored by Myle Ott, committed by Facebook GitHub Bot

Misc dataset improvements

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/911

Differential Revision: D16536559

Pulled By: myleott

fbshipit-source-id: 7fe495054ce5b7658b1d3a43eca38c5858360236
parent abc13e28
@@ -8,6 +8,7 @@
 import bisect

 import numpy as np
+from torch.utils.data.dataloader import default_collate

 from . import FairseqDataset
@@ -50,7 +51,10 @@ class ConcatDataset(FairseqDataset):
     def collater(self, samples):
         # For now only supports datasets with same underlying collater implementations
-        return self.datasets[0].collater(samples)
+        if hasattr(self.datasets[0], 'collater'):
+            return self.datasets[0].collater(samples)
+        else:
+            return default_collate(samples)

     def size(self, idx: int):
         """
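Note on the ConcatDataset change above: collater() now falls back to torch's default_collate when the first wrapped dataset does not define a collater() of its own. A minimal sketch of that fallback (illustrative only, not part of the commit):

```python
# Hypothetical illustration: two same-shaped samples, as they might come from a
# plain torch Dataset that defines only __getitem__/__len__ and no collater().
import torch
from torch.utils.data.dataloader import default_collate

samples = [torch.tensor([1, 2]), torch.tensor([3, 4])]

# This is the else-branch of the new collater(): stack the samples into a batch.
batch = default_collate(samples)
print(batch.shape)  # torch.Size([2, 2])
```

The hasattr check keeps the old behaviour for datasets that implement their own batching, while letting ConcatDataset wrap generic torch datasets.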
@@ -52,6 +52,9 @@ class Dictionary(object):
         """Returns the number of symbols in the dictionary"""
         return len(self.symbols)

+    def __contains__(self, sym):
+        return sym in self.indices
+
     def index(self, sym):
         """Returns the index of the specified symbol"""
         assert isinstance(sym, str)
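Note on the Dictionary change above: __contains__ makes membership tests (`sym in dictionary`) consult the symbol-to-index map directly. A small standalone sketch of the behaviour, using a toy TinyDictionary (hypothetical, not the real fairseq class):

```python
# Toy stand-in for fairseq's Dictionary, showing what __contains__ enables.
class TinyDictionary:
    def __init__(self):
        self.symbols = []   # index -> symbol
        self.indices = {}   # symbol -> index

    def add_symbol(self, sym):
        if sym not in self.indices:
            self.indices[sym] = len(self.symbols)
            self.symbols.append(sym)
        return self.indices[sym]

    def __contains__(self, sym):
        # Same one-liner as the commit: delegate to the symbol-to-index map.
        return sym in self.indices


d = TinyDictionary()
d.add_symbol('hello')
print('hello' in d)  # True
print('world' in d)  # False
```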
@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+import numpy as np
 import torch.utils.data
@@ -41,7 +42,7 @@ class FairseqDataset(torch.utils.data.Dataset):
     def ordered_indices(self):
         """Return an ordered list of indices. Batches will be constructed based
         on this order."""
-        raise NotImplementedError
+        return np.arange(len(self))

     @property
     def supports_prefetch(self):
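Note on the FairseqDataset change above: ordered_indices() now has a sensible default, the natural order 0..len-1, instead of raising NotImplementedError; subclasses can still override it, e.g. to sort by example length. A sketch under those assumptions (ToyDataset and LengthSortedToyDataset are hypothetical):

```python
import numpy as np


class ToyDataset:
    def __init__(self, sizes):
        self.sizes = np.array(sizes)

    def __len__(self):
        return len(self.sizes)

    def ordered_indices(self):
        # Default behaviour after this commit: natural order.
        return np.arange(len(self))


class LengthSortedToyDataset(ToyDataset):
    def ordered_indices(self):
        # A common override: visit examples from shortest to longest.
        return np.argsort(self.sizes, kind='mergesort')


print(ToyDataset([5, 2, 9]).ordered_indices())              # [0 1 2]
print(LengthSortedToyDataset([5, 2, 9]).ordered_indices())  # [1 0 2]
```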
@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+from functools import lru_cache
 import os
 import shutil
 import struct
@@ -146,6 +147,7 @@ class IndexedDataset(FairseqDataset):
         if self.data_file:
             self.data_file.close()

+    @lru_cache(maxsize=8)
     def __getitem__(self, i):
         if not self.data_file:
             self.read_data(self.path)
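Note on the @lru_cache additions in this file: decorating __getitem__ with functools.lru_cache(maxsize=8) memoizes the most recent item lookups, so repeated reads of the same index skip the underlying file access. A standalone sketch with a toy CountingDataset (hypothetical, not the real IndexedDataset):

```python
from functools import lru_cache


class CountingDataset:
    def __init__(self, data):
        self.data = list(data)
        self.reads = 0  # counts how many "expensive" reads actually happen

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        self.reads += 1  # stands in for a disk read
        return self.data[i]


ds = CountingDataset(range(100))
ds[3]
ds[3]            # second lookup is served from the cache
print(ds.reads)  # 1
```

One caveat worth keeping in mind: lru_cache keys on (self, i), so the cache also holds references to the dataset instance; the small maxsize keeps that footprint bounded.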
@@ -214,6 +216,7 @@ class IndexedCachedDataset(IndexedDataset):
             self.data_file.close()
             self.data_file = None

+    @lru_cache(maxsize=8)
     def __getitem__(self, i):
         self.check_index(i)
         tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
@@ -255,6 +258,7 @@ class IndexedRawTextDataset(FairseqDataset):
         if i < 0 or i >= self.size:
             raise IndexError('index out of range')

+    @lru_cache(maxsize=8)
     def __getitem__(self, i):
         self.check_index(i)
         return self.tokens_list[i]
@@ -429,6 +433,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         def sizes(self):
             return self._sizes

+        @lru_cache(maxsize=8)
         def __getitem__(self, i):
             return self._pointers[i], self._sizes[i]
@@ -466,6 +471,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self._index)

+    @lru_cache(maxsize=8)
     def __getitem__(self, i):
         ptr, size = self._index[i]
         np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
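Note on the MMapIndexedDataset.__getitem__ shown above: each item is recovered from one flat memory-mapped buffer via np.frombuffer, using a per-item byte offset (ptr) and element count (size). A toy illustration of that pointer/size layout (the buffer construction here is assumed, not the real on-disk format):

```python
import numpy as np

dtype = np.int32
items = [np.array([1, 2, 3], dtype=dtype), np.array([4, 5], dtype=dtype)]

# Flatten all items into a single buffer and record byte offsets and lengths,
# playing the role of the index's (pointer, size) metadata.
buffer = b''.join(a.tobytes() for a in items)
pointers, sizes, offset = [], [], 0
for a in items:
    pointers.append(offset)
    sizes.append(len(a))
    offset += a.nbytes


def get_item(i):
    # count selects how many elements to read; offset is in bytes.
    return np.frombuffer(buffer, dtype=dtype, count=sizes[i], offset=pointers[i])


print(get_item(1))  # [4 5]
```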