"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "544ba677dd97a49c8124208837025aa8b5ab639e"
Unverified Commit 283b2cfb authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[Distributed] Fix issues of GraphSage to support running on GPU (#2084)



* fix issues on GPU

* update

* update

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: default avatarMa <manchao@38f9d3587685.ant.amazon.com>
parent b57bedb7
......@@ -78,12 +78,27 @@ To run unsupervised training:
python3 ~/dgl/tools/launch.py \
--workspace ~/graphsage/ \
--num_trainers 1 \
--num_samplers 4 \
--num_servers 1 \
--part_config ogb-product/ogb-product.json \
--ip_config ip_config.txt \
"python3 dgl_code/train_dist_unsupervised.py --graph_name ogb-product --ip_config ip_config.txt --num_servers 1 --num_epochs 3 --batch_size 1000 --num_workers 4"
```
By default, this code runs on CPU. If you have GPU support, you can add the `--num_gpus` argument to the user command:
```bash
python3 ~/dgl/tools/launch.py \
--workspace ~/graphsage/ \
--num_trainers 4 \
--num_samplers 4 \
--num_servers 1 \
--part_config ogb-product/ogb-product.json \
--ip_config ip_config.txt \
"python3 dgl_code/train_dist_unsupervised.py --graph_name ogb-product --ip_config ip_config.txt --num_servers 1 --num_epochs 3 --batch_size 1000"
"python3 dgl_code/train_dist_unsupervised.py --graph_name ogb-product --ip_config ip_config.txt --num_servers 1 --num_epochs 3 --batch_size 1000 --num_workers 4 --num_gpus 4"
```
## Distributed code runs in the standalone mode
The standalone mode is mainly used for development and testing. The procedure to run the code is much simpler.
......
......@@ -52,7 +52,7 @@ class NeighborSampler(object):
input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID]
batch_inputs, batch_labels = load_subtensor(self.g, seeds, input_nodes, self.device)
batch_inputs, batch_labels = load_subtensor(self.g, seeds, input_nodes, "cpu")
blocks[0].srcdata['features'] = batch_inputs
blocks[-1].dstdata['labels'] = batch_labels
return blocks
......@@ -115,7 +115,7 @@ class DistSAGE(nn.Module):
drop_last=False)
for blocks in tqdm.tqdm(dataloader):
block = blocks[0]
block = blocks[0].to(device)
input_nodes = block.srcdata[dgl.NID]
output_nodes = block.dstdata[dgl.NID]
h = x[input_nodes].to(device)
......@@ -173,7 +173,11 @@ def run(args, device, data):
model = DistSAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
model = model.to(device)
if not args.standalone:
model = th.nn.parallel.DistributedDataParallel(model)
if args.num_gpus == -1:
model = th.nn.parallel.DistributedDataParallel(model)
else:
dev_id = g.rank() % args.num_gpus
model = th.nn.parallel.DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
......@@ -211,6 +215,8 @@ def run(args, device, data):
num_seeds += len(blocks[-1].dstdata[dgl.NID])
num_inputs += len(blocks[0].srcdata[dgl.NID])
blocks = [block.to(device) for block in blocks]
batch_labels = batch_labels.to(device)
# Compute loss and prediction
start = time.time()
batch_pred = model(blocks, batch_inputs)
......@@ -275,7 +281,10 @@ def main(args):
g.rank(), len(train_nid), len(np.intersect1d(train_nid.numpy(), local_nid)),
len(val_nid), len(np.intersect1d(val_nid.numpy(), local_nid)),
len(test_nid), len(np.intersect1d(test_nid.numpy(), local_nid))))
device = th.device('cpu')
if args.num_gpus == -1:
device = th.device('cpu')
else:
device = th.device('cuda:'+str(g.rank() % args.num_gpus))
labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
print('#labels:', n_classes)
......@@ -296,8 +305,8 @@ if __name__ == '__main__':
parser.add_argument('--num_clients', type=int, help='The number of clients')
parser.add_argument('--num_servers', type=int, default=1, help='The number of servers')
parser.add_argument('--n_classes', type=int, help='the number of classes')
parser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training")
parser.add_argument('--num_gpus', type=int, default=-1,
help="the number of GPU device. Use -1 for CPU training")
parser.add_argument('--num_epochs', type=int, default=20)
parser.add_argument('--num_hidden', type=int, default=16)
parser.add_argument('--num_layers', type=int, default=2)
......
......@@ -150,6 +150,8 @@ class NeighborSampler(object):
blocks.insert(0, block)
input_nodes = blocks[0].srcdata[dgl.NID]
blocks[0].srcdata['features'] = load_subtensor(self.g, input_nodes, 'cpu')
# Pre-generate CSR format that it can be used in training directly
return pos_graph, neg_graph, blocks
......@@ -214,7 +216,7 @@ class DistSAGE(SAGE):
num_workers=args.num_workers)
for blocks in tqdm.tqdm(dataloader):
block = blocks[0]
block = blocks[0].to(device)
input_nodes = block.srcdata[dgl.NID]
output_nodes = block.dstdata[dgl.NID]
h = x[input_nodes].to(device)
......@@ -306,19 +308,22 @@ def run(args, device, data):
dgl.distributed.sample_neighbors, args.num_negs, args.remove_edge)
# Create PyTorch DataLoader for constructing blocks
dataloader = DataLoader(
dataloader = dgl.distributed.DistDataLoader(
dataset=train_eids.numpy(),
batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
drop_last=False)
# Define model and optimizer
model = DistSAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
model = model.to(device)
if not args.standalone:
model = th.nn.parallel.DistributedDataParallel(model)
if args.num_gpus == -1:
model = th.nn.parallel.DistributedDataParallel(model)
else:
dev_id = g.rank() % args.num_gpus
model = th.nn.parallel.DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
loss_fcn = CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
......@@ -352,12 +357,14 @@ def run(args, device, data):
tic_step = time.time()
sample_t.append(tic_step - start)
pos_graph = pos_graph.to(device)
neg_graph = neg_graph.to(device)
blocks = [block.to(device) for block in blocks]
# The nodes for input lies at the LHS side of the first block.
# The nodes for output lies at the RHS side of the last block.
input_nodes = blocks[0].srcdata[dgl.NID]
# Load the input features as well as output labels
batch_inputs = load_subtensor(g, input_nodes, device)
batch_inputs = blocks[0].srcdata['features']
copy_time = time.time()
feat_copy_t.append(copy_time - tic_step)
......@@ -431,7 +438,10 @@ def main(args):
global_valid_nid = th.LongTensor(np.nonzero(g.ndata['val_mask'][np.arange(g.number_of_nodes())]))
global_test_nid = th.LongTensor(np.nonzero(g.ndata['test_mask'][np.arange(g.number_of_nodes())]))
labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
device = th.device('cpu')
if args.num_gpus == -1:
device = th.device('cpu')
else:
device = th.device('cuda:'+str(g.rank() % args.num_gpus))
# Pack data
in_feats = g.ndata['features'].shape[1]
......@@ -454,8 +464,8 @@ if __name__ == '__main__':
parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num_servers', type=int, default=1, help='Server count on each machine.')
parser.add_argument('--n_classes', type=int, help='the number of classes')
parser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training")
parser.add_argument('--num_gpus', type=int, default=-1,
help="the number of GPU device. Use -1 for CPU training")
parser.add_argument('--num_epochs', type=int, default=20)
parser.add_argument('--num_hidden', type=int, default=16)
parser.add_argument('--num-layers', type=int, default=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment