Commit 9ca82a0e authored by Haoran Li, committed by Facebook Github Bot

data per gpu change

Summary: Avoid loading the entire data set per GPU, to reduce memory footprint

Reviewed By: rutyrinott

Differential Revision: D13163548

fbshipit-source-id: 4ba717c8021ba5723d02225bae5782e2c3a18640
parent c37250ab
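In outline, the hunks below change the loading contract: `TokenBlockDataset` is built from an on-disk `IndexedDataset` handle, computes block boundaries from size metadata alone, and only reads tokens when `prefetch(indices)` is called with the block indices a given GPU will actually touch. A minimal sketch of the intended call sequence, assuming the signatures in this diff; the path, pad/eos ids, and sharding scheme are illustrative, not part of the commit:

    from fairseq.data import IndexedDataset, TokenBlockDataset

    rank, world_size = 0, 8                       # toy sharding values
    path = 'data-bin/train'                       # hypothetical dataset prefix

    ds = IndexedDataset(path, fix_lua_indexing=True)   # opens files, loads no tokens

    # Block boundaries come from ds.sizes alone; token data stays on disk.
    blocks = TokenBlockDataset(ds, 512, pad=1, eos=2,
                               break_mode='none', include_targets=True)

    # Each rank pulls only its own shard of blocks into memory...
    my_indices = list(range(rank, len(blocks), world_size))
    blocks.prefetch(my_indices)

    # ...and __getitem__ is then served from the per-rank cache.
    source, item, past_target = blocks[my_indices[0]]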
@@ -98,6 +98,12 @@ class IndexedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return self.size

+    def read_into(self, start, dst):
+        self.data_file.seek(start * self.element_size)
+        self.data_file.readinto(dst)
+        if self.fix_lua_indexing:
+            dst -= 1  # subtract 1 for 0-based indexing
+
     @staticmethod
     def exists(path):
         return (
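The new `read_into` fills a caller-supplied buffer straight from the underlying file, so the only allocation is the destination array the caller already owns. A rough standalone illustration of the same seek-and-readinto pattern; the file name and dtype are assumptions for the example (fairseq's `element_size` depends on the dataset's dtype code):

    import numpy as np

    element_size = np.dtype(np.int32).itemsize
    dst = np.empty(16, dtype=np.int32)      # preallocated destination buffer
    with open('tokens.bin', 'rb') as f:
        f.seek(4 * element_size)            # start=4, measured in elements, not bytes
        f.readinto(dst)                     # fills dst in place, no extra copy
    dst -= 1                                # fix_lua_indexing: 1-based -> 0-based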
@@ -158,6 +164,11 @@ class IndexedInMemoryDataset(IndexedDataset):
         if self.fix_lua_indexing:
             self.buffer -= 1  # subtract 1 for 0-based indexing

+    def read_into(self, start, dst):
+        if self.token_blob is None:
+            self.token_blob = [t for l in self.tokens_list for t in l]
+        np.copyto(dst, self.token_blob[start:])
+
     def __del__(self):
         pass
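For the in-memory case there is no file to seek into, so this `read_into` lazily flattens the per-sentence token lists into one blob on first use and copies from it. A toy illustration of the same copy semantics (data is made up; the source is sliced to the destination length so the copy shapes line up):

    import numpy as np

    tokens_list = [[5, 6, 2], [7, 2]]                  # per-sentence token ids
    token_blob = [t for l in tokens_list for t in l]   # flattened once, on first use
    dst = np.empty(2, dtype=np.int32)
    np.copyto(dst, token_blob[3:3 + len(dst)])
    print(dst)  # [7 2]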
@@ -182,3 +182,10 @@ class MonolingualDataset(FairseqDataset):
         order = [np.arange(len(self))]
         order.append(self.sizes)
         return np.lexsort(order)
+
+    @property
+    def supports_prefetch(self):
+        return self.dataset.supports_prefetch
+
+    def prefetch(self, indices):
+        self.dataset.prefetch(indices)
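`MonolingualDataset` participates in prefetching purely by delegation: it advertises `supports_prefetch` and forwards `prefetch(indices)` to the dataset it wraps. Any wrapper can opt in the same way; a minimal sketch of the pattern for a hypothetical wrapper (the class and its name are illustrative, not from this commit):

    from fairseq.data import FairseqDataset

    class FilteredDataset(FairseqDataset):      # hypothetical wrapper
        def __init__(self, dataset):
            self.dataset = dataset

        @property
        def supports_prefetch(self):
            # advertise prefetching only if the wrapped dataset can do it
            return getattr(self.dataset, 'supports_prefetch', False)

        def prefetch(self, indices):
            # forward exactly the indices this wrapper will later __getitem__
            self.dataset.prefetch(indices)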
@@ -10,8 +10,10 @@ import math
 import numpy as np
 import torch

+from . import FairseqDataset

-class TokenBlockDataset(torch.utils.data.Dataset):
+class TokenBlockDataset(FairseqDataset):
     """Break a 1d tensor of tokens into blocks.

     The blocks are fetched from the original tensor so no additional memory is allocated.
@@ -29,27 +31,27 @@ class TokenBlockDataset(torch.utils.data.Dataset):
         include_targets: return next tokens as targets
     """

-    def __init__(self, tokens, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
+    def __init__(self, ds, block_size, pad, eos, break_mode=None, include_targets=False):
         super().__init__()
-        self.tokens = tokens
-        self.total_size = len(tokens)
+        self.dataset = ds
         self.pad = pad
         self.eos = eos
         self.include_targets = include_targets
         self.slice_indices = []
+        self.cache_index = {}

+        sizes = ds.sizes
         if break_mode is None or break_mode == 'none':
-            length = math.ceil(len(tokens) / block_size)
+            total_size = sum(sizes)
+            length = math.ceil(total_size / block_size)

             def block_at(i):
                 start = i * block_size
-                end = min(start + block_size, len(tokens))
+                end = min(start + block_size, total_size)
                 return (start, end)

             self.slice_indices = [block_at(i) for i in range(length)]
         elif break_mode == 'complete':
-            assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
             tok_idx = 0
             sz_idx = 0
             curr_size = 0
@@ -64,7 +66,6 @@ class TokenBlockDataset(torch.utils.data.Dataset):
             if curr_size > 0:
                 self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
-            assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
             curr = 0
             for sz in sizes:
                 # skip samples with just 1 example (which would be just the eos token)
@@ -77,25 +78,43 @@ class TokenBlockDataset(torch.utils.data.Dataset):
         self.sizes = np.array([e - s for s, e in self.slice_indices])

     def __getitem__(self, index):
-        s, e = self.slice_indices[index]
-        item = torch.LongTensor(self.tokens[s:e])
+        s, e = self.cache_index[index]
+        item = torch.from_numpy(self.cache[s:e]).long()
         if self.include_targets:
             # target is the sentence, for source, rotate item one token to the left (would start with eos)
             # past target is rotated to the left by 2 (padded if its first)
             if s == 0:
-                source = np.concatenate([[self.eos], self.tokens[0:e - 1]])
-                past_target = np.concatenate([[self.pad, self.eos], self.tokens[0:e - 2]])
+                source = np.concatenate([[self.eos], self.cache[0:e - 1]])
+                past_target = np.concatenate([[self.pad, self.eos], self.cache[0:e - 2]])
             else:
-                source = self.tokens[s - 1:e - 1]
+                source = self.cache[s - 1:e - 1]
                 if s == 1:
-                    past_target = np.concatenate([[self.eos], self.tokens[0:e - 2]])
+                    past_target = np.concatenate([[self.eos], self.cache[0:e - 2]])
                 else:
-                    past_target = self.tokens[s - 2:e - 2]
+                    past_target = self.cache[s - 2:e - 2]
-            return torch.LongTensor(source), item, torch.LongTensor(past_target)
+            return torch.from_numpy(source).long(), item, torch.from_numpy(past_target).long()
         return item

     def __len__(self):
         return len(self.slice_indices)
+
+    def prefetch(self, indices):
+        indices.sort()
+        total_size = 0
+        for idx in indices:
+            s, e = self.slice_indices[idx]
+            total_size += e - s
+        self.cache = np.empty(total_size, dtype=np.int32)
+        start = 0
+        for idx in indices:
+            s, e = self.slice_indices[idx]
+            self.dataset.read_into(s, self.cache[start:start + e - s])
+            self.cache_index[idx] = (start, start + e - s)
+            start += e - s
+
+    @property
+    def supports_prefetch(self):
+        return True
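`prefetch` packs all requested blocks into one contiguous int32 cache and records, per block index, where its slice landed; `__getitem__` then resolves indices through `cache_index` instead of `slice_indices`, so any index that was never prefetched raises a `KeyError`. A self-contained sketch of the same bookkeeping on toy data (the boundaries and indices are made up):

    import numpy as np

    slice_indices = [(0, 3), (3, 7), (7, 9)]      # toy block boundaries
    indices = [2, 0]                               # blocks this rank will touch
    indices.sort()                                 # -> [0, 2]

    # one contiguous allocation, sized to just the requested blocks
    total_size = sum(e - s for s, e in (slice_indices[i] for i in indices))
    cache = np.empty(total_size, dtype=np.int32)   # 5 elements, not 9

    cache_index, start = {}, 0
    for idx in indices:
        s, e = slice_indices[idx]
        # fairseq fills this slice via self.dataset.read_into(s, cache[start:start + e - s])
        cache_index[idx] = (start, start + e - s)
        start += e - s

    print(cache_index)  # {0: (0, 3), 2: (3, 5)}; block 1 would KeyError in __getitem__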
@@ -9,12 +9,10 @@ import itertools
 import numpy as np
 import os

-from torch.utils.data import ConcatDataset
 from fairseq.data import (
-    Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
-    MonolingualDataset, TokenBlockDataset, TruncatedDictionary
-)
+    ConcatDataset, Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
+    MonolingualDataset, TokenBlockDataset, TruncatedDictionary,
+    IndexedCachedDataset, IndexedDataset)

 from . import FairseqTask, register_task
@@ -140,10 +138,8 @@ class LanguageModelingTask(FairseqTask):
             if self.args.raw_text and IndexedRawTextDataset.exists(path):
                 ds = IndexedRawTextDataset(path, self.dictionary)
-                tokens = [t for l in ds.tokens_list for t in l]
-            elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
-                ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
-                tokens = ds.buffer
+            elif not self.args.raw_text and IndexedDataset.exists(path):
+                ds = IndexedDataset(path, fix_lua_indexing=True)
             else:
                 if k > 0:
                     break
@@ -152,7 +148,7 @@ class LanguageModelingTask(FairseqTask):
             loaded_datasets.append(
                 TokenBlockDataset(
-                    tokens, ds.sizes, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(),
+                    ds, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(),
                     break_mode=self.args.sample_break_mode, include_targets=True,
                 ))
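This task-level hunk is what realizes the summary's claim: the flattened `tokens` list, previously materialized in every training process, is gone, and `TokenBlockDataset` reads only prefetched blocks. Back-of-envelope numbers for the savings (illustrative, not from the commit; the 4-byte element follows the int32 cache dtype above):

    corpus_tokens = 1_000_000_000           # 1B-token corpus stored as int32
    old_per_gpu = corpus_tokens * 4         # flattened token blob: ~4 GB per process
    world_size = 8                          # hypothetical number of GPUs
    new_per_gpu = (corpus_tokens // world_size) * 4   # prefetched shard: ~0.5 GB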