Unverified commit 7c7113f6, authored by Quan (Andy) Gan, committed by GitHub

add use_ddp to dataloaders (#2911)

parent b03077b6
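
This commit adds a `use_ddp` flag and a `set_epoch` method to `NodeDataLoader`, `EdgeDataLoader`, and `GraphDataLoader`, replacing the manual per-GPU splitting previously done in the example scripts. A hedged sketch (illustrative, not taken from the commit) of the resulting usage, assuming `g` and `train_nid` already exist and the process group has been initialized:

# Hedged sketch of the API introduced by this commit. Assumes `g` and
# `train_nid` exist and torch.distributed.init_process_group has been called.
import dgl

def run_ddp_training(g, train_nid, n_epochs):
    sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    dataloader = dgl.dataloading.NodeDataLoader(
        g, train_nid, sampler,
        use_ddp=True,                    # new flag: wraps a DistributedSampler internally
        batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    for epoch in range(n_epochs):
        dataloader.set_epoch(epoch)      # new method: reshuffle consistently per epoch
        for input_nodes, output_nodes, blocks in dataloader:
            pass                         # forward/backward on the minibatch goes here
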
......@@ -264,16 +264,26 @@ def run(proc_id, n_gpus, args, devices, data):
neighbor_samples = args.neighbor_samples
num_workers = args.workers
train_pairs = torch.split(
torch.tensor(train_pairs), math.ceil(len(train_pairs) / n_gpus)
)[proc_id]
neighbor_sampler = NeighborSampler(g, [neighbor_samples])
if n_gpus > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_pairs, num_replicas=world_size, rank=proc_id, shuffle=True, drop_last=False)
train_dataloader = torch.utils.data.DataLoader(
train_pairs,
batch_size=batch_size,
collate_fn=neighbor_sampler.sample,
num_workers=num_workers,
sampler=train_sampler,
pin_memory=True,
)
else:
train_dataloader = torch.utils.data.DataLoader(
train_pairs,
batch_size=batch_size,
collate_fn=neighbor_sampler.sample,
shuffle=True,
num_workers=num_workers,
drop_last=False,
pin_memory=True,
)
......@@ -333,6 +343,8 @@ def run(proc_id, n_gpus, args, devices, data):
start = time.time()
for epoch in range(epochs):
if n_gpus > 1:
train_sampler.set_epoch(epoch)
model.train()
data_iter = train_dataloader
......
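
The hunk above drops the manual torch.split of train_pairs and instead drives the plain PyTorch DataLoader with a DistributedSampler when n_gpus > 1. A hedged sketch of that stock PyTorch pattern; `train_pairs` and `collate` are placeholders for the example's pair tensor and NeighborSampler.sample:

# Hedged sketch of the plain-PyTorch pattern adopted above; names are placeholders.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_distributed_loader(train_pairs, collate, world_size, rank, batch_size):
    sampler = DistributedSampler(
        train_pairs, num_replicas=world_size, rank=rank,
        shuffle=True, drop_last=False)
    # DataLoader shuffle stays off: the sampler does the shuffling, and the
    # caller should invoke sampler.set_epoch(epoch) at the start of each epoch.
    loader = DataLoader(
        train_pairs, batch_size=batch_size, collate_fn=collate,
        sampler=sampler, pin_memory=True)
    return loader, sampler
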
......@@ -182,6 +182,17 @@ def config():
def run(proc_id, n_gpus, args, devices, dataset):
dev_id = devices[proc_id]
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
world_size = n_gpus
th.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=dev_id)
if n_gpus > 0:
th.cuda.set_device(dev_id)
train_labels = dataset.train_labels
train_truths = dataset.train_truths
num_edges = train_truths.shape[0]
......@@ -196,6 +207,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
dataset.train_enc_graph.number_of_edges(etype=to_etype_name(k)))
for k in dataset.possible_rating_values},
sampler,
use_ddp=n_gpus > 1,
batch_size=args.minibatch_size,
shuffle=True,
drop_last=False)
......@@ -218,17 +230,6 @@ def run(proc_id, n_gpus, args, devices, dataset):
shuffle=False,
drop_last=False)
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
world_size = n_gpus
th.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=dev_id)
if n_gpus > 0:
th.cuda.set_device(dev_id)
nd_possible_rating_values = \
th.FloatTensor(dataset.possible_rating_values)
nd_possible_rating_values = nd_possible_rating_values.to(dev_id)
......@@ -254,6 +255,8 @@ def run(proc_id, n_gpus, args, devices, dataset):
iter_idx = 1
for epoch in range(1, args.train_max_epoch):
if n_gpus > 1:
dataloader.set_epoch(epoch)
if epoch > 1:
t0 = time.time()
net.train()
......@@ -340,6 +343,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
if n_gpus > 1:
th.distributed.barrier()
if proc_id == 0:
print(logging_str)
if proc_id == 0:
print('Best epoch Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}'.format(
......
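
The GCMC hunks above move init_process_group ahead of dataloader construction. The ordering matters because the DistributedSampler created internally when use_ddp=True reads the world size and rank from the default process group. A hedged sketch of that ordering, with the address and port mirroring the diff and assuming one process per GPU whose device index equals its rank:

# Hedged sketch of the initialization ordering enforced above; the process
# group must exist before any use_ddp dataloader is built.
import torch as th

def init_distributed(dev_id, n_gpus):
    if n_gpus > 1:
        th.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:12345',
            world_size=n_gpus,
            rank=dev_id)
    if n_gpus > 0:
        th.cuda.set_device(dev_id)
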
......@@ -234,13 +234,19 @@ def run(proc_id, n_gpus, args, devices, data):
train_nid = train_mask.nonzero().squeeze()
val_nid = val_mask.nonzero().squeeze()
# Split train_nid
train_nid = th.split(train_nid, math.ceil(len(train_nid) / n_gpus))[proc_id]
# Create sampler
sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')])
# Create PyTorch DataLoader for constructing blocks
if n_gpus > 1:
dist_sampler = torch.utils.data.distributed.DistributedSampler(train_nid.numpy(), shuffle=True, drop_last=False)
dataloader = DataLoader(
dataset=train_nid.numpy(),
batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
sampler=dist_sampler,
num_workers=args.num_workers_per_gpu)
else:
dataloader = DataLoader(
dataset=train_nid.numpy(),
batch_size=args.batch_size,
......@@ -274,6 +280,8 @@ def run(proc_id, n_gpus, args, devices, data):
avg = 0
iter_tput = []
for epoch in range(args.num_epochs):
if n_gpus > 1:
dist_sampler.set_epoch(epoch)
tic = time.time()
model.train()
for step, (blocks, hist_blocks) in enumerate(dataloader):
......
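
The epoch loop above calls dist_sampler.set_epoch(epoch) before iterating because DistributedSampler seeds its shuffle with the epoch number; skipping the call would replay the same permutation every epoch on every replica. A hedged sketch of that loop shape, with `dataloader`, `dist_sampler`, `model`, and `train_step` as placeholders:

# Hedged sketch of the per-epoch set_epoch call used above.
def train_loop(dataloader, dist_sampler, model, n_epochs, n_gpus, train_step):
    for epoch in range(n_epochs):
        if n_gpus > 1:
            dist_sampler.set_epoch(epoch)   # new permutation, identical on all ranks
        model.train()
        for batch in dataloader:
            train_step(model, batch)
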
......@@ -85,9 +85,7 @@ def run(proc_id, n_gpus, args, devices, data):
train_nid = train_mask.nonzero().squeeze()
val_nid = val_mask.nonzero().squeeze()
test_nid = test_mask.nonzero().squeeze()
# Split train_nid
train_nid = th.split(train_nid, math.ceil(len(train_nid) / n_gpus))[proc_id]
train_nid = train_nid[:n_gpus * args.batch_size + 1]
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler(
......@@ -96,6 +94,7 @@ def run(proc_id, n_gpus, args, devices, data):
train_g,
train_nid,
sampler,
use_ddp=n_gpus > 1,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
......@@ -113,6 +112,8 @@ def run(proc_id, n_gpus, args, devices, data):
avg = 0
iter_tput = []
for epoch in range(args.num_epochs):
if n_gpus > 1:
dataloader.set_epoch(epoch)
tic = time.time()
# Loop over the dataloader to sample the computation dependency graph as a list of
......
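
This example passes use_ddp=n_gpus > 1 so the same script runs with one or many GPUs, and guards set_epoch the same way, since calling set_epoch while use_ddp is False raises DGLError (see the dataloader changes further down). A hedged sketch of that single/multi-GPU-safe pattern with placeholder names:

# Hedged sketch of the conditional use_ddp pattern used above.
import dgl

def make_loader(train_g, train_nid, sampler, n_gpus, batch_size):
    return dgl.dataloading.NodeDataLoader(
        train_g, train_nid, sampler,
        use_ddp=n_gpus > 1,              # only wrap a DistributedSampler under DDP
        batch_size=batch_size, shuffle=True, drop_last=False)

def start_epoch(dataloader, epoch, n_gpus):
    if n_gpus > 1:
        dataloader.set_epoch(epoch)      # would raise DGLError when use_ddp is False
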
......@@ -76,12 +76,6 @@ def run(proc_id, n_gpus, args, devices, data):
# Create PyTorch DataLoader for constructing blocks
n_edges = g.num_edges()
train_seeds = np.arange(n_edges)
if n_gpus > 0:
num_per_gpu = (train_seeds.shape[0] + n_gpus -1) // n_gpus
train_seeds = train_seeds[proc_id * num_per_gpu :
(proc_id + 1) * num_per_gpu \
if (proc_id + 1) * num_per_gpu < train_seeds.shape[0]
else train_seeds.shape[0]]
# Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler(
......@@ -93,6 +87,7 @@ def run(proc_id, n_gpus, args, devices, data):
th.arange(n_edges // 2, n_edges),
th.arange(0, n_edges // 2)]),
negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share),
use_ddp=n_gpus > 1,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
......@@ -116,6 +111,8 @@ def run(proc_id, n_gpus, args, devices, data):
best_eval_acc = 0
best_test_acc = 0
for epoch in range(args.num_epochs):
if n_gpus > 1:
dataloader.set_epoch(epoch)
tic = time.time()
# Loop over the dataloader to sample the computation dependency graph as a list of
......
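
The link-prediction hunks above replace the manual slicing of train_seeds with use_ddp=n_gpus > 1 on EdgeDataLoader. A hedged sketch of EdgeDataLoader under DDP, modeled on the docstring example added below; `g`, `train_eid`, `reverse_eids`, and `neg_sampler` are placeholders:

# Hedged sketch of EdgeDataLoader with use_ddp, based on the docstring example below.
import dgl

def make_edge_loader(g, train_eid, reverse_eids, neg_sampler, batch_size):
    sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    return dgl.dataloading.EdgeDataLoader(
        g, train_eid, sampler,
        use_ddp=True,
        exclude='reverse_id', reverse_eids=reverse_eids,
        negative_sampler=neg_sampler,
        batch_size=batch_size, shuffle=True, drop_last=False)
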
......@@ -2,6 +2,7 @@
import inspect
import torch as th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from ..dataloader import NodeCollator, EdgeCollator, GraphCollator
from ...distributed import DistGraph
from ...distributed import DistDataLoader
......@@ -272,6 +273,12 @@ class NodeDataLoader:
device : device context, optional
The device of the generated MFGs in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
use_ddp : boolean, optional
If True, tells the DataLoader to split the training set for each
participating process appropriately using
:class:`torch.utils.data.distributed.DistributedSampler`.
Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
......@@ -288,6 +295,21 @@ class NodeDataLoader:
>>> for input_nodes, output_nodes, blocks in dataloader:
... train_on(input_nodes, output_nodes, blocks)
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:class:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
turning on the :attr:`use_ddp` option:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, train_nid, sampler, use_ddp=True,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... dataloader.set_epoch(epoch)
... for input_nodes, output_nodes, blocks in dataloader:
... train_on(input_nodes, output_nodes, blocks)
Notes
-----
Please refer to
......@@ -296,7 +318,7 @@ class NodeDataLoader:
"""
collator_arglist = inspect.getfullargspec(NodeCollator).args
def __init__(self, g, nids, block_sampler, device='cpu', **kwargs):
def __init__(self, g, nids, block_sampler, device='cpu', use_ddp=False, **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
......@@ -347,10 +369,21 @@ class NodeDataLoader:
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
self.use_ddp = use_ddp
if use_ddp:
self.dist_sampler = DistributedSampler(
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'])
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader(
dataset,
collate_fn=self.collator.collate,
**dataloader_kwargs)
self.is_distributed = False
# Precompute the CSR and CSC representations so each subprocess does not
......@@ -371,6 +404,24 @@ class NodeDataLoader:
"""Return the number of batches of the data loader."""
return len(self.dataloader)
def set_epoch(self, epoch):
"""Sets the epoch number for the underlying sampler which ensures all replicas
to use a different ordering for each epoch.
Only available when :attr:`use_ddp` is True.
Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.
Parameters
----------
epoch : int
The epoch number.
"""
if self.use_ddp:
self.dist_sampler.set_epoch(epoch)
else:
raise DGLError('set_epoch is only available when use_ddp is True.')
class EdgeDataLoader:
"""PyTorch dataloader for batch-iterating over a set of edges, generating the list
of message flow graphs (MFGs) as computation dependency of the said minibatch for
......@@ -442,6 +493,15 @@ class EdgeDataLoader:
See the description of the argument with the same name in the docstring of
:class:`~dgl.dataloading.EdgeCollator` for more details.
use_ddp : boolean, optional
If True, tells the DataLoader to split the training set for each
participating process appropriately using
:class:`torch.utils.data.distributed.DistributedSampler`.
The dataloader will have a :attr:`dist_sampler` attribute to set the
epoch number, as recommended by PyTorch.
Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
......@@ -524,6 +584,22 @@ class EdgeDataLoader:
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:class:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
turning on the :attr:`use_ddp` option:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, train_eid, sampler, use_ddp=True, exclude='reverse_id',
... reverse_eids=reverse_eids,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... dataloader.set_epoch(epoch)
... for input_nodes, pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, blocks)
See also
--------
dgl.dataloading.dataloader.EdgeCollator
......@@ -544,7 +620,7 @@ class EdgeDataLoader:
"""
collator_arglist = inspect.getfullargspec(EdgeCollator).args
def __init__(self, g, eids, block_sampler, device='cpu', **kwargs):
def __init__(self, g, eids, block_sampler, device='cpu', use_ddp=False, **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
......@@ -553,12 +629,27 @@ class EdgeDataLoader:
else:
dataloader_kwargs[k] = v
self.collator = _EdgeCollator(g, eids, block_sampler, **collator_kwargs)
dataset = self.collator.dataset
assert not isinstance(g, DistGraph), \
'EdgeDataLoader does not support DistGraph for now. ' \
+ 'Please use DistDataLoader directly.'
self.use_ddp = use_ddp
if use_ddp:
self.dist_sampler = DistributedSampler(
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'])
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader(
dataset,
collate_fn=self.collator.collate,
**dataloader_kwargs)
self.device = device
# Precompute the CSR and CSC representations so each subprocess does not
......@@ -574,6 +665,24 @@ class EdgeDataLoader:
"""Return the number of batches of the data loader."""
return len(self.dataloader)
def set_epoch(self, epoch):
"""Sets the epoch number for the underlying sampler which ensures all replicas
to use a different ordering for each epoch.
Only available when :attr:`use_ddp` is True.
Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.
Parameters
----------
epoch : int
The epoch number.
"""
if self.use_ddp:
self.dist_sampler.set_epoch(epoch)
else:
raise DGLError('set_epoch is only available when use_ddp is True.')
class GraphDataLoader:
"""PyTorch dataloader for batch-iterating over a set of graphs, generating the batched
graph and corresponding label tensor (if provided) of the said minibatch.
......@@ -595,10 +704,23 @@ class GraphDataLoader:
... dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
... train_on(batched_graph, labels)
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:class:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
turning on the :attr:`use_ddp` option:
>>> dataloader = dgl.dataloading.GraphDataLoader(
... dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... dataloader.set_epoch(epoch)
... for batched_graph, labels in dataloader:
... train_on(batched_graph, labels)
"""
collator_arglist = inspect.getfullargspec(GraphCollator).args
def __init__(self, dataset, collate_fn=None, **kwargs):
def __init__(self, dataset, collate_fn=None, use_ddp=False, **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
......@@ -612,6 +734,16 @@ class GraphDataLoader:
else:
self.collate = collate_fn
self.use_ddp = use_ddp
if use_ddp:
self.dist_sampler = DistributedSampler(
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'])
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader(dataset=dataset,
collate_fn=self.collate,
**dataloader_kwargs)
......@@ -623,3 +755,21 @@ class GraphDataLoader:
def __len__(self):
"""Return the number of batches of the data loader."""
return len(self.dataloader)
def set_epoch(self, epoch):
"""Sets the epoch number for the underlying sampler which ensures all replicas
to use a different ordering for each epoch.
Only available when :attr:`use_ddp` is True.
Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.
Parameters
----------
epoch : int
The epoch number.
"""
if self.use_ddp:
self.dist_sampler.set_epoch(epoch)
else:
raise DGLError('set_epoch is only available when use_ddp is True.')
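
For graph classification, GraphDataLoader gets the same treatment. A hedged sketch pairing its new use_ddp flag with DistributedDataParallel, assuming an initialized process group, a dataset of (graph, label) pairs, and a single GPU index `device` per process:

# Hedged sketch of GraphDataLoader with use_ddp plus DistributedDataParallel.
import torch
import dgl

def train_graph_classifier(dataset, model, device, n_epochs):
    model = model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])
    dataloader = dgl.dataloading.GraphDataLoader(
        dataset, use_ddp=True,
        batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    for epoch in range(n_epochs):
        dataloader.set_epoch(epoch)      # reshuffle consistently across replicas
        for batched_graph, labels in dataloader:
            pass                         # forward/backward on the batched graph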