Commit d1dc66d9 authored by Ruty Rinott's avatar Ruty Rinott Committed by Facebook Github Bot
Browse files

optimizations for token_block_dataset

Summary:
optimizing memory use of token_block_dataset by replacing python data structures with numpy arrays.
applying needed parts from D13498973, instead of rebasing it on changes

Reviewed By: edunov

Differential Revision: D13678485

fbshipit-source-id: c0c827a8b95834a6a5456476040ebdc8e42136d4
parent cefe3f8a
......@@ -40,7 +40,7 @@ class TokenBlockDataset(FairseqDataset):
self.slice_indices = []
assert len(dataset) == len(sizes)
sizes = np.array(sizes, dtype=int)
if break_mode is None or break_mode == 'none':
total_size = sum(sizes)
length = math.ceil(total_size / block_size)
......@@ -66,21 +66,23 @@ class TokenBlockDataset(FairseqDataset):
if curr_size > 0:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
elif break_mode == 'eos':
self.slice_indices = np.empty((sum(sizes > 1), 2), dtype=int)
curr = 0
for sz in sizes:
for i, sz in enumerate(sizes):
# skip samples with just 1 example (which would be just the eos token)
if sz > 1:
self.slice_indices.append((curr, curr + sz))
self.slice_indices[i] = (curr, curr + sz)
curr += sz
else:
raise ValueError('Invalid break_mode: ' + break_mode)
self.sizes = np.array([e - s for s, e in self.slice_indices])
self.slice_indices = np.array(self.slice_indices, dtype=int)
# build index mapping block indices to the underlying dataset indices
self.block_to_dataset_index = []
self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
ds_idx, ds_remaining = -1, 0
for to_consume in self.sizes:
for i, (s, e) in enumerate(self.slice_indices):
to_consume = e - s
if ds_remaining == 0:
ds_idx += 1
ds_remaining = sizes[ds_idx]
......@@ -91,11 +93,11 @@ class TokenBlockDataset(FairseqDataset):
ds_idx += 1
ds_remaining = sizes[ds_idx]
ds_remaining -= to_consume
self.block_to_dataset_index.append((
self.block_to_dataset_index[i] = (
start_ds_idx, # starting index in dataset
start_offset, # starting offset within starting index
ds_idx, # ending index in dataset
))
)
assert ds_remaining == 0
assert ds_idx == len(self.dataset) - 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment