"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d7d6841406a2cef52da26fc58342e543b5cd9e1d"
Commit 439ead5a authored by Myle Ott, committed by Facebook Github Bot

Integrate with Apache Arrow/Plasma in-memory store for large datasets (#995)

Summary:
Datasets with many examples can generate very large indexes in TokenBlockDataset (and possibly elsewhere). When using `--num-workers>0`, these indexes are pickled and transferred via a multiprocessing pipe, which is slow and can fail if the index grows beyond 4GB (~0.5B examples). Apache Arrow has an in-memory store called Plasma that offloads these arrays to shared memory, which both reduces duplication of the data and avoids the need to pickle it at all.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/995

Differential Revision: D16697219

Pulled By: myleott

fbshipit-source-id: 1b679ee5b3d2726af54ff418f6159a3671173fb8
parent 72f9364c
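To make the summary concrete, the sketch below (illustration only, not part of the commit; the dataset size and dtype are assumptions) measures what a `--num-workers>0` worker currently has to receive: the full index is pickled and pushed through the multiprocessing pipe byte for byte.

```python
# Illustration only: how much data a worker would receive if the index
# were pickled directly (the pre-Plasma behavior described above).
import pickle

import numpy as np

num_examples = 10_000_000  # hypothetical; real datasets may reach ~0.5B examples
# TokenBlockDataset-style index: one integer triple per block
index = np.zeros((num_examples, 3), dtype=np.int64)

payload = pickle.dumps(index, protocol=pickle.HIGHEST_PROTOCOL)
print(f"index size:   {index.nbytes / 2**20:.0f} MiB")
print(f"pickled size: {len(payload) / 2**20:.0f} MiB")  # roughly the same; every byte crosses the pipe
```

As the summary notes, at roughly 0.5B examples the pickled payload passes 4GB and the transfer can fail; Plasma instead places the array in shared memory once and sends workers only a small object id.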
fairseq/data/plasma_utils.py (new file)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import subprocess
import tempfile


class PlasmaArray(object):
    """
    Wrapper around numpy arrays that automatically moves the data to shared
    memory upon serialization. This is particularly helpful when passing numpy
    arrays through multiprocessing, so that data is not unnecessarily
    duplicated or pickled.
    """

    def __init__(self, array):
        super().__init__()
        self.array = array
        self.disable = array.nbytes < 134217728  # disable for arrays <128MB
        self.object_id = None
        self.path = None

        # variables with underscores shouldn't be pickled
        self._client = None
        self._server = None
        self._server_tmp = None
        self._plasma = None

    @property
    def plasma(self):
        if self._plasma is None and not self.disable:
            try:
                import pyarrow.plasma as plasma
                self._plasma = plasma
            except ImportError:
                self._plasma = None
        return self._plasma

    def start_server(self):
        if self.plasma is None or self._server is not None:
            return
        assert self.object_id is None
        assert self.path is None
        self._server_tmp = tempfile.NamedTemporaryFile()
        self.path = self._server_tmp.name
        self._server = subprocess.Popen([
            'plasma_store',
            '-m', str(int(1.05 * self.array.nbytes)),
            '-s', self.path,
        ])

    @property
    def client(self):
        if self._client is None:
            assert self.path is not None
            self._client = self.plasma.connect(self.path)
        return self._client

    def __getstate__(self):
        if self.plasma is None:
            return self.__dict__
        if self.object_id is None:
            self.start_server()
            self.object_id = self.client.put(self.array)
        state = self.__dict__.copy()
        del state['array']
        state['_client'] = None
        state['_server'] = None
        state['_server_tmp'] = None
        state['_plasma'] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        if self.plasma is None:
            return
        self.array = self.client.get(self.object_id)

    def __del__(self):
        if self._server is not None:
            self._server.kill()
            self._server = None
            self._server_tmp.close()
            self._server_tmp = None
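A hedged usage sketch for the wrapper above (not part of the commit): it assumes `pyarrow` is installed with Plasma support and that the `plasma_store` executable is on `PATH`. Arrays under 128MB skip Plasma entirely because of the `disable` check and are pickled as usual.

```python
import pickle

import numpy as np

from fairseq.data import plasma_utils

# ~160MB of int64s, just above the 128MB cutoff, so Plasma is actually used.
big_index = np.arange(20_000_000, dtype=np.int64)
wrapped = plasma_utils.PlasmaArray(big_index)

# Pickling triggers __getstate__: a plasma_store server is started, the array
# is put into shared memory, and only the small object id is serialized.
payload = pickle.dumps(wrapped)
print(len(payload))  # a small amount of metadata, not ~160MB

# Unpickling triggers __setstate__, which reconnects to the store by socket
# path and fetches the array back from shared memory.
restored = pickle.loads(payload)
assert restored.array.shape == big_index.shape
```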
fairseq/data/token_block_dataset.py

@@ -8,7 +8,7 @@ import math
 import numpy as np
 import torch
 
-from . import FairseqDataset
+from fairseq.data import FairseqDataset, plasma_utils
 
 
 class TokenBlockDataset(FairseqDataset):
@@ -43,7 +43,7 @@ class TokenBlockDataset(FairseqDataset):
         self.pad = pad
         self.eos = eos
         self.include_targets = include_targets
-        self.slice_indices = []
+        slice_indices = []
 
         assert len(dataset) == len(sizes)
         assert len(dataset) > 0
@@ -57,7 +57,7 @@ class TokenBlockDataset(FairseqDataset):
                 end = min(start + block_size, total_size)
                 return (start, end)
 
-            self.slice_indices = [block_at(i) for i in range(length)]
+            slice_indices = [block_at(i) for i in range(length)]
         elif break_mode == 'complete':
             tok_idx = 0
             sz_idx = 0
@@ -67,11 +67,11 @@ class TokenBlockDataset(FairseqDataset):
                     curr_size += sizes[sz_idx]
                     sz_idx += 1
                 else:
-                    self.slice_indices.append((tok_idx, tok_idx + curr_size))
+                    slice_indices.append((tok_idx, tok_idx + curr_size))
                     tok_idx += curr_size
                     curr_size = 0
             if curr_size > 0:
-                self.slice_indices.append((tok_idx, tok_idx + curr_size))
+                slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'complete_doc':
             tok_idx = 0
             sz_idx = 0
@@ -85,32 +85,32 @@ class TokenBlockDataset(FairseqDataset):
                     curr_size += sizes[sz_idx]
                     sz_idx += 1
                 else:
-                    self.slice_indices.append((tok_idx, tok_idx + curr_size))
+                    slice_indices.append((tok_idx, tok_idx + curr_size))
                     tok_idx += curr_size
                     curr_size = 0
                     if sizes[sz_idx] == document_sep_len:
                         tok_idx += sizes[sz_idx]
                         sz_idx += 1
             if curr_size > 0:
-                self.slice_indices.append((tok_idx, tok_idx + curr_size))
+                slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
-            self.slice_indices = np.empty((len(sizes), 2), dtype=int)
+            slice_indices = np.empty((len(sizes), 2), dtype=int)
             if not torch.is_tensor(sizes):
                 sizes = torch.tensor(sizes)
             cumsum = torch.cumsum(sizes, dim=0)
-            self.slice_indices[0] = [0, sizes[0]]
+            slice_indices[0] = [0, sizes[0]]
             if len(cumsum) > 1:
-                self.slice_indices[1:] = cumsum.unfold(0, 2, 1)
+                slice_indices[1:] = cumsum.unfold(0, 2, 1)
         else:
             raise ValueError('Invalid break_mode: ' + break_mode)
 
-        self.slice_indices = np.array(self.slice_indices, dtype=int)
-        self.sizes = self.slice_indices[:, 1] - self.slice_indices[:, 0]
+        slice_indices = np.array(slice_indices, dtype=int)
+        self._sizes = slice_indices[:, 1] - slice_indices[:, 0]
 
         # build index mapping block indices to the underlying dataset indices
         if break_mode == 'eos':
             # much faster version for eos break mode
-            self.block_to_dataset_index = np.stack(
+            block_to_dataset_index = np.stack(
                 [
                     np.arange(len(sizes)),  # starting index in dataset
                     np.zeros(len(sizes), dtype=np.long),  # starting offset within starting index
@@ -120,8 +120,8 @@ class TokenBlockDataset(FairseqDataset):
             )
         else:
             ds = DatasetSearcher(sizes)
-            self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
-            for i, (s, e) in enumerate(self.slice_indices):
+            block_to_dataset_index = np.empty((len(slice_indices), 3), dtype=int)
+            for i, (s, e) in enumerate(slice_indices):
                 ds.seek(s)
                 start_ds_idx = ds.current_index
                 start_offset = ds.current_offset
@@ -129,12 +129,28 @@ class TokenBlockDataset(FairseqDataset):
                     continue
                 ds.seek(e - 1)
                 end_ds_idx = ds.current_index
-                self.block_to_dataset_index[i] = (
+                block_to_dataset_index[i] = (
                     start_ds_idx,  # starting index in dataset
                     start_offset,  # starting offset within starting index
                     end_ds_idx,  # ending index in dataset
                 )
 
+        self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
+        self._sizes = plasma_utils.PlasmaArray(self._sizes)
+        self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index)
+
+    @property
+    def slice_indices(self):
+        return self._slice_indices.array
+
+    @property
+    def sizes(self):
+        return self._sizes.array
+
+    @property
+    def block_to_dataset_index(self):
+        return self._block_to_dataset_index.array
+
     def __getitem__(self, index):
         start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
         buffer = torch.cat([
...
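The diff follows a pattern that should also apply to the "possibly elsewhere" cases mentioned in the summary: build the index as a local array, wrap it in `PlasmaArray`, and expose it through a read-only property so existing call sites keep working unchanged. A minimal sketch under those assumptions (the class and field names below are illustrative, not from fairseq):

```python
import numpy as np

from fairseq.data import plasma_utils


class MyIndexedDataset:
    """Toy example of the wrap-and-expose-as-property pattern used above."""

    def __init__(self, sizes):
        sizes = np.asarray(sizes, dtype=np.int64)
        # Stored under an underscore name so only the PlasmaArray wrapper
        # (not the raw array) is pickled when dataloader workers are spawned.
        self._sizes = plasma_utils.PlasmaArray(sizes)

    @property
    def sizes(self):
        # Callers still read `dataset.sizes` as a plain numpy array.
        return self._sizes.array
```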