Commit 9ca82a0e authored by Haoran Li, committed by Facebook GitHub Bot

data per gpu change

Summary: Avoid loading the entire dataset on each GPU, to reduce the memory footprint.

Reviewed By: rutyrinott

Differential Revision: D13163548

fbshipit-source-id: 4ba717c8021ba5723d02225bae5782e2c3a18640
parent c37250ab
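
The diff below threads a small prefetch protocol through the data pipeline: TokenBlockDataset records only (start, end) slice indices at construction time, and a later prefetch(indices) call copies just the needed spans from the backing dataset via read_into, into one preallocated buffer. A minimal sketch of that interplay, assuming a toy in-memory backing store (ToyBackend and ToyBlocks are hypothetical stand-ins, not classes from this commit):

import numpy as np

class ToyBackend:
    """Hypothetical stand-in for IndexedDataset: a flat int32 token stream."""
    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.int32)

    def read_into(self, start, dst):
        # copy exactly dst.size elements starting at `start`
        np.copyto(dst, self.data[start:start + dst.size])

class ToyBlocks:
    """Hypothetical mirror of TokenBlockDataset's caching scheme."""
    def __init__(self, backend, slice_indices):
        self.backend = backend
        self.slice_indices = slice_indices  # [(start, end), ...] in the token stream
        self.cache = None
        self.cache_index = {}               # block idx -> (start, end) in the cache

    def prefetch(self, indices):
        indices = sorted(indices)           # the commit sorts in place, same idea
        total = sum(e - s for s, e in (self.slice_indices[i] for i in indices))
        self.cache = np.empty(total, dtype=np.int32)
        pos = 0
        for i in indices:
            s, e = self.slice_indices[i]
            self.backend.read_into(s, self.cache[pos:pos + e - s])
            self.cache_index[i] = (pos, pos + e - s)
            pos += e - s

backend = ToyBackend(range(100))
blocks = ToyBlocks(backend, [(0, 10), (10, 20), (20, 30)])
blocks.prefetch([2, 0])       # only blocks 0 and 2 become resident in memory
print(blocks.cache_index)     # {0: (0, 10), 2: (10, 20)}
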
@@ -98,6 +98,12 @@ class IndexedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return self.size
 
+    def read_into(self, start, dst):
+        self.data_file.seek(start * self.element_size)
+        self.data_file.readinto(dst)
+        if self.fix_lua_indexing:
+            dst -= 1  # subtract 1 for 0-based indexing
+
     @staticmethod
     def exists(path):
         return (
@@ -158,6 +164,11 @@ class IndexedInMemoryDataset(IndexedDataset):
         if self.fix_lua_indexing:
             self.buffer -= 1  # subtract 1 for 0-based indexing
 
+    def read_into(self, start, dst):
+        if self.token_blob is None:
+            self.token_blob = [t for l in self.tokens_list for t in l]
+        np.copyto(dst, self.token_blob[start:start + dst.size])
+
     def __del__(self):
         pass
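
read_into turns the on-disk dataset into a random-access source: it seeks to start * element_size and fills the caller-supplied buffer in place, so the caller decides exactly how much memory is committed. A sketch of the file-level mechanics (the payload.bin file and the sizes here are made up for illustration):

import numpy as np

# write a fake binary payload of 1000 int32 tokens
tokens = np.arange(1000, dtype=np.int32)
tokens.tofile('payload.bin')

element_size = tokens.dtype.itemsize  # 4 bytes per token
dst = np.empty(10, dtype=np.int32)    # caller-owned buffer, 10 tokens

with open('payload.bin', 'rb') as f:
    f.seek(500 * element_size)        # jump to token 500 without reading 0..499
    f.readinto(dst)                   # fill exactly dst.size * element_size bytes

print(dst)                            # [500 501 ... 509]
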
@@ -182,3 +182,10 @@ class MonolingualDataset(FairseqDataset):
         order = [np.arange(len(self))]
         order.append(self.sizes)
         return np.lexsort(order)
+
+    @property
+    def supports_prefetch(self):
+        return self.dataset.supports_prefetch
+
+    def prefetch(self, indices):
+        self.dataset.prefetch(indices)
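
MonolingualDataset just forwards both members to the wrapped TokenBlockDataset, so consumers can stay agnostic about nesting. The expected call pattern on the consumer side is a guarded prefetch, roughly (a sketch; dataset and batch_indices are hypothetical names for whatever the loader holds):

if dataset.supports_prefetch:
    dataset.prefetch(batch_indices)  # pull only these samples into the cache
for i in batch_indices:
    sample = dataset[i]              # now served from the prefetched cache
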
@@ -10,8 +10,10 @@
 import math
 
 import numpy as np
 import torch
 
+from . import FairseqDataset
+
 
-class TokenBlockDataset(torch.utils.data.Dataset):
+class TokenBlockDataset(FairseqDataset):
     """Break a 1d tensor of tokens into blocks.
 
     The blocks are fetched from the original tensor so no additional memory is allocated.
@@ -29,27 +31,27 @@ class TokenBlockDataset(torch.utils.data.Dataset):
         include_targets: return next tokens as targets
     """
 
-    def __init__(self, tokens, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
+    def __init__(self, ds, block_size, pad, eos, break_mode=None, include_targets=False):
         super().__init__()
-        self.tokens = tokens
-        self.total_size = len(tokens)
+        self.dataset = ds
         self.pad = pad
         self.eos = eos
         self.include_targets = include_targets
         self.slice_indices = []
+        self.cache_index = {}
 
+        sizes = ds.sizes
         if break_mode is None or break_mode == 'none':
-            length = math.ceil(len(tokens) / block_size)
+            total_size = sum(sizes)
+            length = math.ceil(total_size / block_size)
 
             def block_at(i):
                 start = i * block_size
-                end = min(start + block_size, len(tokens))
+                end = min(start + block_size, total_size)
                 return (start, end)
 
             self.slice_indices = [block_at(i) for i in range(length)]
         elif break_mode == 'complete':
-            assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
             tok_idx = 0
             sz_idx = 0
             curr_size = 0
@@ -64,7 +66,6 @@ class TokenBlockDataset(torch.utils.data.Dataset):
             if curr_size > 0:
                 self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
-            assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
             curr = 0
             for sz in sizes:
                 # skip samples with just 1 example (which would be just the eos token)
@@ -77,25 +78,43 @@ class TokenBlockDataset(torch.utils.data.Dataset):
         self.sizes = np.array([e - s for s, e in self.slice_indices])
 
     def __getitem__(self, index):
-        s, e = self.slice_indices[index]
-
-        item = torch.LongTensor(self.tokens[s:e])
+        s, e = self.cache_index[index]
 
+        item = torch.from_numpy(self.cache[s:e]).long()
+
         if self.include_targets:
             # target is the sentence, for source, rotate item one token to the left (would start with eos)
             # past target is rotated to the left by 2 (padded if its first)
             if s == 0:
-                source = np.concatenate([[self.eos], self.tokens[0:e - 1]])
-                past_target = np.concatenate([[self.pad, self.eos], self.tokens[0:e - 2]])
+                source = np.concatenate([[self.eos], self.cache[0:e - 1]])
+                past_target = np.concatenate([[self.pad, self.eos], self.cache[0:e - 2]])
             else:
-                source = self.tokens[s - 1:e - 1]
+                source = self.cache[s - 1:e - 1]
                 if s == 1:
-                    past_target = np.concatenate([[self.eos], self.tokens[0:e - 2]])
+                    past_target = np.concatenate([[self.eos], self.cache[0:e - 2]])
                 else:
-                    past_target = self.tokens[s - 2:e - 2]
+                    past_target = self.cache[s - 2:e - 2]
 
-            return torch.LongTensor(source), item, torch.LongTensor(past_target)
+            return torch.from_numpy(source).long(), item, torch.from_numpy(past_target).long()
 
         return item
 
     def __len__(self):
         return len(self.slice_indices)
+
+    def prefetch(self, indices):
+        indices.sort()
+        total_size = 0
+        for idx in indices:
+            s, e = self.slice_indices[idx]
+            total_size += e - s
+        self.cache = np.empty(total_size, dtype=np.int32)
+        start = 0
+        for idx in indices:
+            s, e = self.slice_indices[idx]
+            self.dataset.read_into(s, self.cache[start:start + e - s])
+            self.cache_index[idx] = (start, start + e - s)
+            start += e - s
+
+    @property
+    def supports_prefetch(self):
+        return True
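
After prefetch, __getitem__ indexes into self.cache through cache_index, and the eos/pad rotation produces the (source, target, past_target) triple. A worked toy example of the rotation for a block starting at s == 0 (token values are illustrative):

import numpy as np

eos, pad = 2, 1
cache = np.array([10, 11, 12, 13], dtype=np.int32)  # tokens of one prefetched block
s, e = 0, 4

target = cache[s:e]                                         # [10 11 12 13]
source = np.concatenate([[eos], cache[0:e - 1]])            # [ 2 10 11 12]
past_target = np.concatenate([[pad, eos], cache[0:e - 2]])  # [ 1  2 10 11]
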
@@ -9,12 +9,10 @@ import itertools
 
 import numpy as np
 import os
 
-from torch.utils.data import ConcatDataset
-
 from fairseq.data import (
-    Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
-    MonolingualDataset, TokenBlockDataset, TruncatedDictionary
-)
+    ConcatDataset, Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
+    MonolingualDataset, TokenBlockDataset, TruncatedDictionary,
+    IndexedCachedDataset, IndexedDataset)
 
 from . import FairseqTask, register_task
@@ -140,10 +138,8 @@ class LanguageModelingTask(FairseqTask):
         if self.args.raw_text and IndexedRawTextDataset.exists(path):
             ds = IndexedRawTextDataset(path, self.dictionary)
-            tokens = [t for l in ds.tokens_list for t in l]
-        elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
-            ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
-            tokens = ds.buffer
+        elif not self.args.raw_text and IndexedDataset.exists(path):
+            ds = IndexedDataset(path, fix_lua_indexing=True)
         else:
             if k > 0:
                 break
@@ -152,7 +148,7 @@ class LanguageModelingTask(FairseqTask):
             loaded_datasets.append(
                 TokenBlockDataset(
-                    tokens, ds.sizes, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(),
+                    ds, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(),
                     break_mode=self.args.sample_break_mode, include_targets=True,
                 ))
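
Call sites now hand the dataset object itself to TokenBlockDataset, which reads ds.sizes internally; the flattened tokens list and the redundant sizes argument disappear. Roughly (a sketch of the before/after shape; path, args, and dictionary are whatever the task already has in scope):

from fairseq.data import IndexedDataset, TokenBlockDataset

# before: the whole corpus was materialized on every GPU
# tokens = [t for l in ds.tokens_list for t in l]
# dataset = TokenBlockDataset(tokens, ds.sizes, args.tokens_per_sample, pad=..., eos=...)

# after: the dataset is lazy; memory is only committed at prefetch() time
ds = IndexedDataset(path, fix_lua_indexing=True)
dataset = TokenBlockDataset(
    ds, args.tokens_per_sample,
    pad=dictionary.pad(), eos=dictionary.eos(),
    break_mode=args.sample_break_mode, include_targets=True,
)
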