Unverified commit 4ef01dbb, authored by xiang song (charlie.song) and committed by GitHub

[Example] Rgcn support ogbn-mag dataset. (#1812)



* rgcn support ogbn-mag dataset

* upd

* multi-gpu val and test

* Fix

* fix

* Add support for ogbn-mag

* Fix

* Fix

* Fix

* Fix

* Add layer_norm

* update

* Fix merge

* Clean some code

* update Readme

* upd
Co-authored-by: Ubuntu <ubuntu@ip-172-31-68-185.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-87-240.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent e7515773
@@ -40,7 +40,7 @@ python3 entity_classify.py -d am --n-bases=40 --n-hidden=10 --l2norm=5e-4 --test
### Entity Classification with minibatch
AIFB: accuracy avg(5 runs) 90.56%, best 94.44% (DGL)
```
-python3 entity_classify_mp.py -d aifb --testing --gpu 0 --fanout=20 --batch-size 128
+python3 entity_classify_mp.py -d aifb --testing --gpu 0 --fanout='20,20' --batch-size 128
```
MUTAG: accuracy avg(5 runs) 66.77%, best 69.12% (DGL)
@@ -49,16 +49,30 @@ python3 entity_classify_mp.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gp
```
BGS: accuracy avg(5 runs) 91.72%, best 96.55% (DGL)
```
-python3 entity_classify_mp.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout 40 --n-epochs=40 --batch-size=128
+python3 entity_classify_mp.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout '40,40' --n-epochs=40 --batch-size=128
```
AM: accuracy avg(5 runs) 88.28%, best 90.40% (DGL)
```
-python3 entity_classify_mp.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout 35 --batch-size 256 --lr 1e-2 --n-hidden 16 --use-self-loop --n-epochs=40
+python3 entity_classify_mp.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout '35,35' --batch-size 256 --lr 1e-2 --n-hidden 16 --use-self-loop --n-epochs=40
```
+### Entity Classification on OGBN-MAG
+Test-bd: P3-8xlarge
+OGBN-MAG accuracy 46.22
+```
+python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='25,30' --batch-size 512 --n-hidden 64 --lr 0.01 --num-worker 0 --eval-batch-size 8 --low-mem --gpu 0,1,2,3,4,5,6,7 --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --mix-cpu-gpu --node-feats --layer-norm
+```
+OGBN-MAG without node-feats 43.24
+```
+python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='25,25' --batch-size 256 --n-hidden 64 --lr 0.01 --num-worker 0 --eval-batch-size 8 --low-mem --gpu 0,1,2,3,4,5,6,7 --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --mix-cpu-gpu --layer-norm
+```
+Test-bd: P2-8xlarge
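The minibatch commands above now pass `--fanout` as a comma-separated string with one value per layer. A minimal, illustrative sketch of how such a string maps to per-layer sampling fanouts (mirroring the parsing in entity_classify_mp.py):
```
# Illustrative only: the '--fanout' format used by the commands above.
fanout = '25,30'                               # one entry per layer
fanouts = [int(f) for f in fanout.split(',')]  # -> [25, 30]
# layer 0 samples up to 25 neighbors per seed node, layer 1 samples up to 30
```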
### Link Prediction
FB15k-237: MRR 0.151 (DGL), 0.158 (paper)
```
......
@@ -10,6 +10,7 @@ import argparse
import itertools
import numpy as np
import time
+import gc
import torch as th
import torch.nn as nn
import torch.nn.functional as F
@@ -25,6 +26,9 @@ from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from model import RelGraphEmbedLayer
from dgl.nn import RelGraphConv
from utils import thread_wrapped_func
+import tqdm
+from ogb.nodeproppred import DglNodePropPredDataset

class EntityClassify(nn.Module):
    """ Entity classification class for RGCN
@@ -62,7 +66,8 @@ class EntityClassify(nn.Module):
                 num_hidden_layers=1,
                 dropout=0,
                 use_self_loop=False,
-                 low_mem=False):
+                 low_mem=False,
+                 layer_norm=False):
        super(EntityClassify, self).__init__()
        self.device = th.device(device if device >= 0 else 'cpu')
        self.num_nodes = num_nodes
@@ -74,6 +79,7 @@ class EntityClassify(nn.Module):
        self.dropout = dropout
        self.use_self_loop = use_self_loop
        self.low_mem = low_mem
+        self.layer_norm = layer_norm
        self.layers = nn.ModuleList()
        # i2h
@@ -149,20 +155,50 @@ class NeighborSampler:
            norm = self.g.edata['norm'][frontier.edata[dgl.EID]]
            block = dgl.to_block(frontier, cur)
            block.srcdata[dgl.NTYPE] = self.g.ndata[dgl.NTYPE][block.srcdata[dgl.NID]]
+            block.srcdata['type_id'] = self.g.ndata[dgl.NID][block.srcdata[dgl.NID]]
            block.edata['etype'] = etypes
            block.edata['norm'] = norm
            cur = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return seeds, blocks

+def evaluate(model, embed_layer, eval_loader, node_feats):
+    model.eval()
+    embed_layer.eval()
+    eval_logits = []
+    eval_seeds = []
+
+    with th.no_grad():
+        for sample_data in tqdm.tqdm(eval_loader):
+            th.cuda.empty_cache()
+            seeds, blocks = sample_data
+            feats = embed_layer(blocks[0].srcdata[dgl.NID],
+                                blocks[0].srcdata[dgl.NTYPE],
+                                blocks[0].srcdata['type_id'],
+                                node_feats)
+            logits = model(blocks, feats)
+            eval_logits.append(logits.cpu().detach())
+            eval_seeds.append(seeds.cpu().detach())
+    eval_logits = th.cat(eval_logits)
+    eval_seeds = th.cat(eval_seeds)
+
+    return eval_logits, eval_seeds
+
@thread_wrapped_func
-def run(proc_id, n_gpus, args, devices, dataset):
+def run(proc_id, n_gpus, args, devices, dataset, split, queue=None):
    dev_id = devices[proc_id]
-    g, num_of_ntype, num_classes, num_rels, target_idx, \
+    g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, \
        train_idx, val_idx, test_idx, labels = dataset
+    if split is not None:
+        train_seed, val_seed, test_seed = split
+        train_idx = train_idx[train_seed]
+        val_idx = val_idx[val_seed]
+        test_idx = test_idx[test_seed]
+
+    fanouts = [int(fanout) for fanout in args.fanout.split(',')]
    node_tids = g.ndata[dgl.NTYPE]
-    sampler = NeighborSampler(g, target_idx, [args.fanout] * args.n_layers)
+    sampler = NeighborSampler(g, target_idx, fanouts)
    loader = DataLoader(dataset=train_idx.numpy(),
                        batch_size=args.batch_size,
                        collate_fn=sampler.sample_blocks,
@@ -172,7 +208,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
    # validation sampler
    val_sampler = NeighborSampler(g, target_idx, [None] * args.n_layers)
    val_loader = DataLoader(dataset=val_idx.numpy(),
-                            batch_size=args.batch_size,
+                            batch_size=args.eval_batch_size,
                            collate_fn=val_sampler.sample_blocks,
                            shuffle=False,
                            num_workers=args.num_workers)
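For reference, a self-contained sketch (with made-up sizes) of the seed splitting behind the new `split` argument applied at the top of run(): main() shuffles each index set once and every process keeps only its own slice of the training, validation and test seeds:
```
# Hypothetical sizes; mirrors the seed-splitting logic of this example.
import torch as th

n_gpus = 4
train_idx = th.arange(1000)                    # all training node ids (made up)
train_seeds = th.randperm(train_idx.shape[0])  # shuffled positions, created once in main()
tseeds_per_proc = train_seeds.shape[0] // n_gpus
for proc_id in range(n_gpus):
    end = (proc_id + 1) * tseeds_per_proc \
          if (proc_id + 1) * tseeds_per_proc < train_seeds.shape[0] \
          else train_seeds.shape[0]
    proc_seed = train_seeds[proc_id * tseeds_per_proc : end]
    proc_train_idx = train_idx[proc_seed]      # the slice this process trains on
```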
@@ -180,7 +216,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
    # validation sampler
    test_sampler = NeighborSampler(g, target_idx, [None] * args.n_layers)
    test_loader = DataLoader(dataset=test_idx.numpy(),
-                             batch_size=args.batch_size,
+                             batch_size=args.eval_batch_size,
                             collate_fn=test_sampler.sample_blocks,
                             shuffle=False,
                             num_workers=args.num_workers)
@@ -190,7 +226,9 @@ def run(proc_id, n_gpus, args, devices, dataset):
                                              master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        backend = 'nccl'
-        if args.sparse_embedding:
+        # using sparse embedding or using mix_cpu_gpu mode (the embedding model can not be stored on GPU)
+        if args.sparse_embedding or args.mix_cpu_gpu:
            backend = 'gloo'
        th.distributed.init_process_group(backend=backend,
                                          init_method=dist_init_method,
@@ -199,7 +237,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
    # node features
    # None for one-hot feature, if not none, it should be the feature tensor.
-    node_feats = [None] * num_of_ntype
+    #
    embed_layer = RelGraphEmbedLayer(dev_id,
                                     g.number_of_nodes(),
                                     node_tids,
@@ -209,6 +247,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
                                     sparse_emb=args.sparse_embedding)

    # create model
+    # all model params are in device.
    model = EntityClassify(dev_id,
                           g.number_of_nodes(),
                           args.n_hidden,
@@ -218,9 +257,10 @@ def run(proc_id, n_gpus, args, devices, dataset):
                           num_hidden_layers=args.n_layers - 2,
                           dropout=args.dropout,
                           use_self_loop=args.use_self_loop,
-                           low_mem=args.low_mem)
+                           low_mem=args.low_mem,
+                           layer_norm=args.layer_norm)

-    if dev_id >= 0:
+    if dev_id >= 0 and n_gpus == 1:
        th.cuda.set_device(dev_id)
        labels = labels.to(dev_id)
        model.cuda(dev_id)
@@ -229,15 +269,30 @@ def run(proc_id, n_gpus, args, devices, dataset):
        embed_layer.cuda(dev_id)

    if n_gpus > 1:
-        embed_layer = DistributedDataParallel(embed_layer, device_ids=[dev_id], output_device=dev_id)
+        labels = labels.to(dev_id)
+        model.cuda(dev_id)
+        if args.mix_cpu_gpu:
+            embed_layer = DistributedDataParallel(embed_layer, device_ids=None, output_device=None)
+        else:
+            embed_layer.cuda(dev_id)
+            embed_layer = DistributedDataParallel(embed_layer, device_ids=[dev_id], output_device=dev_id)
        model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)

    # optimizer
    if args.sparse_embedding:
-        optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2norm)
-        emb_optimizer = th.optim.SparseAdam(embed_layer.parameters(), lr=args.lr)
+        dense_params = list(model.parameters())
+        if args.node_feats:
+            if n_gpus > 1:
+                dense_params += list(embed_layer.module.embeds.parameters())
+            else:
+                dense_params += list(embed_layer.embeds.parameters())
+        optimizer = th.optim.Adam(dense_params, lr=args.lr, weight_decay=args.l2norm)
+        if n_gpus > 1:
+            emb_optimizer = th.optim.SparseAdam(embed_layer.module.node_embeds.parameters(), lr=args.lr)
+        else:
+            emb_optimizer = th.optim.SparseAdam(embed_layer.node_embeds.parameters(), lr=args.lr)
    else:
-        all_params = itertools.chain(model.parameters(), embed_layer.parameters())
+        all_params = list(model.parameters()) + list(embed_layer.parameters())
        optimizer = th.optim.Adam(all_params, lr=args.lr, weight_decay=args.l2norm)
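A self-contained sketch of the dense/sparse optimizer split set up above, using stand-in modules instead of the example's real model and embedding layer (Adam for dense parameters, SparseAdam for the sparse node-embedding table):
```
# Stand-in modules; illustrates the two-optimizer pattern only.
import torch as th
import torch.nn as nn

model = nn.Linear(16, 4)                          # plays the role of the RGCN model
node_embeds = nn.Embedding(1000, 16, sparse=True) # plays the role of the sparse embedding table

optimizer = th.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
emb_optimizer = th.optim.SparseAdam(list(node_embeds.parameters()), lr=0.01)

idx = th.randint(0, 1000, (32,))
loss = model(node_embeds(idx)).sum()
optimizer.zero_grad()
emb_optimizer.zero_grad()
loss.backward()                                   # embedding grads are sparse, handled by SparseAdam
optimizer.step()
emb_optimizer.step()
```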
    # training loop
@@ -247,20 +302,22 @@ def run(proc_id, n_gpus, args, devices, dataset):
    for epoch in range(args.n_epochs):
        model.train()
-        optimizer.zero_grad()
-        if args.sparse_embedding:
-            emb_optimizer.zero_grad()
+        embed_layer.train()

        for i, sample_data in enumerate(loader):
            seeds, blocks = sample_data
            t0 = time.time()
-            feats = embed_layer(blocks[0].srcdata[dgl.NID].to(dev_id),
-                                blocks[0].srcdata[dgl.NTYPE].to(dev_id),
+            feats = embed_layer(blocks[0].srcdata[dgl.NID],
+                                blocks[0].srcdata[dgl.NTYPE],
+                                blocks[0].srcdata['type_id'],
                                node_feats)
            logits = model(blocks, feats)
            loss = F.cross_entropy(logits, labels[seeds])
            t1 = time.time()
+            optimizer.zero_grad()
+            if args.sparse_embedding:
+                emb_optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if args.sparse_embedding:
@@ -269,54 +326,64 @@ def run(proc_id, n_gpus, args, devices, dataset):
            forward_time.append(t1 - t0)
            backward_time.append(t2 - t1)
-            print("Epoch {:05d}:{:05d} | Train Forward Time(s) {:.4f} | Backward Time(s) {:.4f}".
-                  format(epoch, i, forward_time[-1], backward_time[-1]))
            train_acc = th.sum(logits.argmax(dim=1) == labels[seeds]).item() / len(seeds)
-            print("Train Accuracy: {:.4f} | Train Loss: {:.4f}".
-                  format(train_acc, loss.item()))
+            if i % 100 and proc_id == 0:
+                print("Train Accuracy: {:.4f} | Train Loss: {:.4f}".
+                      format(train_acc, loss.item()))
+                print("Epoch {:05d}:{:05d} | Train Forward Time(s) {:.4f} | Backward Time(s) {:.4f}".
+                      format(epoch, i, forward_time[-1], backward_time[-1]))

-        # only process 0 will do the evaluation
-        if proc_id == 0:
-            model.eval()
-            eval_logtis = []
-            eval_seeds = []
-            for i, sample_data in enumerate(val_loader):
-                seeds, blocks = sample_data
-                feats = embed_layer(blocks[0].srcdata[dgl.NID].to(dev_id),
-                                    blocks[0].srcdata[dgl.NTYPE].to(dev_id),
-                                    node_feats)
-                logits = model(blocks, feats)
-                eval_logtis.append(logits)
-                eval_seeds.append(seeds)
-            eval_logtis = th.cat(eval_logtis)
-            eval_seeds = th.cat(eval_seeds)
-            val_loss = F.cross_entropy(eval_logtis, labels[eval_seeds])
-            val_acc = th.sum(eval_logtis.argmax(dim=1) == labels[eval_seeds]).item() / len(eval_seeds)
-            print("Validation Accuracy: {:.4f} | Validation loss: {:.4f}".
-                  format(val_acc, val_loss.item()))
+        if (queue is not None) or (proc_id == 0):
+            val_logits, val_seeds = evaluate(model, embed_layer, val_loader, node_feats)
+            if queue is not None:
+                queue.put((val_logits, val_seeds))
+
+            # gather evaluation result from multiple processes
+            if proc_id == 0:
+                if queue is not None:
+                    val_logits = []
+                    val_seeds = []
+                    for i in range(n_gpus):
+                        log = queue.get()
+                        val_l, val_s = log
+                        val_logits.append(val_l)
+                        val_seeds.append(val_s)
+                    val_logits = th.cat(val_logits)
+                    val_seeds = th.cat(val_seeds)
+                val_loss = F.cross_entropy(val_logits, labels[val_seeds].cpu()).item()
+                val_acc = th.sum(val_logits.argmax(dim=1) == labels[val_seeds].cpu()).item() / len(val_seeds)
+                print("Validation Accuracy: {:.4f} | Validation loss: {:.4f}".
+                      format(val_acc, val_loss))
        if n_gpus > 1:
            th.distributed.barrier()
        print()

-    # only process 0 will do the testing
-    if proc_id == 0:
-        model.eval()
-        test_logtis = []
-        test_seeds = []
-        for i, sample_data in enumerate(test_loader):
-            seeds, blocks = sample_data
-            feats = embed_layer(blocks[0].srcdata[dgl.NID].to(dev_id),
-                                blocks[0].srcdata[dgl.NTYPE].to(dev_id),
-                                [None] * num_of_ntype)
-            logits = model(blocks, feats)
-            test_logtis.append(logits)
-            test_seeds.append(seeds)
-        test_logtis = th.cat(test_logtis)
-        test_seeds = th.cat(test_seeds)
-        test_loss = F.cross_entropy(test_logtis, labels[test_seeds])
-        test_acc = th.sum(test_logtis.argmax(dim=1) == labels[test_seeds]).item() / len(test_seeds)
-        print("Test Accuracy: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
-        print()
+    # only process 0 will do the evaluation
+    if (queue is not None) or (proc_id == 0):
+        test_logits, test_seeds = evaluate(model, embed_layer, test_loader, node_feats)
+        if queue is not None:
+            queue.put((test_logits, test_seeds))
+
+        # gather evaluation result from multiple processes
+        if proc_id == 0:
+            if queue is not None:
+                test_logits = []
+                test_seeds = []
+                for i in range(n_gpus):
+                    log = queue.get()
+                    test_l, test_s = log
+                    test_logits.append(test_l)
+                    test_seeds.append(test_s)
+                test_logits = th.cat(test_logits)
+                test_seeds = th.cat(test_seeds)
+            test_loss = F.cross_entropy(test_logits, labels[test_seeds].cpu()).item()
+            test_acc = th.sum(test_logits.argmax(dim=1) == labels[test_seeds].cpu()).item() / len(test_seeds)
+            print("Test Accuracy: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss))
+            print()
+
+    # sync for test
+    if n_gpus > 1:
+        th.distributed.barrier()

    print("{}/{} Mean forward time: {:4f}".format(proc_id, n_gpus,
                                                  np.mean(forward_time[len(forward_time) // 4:])))
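The multi-GPU validation and testing above gather per-process results through a plain multiprocessing queue: every worker puts its local (logits, seeds) pair and process 0 concatenates them. A toy, self-contained sketch of that gather step (shapes are made up):
```
# Toy illustration of the queue-based gather; not the example's actual training code.
import torch as th
import torch.multiprocessing as mp

def worker(proc_id, queue):
    logits = th.randn(4, 3)                        # this process's local predictions
    seeds = th.arange(proc_id * 4, (proc_id + 1) * 4)
    queue.put((logits, seeds))

if __name__ == '__main__':
    n_procs = 2
    queue = mp.Queue(n_procs)
    procs = [mp.Process(target=worker, args=(i, queue)) for i in range(n_procs)]
    for p in procs:
        p.start()
    parts = [queue.get() for _ in range(n_procs)]  # one (logits, seeds) pair per process
    all_logits = th.cat([l for l, _ in parts])
    all_seeds = th.cat([s for _, s in parts])
    for p in procs:
        p.join()
```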
@@ -334,37 +401,82 @@ def main(args, devices):
        dataset = BGSDataset()
    elif args.dataset == 'am':
        dataset = AMDataset()
+    elif args.dataset == 'ogbn-mag':
+        dataset = DglNodePropPredDataset(name=args.dataset)
+        ogb_dataset = True
    else:
        raise ValueError()

-    # Load from hetero-graph
-    hg = dataset[0]
-
-    num_rels = len(hg.canonical_etypes)
-    num_of_ntype = len(hg.ntypes)
-    category = dataset.predict_category
-    num_classes = dataset.num_classes
-    train_mask = hg.nodes[category].data.pop('train_mask')
-    test_mask = hg.nodes[category].data.pop('test_mask')
-    labels = hg.nodes[category].data.pop('labels')
-    train_idx = th.nonzero(train_mask).squeeze()
-    test_idx = th.nonzero(test_mask).squeeze()
-
-    # split dataset into train, validate, test
-    if args.validation:
-        val_idx = train_idx[:len(train_idx) // 5]
-        train_idx = train_idx[len(train_idx) // 5:]
-    else:
-        val_idx = train_idx
+    if ogb_dataset is True:
+        split_idx = dataset.get_idx_split()
+        train_idx = split_idx["train"]['paper']
+        val_idx = split_idx["valid"]['paper']
+        test_idx = split_idx["test"]['paper']
+        hg_orig, labels = dataset[0]
+        subgs = {}
+        for etype in hg_orig.canonical_etypes:
+            u, v = hg_orig.all_edges(etype=etype)
+            subgs[etype] = (u, v)
+            subgs[(etype[2], 'rev-'+etype[1], etype[0])] = (v, u)
+        hg = dgl.heterograph(subgs)
+        hg.nodes['paper'].data['feat'] = hg_orig.nodes['paper'].data['feat']
+        labels = labels['paper'].squeeze()
+
+        num_rels = len(hg.canonical_etypes)
+        num_of_ntype = len(hg.ntypes)
+        num_classes = dataset.num_classes
+        if args.dataset == 'ogbn-mag':
+            category = 'paper'
+        print('Number of relations: {}'.format(num_rels))
+        print('Number of class: {}'.format(num_classes))
+        print('Number of train: {}'.format(len(train_idx)))
+        print('Number of valid: {}'.format(len(val_idx)))
+        print('Number of test: {}'.format(len(test_idx)))
+
+        if args.node_feats:
+            node_feats = []
+            for ntype in hg.ntypes:
+                if len(hg.nodes[ntype].data) == 0:
+                    node_feats.append(None)
+                else:
+                    assert len(hg.nodes[ntype].data) == 1
+                    feat = hg.nodes[ntype].data.pop('feat')
+                    node_feats.append(feat.share_memory_())
+        else:
+            node_feats = [None] * num_of_ntype
+    else:
+        # Load from hetero-graph
+        hg = dataset[0]
+
+        num_rels = len(hg.canonical_etypes)
+        num_of_ntype = len(hg.ntypes)
+        category = dataset.predict_category
+        num_classes = dataset.num_classes
+        train_mask = hg.nodes[category].data.pop('train_mask')
+        test_mask = hg.nodes[category].data.pop('test_mask')
+        labels = hg.nodes[category].data.pop('labels')
+        train_idx = th.nonzero(train_mask).squeeze()
+        test_idx = th.nonzero(test_mask).squeeze()
+        node_feats = [None] * num_of_ntype
+
+        # AIFB, MUTAG, BGS and AM datasets do not provide validation set split.
+        # Split train set into train and validation if args.validation is set
+        # otherwise use train set as the validation set.
+        if args.validation:
+            val_idx = train_idx[:len(train_idx) // 5]
+            train_idx = train_idx[len(train_idx) // 5:]
+        else:
+            val_idx = train_idx

    # calculate norm for each edge type and store in edge
-    for canonical_etype in hg.canonical_etypes:
-        u, v, eid = hg.all_edges(form='all', etype=canonical_etype)
-        _, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
-        degrees = count[inverse_index]
-        norm = th.ones(eid.shape[0]) / degrees
-        norm = norm.unsqueeze(1)
-        hg.edges[canonical_etype].data['norm'] = norm
+    if args.global_norm is False:
+        for canonical_etype in hg.canonical_etypes:
+            u, v, eid = hg.all_edges(form='all', etype=canonical_etype)
+            _, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
+            degrees = count[inverse_index]
+            norm = th.ones(eid.shape[0]) / degrees
+            norm = norm.unsqueeze(1)
+            hg.edges[canonical_etype].data['norm'] = norm

    # get target category id
    category_id = len(hg.ntypes)
@@ -373,6 +485,14 @@ def main(args, devices):
            category_id = i

    g = dgl.to_homo(hg)
+    if args.global_norm:
+        u, v, eid = g.all_edges(form='all')
+        _, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
+        degrees = count[inverse_index]
+        norm = th.ones(eid.shape[0]) / degrees
+        norm = norm.unsqueeze(1)
+        g.edata['norm'] = norm
    g.ndata[dgl.NTYPE].share_memory_()
    g.edata[dgl.ETYPE].share_memory_()
    g.edata['norm'].share_memory_()
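As a standalone illustration of the normalization computed above, each edge receives 1 / in-degree of its destination node (toy edge list):
```
# Toy edge list; same computation as the global/per-etype norm above.
import torch as th

v = th.tensor([1, 1, 2, 0, 1])                       # destination node of each edge
_, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
degrees = count[inverse_index]                       # in-degree of each edge's destination
norm = (th.ones(v.shape[0]) / degrees).unsqueeze(1)  # shape (num_edges, 1)
# norm -> [[1/3], [1/3], [1], [1], [1/3]]
```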
@@ -383,31 +503,54 @@ def main(args, devices):
    loc = (node_tids == category_id)
    target_idx = node_ids[loc]
    target_idx.share_memory_()
+    train_idx.share_memory_()
+    val_idx.share_memory_()
+    test_idx.share_memory_()

    n_gpus = len(devices)
    # cpu
    if devices[0] == -1:
        run(0, 0, args, ['cpu'],
-            (g, num_of_ntype, num_classes, num_rels, target_idx,
-             train_idx, val_idx, test_idx, labels))
+            (g, node_feats, num_of_ntype, num_classes, num_rels, target_idx,
+             train_idx, val_idx, test_idx, labels), None, None)
    # gpu
    elif n_gpus == 1:
        run(0, n_gpus, args, devices,
-            (g, num_of_ntype, num_classes, num_rels, target_idx,
-             train_idx, val_idx, test_idx, labels))
+            (g, node_feats, num_of_ntype, num_classes, num_rels, target_idx,
+             train_idx, val_idx, test_idx, labels), None, None)
    # multi gpu
    else:
+        queue = mp.Queue(n_gpus)
        procs = []
        num_train_seeds = train_idx.shape[0]
+        num_valid_seeds = val_idx.shape[0]
+        num_test_seeds = test_idx.shape[0]
+        train_seeds = th.randperm(num_train_seeds)
+        valid_seeds = th.randperm(num_valid_seeds)
+        test_seeds = th.randperm(num_test_seeds)
        tseeds_per_proc = num_train_seeds // n_gpus
+        vseeds_per_proc = num_valid_seeds // n_gpus
+        tstseeds_per_proc = num_test_seeds // n_gpus
        for proc_id in range(n_gpus):
-            proc_train_seeds = train_idx[proc_id * tseeds_per_proc :
-                                         (proc_id + 1) * tseeds_per_proc \
-                                         if (proc_id + 1) * tseeds_per_proc < num_train_seeds \
-                                         else num_train_seeds]
+            # we have multi-gpu for training, evaluation and testing
+            # so split train set, valid set and test set into num-of-gpu parts.
+            proc_train_seeds = train_seeds[proc_id * tseeds_per_proc :
+                                           (proc_id + 1) * tseeds_per_proc \
+                                           if (proc_id + 1) * tseeds_per_proc < num_train_seeds \
+                                           else num_train_seeds]
+            proc_valid_seeds = valid_seeds[proc_id * vseeds_per_proc :
+                                           (proc_id + 1) * vseeds_per_proc \
+                                           if (proc_id + 1) * vseeds_per_proc < num_valid_seeds \
+                                           else num_valid_seeds]
+            proc_test_seeds = test_seeds[proc_id * tstseeds_per_proc :
+                                         (proc_id + 1) * tstseeds_per_proc \
+                                         if (proc_id + 1) * tstseeds_per_proc < num_test_seeds \
+                                         else num_test_seeds]
            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices,
-                                             (g, num_of_ntype, num_classes, num_rels, target_idx,
-                                              proc_train_seeds, val_idx, test_idx, labels)))
+                                             (g, node_feats, num_of_ntype, num_classes, num_rels, target_idx,
+                                              train_idx, val_idx, test_idx, labels),
+                                             (proc_train_seeds, proc_valid_seeds, proc_test_seeds),
+                                             queue))
            p.start()
            procs.append(p)
        for p in procs:
@@ -436,7 +579,7 @@ def config():
                        help="l2 norm coef")
    parser.add_argument("--relabel", default=False, action='store_true',
                        help="remove untouched nodes and relabel")
-    parser.add_argument("--fanout", type=int, default=4,
+    parser.add_argument("--fanout", type=str, default="4, 4",
                        help="Fan-out of neighbor sampling.")
    parser.add_argument("--use-self-loop", default=False, action='store_true',
                        help="include self feature as a special relation")
@@ -445,6 +588,8 @@ def config():
    fp.add_argument('--testing', dest='validation', action='store_false')
    parser.add_argument("--batch-size", type=int, default=100,
                        help="Mini-batch size. ")
+    parser.add_argument("--eval-batch-size", type=int, default=128,
+                        help="Mini-batch size. ")
    parser.add_argument("--num-workers", type=int, default=0,
                        help="Number of workers for dataloader.")
    parser.add_argument("--low-mem", default=False, action='store_true',
@@ -453,6 +598,12 @@ def config():
                        help="Whether store node embeddings in cpu")
    parser.add_argument("--sparse-embedding", action='store_true',
                        help='Use sparse embedding for node embeddings.')
+    parser.add_argument('--node-feats', default=False, action='store_true',
+                        help='Whether to use node features')
+    parser.add_argument('--global-norm', default=False, action='store_true',
+                        help='Use global norm instead of per node type norm')
+    parser.add_argument('--layer-norm', default=False, action='store_true',
+                        help='Use layer norm')
    parser.set_defaults(validation=True)
    args = parser.parse_args()
    return args
......
@@ -61,7 +61,7 @@ class RelGraphEmbedLayer(nn.Module):
    num_of_ntype : int
        Number of node types
    input_size : list of int
        A list of input feature size for each node type. If None, we then
        treat certain input feature as an one-hot encoding feature.
    embed_size : int
        Output embed size
@@ -91,16 +91,15 @@ class RelGraphEmbedLayer(nn.Module):
        for ntype in range(num_of_ntype):
            if input_size[ntype] is not None:
-                loc = node_tids == ntype
-                input_emb_size = node_tids[loc].shape[0]
+                input_emb_size = input_size[ntype].shape[1]
                embed = nn.Parameter(th.Tensor(input_emb_size, self.embed_size))
-                nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
+                nn.init.xavier_uniform_(embed)
                self.embeds[str(ntype)] = embed

        self.node_embeds = th.nn.Embedding(node_tids.shape[0], self.embed_size, sparse=self.sparse_emb)
        nn.init.uniform_(self.node_embeds.weight, -1.0, 1.0)

-    def forward(self, node_ids, node_tids, features):
+    def forward(self, node_ids, node_tids, type_ids, features):
        """Forward computation
        Parameters
        ----------
@@ -111,19 +110,21 @@
        features : list of features
            list of initial features for nodes belong to different node type.
            If None, the corresponding features is an one-hot encoding feature,
            else use the features directly as input feature and matmul a
            projection matrix.
        Returns
        -------
        tensor
            embeddings as the input of the next layer
        """
-        tsd_idx = node_ids < self.num_nodes
-        tsd_ids = node_ids[tsd_idx]
-        embeds = self.node_embeds(tsd_ids)
+        tsd_ids = node_ids.to(self.node_embeds.weight.device)
+        embeds = th.empty(node_ids.shape[0], self.embed_size, device=self.dev_id)
        for ntype in range(self.num_of_ntype):
            if features[ntype] is not None:
                loc = node_tids == ntype
-                embeds[loc] = features[ntype] @ self.embeds[str(ntype)]
+                embeds[loc] = features[ntype][type_ids[loc]].to(self.dev_id) @ self.embeds[str(ntype)].to(self.dev_id)
+            else:
+                loc = node_tids == ntype
+                embeds[loc] = self.node_embeds(tsd_ids[loc]).to(self.dev_id)

-        return embeds.to(self.dev_id)
+        return embeds
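A standalone sketch (made-up sizes) of the two branches in forward(): node types that ship an input feature matrix are projected through a learned per-type matrix, while featureless types fall back to the learnable per-node embedding table:
```
# Made-up sizes; mirrors the featured vs. featureless branches of forward().
import torch as th

embed_size = 8
feat_paper = th.randn(100, 128)                 # input feature matrix of the 'paper' type
proj_paper = th.randn(128, embed_size)          # learned projection for that type
node_embeds = th.nn.Embedding(200, embed_size)  # learnable table, one row per graph node

node_ids = th.tensor([7, 150, 9])               # homogeneous-graph node ids in the block
node_tids = th.tensor([0, 1, 0])                # per-node type id (0 = paper, 1 = author)
type_ids = th.tensor([10, 3, 42])               # per-node row index inside its own type
embeds = th.empty(3, embed_size)

loc = node_tids == 0                            # 'paper' has features: project them
embeds[loc] = feat_paper[type_ids[loc]] @ proj_paper
loc = node_tids == 1                            # 'author' has none: look up its node embedding
embeds[loc] = node_embeds(node_ids[loc])
```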
@@ -61,6 +61,8 @@ class RelGraphConv(gluon.Block):
        Default: False.
    dropout : float, optional
        Dropout rate. Default: 0.0
+    layer_norm : bool, optional
+        Add layer norm. Default: False
    """
    def __init__(self,
                 in_feat,
@@ -72,7 +74,8 @@ class RelGraphConv(gluon.Block):
                 activation=None,
                 self_loop=False,
                 low_mem=False,
-                 dropout=0.0):
+                 dropout=0.0,
+                 layer_norm=False):
        super(RelGraphConv, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
@@ -86,6 +89,7 @@ class RelGraphConv(gluon.Block):
        self.self_loop = self_loop
        assert low_mem is False, 'MXNet currently does not support low-memory implementation.'
+        assert layer_norm is False, 'MXNet currently does not support layer norm.'

        if regularizer == "basis":
            # add basis weights
......
@@ -59,6 +59,8 @@ class RelGraphConv(nn.Module):
        Turn it on when you encounter OOM problem during training or evaluation.
    dropout : float, optional
        Dropout rate. Default: 0.0
+    layer_norm : bool, optional
+        Add layer norm. Default: False
    """
    def __init__(self,
                 in_feat,
@@ -70,7 +72,8 @@ class RelGraphConv(nn.Module):
                 activation=None,
                 self_loop=False,
                 low_mem=False,
-                 dropout=0.0):
+                 dropout=0.0,
+                 layer_norm=False):
        super(RelGraphConv, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
@@ -83,6 +86,7 @@ class RelGraphConv(nn.Module):
        self.activation = activation
        self.self_loop = self_loop
        self.low_mem = low_mem
+        self.layer_norm = layer_norm

        if regularizer == "basis":
            # add basis weights
@@ -120,6 +124,10 @@ class RelGraphConv(nn.Module):
            self.h_bias = nn.Parameter(th.Tensor(out_feat))
            nn.init.zeros_(self.h_bias)

+        # layer norm
+        if self.layer_norm:
+            self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True)
+
        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
@@ -219,6 +227,8 @@ class RelGraphConv(nn.Module):
        g.update_all(self.message_func, fn.sum(msg='msg', out='h'))
        # apply bias and activation
        node_repr = g.dstdata['h']
+        if self.layer_norm:
+            node_repr = self.layer_norm_weight(node_repr)
        if self.bias:
            node_repr = node_repr + self.h_bias
        if self.self_loop:
......
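A hedged usage sketch of the new flag on the PyTorch RelGraphConv (assuming a DGL build that already includes this change); layer normalization is applied to the aggregated messages before the bias and self-loop terms are added:
```
# Assumes a DGL version that includes the layer_norm option added here.
import dgl
import torch as th
from dgl.nn import RelGraphConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))   # toy homogeneous graph with 3 nodes / 3 edges
etypes = th.tensor([0, 1, 1])            # relation id of each edge
feat = th.randn(3, 10)

conv = RelGraphConv(10, 8, num_rels=2, regularizer='basis', num_bases=2,
                    self_loop=True, layer_norm=True)
out = conv(g, feat, etypes)              # shape (3, 8)
```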
@@ -59,6 +59,8 @@ class RelGraphConv(layers.Layer):
        Turn it on when you encounter OOM problem during training or evaluation.
    dropout : float, optional
        Dropout rate. Default: 0.0
+    layer_norm : bool, optional
+        Add layer norm. Default: False
    """

    def __init__(self,
@@ -71,7 +73,8 @@ class RelGraphConv(layers.Layer):
                 activation=None,
                 self_loop=False,
                 low_mem=False,
-                 dropout=0.0):
+                 dropout=0.0,
+                 layer_norm=False):
        super(RelGraphConv, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
@@ -85,6 +88,8 @@ class RelGraphConv(layers.Layer):
        self.self_loop = self_loop
        self.low_mem = low_mem

+        assert layer_norm is False, 'TensorFlow currently does not support layer norm.'
+
        xinit = tf.keras.initializers.glorot_uniform()
        zeroinit = tf.keras.initializers.zeros()
......