Commit 29153e27 authored by theweiho's avatar theweiho Committed by Myle Ott
Browse files

Update dataset code for use by https://github.com/pytorch/translate/pull/62 (#161)

parent 3ae97589
...@@ -40,11 +40,19 @@ def code(dtype): ...@@ -40,11 +40,19 @@ def code(dtype):
return k return k
def index_file_path(prefix_path):
    """Return the filename of the index (``.idx``) file for *prefix_path*."""
    return '%s.idx' % prefix_path
def data_file_path(prefix_path):
    """Return the filename of the binary data (``.bin``) file for *prefix_path*."""
    return '%s.bin' % prefix_path
class IndexedDataset(object): class IndexedDataset(object):
"""Loader for TorchNet IndexedDataset""" """Loader for TorchNet IndexedDataset"""
def __init__(self, path): def __init__(self, path):
with open(path + '.idx', 'rb') as f: with open(index_file_path(path), 'rb') as f:
magic = f.read(8) magic = f.read(8)
assert magic == b'TNTIDX\x00\x00' assert magic == b'TNTIDX\x00\x00'
version = f.read(8) version = f.read(8)
...@@ -58,7 +66,7 @@ class IndexedDataset(object): ...@@ -58,7 +66,7 @@ class IndexedDataset(object):
self.read_data(path) self.read_data(path)
def read_data(self, path): def read_data(self, path):
self.data_file = open(path + '.bin', 'rb', buffering=0) self.data_file = open(data_file_path(path), 'rb', buffering=0)
def check_index(self, i): def check_index(self, i):
if i < 0 or i >= self.size: if i < 0 or i >= self.size:
...@@ -80,14 +88,17 @@ class IndexedDataset(object): ...@@ -80,14 +88,17 @@ class IndexedDataset(object):
@staticmethod @staticmethod
def exists(path): def exists(path):
return os.path.exists(path + '.idx') return (
os.path.exists(index_file_path(path)) and
os.path.exists(data_file_path(path))
)
class IndexedInMemoryDataset(IndexedDataset): class IndexedInMemoryDataset(IndexedDataset):
"""Loader for TorchNet IndexedDataset, keeps all the data in memory""" """Loader for TorchNet IndexedDataset, keeps all the data in memory"""
def read_data(self, path): def read_data(self, path):
self.data_file = open(path + '.bin', 'rb') self.data_file = open(data_file_path(path), 'rb')
self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype) self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
self.data_file.readinto(self.buffer) self.data_file.readinto(self.buffer)
self.data_file.close() self.data_file.close()
......
...@@ -40,7 +40,8 @@ class Tokenizer: ...@@ -40,7 +40,8 @@ class Tokenizer:
dict.add_symbol(dict.eos_word) dict.add_symbol(dict.eos_word)
@staticmethod @staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line): def binarize(filename, dict, consumer, tokenize=tokenize_line,
append_eos=True, reverse_order=False):
nseq, ntok = 0, 0 nseq, ntok = 0, 0
replaced = Counter() replaced = Counter()
...@@ -50,7 +51,15 @@ class Tokenizer: ...@@ -50,7 +51,15 @@ class Tokenizer:
with open(filename, 'r') as f: with open(filename, 'r') as f:
for line in f: for line in f:
ids = Tokenizer.tokenize(line, dict, tokenize, add_if_not_exist=False, consumer=replaced_consumer) ids = Tokenizer.tokenize(
line=line,
dict=dict,
tokenize=tokenize,
add_if_not_exist=False,
consumer=replaced_consumer,
append_eos=append_eos,
reverse_order=reverse_order,
)
nseq += 1 nseq += 1
consumer(ids) consumer(ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment