"vscode:/vscode.git/clone" did not exist on "86da45bc66ee7ca782d5f498b0cacb490051d0f6"
Unverified commit 4f00d5ac authored by Quan (Andy) Gan, committed by GitHub

[Bugfix] Fix graph being duplicated in multi-GPU and CPU dataloader workers (#3760)

* fix shared memory issue

* oops

* add explanation

* add explanation
parent 3f138eba
@@ -60,15 +60,10 @@ class SAGE(nn.Module):
         return y
 
-def train(rank, world_size, shared_memory_name, features, num_classes, split_idx):
+def train(rank, world_size, graph, num_classes, split_idx):
     torch.cuda.set_device(rank)
     dist.init_process_group('nccl', 'tcp://127.0.0.1:12347', world_size=world_size, rank=rank)
-    graph = dgl.hetero_from_shared_memory(shared_memory_name)
-    feat, labels = features
-    graph.ndata['feat'] = feat
-    graph.ndata['label'] = labels
     model = SAGE(graph.ndata['feat'].shape[1], 256, num_classes).cuda()
     model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
     opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
@@ -132,9 +127,8 @@ def train(rank, world_size, shared_memory_name, features, num_classes, split_idx
 if __name__ == '__main__':
     dataset = DglNodePropPredDataset('ogbn-products')
     graph, labels = dataset[0]
-    shared_memory_name = 'shm'  # can be any string
-    feat = graph.ndata['feat']
-    graph = graph.shared_memory(shared_memory_name)
+    graph.ndata['label'] = labels
+    graph.create_formats_()  # must be called before mp.spawn().
     split_idx = dataset.get_idx_split()
     num_classes = dataset.num_classes
     n_procs = 4
@@ -142,4 +136,4 @@ if __name__ == '__main__':
     # Tested with mp.spawn and fork. Both worked and got 4s per epoch with 4 GPUs
     # and 3.86s per epoch with 8 GPUs on p2.8x, compared to 5.2s from official examples.
     import torch.multiprocessing as mp
-    mp.spawn(train, args=(n_procs, shared_memory_name, (feat, labels), num_classes, split_idx), nprocs=n_procs)
+    mp.spawn(train, args=(n_procs, graph, num_classes, split_idx), nprocs=n_procs)
@@ -33,7 +33,7 @@ def _set_python_exit_flag():
     PYTHON_EXIT_STATUS = True
 atexit.register(_set_python_exit_flag)
 
-prefetcher_timeout = int(os.environ.get('DGL_PREFETCHER_TIMEOUT', '10'))
+prefetcher_timeout = int(os.environ.get('DGL_PREFETCHER_TIMEOUT', '30'))
 
 class _TensorizedDatasetIter(object):
     def __init__(self, dataset, batch_size, drop_last, mapping_keys):
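
The default prefetcher timeout goes from 10 to 30 seconds here. Since the value is read once, when `dgl.dataloading` is imported, an override through the environment variable has to be set beforehand; a minimal sketch (the value 60 is arbitrary):

```python
# Minimal sketch: DGL_PREFETCHER_TIMEOUT is read once at import time of
# dgl.dataloading, so any override must be in place before the import.
import os
os.environ['DGL_PREFETCHER_TIMEOUT'] = '60'  # seconds; the new default is 30

import dgl.dataloading  # picks up the value set above
```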
@@ -615,9 +615,11 @@ class DataLoader(torch.utils.data.DataLoader):
             raise ValueError(
                 'Expect graph and indices to be on the same device. '
                 'If you wish to use UVA sampling, please set use_uva=True.')
-        if self.graph.device.type == 'cuda':
-            if num_workers > 0:
-                raise ValueError('num_workers must be 0 if graph and indices are on CUDA.')
+        if self.graph.device.type == 'cuda' and num_workers > 0:
+            raise ValueError('num_workers must be 0 if graph and indices are on CUDA.')
+        if self.graph.device.type == 'cpu' and num_workers > 0:
+            # Instantiate all the formats if the number of workers is greater than 0.
+            self.graph.create_formats_()
 
         # Check pin_prefetcher and use_prefetch_thread - should be only effective
         # if performing CPU sampling but output device is CUDA
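
The device check above is now two flat conditions: a CUDA graph with `num_workers > 0` is rejected outright, while a CPU graph with worker processes makes the loader call `create_formats_()` itself, so the sparse formats are created once and shared rather than materialized per worker. A rough usage sketch; the random graph, node IDs and sampler are placeholders, not part of this diff:

```python
# Illustrative sketch of the two branches above; the graph, node IDs and
# sampler here are placeholders, not part of this commit.
import torch
import dgl

g = dgl.rand_graph(1000, 5000)            # CPU graph
train_nids = torch.arange(100)
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10])

# CPU graph + worker processes: allowed; the loader calls create_formats_()
# up front so workers share the formats instead of building their own copies.
loader = dgl.dataloading.NodeDataLoader(
    g, train_nids, sampler, batch_size=32, shuffle=True, num_workers=4)

# CUDA graph + worker processes: rejected by the first branch.
# dgl.dataloading.NodeDataLoader(g.to('cuda'), train_nids.to('cuda'), sampler,
#                                batch_size=32, num_workers=4)  # raises ValueError
```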
@@ -666,10 +668,6 @@ class DataLoader(torch.utils.data.DataLoader):
         self.use_prefetch_thread = use_prefetch_thread
         worker_init_fn = WorkerInitWrapper(kwargs.get('worker_init_fn', None))
 
-        # Instantiate all the formats if the number of workers is greater than 0.
-        if num_workers > 0 and hasattr(self.graph, 'create_formats_'):
-            self.graph.create_formats_()
-
         self.other_storages = {}
 
         super().__init__(
@@ -732,6 +730,14 @@ class NodeDataLoader(DataLoader):
         graph and feature tensors into pinned memory.
 
         Default: False.
 
+        .. warning::
+
+           Using UVA with multiple GPUs may crash with device mismatch errors on
+           older CUDA drivers.  We have confirmed that CUDA driver 450.142 crashes
+           while 465.19 works, so we recommend upgrading your CUDA driver if you
+           wish to use UVA with multiple GPUs.
+
     use_prefetch_thread : bool, optional
         (Advanced option)
         Spawns a new Python thread to perform feature slicing
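
A small self-check sketch for the warning above; `pynvml` is an assumed extra dependency (DGL does not use it), and treating anything below 465 as suspect is a guess based only on the two versions the warning mentions:

```python
# Hedged sketch: query the NVIDIA driver version before enabling use_uva on
# multiple GPUs.  pynvml is an assumption here, not a DGL dependency.
import pynvml

pynvml.nvmlInit()
driver = pynvml.nvmlSystemGetDriverVersion()
if isinstance(driver, bytes):
    driver = driver.decode()
pynvml.nvmlShutdown()

# The warning only confirms that 450.142 crashes and 465.19 works; using 465
# as the cutoff is a conservative guess.
if int(driver.split('.')[0]) < 465:
    print(f'CUDA driver {driver} may crash with use_uva=True on multiple GPUs; '
          'consider upgrading.')
```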
@@ -916,6 +922,14 @@ class EdgeDataLoader(DataLoader):
         graph and feature tensors into pinned memory.
 
         Default: False.
 
+        .. warning::
+
+           Using UVA with multiple GPUs may crash with device mismatch errors on
+           older CUDA drivers.  We have confirmed that CUDA driver 450.142 crashes
+           while 465.19 works, so we recommend upgrading your CUDA driver if you
+           wish to use UVA with multiple GPUs.
+
     batch_size : int, optional
     drop_last : bool, optional
     shuffle : bool, optional
@@ -1370,9 +1370,21 @@ def _forking_rebuild(pk_state):
     meta, arrays = pk_state
     arrays = [F.to_dgl_nd(arr) for arr in arrays]
     states = _CAPI_DGLCreateHeteroPickleStates(meta, arrays)
-    return _CAPI_DGLHeteroForkingUnpickle(states)
+    graph_index = _CAPI_DGLHeteroForkingUnpickle(states)
+    graph_index._forking_pk_state = pk_state
+    return graph_index
 
 def _forking_reduce(graph_index):
+    # Because F.from_dgl_nd(F.to_dgl_nd(x)) loses the shared memory file descriptor
+    # (DLPack does not keep it), PyTorch would allocate one shared memory region for
+    # every single worker if we did not cache the tensors here.
+    # The downside is that if a graph_index is shared by forking and new formats are
+    # created afterwards, sharing it again will not carry over the new formats.  This
+    # case should be rare though, because (1) DataLoader creates all the formats if
+    # num_workers > 0 anyway, and (2) we require users to explicitly create all formats
+    # before calling mp.spawn().
+    if hasattr(graph_index, '_forking_pk_state'):
+        return _forking_rebuild, (graph_index._forking_pk_state,)
     states = _CAPI_DGLHeteroForkingPickle(graph_index)
     arrays = [F.from_dgl_nd(arr) for arr in states.arrays]
     # Similar to what is mentioned in HeteroGraphIndex.__getstate__, we need to save
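
The change above follows a simple caching pattern: `_forking_rebuild` stashes the pickle state on the rebuilt object, and `_forking_reduce` reuses that cache when present, so re-sharing the same graph maps one shared-memory region instead of exporting a new copy per worker. A generic, self-contained sketch of the same idea (all names here are illustrative, not DGL's):

```python
# Generic sketch of the reduce/rebuild caching pattern used above.
from multiprocessing.reduction import ForkingPickler


class Payload:
    def __init__(self, data):
        self.data = data


def _rebuild(state):
    obj = Payload(state)
    obj._forking_state = state      # cache, mirroring _forking_pk_state above
    return obj


def _reduce(obj):
    # If the object was shared once already, reuse the cached state so every
    # worker maps the same region instead of exporting a fresh copy.
    if hasattr(obj, '_forking_state'):
        return _rebuild, (obj._forking_state,)
    return _rebuild, (obj.data,)    # in DGL this is the pickled array state


# Registering the reducer makes multiprocessing (and torch.multiprocessing)
# use it whenever a Payload is sent to a worker process.
ForkingPickler.register(Payload, _reduce)
```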