Commit d1dc66d9 authored by Ruty Rinott, committed by Facebook Github Bot

optimizations for token_block_dataset

Summary:
Optimizing memory use of token_block_dataset by replacing Python data structures with numpy arrays.
Applying the needed parts of D13498973 instead of rebasing it onto these changes.

Reviewed By: edunov

Differential Revision: D13678485

fbshipit-source-id: c0c827a8b95834a6a5456476040ebdc8e42136d4
parent cefe3f8a
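
The memory claim in the summary is easy to see in isolation. Below is a minimal sketch (not from the commit; the count and variable names are illustrative) comparing a Python list of (start, end) tuples, the structure slice_indices used to be, against a preallocated integer array like the one the diff introduces:

```python
import sys

import numpy as np

# Hypothetical scale, just to make the overhead visible.
n = 1_000_000

# A Python list of (start, end) tuples: the list holds n pointers, and each
# tuple is a separate heap object on top of that (the int elements add more).
as_list = [(i, i + 1) for i in range(n)]
list_bytes = sys.getsizeof(as_list) + sum(sys.getsizeof(t) for t in as_list)

# A preallocated numpy array: one contiguous buffer of n * 2 native ints.
as_array = np.empty((n, 2), dtype=int)

print(f"list + tuples: ~{list_bytes / 1e6:.0f} MB (int objects not counted)")
print(f"numpy array:   ~{as_array.nbytes / 1e6:.0f} MB")
```

Beyond the raw byte count, the contiguous array also avoids heap fragmentation and allows vectorized indexing, which the rest of the diff relies on.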
@@ -40,7 +40,7 @@ class TokenBlockDataset(FairseqDataset):
         self.slice_indices = []
 
         assert len(dataset) == len(sizes)
-
+        sizes = np.array(sizes, dtype=int)
         if break_mode is None or break_mode == 'none':
             total_size = sum(sizes)
             length = math.ceil(total_size / block_size)
@@ -66,21 +66,23 @@ class TokenBlockDataset(FairseqDataset):
             if curr_size > 0:
                 self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
+            self.slice_indices = np.empty((sum(sizes > 1), 2), dtype=int)
             curr = 0
-            for sz in sizes:
+            for i, sz in enumerate(sizes):
                 # skip samples with just 1 example (which would be just the eos token)
                 if sz > 1:
-                    self.slice_indices.append((curr, curr + sz))
+                    self.slice_indices[i] = (curr, curr + sz)
                 curr += sz
         else:
             raise ValueError('Invalid break_mode: ' + break_mode)
 
         self.sizes = np.array([e - s for s, e in self.slice_indices])
-
+        self.slice_indices = np.array(self.slice_indices, dtype=int)
         # build index mapping block indices to the underlying dataset indices
-        self.block_to_dataset_index = []
+        self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
         ds_idx, ds_remaining = -1, 0
-        for to_consume in self.sizes:
+        for i, (s, e) in enumerate(self.slice_indices):
+            to_consume = e - s
             if ds_remaining == 0:
                 ds_idx += 1
                 ds_remaining = sizes[ds_idx]
@@ -91,11 +93,11 @@ class TokenBlockDataset(FairseqDataset):
                 ds_idx += 1
                 ds_remaining = sizes[ds_idx]
             ds_remaining -= to_consume
-            self.block_to_dataset_index.append((
+            self.block_to_dataset_index[i] = (
                 start_ds_idx,  # starting index in dataset
                 start_offset,  # starting offset within starting index
                 ds_idx,  # ending index in dataset
-            ))
+            )
         assert ds_remaining == 0
         assert ds_idx == len(self.dataset) - 1
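
To make the new eos branch concrete, here is a toy walk-through of the preallocate-and-fill pattern on made-up sentence lengths. It uses a separate row counter for the destination slot rather than the loop index, so it is an adaptation of the diff's pattern, not a copy of it:

```python
import numpy as np

# Toy per-sentence token counts, each ending in an eos token.
sizes = np.array([4, 1, 3, 5], dtype=int)

# sum(sizes > 1) counts the samples longer than a lone eos token,
# which gives the row count for the preallocated slice array.
n_slices = int((sizes > 1).sum())  # -> 3

slice_indices = np.empty((n_slices, 2), dtype=int)
row, curr = 0, 0
for sz in sizes:
    if sz > 1:
        # one (start, end) token span per kept sentence
        slice_indices[row] = (curr, curr + sz)
        row += 1
    curr += sz

print(slice_indices)
# [[ 0  4]
#  [ 5  8]
#  [ 8 13]]
```

Note that converting sizes to a numpy array up front (the first hunk) is what makes the boolean reduction `sum(sizes > 1)` valid at all: on a plain Python list, the comparison `sizes > 1` would raise a TypeError under Python 3.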