"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "3f70e3c4aede3d6412c3c9a73e92d6dc72a7232d"
Commit b1714c14 authored by Mohammad Shoeybi

Merge branch 'albert_data_loader' of ssh://gitlab-master.nvidia.com:12051/ADLR/megatron-lm into albert_data_loader
parents f51ceb7c 3f4bc91b
@@ -14,6 +14,7 @@ from functools import lru_cache
 import os
 import shutil
 import struct
+from itertools import accumulate
 
 import numpy as np
 import torch
@@ -50,11 +51,11 @@ def make_builder(out_file, impl, vocab_size=None):
         return IndexedDatasetBuilder(out_file)
 
 
-def make_dataset(path, impl, fix_lua_indexing=False):
+def make_dataset(path, impl):
     if impl == 'lazy' and IndexedDataset.exists(path):
-        return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
+        return IndexedDataset(path)
     elif impl == 'cached' and IndexedDataset.exists(path):
-        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
+        return IndexedCachedDataset(path)
     elif impl == 'mmap' and MMapIndexedDataset.exists(path):
         return MMapIndexedDataset(path)
     return None
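For orientation, a minimal sketch of calling the updated factory after this change (not part of the commit; the import path and dataset path prefix below are assumptions):

    from megatron.data import indexed_dataset  # import path assumed; may differ in this branch

    # 'mmap' selects MMapIndexedDataset; 'lazy'/'cached' select the classes changed below.
    ds = indexed_dataset.make_dataset('my_corpus_text_sentence', impl='mmap')
    if ds is None:
        raise FileNotFoundError('no index/data files found for the given prefix')
    print(len(ds))  # number of stored sequences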
@@ -114,10 +115,9 @@ class IndexedDataset(torch.utils.data.Dataset):
     """Loader for IndexedDataset"""
 
     _HDR_MAGIC = b'TNTIDX\x00\x00'
 
-    def __init__(self, path, fix_lua_indexing=False):
+    def __init__(self, path):
         super().__init__()
         self.path = path
-        self.fix_lua_indexing = fix_lua_indexing
         self.data_file = None
         self.read_index(path)
@@ -150,19 +150,30 @@ class IndexedDataset(torch.utils.data.Dataset):
         if self.data_file:
             self.data_file.close()
 
-    @lru_cache(maxsize=8)
-    def __getitem__(self, i):
+    #@lru_cache(maxsize=8)
+    def __getitem__(self, idx):
         if not self.data_file:
             self.read_data(self.path)
-        self.check_index(i)
-        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
-        a = np.empty(tensor_size, dtype=self.dtype)
-        self.data_file.seek(self.data_offsets[i] * self.element_size)
-        self.data_file.readinto(a)
-        item = torch.from_numpy(a).long()
-        if self.fix_lua_indexing:
-            item -= 1  # subtract 1 for 0-based indexing
-        return item
+        if isinstance(idx, int):
+            i = idx
+            self.check_index(i)
+            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+            a = np.empty(tensor_size, dtype=self.dtype)
+            self.data_file.seek(self.data_offsets[i] * self.element_size)
+            self.data_file.readinto(a)
+            return a
+        elif isinstance(idx, slice):
+            start, stop, step = idx.indices(len(self))
+            if step != 1:
+                raise ValueError("Slices into indexed_dataset must be contiguous")
+            sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
+            size = sum(sizes)
+            a = np.empty(size, dtype=self.dtype)
+            self.data_file.seek(self.data_offsets[start] * self.element_size)
+            self.data_file.readinto(a)
+            offsets = list(accumulate(sizes))
+            sents = np.split(a, offsets[:-1])
+            return sents
 
     def __len__(self):
         return self._len
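A hedged sketch of the new indexing semantics added above: an integer index returns one sequence as a raw numpy array (the torch conversion and fix_lua_indexing adjustment are dropped), while a contiguous slice does a single seek/read and returns a list of per-sentence arrays. The prefix is hypothetical, as before; ds.doc_idx is the document-to-sentence index exercised by the test at the end of this diff:

    ds = indexed_dataset.make_dataset('my_corpus_text_sentence', impl='lazy')  # IndexedDataset
    sentence = ds[0]                           # np.ndarray of token ids for one sentence
    start, end = ds.doc_idx[0], ds.doc_idx[1]
    doc = ds[start:end]                        # list of arrays, one per sentence of document 0
    assert len(doc) == end - start             # holds for 1-D sentences (one size entry each)
    # non-contiguous slices are rejected: ds[start:end:2] raises ValueError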
@@ -186,8 +197,8 @@ class IndexedDataset(torch.utils.data.Dataset):
 
 class IndexedCachedDataset(IndexedDataset):
 
-    def __init__(self, path, fix_lua_indexing=False):
-        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
+    def __init__(self, path):
+        super().__init__(path)
         self.cache = None
         self.cache_index = {}
 
@@ -219,17 +230,22 @@ class IndexedCachedDataset(IndexedDataset):
                 self.data_file.close()
                 self.data_file = None
 
-    @lru_cache(maxsize=8)
-    def __getitem__(self, i):
-        self.check_index(i)
-        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
-        a = np.empty(tensor_size, dtype=self.dtype)
-        ptx = self.cache_index[i]
-        np.copyto(a, self.cache[ptx: ptx + a.size])
-        item = torch.from_numpy(a).long()
-        if self.fix_lua_indexing:
-            item -= 1  # subtract 1 for 0-based indexing
-        return item
+    #@lru_cache(maxsize=8)
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            i = idx
+            self.check_index(i)
+            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+            a = np.empty(tensor_size, dtype=self.dtype)
+            ptx = self.cache_index[i]
+            np.copyto(a, self.cache[ptx: ptx + a.size])
+            return a
+        elif isinstance(idx, slice):
+            # Hack just to make this work, can optimizer later if necessary
+            sents = []
+            for i in range(*idx.indices(len(self))):
+                sents.append(self[i])
+            return sents
 
 
 class IndexedDatasetBuilder(object):
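The cached variant keeps the same integer path but, per the in-code comment, handles a slice by looping over integer lookups; it also still needs prefetch() to populate its cache before indexing. A small sketch under the same assumptions as above:

    ds_cached = indexed_dataset.make_dataset('my_corpus_text_sentence', impl='cached')
    ds_cached.prefetch(range(len(ds_cached)))  # fills self.cache / self.cache_index
    first_two = ds_cached[0:2]                 # list of two numpy arrays, fetched one element at a time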
@@ -434,14 +450,26 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self._index)
 
-    @lru_cache(maxsize=8)
-    def __getitem__(self, i):
-        ptr, size = self._index[i]
-        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
-        if self._index.dtype != np.int64:
-            np_array = np_array.astype(np.int64)
-
-        return torch.from_numpy(np_array)
+    #@lru_cache(maxsize=8)
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            ptr, size = self._index[idx]
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
+            if self._index.dtype != np.int64:
+                np_array = np_array.astype(np.int64)
+
+            return torch.from_numpy(np_array)
+        elif isinstance(idx, slice):
+            start, stop, step = idx.indices(len(self))
+            if step != 1:
+                raise ValueError("Slices into indexed_dataset must be contiguous")
+            ptr = self._index._pointers[start]
+            sizes = self._index._sizes[idx]
+            offsets = list(accumulate(sizes))
+            total_size = sum(sizes)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
+            sents = np.split(np_array, offsets[:-1])
+            return sents
 
     @property
     def sizes(self):
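Both slice paths rely on the same itertools.accumulate / np.split pattern: read (or view) one contiguous block, then cut it at the cumulative sentence sizes. A self-contained illustration, independent of the commit's classes:

    import numpy as np
    from itertools import accumulate

    sizes = [3, 5, 2]                              # per-sentence token counts
    flat = np.arange(sum(sizes), dtype=np.int64)   # stands in for the contiguous buffer
    offsets = list(accumulate(sizes))              # [3, 8, 10]
    sents = np.split(flat, offsets[:-1])           # cut points at 3 and 8
    assert [len(s) for s in sents] == sizes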
...
@@ -44,7 +44,8 @@ class Encoder(object):
         for sentence in Encoder.splitter.tokenize(text):
             tokens = Encoder.tokenizer.tokenize(sentence)
             ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
-            doc_ids.append(ids)
+            if len(ids) > 0:
+                doc_ids.append(ids)
         return doc_ids, len(json_line)
 
 def main():
...
@@ -18,16 +18,18 @@ def test_indexed_dataset(args):
     if ds.supports_prefetch:
         # just prefetch the whole thing in test (so assume it is small)
         ds.prefetch(range(len(ds)))
-    for i in range(1):
+    for i in range(len(ds.doc_idx)-1):
         start = ds.doc_idx[i]
         end = ds.doc_idx[i+1]
-        print(start, end)
-        for j in range(start, end):
-            ids = ds[j].data.tolist()
-            print(ids)
-            tokens = tokenizer.convert_ids_to_tokens(ids)
-            print(tokens)
-        print("******** END DOCUMENT **********")
+        ids = ds[start:end]
+        for s in ids:
+            assert len(s) > 0
+            l = s.data.tolist()
+            tokens = tokenizer.convert_ids_to_tokens(l)
+            for t in tokens:
+                if '\n' in t:
+                    print("Newline in string!")
+                    print(i)
 
 def main():
     parser = argparse.ArgumentParser()
...