"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "0368483b61ad690aef5b0d92611270868c0ce8ea"
Commit 8207f263 authored by Myle Ott, committed by Facebook Github Bot
Browse files

Misc dataset improvements

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/911

Differential Revision: D16536559

Pulled By: myleott

fbshipit-source-id: 7fe495054ce5b7658b1d3a43eca38c5858360236
parent abc13e28
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
import bisect import bisect
import numpy as np import numpy as np
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset from . import FairseqDataset
...@@ -50,7 +51,10 @@ class ConcatDataset(FairseqDataset): ...@@ -50,7 +51,10 @@ class ConcatDataset(FairseqDataset):
def collater(self, samples): def collater(self, samples):
# For now only supports datasets with same underlying collater implementations # For now only supports datasets with same underlying collater implementations
return self.datasets[0].collater(samples) if hasattr(self.datasets[0], 'collater'):
return self.datasets[0].collater(samples)
else:
return default_collate(samples)
def size(self, idx: int): def size(self, idx: int):
""" """
......
...@@ -52,6 +52,9 @@ class Dictionary(object): ...@@ -52,6 +52,9 @@ class Dictionary(object):
"""Returns the number of symbols in the dictionary""" """Returns the number of symbols in the dictionary"""
return len(self.symbols) return len(self.symbols)
def __contains__(self, sym):
return sym in self.indices
def index(self, sym): def index(self, sym):
"""Returns the index of the specified symbol""" """Returns the index of the specified symbol"""
assert isinstance(sym, str) assert isinstance(sym, str)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import numpy as np
import torch.utils.data import torch.utils.data
...@@ -41,7 +42,7 @@ class FairseqDataset(torch.utils.data.Dataset): ...@@ -41,7 +42,7 @@ class FairseqDataset(torch.utils.data.Dataset):
def ordered_indices(self): def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based """Return an ordered list of indices. Batches will be constructed based
on this order.""" on this order."""
raise NotImplementedError return np.arange(len(self))
@property @property
def supports_prefetch(self): def supports_prefetch(self):
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
from functools import lru_cache
import os import os
import shutil import shutil
import struct import struct
...@@ -146,6 +147,7 @@ class IndexedDataset(FairseqDataset): ...@@ -146,6 +147,7 @@ class IndexedDataset(FairseqDataset):
if self.data_file: if self.data_file:
self.data_file.close() self.data_file.close()
@lru_cache(maxsize=8)
def __getitem__(self, i): def __getitem__(self, i):
if not self.data_file: if not self.data_file:
self.read_data(self.path) self.read_data(self.path)
...@@ -214,6 +216,7 @@ class IndexedCachedDataset(IndexedDataset): ...@@ -214,6 +216,7 @@ class IndexedCachedDataset(IndexedDataset):
self.data_file.close() self.data_file.close()
self.data_file = None self.data_file = None
@lru_cache(maxsize=8)
def __getitem__(self, i): def __getitem__(self, i):
self.check_index(i) self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
...@@ -255,6 +258,7 @@ class IndexedRawTextDataset(FairseqDataset): ...@@ -255,6 +258,7 @@ class IndexedRawTextDataset(FairseqDataset):
if i < 0 or i >= self.size: if i < 0 or i >= self.size:
raise IndexError('index out of range') raise IndexError('index out of range')
@lru_cache(maxsize=8)
def __getitem__(self, i): def __getitem__(self, i):
self.check_index(i) self.check_index(i)
return self.tokens_list[i] return self.tokens_list[i]
...@@ -429,6 +433,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -429,6 +433,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def sizes(self): def sizes(self):
return self._sizes return self._sizes
@lru_cache(maxsize=8)
def __getitem__(self, i): def __getitem__(self, i):
return self._pointers[i], self._sizes[i] return self._pointers[i], self._sizes[i]
...@@ -466,6 +471,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -466,6 +471,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def __len__(self): def __len__(self):
return len(self._index) return len(self._index)
@lru_cache(maxsize=8)
def __getitem__(self, i): def __getitem__(self, i):
ptr, size = self._index[i] ptr, size = self._index[i]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr) np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment