Commit 61697cab authored by Jared Casper

Add ability to take a contiguous slice to indexed_dataset.

Also removed the fix_lua_indexing stuff from fairseq.
parent 6140718f
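
To illustrate the behavior this commit adds, here is a minimal usage sketch. It is not part of the commit: the module import and the dataset path are assumptions and depend on the repo layout and on what data has been built.

import indexed_dataset   # assumed importable; actual package path depends on the repo layout

ds = indexed_dataset.make_dataset('my_corpus_text_document', impl='mmap')   # illustrative path

# Integer indexing returns a single sample, as before.
sample = ds[0]

# New in this commit: a contiguous slice returns a list of per-sample arrays,
# fetched with a single seek/read (or a single frombuffer for the mmap backend).
doc = ds[3:7]       # list of 4 arrays
# ds[0:10:2]        # would raise ValueError: slices must be contiguous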
@@ -14,6 +14,7 @@ from functools import lru_cache
 import os
 import shutil
 import struct
+from itertools import accumulate

 import numpy as np
 import torch
@@ -50,11 +51,11 @@ def make_builder(out_file, impl, vocab_size=None):
         return IndexedDatasetBuilder(out_file)


-def make_dataset(path, impl, fix_lua_indexing=False):
+def make_dataset(path, impl):
     if impl == 'lazy' and IndexedDataset.exists(path):
-        return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
+        return IndexedDataset(path)
     elif impl == 'cached' and IndexedDataset.exists(path):
-        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
+        return IndexedCachedDataset(path)
     elif impl == 'mmap' and MMapIndexedDataset.exists(path):
         return MMapIndexedDataset(path)
     return None
@@ -114,10 +115,9 @@ class IndexedDataset(torch.utils.data.Dataset):
     """Loader for IndexedDataset"""
     _HDR_MAGIC = b'TNTIDX\x00\x00'

-    def __init__(self, path, fix_lua_indexing=False):
+    def __init__(self, path):
         super().__init__()
         self.path = path
-        self.fix_lua_indexing = fix_lua_indexing
         self.data_file = None
         self.read_index(path)
@@ -150,19 +150,30 @@ class IndexedDataset(torch.utils.data.Dataset):
         if self.data_file:
             self.data_file.close()

-    @lru_cache(maxsize=8)
-    def __getitem__(self, i):
+    #@lru_cache(maxsize=8)
+    def __getitem__(self, idx):
         if not self.data_file:
             self.read_data(self.path)
-        self.check_index(i)
-        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
-        a = np.empty(tensor_size, dtype=self.dtype)
-        self.data_file.seek(self.data_offsets[i] * self.element_size)
-        self.data_file.readinto(a)
-        item = torch.from_numpy(a).long()
-        if self.fix_lua_indexing:
-            item -= 1  # subtract 1 for 0-based indexing
-        return item
+        if isinstance(idx, int):
+            i = idx
+            self.check_index(i)
+            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+            a = np.empty(tensor_size, dtype=self.dtype)
+            self.data_file.seek(self.data_offsets[i] * self.element_size)
+            self.data_file.readinto(a)
+            return a
+        elif isinstance(idx, slice):
+            start, stop, step = idx.indices(len(self))
+            if step != 1:
+                raise ValueError("Slices into indexed_dataset must be contiguous")
+            sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
+            size = sum(sizes)
+            a = np.empty(size, dtype=self.dtype)
+            self.data_file.seek(self.data_offsets[start] * self.element_size)
+            self.data_file.readinto(a)
+            offsets = list(accumulate(sizes))
+            sents = np.split(a, offsets[:-1])
+            return sents

     def __len__(self):
         return self._len
@@ -186,8 +197,8 @@ class IndexedDataset(torch.utils.data.Dataset):

 class IndexedCachedDataset(IndexedDataset):

-    def __init__(self, path, fix_lua_indexing=False):
-        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
+    def __init__(self, path):
+        super().__init__(path)
         self.cache = None
         self.cache_index = {}
@@ -219,17 +230,22 @@ class IndexedCachedDataset(IndexedDataset):
             self.data_file.close()
             self.data_file = None

-    @lru_cache(maxsize=8)
-    def __getitem__(self, i):
-        self.check_index(i)
-        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
-        a = np.empty(tensor_size, dtype=self.dtype)
-        ptx = self.cache_index[i]
-        np.copyto(a, self.cache[ptx: ptx + a.size])
-        item = torch.from_numpy(a).long()
-        if self.fix_lua_indexing:
-            item -= 1  # subtract 1 for 0-based indexing
-        return item
+    #@lru_cache(maxsize=8)
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            i = idx
+            self.check_index(i)
+            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+            a = np.empty(tensor_size, dtype=self.dtype)
+            ptx = self.cache_index[i]
+            np.copyto(a, self.cache[ptx: ptx + a.size])
+            return a
+        elif isinstance(idx, slice):
+            # Hack just to make this work, can optimize later if necessary
+            sents = []
+            for i in range(*idx.indices(len(self))):
+                sents.append(self[i])
+            return sents


 class IndexedDatasetBuilder(object):
@@ -434,14 +450,26 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self._index)

-    @lru_cache(maxsize=8)
-    def __getitem__(self, i):
-        ptr, size = self._index[i]
-        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
-        if self._index.dtype != np.int64:
-            np_array = np_array.astype(np.int64)
-        return torch.from_numpy(np_array)
+    #@lru_cache(maxsize=8)
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            ptr, size = self._index[idx]
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
+            if self._index.dtype != np.int64:
+                np_array = np_array.astype(np.int64)
+            return torch.from_numpy(np_array)
+        elif isinstance(idx, slice):
+            start, stop, step = idx.indices(len(self))
+            if step != 1:
+                raise ValueError("Slices into indexed_dataset must be contiguous")
+            ptr = self._index._pointers[start]
+            sizes = self._index._sizes[idx]
+            offsets = list(accumulate(sizes))
+            total_size = sum(sizes)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
+            sents = np.split(np_array, offsets[:-1])
+            return sents

     @property
     def sizes(self):
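Both slice branches above rely on the same offset arithmetic: read (or view) one contiguous block covering all requested samples, then split it at the cumulative sample sizes. A standalone sketch with made-up sizes, to show what accumulate and np.split do here (not code from the commit):

from itertools import accumulate
import numpy as np

sizes = [3, 2, 4]                      # lengths of three consecutive samples
flat = np.arange(sum(sizes))           # stands in for one contiguous read

offsets = list(accumulate(sizes))      # [3, 5, 9]: end offset of each sample
sents = np.split(flat, offsets[:-1])   # split at [3, 5]; the final end is implicit
# sents == [array([0, 1, 2]), array([3, 4]), array([5, 6, 7, 8])]

The remaining hunk below updates the indexed_dataset test script so that each document is fetched with a single ds[start:end] slice instead of one lookup per sentence.
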
@@ -18,14 +18,14 @@ def test_indexed_dataset(args):
     if ds.supports_prefetch:
         # just prefetch the whole thing in test (so assume it is small)
         ds.prefetch(range(len(ds)))
-    for i in range(1):
+    for i in range(2):
         start = ds.doc_idx[i]
         end = ds.doc_idx[i+1]
         print(start, end)
-        for j in range(start, end):
-            ids = ds[j].data.tolist()
-            print(ids)
-            tokens = tokenizer.convert_ids_to_tokens(ids)
+        ids = ds[start:end]
+        for s in ids:
+            l = s.data.tolist()
+            print(l)
+            tokens = tokenizer.convert_ids_to_tokens(l)
             print(tokens)
         print("******** END DOCUMENT **********")
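
A likely reason the @lru_cache decorators are commented out rather than kept: functools.lru_cache builds a hash key from the call arguments, and slice objects are not hashable on the Python versions this code targets, so a cached __getitem__ would raise TypeError on the new slice path. This is an assumption about the motivation; the commit message does not say. A tiny demonstration:

from functools import lru_cache

@lru_cache(maxsize=8)
def getitem(idx):            # stand-in for the dataset __getitem__
    return idx

getitem(3)                   # fine: ints are hashable
# getitem(slice(0, 4))       # TypeError: unhashable type: 'slice'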