Unverified Commit 2a107320 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[User Guide] Update distributed training code (#2114)

* check num_workers

* update

* update

* update

* update

* update

* update

* update user guide
parent 4327d712
......@@ -22,17 +22,24 @@ are the same as :ref:`mini-batch training <guide-minibatch>`.
import dgl
import torch as th
dgl.distributed.initialize(ip_config, num_workers=num_workers)
dgl.distributed.initialize('ip_config.txt', num_servers, num_workers)
th.distributed.init_process_group(backend='gloo')
g = dgl.distributed.DistGraph('ip_config.txt', 'graph_name')
train_nid = dgl.distributed.node_split(g.ndata['train_mask'])
g = dgl.distributed.DistGraph('graph_name', 'part_config.json')
pb = g.get_partition_book()
train_nid = dgl.distributed.node_split(g.ndata['train_mask'], pb, force_even=True)
# Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
train_dataloader = dgl.dataloading.NodeDataLoader(g, train_nid,
sampler, batch_size=1024,
shuffle=True,drop_last=False)
sampler = NeighborSampler(g, [10,25],
dgl.distributed.sample_neighbors,
device)
dataloader = DistDataLoader(
dataset=train_nid.numpy(),
batch_size=batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
drop_last=False)
# Define model and optimizer
model = SAGE(in_feats, num_hidden, n_classes, num_layers, F.relu, dropout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment