Unverified Commit 7c7113f6 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

add use_ddp to dataloaders (#2911)

parent b03077b6
...@@ -264,18 +264,28 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -264,18 +264,28 @@ def run(proc_id, n_gpus, args, devices, data):
neighbor_samples = args.neighbor_samples neighbor_samples = args.neighbor_samples
num_workers = args.workers num_workers = args.workers
train_pairs = torch.split(
torch.tensor(train_pairs), math.ceil(len(train_pairs) / n_gpus)
)[proc_id]
neighbor_sampler = NeighborSampler(g, [neighbor_samples]) neighbor_sampler = NeighborSampler(g, [neighbor_samples])
train_dataloader = torch.utils.data.DataLoader( if n_gpus > 1:
train_pairs, train_sampler = torch.utils.data.distributed.DistributedSampler(
batch_size=batch_size, train_pairs, num_replicas=world_size, rank=proc_id, shuffle=True, drop_last=False)
collate_fn=neighbor_sampler.sample, train_dataloader = torch.utils.data.DataLoader(
shuffle=True, train_pairs,
num_workers=num_workers, batch_size=batch_size,
pin_memory=True, collate_fn=neighbor_sampler.sample,
) num_workers=num_workers,
sampler=train_sampler,
pin_memory=True,
)
else:
train_dataloader = torch.utils.data.DataLoader(
train_pairs,
batch_size=batch_size,
collate_fn=neighbor_sampler.sample,
num_workers=num_workers,
shuffle=True,
drop_last=False,
pin_memory=True,
)
model = DGLGATNE( model = DGLGATNE(
num_nodes, embedding_size, embedding_u_size, edge_types, edge_type_count, dim_a, num_nodes, embedding_size, embedding_u_size, edge_types, edge_type_count, dim_a,
...@@ -333,6 +343,8 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -333,6 +343,8 @@ def run(proc_id, n_gpus, args, devices, data):
start = time.time() start = time.time()
for epoch in range(epochs): for epoch in range(epochs):
if n_gpus > 1:
train_sampler.set_epoch(epoch)
model.train() model.train()
data_iter = train_dataloader data_iter = train_dataloader
......
...@@ -182,6 +182,17 @@ def config(): ...@@ -182,6 +182,17 @@ def config():
def run(proc_id, n_gpus, args, devices, dataset): def run(proc_id, n_gpus, args, devices, dataset):
dev_id = devices[proc_id] dev_id = devices[proc_id]
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
world_size = n_gpus
th.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=dev_id)
if n_gpus > 0:
th.cuda.set_device(dev_id)
train_labels = dataset.train_labels train_labels = dataset.train_labels
train_truths = dataset.train_truths train_truths = dataset.train_truths
num_edges = train_truths.shape[0] num_edges = train_truths.shape[0]
...@@ -196,6 +207,7 @@ def run(proc_id, n_gpus, args, devices, dataset): ...@@ -196,6 +207,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
dataset.train_enc_graph.number_of_edges(etype=to_etype_name(k))) dataset.train_enc_graph.number_of_edges(etype=to_etype_name(k)))
for k in dataset.possible_rating_values}, for k in dataset.possible_rating_values},
sampler, sampler,
use_ddp=n_gpus > 1,
batch_size=args.minibatch_size, batch_size=args.minibatch_size,
shuffle=True, shuffle=True,
drop_last=False) drop_last=False)
...@@ -218,17 +230,6 @@ def run(proc_id, n_gpus, args, devices, dataset): ...@@ -218,17 +230,6 @@ def run(proc_id, n_gpus, args, devices, dataset):
shuffle=False, shuffle=False,
drop_last=False) drop_last=False)
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
world_size = n_gpus
th.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=dev_id)
if n_gpus > 0:
th.cuda.set_device(dev_id)
nd_possible_rating_values = \ nd_possible_rating_values = \
th.FloatTensor(dataset.possible_rating_values) th.FloatTensor(dataset.possible_rating_values)
nd_possible_rating_values = nd_possible_rating_values.to(dev_id) nd_possible_rating_values = nd_possible_rating_values.to(dev_id)
...@@ -254,6 +255,8 @@ def run(proc_id, n_gpus, args, devices, dataset): ...@@ -254,6 +255,8 @@ def run(proc_id, n_gpus, args, devices, dataset):
iter_idx = 1 iter_idx = 1
for epoch in range(1, args.train_max_epoch): for epoch in range(1, args.train_max_epoch):
if n_gpus > 1:
dataloader.set_epoch(epoch)
if epoch > 1: if epoch > 1:
t0 = time.time() t0 = time.time()
net.train() net.train()
...@@ -340,7 +343,8 @@ def run(proc_id, n_gpus, args, devices, dataset): ...@@ -340,7 +343,8 @@ def run(proc_id, n_gpus, args, devices, dataset):
if n_gpus > 1: if n_gpus > 1:
th.distributed.barrier() th.distributed.barrier()
print(logging_str) if proc_id == 0:
print(logging_str)
if proc_id == 0: if proc_id == 0:
print('Best epoch Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}'.format( print('Best epoch Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}'.format(
best_epoch, best_valid_rmse, best_test_rmse)) best_epoch, best_valid_rmse, best_test_rmse))
......
...@@ -234,20 +234,26 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -234,20 +234,26 @@ def run(proc_id, n_gpus, args, devices, data):
train_nid = train_mask.nonzero().squeeze() train_nid = train_mask.nonzero().squeeze()
val_nid = val_mask.nonzero().squeeze() val_nid = val_mask.nonzero().squeeze()
# Split train_nid
train_nid = th.split(train_nid, math.ceil(len(train_nid) / n_gpus))[proc_id]
# Create sampler # Create sampler
sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')]) sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')])
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
dataloader = DataLoader( if n_gpus > 1:
dataset=train_nid.numpy(), dist_sampler = torch.utils.data.distributed.DistributedSampler(train_nid.numpy(), shuffle=True, drop_last=False)
batch_size=args.batch_size, dataloader = DataLoader(
collate_fn=sampler.sample_blocks, dataset=train_nid.numpy(),
shuffle=True, batch_size=args.batch_size,
drop_last=False, collate_fn=sampler.sample_blocks,
num_workers=args.num_workers_per_gpu) sampler=dist_sampler,
num_workers=args.num_workers_per_gpu)
else:
dataloader = DataLoader(
dataset=train_nid.numpy(),
batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
drop_last=False,
num_workers=args.num_workers_per_gpu)
# Define model # Define model
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu) model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu)
...@@ -274,6 +280,8 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -274,6 +280,8 @@ def run(proc_id, n_gpus, args, devices, data):
avg = 0 avg = 0
iter_tput = [] iter_tput = []
for epoch in range(args.num_epochs): for epoch in range(args.num_epochs):
if n_gpus > 1:
dist_sampler.set_epoch(epoch)
tic = time.time() tic = time.time()
model.train() model.train()
for step, (blocks, hist_blocks) in enumerate(dataloader): for step, (blocks, hist_blocks) in enumerate(dataloader):
......
...@@ -85,9 +85,7 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -85,9 +85,7 @@ def run(proc_id, n_gpus, args, devices, data):
train_nid = train_mask.nonzero().squeeze() train_nid = train_mask.nonzero().squeeze()
val_nid = val_mask.nonzero().squeeze() val_nid = val_mask.nonzero().squeeze()
test_nid = test_mask.nonzero().squeeze() test_nid = test_mask.nonzero().squeeze()
train_nid = train_nid[:n_gpus * args.batch_size + 1]
# Split train_nid
train_nid = th.split(train_nid, math.ceil(len(train_nid) / n_gpus))[proc_id]
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler( sampler = dgl.dataloading.MultiLayerNeighborSampler(
...@@ -96,6 +94,7 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -96,6 +94,7 @@ def run(proc_id, n_gpus, args, devices, data):
train_g, train_g,
train_nid, train_nid,
sampler, sampler,
use_ddp=n_gpus > 1,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
...@@ -113,6 +112,8 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -113,6 +112,8 @@ def run(proc_id, n_gpus, args, devices, data):
avg = 0 avg = 0
iter_tput = [] iter_tput = []
for epoch in range(args.num_epochs): for epoch in range(args.num_epochs):
if n_gpus > 1:
dataloader.set_epoch(epoch)
tic = time.time() tic = time.time()
# Loop over the dataloader to sample the computation dependency graph as a list of # Loop over the dataloader to sample the computation dependency graph as a list of
......
...@@ -76,12 +76,6 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -76,12 +76,6 @@ def run(proc_id, n_gpus, args, devices, data):
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
n_edges = g.num_edges() n_edges = g.num_edges()
train_seeds = np.arange(n_edges) train_seeds = np.arange(n_edges)
if n_gpus > 0:
num_per_gpu = (train_seeds.shape[0] + n_gpus -1) // n_gpus
train_seeds = train_seeds[proc_id * num_per_gpu :
(proc_id + 1) * num_per_gpu \
if (proc_id + 1) * num_per_gpu < train_seeds.shape[0]
else train_seeds.shape[0]]
# Create sampler # Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler( sampler = dgl.dataloading.MultiLayerNeighborSampler(
...@@ -93,6 +87,7 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -93,6 +87,7 @@ def run(proc_id, n_gpus, args, devices, data):
th.arange(n_edges // 2, n_edges), th.arange(n_edges // 2, n_edges),
th.arange(0, n_edges // 2)]), th.arange(0, n_edges // 2)]),
negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share), negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share),
use_ddp=n_gpus > 1,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
...@@ -116,6 +111,8 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -116,6 +111,8 @@ def run(proc_id, n_gpus, args, devices, data):
best_eval_acc = 0 best_eval_acc = 0
best_test_acc = 0 best_test_acc = 0
for epoch in range(args.num_epochs): for epoch in range(args.num_epochs):
if n_gpus > 1:
dataloader.set_epoch(epoch)
tic = time.time() tic = time.time()
# Loop over the dataloader to sample the computation dependency graph as a list of # Loop over the dataloader to sample the computation dependency graph as a list of
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import inspect import inspect
import torch as th import torch as th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from ..dataloader import NodeCollator, EdgeCollator, GraphCollator from ..dataloader import NodeCollator, EdgeCollator, GraphCollator
from ...distributed import DistGraph from ...distributed import DistGraph
from ...distributed import DistDataLoader from ...distributed import DistDataLoader
...@@ -272,6 +273,12 @@ class NodeDataLoader: ...@@ -272,6 +273,12 @@ class NodeDataLoader:
device : device context, optional device : device context, optional
The device of the generated MFGs in each iteration, which should be a The device of the generated MFGs in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``). PyTorch device object (e.g., ``torch.device``).
use_ddp : boolean, optional
If True, tells the DataLoader to split the training set among the
participating processes using
:class:`torch.utils.data.distributed.DistributedSampler`.
Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
kwargs : dict kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`. Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
...@@ -288,6 +295,21 @@ class NodeDataLoader: ...@@ -288,6 +295,21 @@ class NodeDataLoader:
>>> for input_nodes, output_nodes, blocks in dataloader: >>> for input_nodes, output_nodes, blocks in dataloader:
... train_on(input_nodes, output_nodes, blocks) ... train_on(input_nodes, output_nodes, blocks)
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:class:`torch.nn.parallel.DistributedDataParallel`), you can train the model by turning
on the :attr:`use_ddp` option and calling :meth:`set_epoch` at the start of each epoch:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, train_nid, sampler, use_ddp=True,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... dataloader.set_epoch(epoch)
... for input_nodes, output_nodes, blocks in dataloader:
... train_on(input_nodes, output_nodes, blocks)
Notes Notes
----- -----
Please refer to Please refer to
...@@ -296,7 +318,7 @@ class NodeDataLoader: ...@@ -296,7 +318,7 @@ class NodeDataLoader:
""" """
collator_arglist = inspect.getfullargspec(NodeCollator).args collator_arglist = inspect.getfullargspec(NodeCollator).args
def __init__(self, g, nids, block_sampler, device='cpu', **kwargs): def __init__(self, g, nids, block_sampler, device='cpu', use_ddp=False, **kwargs):
collator_kwargs = {} collator_kwargs = {}
dataloader_kwargs = {} dataloader_kwargs = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
...@@ -347,10 +369,21 @@ class NodeDataLoader: ...@@ -347,10 +369,21 @@ class NodeDataLoader:
dataloader_kwargs['shuffle'] = False dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False dataloader_kwargs['drop_last'] = False
self.use_ddp = use_ddp
if use_ddp:
self.dist_sampler = DistributedSampler(
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'])
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader( self.dataloader = DataLoader(
dataset, dataset,
collate_fn=self.collator.collate, collate_fn=self.collator.collate,
**dataloader_kwargs) **dataloader_kwargs)
self.is_distributed = False self.is_distributed = False
# Precompute the CSR and CSC representations so each subprocess does not # Precompute the CSR and CSC representations so each subprocess does not
...@@ -371,6 +404,24 @@ class NodeDataLoader: ...@@ -371,6 +404,24 @@ class NodeDataLoader:
"""Return the number of batches of the data loader.""" """Return the number of batches of the data loader."""
return len(self.dataloader) return len(self.dataloader)
def set_epoch(self, epoch):
    """Set the epoch number on the underlying distributed sampler.

    Call this at the beginning of every epoch, before iterating over the
    data loader. When shuffling is enabled, it makes the replicas use a
    different ordering in each epoch.

    Only available when :attr:`use_ddp` is True.

    Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.

    Parameters
    ----------
    epoch : int
        The epoch number.

    Raises
    ------
    DGLError
        If the data loader was created without ``use_ddp=True``.
    """
    # Fail early: ``dist_sampler`` is only created when ``use_ddp`` is True.
    if not self.use_ddp:
        raise DGLError('set_epoch is only available when use_ddp is True.')
    self.dist_sampler.set_epoch(epoch)
class EdgeDataLoader: class EdgeDataLoader:
"""PyTorch dataloader for batch-iterating over a set of edges, generating the list """PyTorch dataloader for batch-iterating over a set of edges, generating the list
of message flow graphs (MFGs) as computation dependency of the said minibatch for of message flow graphs (MFGs) as computation dependency of the said minibatch for
...@@ -442,6 +493,15 @@ class EdgeDataLoader: ...@@ -442,6 +493,15 @@ class EdgeDataLoader:
See the description of the argument with the same name in the docstring of See the description of the argument with the same name in the docstring of
:class:`~dgl.dataloading.EdgeCollator` for more details. :class:`~dgl.dataloading.EdgeCollator` for more details.
use_ddp : boolean, optional
If True, tells the DataLoader to split the training set among the
participating processes using
:class:`torch.utils.data.distributed.DistributedSampler`.
The dataloader will have a :attr:`dist_sampler` attribute holding that sampler;
call :meth:`set_epoch` at the start of each epoch to set its epoch number,
as recommended by PyTorch.
Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
kwargs : dict kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`. Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
...@@ -524,6 +584,22 @@ class EdgeDataLoader: ...@@ -524,6 +584,22 @@ class EdgeDataLoader:
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader: >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, neg_pair_graph, blocks) ... train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:class:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
turning on the :attr:`use_ddp` option and calling :meth:`set_epoch` at the start of each epoch:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, train_eid, sampler, use_ddp=True, exclude='reverse_id',
... reverse_eids=reverse_eids,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... dataloader.set_epoch(epoch)
... for input_nodes, pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, blocks)
See also See also
-------- --------
dgl.dataloading.dataloader.EdgeCollator dgl.dataloading.dataloader.EdgeCollator
...@@ -544,7 +620,7 @@ class EdgeDataLoader: ...@@ -544,7 +620,7 @@ class EdgeDataLoader:
""" """
collator_arglist = inspect.getfullargspec(EdgeCollator).args collator_arglist = inspect.getfullargspec(EdgeCollator).args
def __init__(self, g, eids, block_sampler, device='cpu', **kwargs): def __init__(self, g, eids, block_sampler, device='cpu', use_ddp=False, **kwargs):
collator_kwargs = {} collator_kwargs = {}
dataloader_kwargs = {} dataloader_kwargs = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
...@@ -553,12 +629,27 @@ class EdgeDataLoader: ...@@ -553,12 +629,27 @@ class EdgeDataLoader:
else: else:
dataloader_kwargs[k] = v dataloader_kwargs[k] = v
self.collator = _EdgeCollator(g, eids, block_sampler, **collator_kwargs) self.collator = _EdgeCollator(g, eids, block_sampler, **collator_kwargs)
dataset = self.collator.dataset
assert not isinstance(g, DistGraph), \ assert not isinstance(g, DistGraph), \
'EdgeDataLoader does not support DistGraph for now. ' \ 'EdgeDataLoader does not support DistGraph for now. ' \
+ 'Please use DistDataLoader directly.' + 'Please use DistDataLoader directly.'
self.use_ddp = use_ddp
if use_ddp:
self.dist_sampler = DistributedSampler(
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'])
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader( self.dataloader = DataLoader(
self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs) dataset,
collate_fn=self.collator.collate,
**dataloader_kwargs)
self.device = device self.device = device
# Precompute the CSR and CSC representations so each subprocess does not # Precompute the CSR and CSC representations so each subprocess does not
...@@ -574,6 +665,24 @@ class EdgeDataLoader: ...@@ -574,6 +665,24 @@ class EdgeDataLoader:
"""Return the number of batches of the data loader.""" """Return the number of batches of the data loader."""
return len(self.dataloader) return len(self.dataloader)
def set_epoch(self, epoch):
    """Set the epoch number on the underlying distributed sampler.

    Call this at the beginning of every epoch, before iterating over the
    data loader. When shuffling is enabled, it makes the replicas use a
    different ordering in each epoch.

    Only available when :attr:`use_ddp` is True.

    Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.

    Parameters
    ----------
    epoch : int
        The epoch number.

    Raises
    ------
    DGLError
        If the data loader was created without ``use_ddp=True``.
    """
    # Fail early: ``dist_sampler`` is only created when ``use_ddp`` is True.
    if not self.use_ddp:
        raise DGLError('set_epoch is only available when use_ddp is True.')
    self.dist_sampler.set_epoch(epoch)
class GraphDataLoader: class GraphDataLoader:
"""PyTorch dataloader for batch-iterating over a set of graphs, generating the batched """PyTorch dataloader for batch-iterating over a set of graphs, generating the batched
graph and corresponding label tensor (if provided) of the said minibatch. graph and corresponding label tensor (if provided) of the said minibatch.
...@@ -595,10 +704,23 @@ class GraphDataLoader: ...@@ -595,10 +704,23 @@ class GraphDataLoader:
... dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4) ... dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader: >>> for batched_graph, labels in dataloader:
... train_on(batched_graph, labels) ... train_on(batched_graph, labels)
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:class:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
turning on the :attr:`use_ddp` option and calling :meth:`set_epoch` at the start of each epoch:
>>> dataloader = dgl.dataloading.GraphDataLoader(
... dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... dataloader.set_epoch(epoch)
... for batched_graph, labels in dataloader:
... train_on(batched_graph, labels)
""" """
collator_arglist = inspect.getfullargspec(GraphCollator).args collator_arglist = inspect.getfullargspec(GraphCollator).args
def __init__(self, dataset, collate_fn=None, **kwargs): def __init__(self, dataset, collate_fn=None, use_ddp=False, **kwargs):
collator_kwargs = {} collator_kwargs = {}
dataloader_kwargs = {} dataloader_kwargs = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
...@@ -612,6 +734,16 @@ class GraphDataLoader: ...@@ -612,6 +734,16 @@ class GraphDataLoader:
else: else:
self.collate = collate_fn self.collate = collate_fn
self.use_ddp = use_ddp
if use_ddp:
self.dist_sampler = DistributedSampler(
dataset,
shuffle=dataloader_kwargs['shuffle'],
drop_last=dataloader_kwargs['drop_last'])
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
dataloader_kwargs['sampler'] = self.dist_sampler
self.dataloader = DataLoader(dataset=dataset, self.dataloader = DataLoader(dataset=dataset,
collate_fn=self.collate, collate_fn=self.collate,
**dataloader_kwargs) **dataloader_kwargs)
...@@ -623,3 +755,21 @@ class GraphDataLoader: ...@@ -623,3 +755,21 @@ class GraphDataLoader:
def __len__(self): def __len__(self):
"""Return the number of batches of the data loader.""" """Return the number of batches of the data loader."""
return len(self.dataloader) return len(self.dataloader)
def set_epoch(self, epoch):
    """Set the epoch number on the underlying distributed sampler.

    Call this at the beginning of every epoch, before iterating over the
    data loader. When shuffling is enabled, it makes the replicas use a
    different ordering in each epoch.

    Only available when :attr:`use_ddp` is True.

    Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.

    Parameters
    ----------
    epoch : int
        The epoch number.

    Raises
    ------
    DGLError
        If the data loader was created without ``use_ddp=True``.
    """
    # Fail early: ``dist_sampler`` is only created when ``use_ddp`` is True.
    if not self.use_ddp:
        raise DGLError('set_epoch is only available when use_ddp is True.')
    self.dist_sampler.set_epoch(epoch)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment