Unverified Commit 40a2f3c7 authored by Chang Liu, committed by GitHub

[Example][Refactor] Refactor RGCN example (#4327)

* Refactor full graph entity classification

* Refactor rgcn with sampling

* README update

* Update

* Results update

* Respect default setting of self_loop=false in entity.py

* Update

* Update README

* Update for multi-gpu

* Update
parent ea6195c2
README.md:

* Author's code for link prediction: [https://github.com/MichSchli/RelationPrediction](https://github.com/MichSchli/RelationPrediction)

### Dependencies

- rdflib
- torchmetrics

Install as follows:
```bash
pip install rdflib
pip install torchmetrics
```
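The commands above cover only the extra Python dependencies. The scripts also assume a working PyTorch and DGL installation; a minimal CPU-only sketch (pick the wheels matching your CUDA setup instead if you train on GPU):
```bash
pip install torch
pip install dgl
```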
How to run
-------

### Entity Classification

Run with the following for entity classification (available datasets: aifb (default), mutag, bgs, and am)
```bash
python3 entity.py --dataset aifb
```
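The other datasets are chosen the same way via `--dataset`, for example (keeping every other setting at its default):
```bash
python3 entity.py --dataset mutag
```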
For mini-batch training, run with the following (available datasets are the same as above)
```bash
python3 entity_sample.py --dataset aifb
```

For multi-gpu training with sampling, run with the following (same datasets as above; pass the GPU IDs separated by commas)
```bash
python3 entity_sample_multi_gpu.py --dataset aifb --gpu 0,1
```
### Link Prediction

FB15k-237 in RAW-MRR
...

FB15k-237 in Filtered-MRR
```
python link.py --gpu 0 --eval-protocol filtered
```
Summary
-------

### Entity Classification (accuracy)

| Dataset | Full-graph | Mini-batch |
| ------- | ---------- | ---------- |
| aifb    | ~0.85      | ~0.82      |
| mutag   | ~0.70      | ~0.50      |
| bgs     | ~0.86      | ~0.64      |
| am      | ~0.78      | ~0.42      |

entity.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.nn.pytorch import RelGraphConv
import argparse


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)

    def forward(self, g):
        x = self.emb.weight
        h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata['norm']))
        h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata['norm'])
        return h


def evaluate(g, target_idx, labels, test_mask, model):
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    model.eval()
    with torch.no_grad():
        logits = model(g)
    logits = logits[target_idx]
    return accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item()


def train(g, target_idx, labels, train_mask, model):
    # define train idx, loss function and optimizer
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    model.train()
    for epoch in range(50):
        logits = model(g)
        logits = logits[target_idx]
        loss = loss_fcn(logits[train_idx], labels[train_idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = accuracy(logits[train_idx].argmax(dim=1), labels[train_idx]).item()
        print("Epoch {:05d} | Loss {:.4f} | Train Accuracy {:.4f}"
              .format(epoch, loss.item(), acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN for entity classification')
    parser.add_argument("--dataset", type=str, default="aifb",
                        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').")
    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Training with DGL built-in RGCN module.')

    # load and preprocess dataset
    if args.dataset == 'aifb':
        data = AIFBDataset()
    elif args.dataset == 'mutag':
        data = MUTAGDataset()
    elif args.dataset == 'bgs':
        data = BGSDataset()
    elif args.dataset == 'am':
        data = AMDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    g = data[0]
    g = g.int().to(device)
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop('labels')
    train_mask = g.nodes[category].data.pop('train_mask')
    test_mask = g.nodes[category].data.pop('test_mask')

    # calculate normalization weight for each edge, and find target category and node id
    for cetype in g.canonical_etypes:
        g.edges[cetype].data['norm'] = dgl.norm_by_dst(g, cetype).unsqueeze(1)
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g, edata=['norm'])
    node_ids = torch.arange(g.num_nodes()).to(device)
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]

    # create RGCN model
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = data.num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)

    train(g, target_idx, labels, train_mask, model)
    acc = evaluate(g, target_idx, labels, test_mask, model)
    print("Test accuracy {:.4f}".format(acc))
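The `norm` edge feature that `entity.py` (and the sampling scripts below) feed into `RelGraphConv` comes from `dgl.norm_by_dst`. A minimal sketch of the assumed semantics (each edge gets 1 / in-degree of its destination node); the toy graph is made up for illustration:
```python
# Minimal sketch (assumption: dgl.norm_by_dst assigns each edge 1 / in-degree of its
# destination node, which is how the 'norm' feature is consumed by RelGraphConv above).
import dgl
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 2]))          # toy graph with edges 0->1, 1->2, 2->2
norm = dgl.norm_by_dst(g)                       # one normalization value per edge
_, dst = g.edges()
expected = 1.0 / g.in_degrees()[dst].float()    # [1.0, 0.5, 0.5]
print(torch.allclose(norm, expected))           # expected to print True
```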
entity_sample.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from dgl.nn.pytorch import RelGraphConv
import argparse


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)

    def forward(self, g):
        x = self.emb(g[0].srcdata[dgl.NID])
        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata['norm']))
        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata['norm'])
        return h


def evaluate(model, labels, dataloader, inv_target):
    model.eval()
    eval_logits = []
    eval_seeds = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in dataloader:
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            eval_logits.append(logits.cpu().detach())
            eval_seeds.append(output_nodes.cpu().detach())
    eval_logits = torch.cat(eval_logits)
    eval_seeds = torch.cat(eval_seeds)
    return accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()).item()


def train(device, g, target_idx, labels, train_mask, model):
    # define train idx, loss function and optimizer
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # construct sampler and dataloaders
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(g, target_idx[train_idx], sampler, device=device,
                              batch_size=100, shuffle=True)
    # no separate validation subset, use train index instead for validation
    val_loader = DataLoader(g, target_idx[train_idx], sampler, device=device,
                            batch_size=100, shuffle=False)
    for epoch in range(50):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        acc = evaluate(model, labels, val_loader, inv_target)
        print("Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f}"
              .format(epoch, total_loss / (it + 1), acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN for entity classification with sampling')
    parser.add_argument("--dataset", type=str, default="aifb",
                        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').")
    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Training with DGL built-in RGCN module with sampling.')

    # load and preprocess dataset
    if args.dataset == 'aifb':
        data = AIFBDataset()
    elif args.dataset == 'mutag':
        data = MUTAGDataset()
    elif args.dataset == 'bgs':
        data = BGSDataset()
    elif args.dataset == 'am':
        data = AMDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop('labels').to(device)
    train_mask = g.nodes[category].data.pop('train_mask')
    test_mask = g.nodes[category].data.pop('test_mask')

    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename the fields as they can be changed by DataLoader
    g.ndata['ntype'] = g.ndata.pop(dgl.NTYPE)
    g.ndata['type_id'] = g.ndata.pop(dgl.NID)
    # find the mapping (inv_target) from global node IDs to type-specific node IDs
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
    inv_target[target_idx] = torch.arange(0, target_idx.shape[0],
                                          dtype=inv_target.dtype).to(device)

    # create RGCN model
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = data.num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)

    train(device, g, target_idx, labels, train_mask, model)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    test_sampler = MultiLayerNeighborSampler([-1, -1])  # -1 for sampling all neighbors
    test_loader = DataLoader(g, target_idx[test_idx], test_sampler, device=device,
                             batch_size=32, shuffle=False)
    acc = evaluate(model, labels, test_loader, inv_target)
    print("Test accuracy {:.4f}".format(acc))
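To make the `inv_target` bookkeeping above concrete, here is a toy illustration with made-up numbers: if the homogeneous graph has 6 nodes and the target-category nodes received global IDs 2, 4 and 5, then `inv_target` maps those global IDs back to the type-specific IDs 0, 1 and 2, which is how `labels` is indexed.
```python
# Toy illustration of the inv_target mapping (hypothetical node counts and IDs).
import torch

num_nodes = 6                               # nodes in the homogeneous graph
target_idx = torch.tensor([2, 4, 5])        # global IDs of the target-category nodes
inv_target = torch.empty((num_nodes,), dtype=torch.int64)
inv_target[target_idx] = torch.arange(target_idx.shape[0], dtype=inv_target.dtype)
print(inv_target[torch.tensor([2, 4, 5])])  # tensor([0, 1, 2])
```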
entity_sample_multi_gpu.py:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torchmetrics.functional import accuracy
import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from dgl.nn.pytorch import RelGraphConv
import argparse


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)

    def forward(self, g):
        x = self.emb(g[0].srcdata[dgl.NID])
        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata['norm']))
        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata['norm'])
        return h


def evaluate(model, labels, dataloader, inv_target):
    model.eval()
    eval_logits = []
    eval_seeds = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in dataloader:
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            eval_logits.append(logits.cpu().detach())
            eval_seeds.append(output_nodes.cpu().detach())
    eval_logits = torch.cat(eval_logits)
    eval_seeds = torch.cat(eval_seeds)
    num_seeds = len(eval_seeds)
    # torchmetrics accuracy is num_correct / num_seeds; return the pair
    # [accuracy * num_seeds, num_seeds] so it can be summed across processes
    loc_sum = accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()) * float(num_seeds)
    return torch.tensor([loc_sum.item(), float(num_seeds)])


def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model):
    # define loss function and optimizer
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # construct sampler and dataloaders; use_ddp splits the seed nodes across processes
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(g, target_idx[train_idx], sampler, device=device,
                              batch_size=100, shuffle=True, use_ddp=True)
    # no separate validation subset, use train index instead for validation
    val_loader = DataLoader(g, target_idx[train_idx], sampler, device=device,
                            batch_size=100, shuffle=False, use_ddp=True)
    for epoch in range(50):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # loc_acc_split = [loc_accuracy * loc_num_seeds, loc_num_seeds]
        loc_acc_split = evaluate(model, labels, val_loader, inv_target).to(device)
        dist.reduce(loc_acc_split, 0)
        if proc_id == 0:
            acc = loc_acc_split[0] / loc_acc_split[1]
            print("Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f}"
                  .format(epoch, total_loss / (it + 1), acc.item()))


def run(proc_id, nprocs, devices, g, data):
    # find corresponding device for my rank
    device = devices[proc_id]
    torch.cuda.set_device(device)
    # initialize process group and unpack data for sub-processes
    dist.init_process_group(backend="nccl", init_method='tcp://127.0.0.1:12345',
                            world_size=nprocs, rank=proc_id)
    num_rels, num_classes, labels, train_idx, test_idx, target_idx, inv_target = data
    labels = labels.to(device)
    inv_target = inv_target.to(device)
    # create RGCN model (distributed)
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)
    model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    # training + testing
    train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model)
    test_sampler = MultiLayerNeighborSampler([-1, -1])  # -1 for sampling all neighbors
    test_loader = DataLoader(g, target_idx[test_idx], test_sampler, device=device,
                             batch_size=32, shuffle=False, use_ddp=True)
    loc_acc_split = evaluate(model, labels, test_loader, inv_target).to(device)
    dist.reduce(loc_acc_split, 0)
    if proc_id == 0:
        acc = loc_acc_split[0] / loc_acc_split[1]
        print("Test accuracy {:.4f}".format(acc))
    # cleanup process group
    dist.destroy_process_group()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN for entity classification with sampling (multi-gpu)')
    parser.add_argument("--dataset", type=str, default="aifb",
                        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').")
    parser.add_argument("--gpu", type=str, default='0',
                        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
                             " e.g., 0,1,2,3.")
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(',')))
    nprocs = len(devices)
    print(f'Training with DGL built-in RGCN module with sampling using {nprocs} GPU(s).')

    # load and preprocess dataset at master (parent) process
    if args.dataset == 'aifb':
        data = AIFBDataset()
    elif args.dataset == 'mutag':
        data = MUTAGDataset()
    elif args.dataset == 'bgs':
        data = BGSDataset()
    elif args.dataset == 'am':
        data = AMDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop('labels')
    train_mask = g.nodes[category].data.pop('train_mask')
    test_mask = g.nodes[category].data.pop('test_mask')

    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename the fields as they can be changed by DataLoader
    g.ndata['ntype'] = g.ndata.pop(dgl.NTYPE)
    g.ndata['type_id'] = g.ndata.pop(dgl.NID)
    # find the mapping (inv_target) from global node IDs to type-specific node IDs
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64)
    inv_target[target_idx] = torch.arange(0, target_idx.shape[0], dtype=inv_target.dtype)
    # create graph formats and train/test indexes here to avoid doing so in each
    # sub-process, which saves memory
    g.create_formats_()
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    # thread limiting to avoid resource competition
    os.environ['OMP_NUM_THREADS'] = str(mp.cpu_count() // 2 // nprocs)

    data = num_rels, data.num_classes, labels, train_idx, test_idx, target_idx, inv_target
    mp.spawn(run, args=(nprocs, devices, g, data), nprocs=nprocs)
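Each process returns `[accuracy * num_seeds, num_seeds]` from `evaluate`, so summing the pairs across ranks with `dist.reduce` and dividing recovers the global accuracy. A small worked example with made-up numbers:
```python
# Made-up two-GPU example of the accuracy reduction used in the script above.
import torch

rank0 = torch.tensor([0.80 * 100, 100.0])  # 80 correct predictions out of 100 seeds
rank1 = torch.tensor([0.90 * 50, 50.0])    # 45 correct predictions out of 50 seeds
reduced = rank0 + rank1                    # what rank 0 holds after dist.reduce(..., dst=0)
print((reduced[0] / reduced[1]).item())    # ~0.8333 (125 correct out of 150 seeds)
```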