"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a042909c836c794a34508b314b3ce8ce93a96284"
Unverified Commit bb43d042 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

# [Benchmark] relocate distributed node classification (#5681)

## Distributed training
This is an example of training GraphSAGE in a distributed fashion. Before training, please install the required Python packages with pip:
```
pip3 install ogb
```
**Requires PyTorch 1.12.0+ to work.**
Distributed training of GraphSAGE consists of the following steps:
### Step 0: Set up a Distributed File System
* You may skip this step if your cluster already has folder(s) synchronized across machines.
To perform distributed training, files and code need to be accessible across multiple machines. A distributed file system (e.g., NFS or Ceph) handles this well.
#### Server side setup
Here is an example of how to set up NFS. First, install the essential packages on the storage server:
```
sudo apt-get install nfs-kernel-server
```
Below we assume the user account is `ubuntu` and that we create a `workspace` directory under the home directory.
```
mkdir -p /home/ubuntu/workspace
```
We assume that all the servers are in a subnet with the IP range `192.168.0.0` to `192.168.255.255`. The exports configuration needs to be modified as follows:
```
sudo vim /etc/exports
# add the following line
/home/ubuntu/workspace 192.168.0.0/16(rw,sync,no_subtree_check)
```
The server's internal IP can be checked via `ifconfig` or `ip` (see the short example after the next block). If the IP does not begin with `192.168`, you may use one of the following instead:
```
/home/ubuntu/workspace 10.0.0.0/8(rw,sync,no_subtree_check)
/home/ubuntu/workspace 172.16.0.0/12(rw,sync,no_subtree_check)
```
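The addresses assigned to the server can be listed with, for example (assuming standard Ubuntu tooling):
```
ip -4 addr show
# or simply
hostname -I
```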
Then restart NFS; the server-side setup is finished.
```
sudo systemctl restart nfs-kernel-server
```
For configuration details, please refer to the [NFS ArchWiki](https://wiki.archlinux.org/index.php/NFS).
#### Client side setup
To use NFS, clients also need to install the required packages
```
sudo apt-get install nfs-common
```
You can either mount the NFS share manually
```
mkdir -p /home/ubuntu/workspace
sudo mount -t nfs <nfs-server-ip>:/home/ubuntu/workspace /home/ubuntu/workspace
```
or edit `/etc/fstab` so the folder is mounted automatically
```
# vim /etc/fstab
## append the following line to the file
<nfs-server-ip>:/home/ubuntu/workspace /home/ubuntu/workspace nfs defaults 0 0
```
Then run `mount -a`.
Now go to `/home/ubuntu/workspace` and clone the DGL GitHub repository.
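One possible way to do this (a minimal sketch assuming the public `dmlc/dgl` repository; adjust the URL or branch to your setup):
```
cd /home/ubuntu/workspace
git clone --recurse-submodules https://github.com/dmlc/dgl.git
```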
### Step 1: Set the IP configuration file
Users need to set up their own IP configuration file `ip_config.txt` before training. For example, if we have four machines in the current cluster, the IP configuration could look like this:
```
172.31.19.1
172.31.23.205
172.31.29.175
172.31.16.98
```
Users need to make sure that the master node (node-0) is able to SSH into all the other nodes without password authentication.
[This link](https://linuxize.com/post/how-to-setup-passwordless-ssh-login/) provides instructions for setting up passwordless SSH login.
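A minimal sketch of the usual workflow on node-0 (assuming the `ubuntu` account from Step 0; repeat the `ssh-copy-id` line for every IP in `ip_config.txt`):
```
ssh-keygen -t rsa          # accept the defaults when prompted
ssh-copy-id ubuntu@172.31.23.205
```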
### Step 2: Partition the graph
The example provides a script to partition some built-in graphs, such as Reddit and the OGB products graph.
If we want to train GraphSage on 4 machines, we need to partition the graph into 4 parts.
In this example, we partition the ogbn-products graph into 4 parts with Metis on node-0. The partitions are balanced with respect to
the number of nodes, the number of edges and the number of labelled nodes.
```
python3 partition_graph.py --dataset ogbn-products --num_parts 4 --balance_train --balance_edges
```
This script generates the partitioned graph and stores it in a directory called `data`.
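The layout typically looks like the sketch below (file names may differ slightly across DGL versions); the JSON file is what `--part_config` points to in the next step:
```
data/
|-- ogbn-products.json   # partition configuration used by --part_config
|-- part0/
|   |-- graph.dgl
|   |-- node_feat.dgl
|   `-- edge_feat.dgl
|-- part1/
`-- ...
```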
### Step 3: Launch distributed jobs
DGL provides a script to launch the training job in the cluster. `part_config` and `ip_config`
are specified as paths relative to the workspace directory.
The command below launches one process per machine for both sampling and training.
```
python3 ~/workspace/dgl/tools/launch.py \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/dist/ \
--num_trainers 1 \
--num_samplers 0 \
--num_servers 1 \
--part_config data/ogbn-products.json \
--ip_config ip_config.txt \
"python3 node_classification.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 30 --batch_size 1000"
```
By default, this code runs on CPU. If you have GPU support, add the `--num_gpus` argument to the user command (the example below also raises `--num_trainers` to match the number of GPUs per machine):
```
python3 ~/workspace/dgl/tools/launch.py \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/dist/ \
--num_trainers 4 \
--num_samplers 0 \
--num_servers 1 \
--part_config data/ogbn-products.json \
--ip_config ip_config.txt \
"python3 node_classification.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 30 --batch_size 1000 --num_gpus 4"
```
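The training script `node_classification.py` invoked by the launch commands above is listed below.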
import argparse
import socket
import time
import dgl
import dgl.nn.pytorch as dglnn
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
"""
    Copies features and labels of a set of nodes onto the given device.
"""
batch_inputs = (
g.ndata["features"][input_nodes].to(device) if load_feat else None
)
batch_labels = g.ndata["labels"][seeds].to(device)
return batch_inputs, batch_labels
class DistSAGE(nn.Module):
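    # GraphSAGE model for DistDGL: a stack of SAGEConv layers with mean
    # aggregation; activation and dropout are applied between layers.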
def __init__(
self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
super().__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
for _ in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
self.dropout = nn.Dropout(dropout)
self.activation = activation
def forward(self, blocks, x):
h = x
for i, (layer, block) in enumerate(zip(self.layers, blocks)):
h = layer(block, h)
if i != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
return h
def inference(self, g, x, batch_size, device):
"""
Inference with the GraphSAGE model on full neighbors (i.e. without
neighbor sampling).
g : the entire graph.
x : the input of entire node set.
Distributed layer-wise inference.
"""
# During inference with sampling, multi-layer blocks are very
# inefficient because lots of computations in the first few layers
# are repeated. Therefore, we compute the representation of all nodes
        # layer by layer. The nodes in each layer are, of course, split into
        # batches.
# TODO: can we standardize this?
nodes = dgl.distributed.node_split(
np.arange(g.num_nodes()),
g.get_partition_book(),
force_even=True,
)
y = dgl.distributed.DistTensor(
(g.num_nodes(), self.n_hidden),
th.float32,
"h",
persistent=True,
)
for i, layer in enumerate(self.layers):
if i == len(self.layers) - 1:
y = dgl.distributed.DistTensor(
(g.num_nodes(), self.n_classes),
th.float32,
"h_last",
persistent=True,
)
print(f"|V|={g.num_nodes()}, eval batch size: {batch_size}")
sampler = dgl.dataloading.NeighborSampler([-1])
dataloader = dgl.dataloading.DistNodeDataLoader(
g,
nodes,
sampler,
batch_size=batch_size,
shuffle=False,
drop_last=False,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0].to(device)
h = x[input_nodes].to(device)
h_dst = h[: block.number_of_dst_nodes()]
h = layer(block, (h, h_dst))
if i != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
y[output_nodes] = h.cpu()
x = y
g.barrier()
return y
def compute_acc(pred, labels):
"""
Compute the accuracy of prediction given the labels.
"""
labels = labels.long()
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
"""
Evaluate the model on the validation set specified by ``val_nid``.
g : The entire graph.
inputs : The features of all the nodes.
labels : The labels of all the nodes.
val_nid : the node Ids for validation.
batch_size : Number of nodes to compute at the same time.
device : The GPU device to evaluate on.
"""
model.eval()
with th.no_grad():
pred = model.inference(g, inputs, batch_size, device)
model.train()
return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(
pred[test_nid], labels[test_nid]
)
def run(args, device, data):
train_nid, val_nid, test_nid, in_feats, n_classes, g = data
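    # Build the neighbor sampler: one fan-out value per GNN layer
    # (e.g. the default --fan_out "10,25" corresponds to a 2-layer model).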
sampler = dgl.dataloading.NeighborSampler(
[int(fanout) for fanout in args.fan_out.split(",")]
)
dataloader = dgl.dataloading.DistNodeDataLoader(
g,
train_nid,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
)
model = DistSAGE(
in_feats,
args.num_hidden,
n_classes,
args.num_layers,
F.relu,
args.dropout,
)
model = model.to(device)
if args.num_gpus == 0:
model = th.nn.parallel.DistributedDataParallel(model)
else:
model = th.nn.parallel.DistributedDataParallel(
model, device_ids=[device], output_device=device
)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Training loop.
iter_tput = []
epoch = 0
epoch_time = []
test_acc = 0.0
for _ in range(args.num_epochs):
epoch += 1
tic = time.time()
sample_time = 0
forward_time = 0
backward_time = 0
update_time = 0
num_seeds = 0
num_inputs = 0
start = time.time()
# Loop over the dataloader to sample the computation dependency graph
# as a list of blocks.
step_time = []
with model.join():
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
tic_step = time.time()
sample_time += tic_step - start
batch_inputs, batch_labels = load_subtensor(
g, seeds, input_nodes, "cpu"
)
batch_labels = batch_labels.long()
num_seeds += len(blocks[-1].dstdata[dgl.NID])
num_inputs += len(blocks[0].srcdata[dgl.NID])
# Move to target device.
blocks = [block.to(device) for block in blocks]
batch_inputs = batch_inputs.to(device)
batch_labels = batch_labels.to(device)
# Compute loss and prediction.
start = time.time()
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, batch_labels)
forward_end = time.time()
optimizer.zero_grad()
loss.backward()
compute_end = time.time()
forward_time += forward_end - start
backward_time += compute_end - forward_end
optimizer.step()
update_time += time.time() - compute_end
step_t = time.time() - tic_step
step_time.append(step_t)
iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
if (step + 1) % args.log_every == 0:
acc = compute_acc(batch_pred, batch_labels)
gpu_mem_alloc = (
th.cuda.max_memory_allocated() / 1000000
if th.cuda.is_available()
else 0
)
print(
"Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
"Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU "
"{:.1f} MB | time {:.3f} s".format(
g.rank(),
epoch,
step,
loss.item(),
acc.item(),
np.mean(iter_tput[3:]),
gpu_mem_alloc,
np.mean(step_time[-args.log_every :]),
)
)
start = time.time()
toc = time.time()
print(
"Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, "
"forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, "
"#inputs: {}".format(
g.rank(),
toc - tic,
sample_time,
forward_time,
backward_time,
update_time,
num_seeds,
num_inputs,
)
)
epoch_time.append(toc - tic)
if epoch % args.eval_every == 0 or epoch == args.num_epochs:
start = time.time()
val_acc, test_acc = evaluate(
model.module,
g,
g.ndata["features"],
g.ndata["labels"],
val_nid,
test_nid,
args.batch_size_eval,
device,
)
print(
"Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
g.rank(), val_acc, test_acc, time.time() - start
)
)
return np.mean(epoch_time[-int(args.num_epochs * 0.8) :]), test_acc
def main(args):
print(socket.gethostname(), "Initializing DistDGL.")
dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
print(socket.gethostname(), "Initializing PyTorch process group.")
th.distributed.init_process_group(backend=args.backend)
print(socket.gethostname(), "Initializing DistGraph.")
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
print(socket.gethostname(), "rank:", g.rank())
pb = g.get_partition_book()
if "trainer_id" in g.ndata:
train_nid = dgl.distributed.node_split(
g.ndata["train_mask"],
pb,
force_even=True,
node_trainer_ids=g.ndata["trainer_id"],
)
val_nid = dgl.distributed.node_split(
g.ndata["val_mask"],
pb,
force_even=True,
node_trainer_ids=g.ndata["trainer_id"],
)
test_nid = dgl.distributed.node_split(
g.ndata["test_mask"],
pb,
force_even=True,
node_trainer_ids=g.ndata["trainer_id"],
)
else:
train_nid = dgl.distributed.node_split(
g.ndata["train_mask"], pb, force_even=True
)
val_nid = dgl.distributed.node_split(
g.ndata["val_mask"], pb, force_even=True
)
test_nid = dgl.distributed.node_split(
g.ndata["test_mask"], pb, force_even=True
)
local_nid = pb.partid2nids(pb.partid).detach().numpy()
print(
"part {}, train: {} (local: {}), val: {} (local: {}), test: {} "
"(local: {})".format(
g.rank(),
len(train_nid),
len(np.intersect1d(train_nid.numpy(), local_nid)),
len(val_nid),
len(np.intersect1d(val_nid.numpy(), local_nid)),
len(test_nid),
len(np.intersect1d(test_nid.numpy(), local_nid)),
)
)
del local_nid
if args.num_gpus == 0:
device = th.device("cpu")
else:
dev_id = g.rank() % args.num_gpus
device = th.device("cuda:" + str(dev_id))
n_classes = args.n_classes
if n_classes == 0:
labels = g.ndata["labels"][np.arange(g.num_nodes())]
n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
del labels
print(f"Number of classes: {n_classes}")
# Pack data.
in_feats = g.ndata["features"].shape[1]
data = train_nid, val_nid, test_nid, in_feats, n_classes, g
# Train and evaluate.
epoch_time, test_acc = run(args, device, data)
print(
f"Summary of node classification(GraphSAGE): GraphName "
f"{args.graph_name} | TrainEpochTime(mean) {epoch_time:.4f} "
f"| TestAccuracy {test_acc:.4f}"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Distributed GraphSAGE.")
parser.add_argument("--graph_name", type=str, help="graph name")
parser.add_argument(
"--ip_config", type=str, help="The file for IP configuration"
)
parser.add_argument(
"--part_config", type=str, help="The path to the partition config file"
)
parser.add_argument(
"--n_classes", type=int, default=0, help="the number of classes"
)
parser.add_argument(
"--backend",
type=str,
default="gloo",
help="pytorch distributed backend",
)
parser.add_argument(
"--num_gpus",
type=int,
default=0,
help="the number of GPU device. Use 0 for CPU training",
)
parser.add_argument("--num_epochs", type=int, default=20)
parser.add_argument("--num_hidden", type=int, default=16)
parser.add_argument("--num_layers", type=int, default=2)
parser.add_argument("--fan_out", type=str, default="10,25")
parser.add_argument("--batch_size", type=int, default=1000)
parser.add_argument("--batch_size_eval", type=int, default=100000)
parser.add_argument("--log_every", type=int, default=20)
parser.add_argument("--eval_every", type=int, default=5)
parser.add_argument("--lr", type=float, default=0.003)
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument(
"--local_rank", type=int, help="get rank of the process"
)
parser.add_argument(
"--pad-data",
default=False,
action="store_true",
help="Pad train nid to the same length across machine, to ensure num "
"of batches to be the same.",
)
parser.add_argument(
"--net_type",
type=str,
default="socket",
help="backend net type, 'socket' or 'tensorpipe'",
)
args = parser.parse_args()
print(f"Arguments: {args}")
main(args)
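The partition script `partition_graph.py` used in Step 2 follows.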
import argparse
import time
import dgl
import torch as th
from dgl.data import RedditDataset
from ogb.nodeproppred import DglNodePropPredDataset
def load_reddit(self_loop=True):
"""Load reddit dataset."""
data = RedditDataset(self_loop=self_loop)
g = data[0]
g.ndata["features"] = g.ndata.pop("feat")
g.ndata["labels"] = g.ndata.pop("label")
return g, data.num_classes
def load_ogb(name, root="dataset"):
"""Load ogbn dataset."""
data = DglNodePropPredDataset(name=name, root=root)
splitted_idx = data.get_idx_split()
graph, labels = data[0]
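    # Node labels from OGB come as an (N, 1) tensor; keep the single column.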
labels = labels[:, 0]
graph.ndata["features"] = graph.ndata.pop("feat")
graph.ndata["labels"] = labels
num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
# Find the node IDs in the training, validation, and test set.
train_nid, val_nid, test_nid = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
train_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)
train_mask[train_nid] = True
val_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)
val_mask[val_nid] = True
test_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)
test_mask[test_nid] = True
graph.ndata["train_mask"] = train_mask
graph.ndata["val_mask"] = val_mask
graph.ndata["test_mask"] = test_mask
return graph, num_labels
if __name__ == "__main__":
argparser = argparse.ArgumentParser("Partition graph")
argparser.add_argument(
"--dataset",
type=str,
default="reddit",
help="datasets: reddit, ogbn-products, ogbn-papers100M",
)
argparser.add_argument(
"--num_parts", type=int, default=4, help="number of partitions"
)
argparser.add_argument(
"--part_method", type=str, default="metis", help="the partition method"
)
argparser.add_argument(
"--balance_train",
action="store_true",
help="balance the training size in each partition.",
)
argparser.add_argument(
"--undirected",
action="store_true",
help="turn the graph into an undirected graph.",
)
argparser.add_argument(
"--balance_edges",
action="store_true",
help="balance the number of edges in each partition.",
)
argparser.add_argument(
"--num_trainers_per_machine",
type=int,
default=1,
help="the number of trainers per machine. The trainer ids are stored\
in the node feature 'trainer_id'",
)
argparser.add_argument(
"--output",
type=str,
default="data",
help="Output path of partitioned graph.",
)
args = argparser.parse_args()
start = time.time()
if args.dataset == "reddit":
g, _ = load_reddit()
elif args.dataset in ["ogbn-products", "ogbn-papers100M"]:
g, _ = load_ogb(args.dataset)
else:
raise RuntimeError(f"Unknown dataset: {args.dataset}")
print(
"Load {} takes {:.3f} seconds".format(args.dataset, time.time() - start)
)
print("|V|={}, |E|={}".format(g.num_nodes(), g.num_edges()))
print(
"train: {}, valid: {}, test: {}".format(
th.sum(g.ndata["train_mask"]),
th.sum(g.ndata["val_mask"]),
th.sum(g.ndata["test_mask"]),
)
)
if args.balance_train:
balance_ntypes = g.ndata["train_mask"]
else:
balance_ntypes = None
if args.undirected:
sym_g = dgl.to_bidirected(g, readonly=True)
for key in g.ndata:
sym_g.ndata[key] = g.ndata[key]
g = sym_g
dgl.distributed.partition_graph(
g,
args.dataset,
args.num_parts,
args.output,
part_method=args.part_method,
balance_ntypes=balance_ntypes,
balance_edges=args.balance_edges,
num_trainers_per_machine=args.num_trainers_per_machine,
)