Unverified commit a7fe461c authored by Quan (Andy) Gan, committed by GitHub

[Doc] Single-machine Multi-GPU node classification tutorial (#2976)

* multi-GPU node classification tutorial

* Update 2_node_classification.py

* fixes

* elaborate a bit more

* address comments

* address comments
parent 6c59fee9
@@ -198,11 +198,13 @@ from sphinx_gallery.sorting import FileNameSortKey
 examples_dirs = ['../../tutorials/blitz',
                  '../../tutorials/large',
                  '../../tutorials/dist',
-                 '../../tutorials/models'] # path to find sources
+                 '../../tutorials/models',
+                 '../../tutorials/multi'] # path to find sources
 gallery_dirs = ['tutorials/blitz/',
                 'tutorials/large/',
                 'tutorials/dist/',
-                'tutorials/models/'] # path to generate docs
+                'tutorials/models/',
+                'tutorials/multi/'] # path to generate docs
 reference_url = {
     'dgl' : None,
     'numpy': 'http://docs.scipy.org/doc/numpy/',
......
"""DGL PyTorch DataLoaders""" """DGL PyTorch DataLoaders"""
import inspect import inspect
import math
import torch as th import torch as th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from ..dataloader import NodeCollator, EdgeCollator, GraphCollator from ..dataloader import NodeCollator, EdgeCollator, GraphCollator
from ...distributed import DistGraph from ...distributed import DistGraph
from ...distributed import DistDataLoader from ...distributed import DistDataLoader
from ...ndarray import NDArray as DGLNDArray from ...ndarray import NDArray as DGLNDArray
from ... import backend as F from ... import backend as F
from ...base import DGLError
class _ScalarDataBatcherIter: class _ScalarDataBatcherIter:
def __init__(self, dataset, batch_size, drop_last): def __init__(self, dataset, batch_size, drop_last):
...@@ -42,16 +45,44 @@ class _ScalarDataBatcher(th.utils.data.IterableDataset): ...@@ -42,16 +45,44 @@ class _ScalarDataBatcher(th.utils.data.IterableDataset):
is passed in. is passed in.
""" """
def __init__(self, dataset, shuffle=False, batch_size=1, def __init__(self, dataset, shuffle=False, batch_size=1,
drop_last=False): drop_last=False, use_ddp=False, ddp_seed=0):
super(_ScalarDataBatcher).__init__() super(_ScalarDataBatcher).__init__()
self.dataset = dataset self.dataset = dataset
self.batch_size = batch_size self.batch_size = batch_size
self.shuffle = shuffle self.shuffle = shuffle
self.drop_last = drop_last self.drop_last = drop_last
self.use_ddp = use_ddp
if use_ddp:
self.rank = dist.get_rank()
self.num_replicas = dist.get_world_size()
self.seed = ddp_seed
self.epoch = 0
# The following code (and the idea of cross-process shuffling with the same seed)
# comes from PyTorch. See torch/utils/data/distributed.py for details.
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any sample, since the dataset will be split evenly.
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
# `type:ignore` is required because Dataset cannot provide a default __len__
# see NOTE in pytorch/torch/utils/data/sampler.py
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore
self.total_size = self.num_samples * self.num_replicas
def __iter__(self): def __iter__(self):
if self.use_ddp:
return self._iter_ddp()
else:
return self._iter_non_ddp()
def _divide_by_worker(self, dataset):
worker_info = th.utils.data.get_worker_info() worker_info = th.utils.data.get_worker_info()
dataset = self.dataset
if worker_info: if worker_info:
# worker gets only a fraction of the dataset # worker gets only a fraction of the dataset
chunk_size = dataset.shape[0] // worker_info.num_workers chunk_size = dataset.shape[0] // worker_info.num_workers
...@@ -62,6 +93,11 @@ class _ScalarDataBatcher(th.utils.data.IterableDataset): ...@@ -62,6 +93,11 @@ class _ScalarDataBatcher(th.utils.data.IterableDataset):
end == dataset.shape[0] end == dataset.shape[0]
dataset = dataset[start:end] dataset = dataset[start:end]
return dataset
def _iter_non_ddp(self):
dataset = self._divide_by_worker(self.dataset)
if self.shuffle: if self.shuffle:
# permute the dataset # permute the dataset
perm = th.randperm(dataset.shape[0], device=dataset.device) perm = th.randperm(dataset.shape[0], device=dataset.device)
...@@ -69,9 +105,40 @@ class _ScalarDataBatcher(th.utils.data.IterableDataset): ...@@ -69,9 +105,40 @@ class _ScalarDataBatcher(th.utils.data.IterableDataset):
return _ScalarDataBatcherIter(dataset, self.batch_size, self.drop_last) return _ScalarDataBatcherIter(dataset, self.batch_size, self.drop_last)
def _iter_ddp(self):
# The following code (and the idea of cross-process shuffling with the same seed)
# comes from PyTorch. See torch/utils/data/distributed.py for details.
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = th.Generator()
g.manual_seed(self.seed + self.epoch)
indices = th.randperm(len(self.dataset), generator=g)
else:
indices = th.arange(len(self.dataset))
if not self.drop_last:
# add extra samples to make it evenly divisible
indices = th.cat([indices, indices[:(self.total_size - indices.shape[0])]])
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert indices.shape[0] == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert indices.shape[0] == self.num_samples
# Dividing by worker is our own stuff.
dataset = self._divide_by_worker(self.dataset[indices])
return _ScalarDataBatcherIter(dataset, self.batch_size, self.drop_last)
def __len__(self): def __len__(self):
return (self.dataset.shape[0] + (0 if self.drop_last else self.batch_size - 1)) // \ num_samples = self.num_samples if self.use_ddp else self.dataset.shape[0]
self.batch_size return (num_samples + (0 if self.drop_last else self.batch_size - 1)) // self.batch_size
def set_epoch(self, epoch):
"""Set epoch number for distributed training."""
self.epoch = epoch
def _remove_kwargs_dist(kwargs): def _remove_kwargs_dist(kwargs):
if 'num_workers' in kwargs: if 'num_workers' in kwargs:
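For illustration, the cross-process split that the DDP branch above implements boils down to the following standalone sketch. The numbers (10 samples, 4 replicas, rank 1) are hypothetical; in the class itself the rank and world size come from ``torch.distributed``:

    import math
    import torch as th

    # Hypothetical numbers: 10 samples, 4 replicas, rank 1, drop_last=False.
    dataset_len, num_replicas, rank, drop_last = 10, 4, 1, False

    if drop_last and dataset_len % num_replicas != 0:
        num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)  # 2 per rank
    else:
        num_samples = math.ceil(dataset_len / num_replicas)                   # 3 per rank
    total_size = num_samples * num_replicas                                   # 12

    indices = th.arange(dataset_len)
    if not drop_last:
        # pad by wrapping around so every rank gets the same number of samples
        indices = th.cat([indices, indices[:total_size - indices.shape[0]]])
    else:
        indices = indices[:total_size]

    # rank-strided subsample: rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    print(indices[rank:total_size:num_replicas])  # tensor([1, 5, 9]) on rank 1

With ``shuffle=True`` the batcher additionally permutes the indices with a generator seeded by ``ddp_seed + epoch``, which is why ``set_epoch`` must be called every epoch to get a new ordering while keeping all ranks in sync.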
@@ -276,9 +343,14 @@ class NodeDataLoader:
     use_ddp : boolean, optional
         If True, tells the DataLoader to split the training set for each
         participating process appropriately using
-        :mod:`torch.utils.data.distributed.DistributedSampler`.
+        :class:`torch.utils.data.distributed.DistributedSampler`.

         Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
+    ddp_seed : int, optional
+        The seed for shuffling the dataset in
+        :class:`torch.utils.data.distributed.DistributedSampler`.
+
+        Only effective when :attr:`use_ddp` is True.
     kwargs : dict
         Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
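As a usage sketch for the new arguments (hypothetical values; it assumes the process group has been initialized and that ``g``, ``train_nids``, and ``sampler`` are defined as in the tutorial further below):

    import dgl

    dataloader = dgl.dataloading.NodeDataLoader(
        g, train_nids, sampler,
        use_ddp=True,    # each rank iterates over its own shard of train_nids
        ddp_seed=42,     # the same seed on every rank keeps the shuffles consistent
        batch_size=1024, shuffle=True, drop_last=False, num_workers=0)

    for epoch in range(10):
        dataloader.set_epoch(epoch)   # reshuffle differently on every epoch
        for input_nodes, output_nodes, blocks in dataloader:
            pass                      # forward/backward as usual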
@@ -318,7 +390,7 @@ class NodeDataLoader:
     """
     collator_arglist = inspect.getfullargspec(NodeCollator).args

-    def __init__(self, g, nids, block_sampler, device='cpu', use_ddp=False, **kwargs):
+    def __init__(self, g, nids, block_sampler, device='cpu', use_ddp=False, ddp_seed=0, **kwargs):
         collator_kwargs = {}
         dataloader_kwargs = {}
         for k, v in kwargs.items():
@@ -340,6 +412,7 @@ class NodeDataLoader:
         else:
             self.collator = _NodeCollator(g, nids, block_sampler, **collator_kwargs)
         dataset = self.collator.dataset
+        use_scalar_batcher = False

         if th.device(device) != th.device('cpu'):
             # Only use the '_ScalarDataBatcher' when for the GPU, as it
@@ -363,18 +436,24 @@ class NodeDataLoader:
                 dataset = _ScalarDataBatcher(dataset,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
-                                             drop_last=drop_last)
+                                             drop_last=drop_last,
+                                             use_ddp=use_ddp,
+                                             ddp_seed=ddp_seed)
                 # need to overwrite things that will be handled by the batcher
                 dataloader_kwargs['batch_size'] = None
                 dataloader_kwargs['shuffle'] = False
                 dataloader_kwargs['drop_last'] = False
+                use_scalar_batcher = True
+                self.scalar_batcher = dataset

         self.use_ddp = use_ddp
-        if use_ddp:
+        self.use_scalar_batcher = use_scalar_batcher
+        if use_ddp and not use_scalar_batcher:
             self.dist_sampler = DistributedSampler(
                 dataset,
                 shuffle=dataloader_kwargs['shuffle'],
-                drop_last=dataloader_kwargs['drop_last'])
+                drop_last=dataloader_kwargs['drop_last'],
+                seed=ddp_seed)
             dataloader_kwargs['shuffle'] = False
             dataloader_kwargs['drop_last'] = False
             dataloader_kwargs['sampler'] = self.dist_sampler
@@ -418,6 +497,9 @@ class NodeDataLoader:
             The epoch number.
         """
         if self.use_ddp:
+            if self.use_scalar_batcher:
+                self.scalar_batcher.set_epoch(epoch)
+            else:
                 self.dist_sampler.set_epoch(epoch)
         else:
             raise DGLError('set_epoch is only available when use_ddp is True.')
@@ -502,6 +584,11 @@ class EdgeDataLoader:
         epoch number, as recommended by PyTorch.

         Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
+    ddp_seed : int, optional
+        The seed for shuffling the dataset in
+        :class:`torch.utils.data.distributed.DistributedSampler`.
+
+        Only effective when :attr:`use_ddp` is True.
     kwargs : dict
         Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
@@ -620,7 +707,7 @@ class EdgeDataLoader:
     """
     collator_arglist = inspect.getfullargspec(EdgeCollator).args

-    def __init__(self, g, eids, block_sampler, device='cpu', use_ddp=False, **kwargs):
+    def __init__(self, g, eids, block_sampler, device='cpu', use_ddp=False, ddp_seed=0, **kwargs):
         collator_kwargs = {}
         dataloader_kwargs = {}
         for k, v in kwargs.items():
@@ -640,7 +727,8 @@ class EdgeDataLoader:
             self.dist_sampler = DistributedSampler(
                 dataset,
                 shuffle=dataloader_kwargs['shuffle'],
-                drop_last=dataloader_kwargs['drop_last'])
+                drop_last=dataloader_kwargs['drop_last'],
+                seed=ddp_seed)
             dataloader_kwargs['shuffle'] = False
             dataloader_kwargs['drop_last'] = False
             dataloader_kwargs['sampler'] = self.dist_sampler
@@ -692,6 +780,17 @@ class GraphDataLoader:
     collate_fn : Function, default is None
         The customized collate function. Will use the default collate
         function if not given.
+    use_ddp : boolean, optional
+        If True, tells the DataLoader to split the training set for each
+        participating process appropriately using
+        :class:`torch.utils.data.distributed.DistributedSampler`.
+
+        Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
+    ddp_seed : int, optional
+        The seed for shuffling the dataset in
+        :class:`torch.utils.data.distributed.DistributedSampler`.
+
+        Only effective when :attr:`use_ddp` is True.
     kwargs : dict
         Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
@@ -720,7 +819,7 @@ class GraphDataLoader:
     """
     collator_arglist = inspect.getfullargspec(GraphCollator).args

-    def __init__(self, dataset, collate_fn=None, use_ddp=False, **kwargs):
+    def __init__(self, dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs):
         collator_kwargs = {}
         dataloader_kwargs = {}
         for k, v in kwargs.items():
@@ -739,7 +838,8 @@ class GraphDataLoader:
             self.dist_sampler = DistributedSampler(
                 dataset,
                 shuffle=dataloader_kwargs['shuffle'],
-                drop_last=dataloader_kwargs['drop_last'])
+                drop_last=dataloader_kwargs['drop_last'],
+                seed=ddp_seed)
             dataloader_kwargs['shuffle'] = False
             dataloader_kwargs['drop_last'] = False
             dataloader_kwargs['sampler'] = self.dist_sampler
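``GraphDataLoader`` gains the same pair of arguments. A minimal sketch of how they might be used for graph classification (hypothetical; it assumes each worker process has already called ``torch.distributed.init_process_group`` and that ``dataset`` is a graph classification dataset such as ``GINDataset``):

    import dgl

    dataloader = dgl.dataloading.GraphDataLoader(
        dataset,
        use_ddp=True,   # each rank sees a disjoint shard of the graphs
        ddp_seed=0,     # keep cross-rank shuffling consistent
        batch_size=32, shuffle=True, drop_last=False)

    for epoch in range(10):
        dataloader.set_epoch(epoch)
        for batched_graph, labels in dataloader:
            pass        # train on this rank's shard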
......
@@ -58,6 +58,11 @@ labels = labels.to(device)
 # That’s the core behind this tutorial. We will explore it more in detail with
 # a complete example below.
 #
+# .. note::
+#
+#    See `this tutorial <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
+#    from PyTorch for general multi-GPU training with ``DistributedDataParallel``.
+#
 # Distributed Process Group Initialization
 # ----------------------------------------
 #
@@ -222,9 +227,16 @@ def main(rank, world_size, dataset, seed=0):
 ###############################################################################
 # Finally we load the dataset and launch the processes.
 #
+# .. note::
+#
+#    You will need to use ``dgl.multiprocessing`` instead of the Python
+#    ``multiprocessing`` package. ``dgl.multiprocessing`` is identical to
+#    Python’s built-in ``multiprocessing`` except that it handles the
+#    subtleties between forking and multithreading in Python.
+#
 if __name__ == '__main__':
-    import torch.multiprocessing as mp
+    import dgl.multiprocessing as mp
     from dgl.data import GINDataset
......
"""
Single Machine Multi-GPU Minibatch Node Classification
======================================================
In this tutorial, you will learn how to use multiple GPUs to train a
graph neural network (GNN) for node classification.
(Time estimate: 8 minutes)
This tutorial assumes that you have read the :doc:`Training GNN with Neighbor
Sampling for Node Classification <../large/L1_large_node_classification>`
tutorial. It also assumes that you know the basics of multi-GPU training of
general models with ``DistributedDataParallel``.
.. note::
See `this tutorial <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
from PyTorch for general multi-GPU training with ``DistributedDataParallel``. Also,
see the first section of :doc:`the multi-GPU graph classification
tutorial <1_graph_classification>`
for an overview of using ``DistributedDataParallel`` with DGL.
"""
######################################################################
# Loading Dataset
# ---------------
#
# OGB has already prepared the data as a ``DGLGraph`` object. The following code is
# copy-pasted from the :doc:`Training GNN with Neighbor Sampling for Node
# Classification <../large/L1_large_node_classification>`
# tutorial.
#
import dgl
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
import sklearn.metrics
dataset = DglNodePropPredDataset('ogbn-arxiv')
graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
graph.ndata['label'] = node_labels[:, 0]
node_features = graph.ndata['feat']
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test'] # Test node IDs, not used in the tutorial though.
######################################################################
# Defining Model
# --------------
#
# The model will again be identical to the :doc:`Training GNN with Neighbor
# Sampling for Node Classification <../large/L1_large_node_classification>`
# tutorial.
#
class Model(nn.Module):
def __init__(self, in_feats, h_feats, num_classes):
super(Model, self).__init__()
self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type='mean')
self.h_feats = h_feats
def forward(self, mfgs, x):
h_dst = x[:mfgs[0].num_dst_nodes()]
h = self.conv1(mfgs[0], (x, h_dst))
h = F.relu(h)
h_dst = h[:mfgs[1].num_dst_nodes()]
h = self.conv2(mfgs[1], (h, h_dst))
return h
######################################################################
# Defining Training Procedure
# ---------------------------
#
# The training procedure will be slightly different from what you saw
# previously, in the sense that you will need to
#
# * Initialize a distributed training context with ``torch.distributed``.
# * Wrap your model with ``torch.nn.parallel.DistributedDataParallel``.
# * Add a ``use_ddp=True`` argument to the DGL dataloader you wish to run
# together with DDP.
#
# You will also need to wrap the training loop inside a function so that
# you can spawn subprocesses to run it.
#
def run(proc_id, devices):
# Initialize distributed training context.
dev_id = devices[proc_id]
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(master_ip='127.0.0.1', master_port='12345')
if torch.cuda.device_count() < 1:
device = torch.device('cpu')
torch.distributed.init_process_group(
backend='gloo', init_method=dist_init_method, world_size=len(devices), rank=proc_id)
else:
torch.cuda.set_device(dev_id)
device = torch.device('cuda:' + str(dev_id))
torch.distributed.init_process_group(
backend='nccl', init_method=dist_init_method, world_size=len(devices), rank=proc_id)
# Define training and validation dataloader, copied from the previous tutorial
# but with one line of difference: use_ddp to enable distributed data parallel
# data loading.
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
# The following arguments are specific to NodeDataLoader.
graph, # The graph
train_nids, # The node IDs to iterate over in minibatches
sampler, # The neighbor sampler
device=device, # Put the sampled MFGs on CPU or GPU
use_ddp=True, # Make it work with distributed data parallel
# The following arguments are inherited from PyTorch DataLoader.
batch_size=1024, # Per-device batch size.
# The effective batch size is this number times the number of GPUs.
shuffle=True, # Whether to shuffle the nodes for every epoch
drop_last=False, # Whether to drop the last incomplete batch
num_workers=0 # Number of sampler processes
)
valid_dataloader = dgl.dataloading.NodeDataLoader(
graph, valid_nids, sampler,
device=device,
use_ddp=False,
batch_size=1024,
shuffle=False,
drop_last=False,
num_workers=0,
)
model = Model(num_features, 128, num_classes).to(device)
# Wrap the model with distributed data parallel module.
if device == torch.device('cpu'):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=None, output_device=None)
else:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], output_device=device)
# Define optimizer
opt = torch.optim.Adam(model.parameters())
best_accuracy = 0
best_model_path = './model.pt'
# Copied from previous tutorial with changes highlighted.
for epoch in range(10):
train_dataloader.set_epoch(epoch) # <--- necessary for dataloader with DDP.
model.train()
with tqdm.tqdm(train_dataloader) as tq:
for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
# feature copy from CPU to GPU takes place here
inputs = mfgs[0].srcdata['feat']
labels = mfgs[-1].dstdata['label']
predictions = model(mfgs, inputs)
loss = F.cross_entropy(predictions, labels)
opt.zero_grad()
loss.backward()
opt.step()
accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())
tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)
model.eval()
# Evaluate on only the first GPU.
if proc_id == 0:
predictions = []
labels = []
with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
for input_nodes, output_nodes, mfgs in tq:
inputs = mfgs[0].srcdata['feat']
labels.append(mfgs[-1].dstdata['label'].cpu().numpy())
predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())
predictions = np.concatenate(predictions)
labels = np.concatenate(labels)
accuracy = sklearn.metrics.accuracy_score(labels, predictions)
print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
if best_accuracy < accuracy:
best_accuracy = accuracy
torch.save(model.state_dict(), best_model_path)
# Note that this tutorial does not train the whole model to the end.
break
######################################################################
# Spawning Trainer Processes
# --------------------------
#
# A typical scenario for multi-GPU training with DDP is to replicate the
# model once per GPU, and spawn one trainer process per GPU.
#
# PyTorch tutorials recommend using ``multiprocessing.spawn`` to spawn
# multiple processes. This, however, is undesirable for training node
# classification or link prediction models on a single large graph,
# especially on Linux. The reason is that a single large graph itself may
# take a lot of memory, and ``mp.spawn`` will duplicate all objects in the
# program, including the large graph. Consequently, the large graph will
# be duplicated as many times as the number of GPUs.
#
# To alleviate the problem we recommend using ``multiprocessing.Process``,
# which *forks* from the main process and allows sharing the same graph
# object with the trainer processes via *copy-on-write*. This can greatly reduce
# the memory consumption.
#
# Normally, DGL maintains only one sparse matrix representation (usually COO)
# for each graph, and will create new formats when some APIs are called for
# efficiency. For instance, calling ``in_degrees`` will create a CSC
# representation for the graph, and calling ``out_degrees`` will create a
# CSR representation. A consequence is that if a graph is shared with
# trainer processes via copy-on-write *before* having its CSC/CSR
# created, each trainer will create its own CSC/CSR replica once ``in_degrees``
# or ``out_degrees`` is called. To avoid this, you need to create
# all sparse matrix representations beforehand using the ``create_formats_``
# method:
#
graph.create_formats_()
######################################################################
# Then you can spawn the subprocesses to train with multiple GPUs.
#
# .. note::
#
# You will need to use ``dgl.multiprocessing`` instead of the Python
# ``multiprocessing`` package. ``dgl.multiprocessing`` is identical to
# Python’s built-in ``multiprocessing`` except that it handles the
# subtleties between forking and multithreading in Python.
#
# Say you have four GPUs.
num_gpus = 4
import dgl.multiprocessing as mp
devices = list(range(num_gpus))
procs = []
for proc_id in range(num_gpus):
p = mp.Process(target=run, args=(proc_id, devices))
p.start()
procs.append(p)
for p in procs:
p.join()