"examples/vscode:/vscode.git/clone" did not exist on "a29ea36d62b294075fb1a3632927f8dc2badc85d"
Unverified Commit 4ef01dbb authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Example] Rgcn support ogbn-mag dataset. (#1812)



* rgcn support ogbn-mag dataset

* upd

* multi-gpu val and test

* Fix

* fix

* Add support for ogbn-mag

* Fix

* Fix

* Fix

* Fix

* Add layer_norm

* update

* Fix merge

* Clean some code

* update Readme

* upd
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-68-185.ec2.internal>
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-87-240.ec2.internal>
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent e7515773
......@@ -40,7 +40,7 @@ python3 entity_classify.py -d am --n-bases=40 --n-hidden=10 --l2norm=5e-4 --test
### Entity Classification with minibatch
AIFB: accuracy avg(5 runs) 90.56%, best 94.44% (DGL)
```
python3 entity_classify_mp.py -d aifb --testing --gpu 0 --fanout=20 --batch-size 128
python3 entity_classify_mp.py -d aifb --testing --gpu 0 --fanout='20,20' --batch-size 128
```
MUTAG: accuracy avg(5 runs) 66.77%, best 69.12% (DGL)
......@@ -49,16 +49,30 @@ python3 entity_classify_mp.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gp
```
BGS: accuracy avg(5 runs) 91.72%, best 96.55% (DGL)
```
python3 entity_classify_mp.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout 40 --n-epochs=40 --batch-size=128
python3 entity_classify_mp.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout '40,40' --n-epochs=40 --batch-size=128
```
AM: accuracy avg(5 runs) 88.28%, best 90.40% (DGL)
```
python3 entity_classify_mp.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout 35 --batch-size 256 --lr 1e-2 --n-hidden 16 --use-self-loop --n-epochs=40
python3 entity_classify_mp.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout '35,35' --batch-size 256 --lr 1e-2 --n-hidden 16 --use-self-loop --n-epochs=40
```
### Entity Classification on OGBN-MAG
Test-bd: P3-8xlarge
OGBN-MAG accuracy 46.22
```
python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='25,30' --batch-size 512 --n-hidden 64 --lr 0.01 --num-worker 0 --eval-batch-size 8 --low-mem --gpu 0,1,2,3,4,5,6,7 --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --mix-cpu-gpu --node-feats --layer-norm
```
OGBN-MAG without node-feats 43.24
```
python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='25,25' --batch-size 256 --n-hidden 64 --lr 0.01 --num-worker 0 --eval-batch-size 8 --low-mem --gpu 0,1,2,3,4,5,6,7 --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --mix-cpu-gpu --layer-norm
```
Test-bd: P2-8xlarge
### Link Prediction
FB15k-237: MRR 0.151 (DGL), 0.158 (paper)
```
......
......@@ -10,6 +10,7 @@ import argparse
import itertools
import numpy as np
import time
import gc
import torch as th
import torch.nn as nn
import torch.nn.functional as F
......@@ -25,6 +26,9 @@ from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from model import RelGraphEmbedLayer
from dgl.nn import RelGraphConv
from utils import thread_wrapped_func
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset
class EntityClassify(nn.Module):
""" Entity classification class for RGCN
......@@ -62,7 +66,8 @@ class EntityClassify(nn.Module):
num_hidden_layers=1,
dropout=0,
use_self_loop=False,
low_mem=False):
low_mem=False,
layer_norm=False):
super(EntityClassify, self).__init__()
self.device = th.device(device if device >= 0 else 'cpu')
self.num_nodes = num_nodes
......@@ -74,6 +79,7 @@ class EntityClassify(nn.Module):
self.dropout = dropout
self.use_self_loop = use_self_loop
self.low_mem = low_mem
self.layer_norm = layer_norm
self.layers = nn.ModuleList()
# i2h
......@@ -149,20 +155,50 @@ class NeighborSampler:
norm = self.g.edata['norm'][frontier.edata[dgl.EID]]
block = dgl.to_block(frontier, cur)
block.srcdata[dgl.NTYPE] = self.g.ndata[dgl.NTYPE][block.srcdata[dgl.NID]]
block.srcdata['type_id'] =self.g.ndata[dgl.NID][block.srcdata[dgl.NID]]
block.edata['etype'] = etypes
block.edata['norm'] = norm
cur = block.srcdata[dgl.NID]
blocks.insert(0, block)
return seeds, blocks
def evaluate(model, embed_layer, eval_loader, node_feats):
    """Run inference over ``eval_loader`` and gather predictions.

    Puts both the classifier and the embedding layer into eval mode,
    computes logits for every minibatch under ``th.no_grad()``, and
    moves the results to CPU so they can be concatenated without
    holding GPU memory.

    Parameters
    ----------
    model : nn.Module
        The entity-classification model; called as ``model(blocks, feats)``.
    embed_layer : nn.Module
        Input embedding layer mapping node ids/types to feature tensors.
    eval_loader : DataLoader
        Yields ``(seeds, blocks)`` pairs produced by the neighbor sampler.
    node_feats : list
        Per-node-type feature tensors (``None`` for one-hot types).

    Returns
    -------
    (Tensor, Tensor)
        Concatenated logits and the corresponding seed node ids, on CPU.
    """
    model.eval()
    embed_layer.eval()
    logits_batches = []
    seed_batches = []
    with th.no_grad():
        for seeds, blocks in tqdm.tqdm(eval_loader):
            # free cached GPU memory between batches to avoid OOM on
            # large evaluation graphs
            th.cuda.empty_cache()
            input_block = blocks[0]
            feats = embed_layer(input_block.srcdata[dgl.NID],
                                input_block.srcdata[dgl.NTYPE],
                                input_block.srcdata['type_id'],
                                node_feats)
            logits = model(blocks, feats)
            logits_batches.append(logits.cpu().detach())
            seed_batches.append(seeds.cpu().detach())
    return th.cat(logits_batches), th.cat(seed_batches)
@thread_wrapped_func
def run(proc_id, n_gpus, args, devices, dataset):
def run(proc_id, n_gpus, args, devices, dataset, split, queue=None):
dev_id = devices[proc_id]
g, num_of_ntype, num_classes, num_rels, target_idx, \
g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, \
train_idx, val_idx, test_idx, labels = dataset
if split is not None:
train_seed, val_seed, test_seed = split
train_idx = train_idx[train_seed]
val_idx = val_idx[val_seed]
test_idx = test_idx[test_seed]
fanouts = [int(fanout) for fanout in args.fanout.split(',')]
node_tids = g.ndata[dgl.NTYPE]
sampler = NeighborSampler(g, target_idx, [args.fanout] * args.n_layers)
sampler = NeighborSampler(g, target_idx, fanouts)
loader = DataLoader(dataset=train_idx.numpy(),
batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
......@@ -172,7 +208,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
# validation sampler
val_sampler = NeighborSampler(g, target_idx, [None] * args.n_layers)
val_loader = DataLoader(dataset=val_idx.numpy(),
batch_size=args.batch_size,
batch_size=args.eval_batch_size,
collate_fn=val_sampler.sample_blocks,
shuffle=False,
num_workers=args.num_workers)
......@@ -180,7 +216,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
# validation sampler
test_sampler = NeighborSampler(g, target_idx, [None] * args.n_layers)
test_loader = DataLoader(dataset=test_idx.numpy(),
batch_size=args.batch_size,
batch_size=args.eval_batch_size,
collate_fn=test_sampler.sample_blocks,
shuffle=False,
num_workers=args.num_workers)
......@@ -190,7 +226,9 @@ def run(proc_id, n_gpus, args, devices, dataset):
master_ip='127.0.0.1', master_port='12345')
world_size = n_gpus
backend = 'nccl'
if args.sparse_embedding:
# using sparse embedding or using mix_cpu_gpu mode (embedding model cannot be stored in GPU)
if args.sparse_embedding or args.mix_cpu_gpu:
backend = 'gloo'
th.distributed.init_process_group(backend=backend,
init_method=dist_init_method,
......@@ -199,7 +237,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
# node features
# None for one-hot feature, if not none, it should be the feature tensor.
node_feats = [None] * num_of_ntype
#
embed_layer = RelGraphEmbedLayer(dev_id,
g.number_of_nodes(),
node_tids,
......@@ -209,6 +247,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
sparse_emb=args.sparse_embedding)
# create model
# all model params are in device.
model = EntityClassify(dev_id,
g.number_of_nodes(),
args.n_hidden,
......@@ -218,9 +257,10 @@ def run(proc_id, n_gpus, args, devices, dataset):
num_hidden_layers=args.n_layers - 2,
dropout=args.dropout,
use_self_loop=args.use_self_loop,
low_mem=args.low_mem)
low_mem=args.low_mem,
layer_norm=args.layer_norm)
if dev_id >= 0:
if dev_id >= 0 and n_gpus == 1:
th.cuda.set_device(dev_id)
labels = labels.to(dev_id)
model.cuda(dev_id)
......@@ -229,15 +269,30 @@ def run(proc_id, n_gpus, args, devices, dataset):
embed_layer.cuda(dev_id)
if n_gpus > 1:
labels = labels.to(dev_id)
model.cuda(dev_id)
if args.mix_cpu_gpu:
embed_layer = DistributedDataParallel(embed_layer, device_ids=None, output_device=None)
else:
embed_layer.cuda(dev_id)
embed_layer = DistributedDataParallel(embed_layer, device_ids=[dev_id], output_device=dev_id)
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
# optimizer
if args.sparse_embedding:
optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2norm)
emb_optimizer = th.optim.SparseAdam(embed_layer.parameters(), lr=args.lr)
dense_params = list(model.parameters())
if args.node_feats:
if n_gpus > 1:
dense_params += list(embed_layer.module.embeds.parameters())
else:
dense_params += list(embed_layer.embeds.parameters())
optimizer = th.optim.Adam(dense_params, lr=args.lr, weight_decay=args.l2norm)
if n_gpus > 1:
emb_optimizer = th.optim.SparseAdam(embed_layer.module.node_embeds.parameters(), lr=args.lr)
else:
all_params = itertools.chain(model.parameters(), embed_layer.parameters())
emb_optimizer = th.optim.SparseAdam(embed_layer.node_embeds.parameters(), lr=args.lr)
else:
all_params = list(model.parameters()) + list(embed_layer.parameters())
optimizer = th.optim.Adam(all_params, lr=args.lr, weight_decay=args.l2norm)
# training loop
......@@ -247,20 +302,22 @@ def run(proc_id, n_gpus, args, devices, dataset):
for epoch in range(args.n_epochs):
model.train()
optimizer.zero_grad()
if args.sparse_embedding:
emb_optimizer.zero_grad()
embed_layer.train()
for i, sample_data in enumerate(loader):
seeds, blocks = sample_data
t0 = time.time()
feats = embed_layer(blocks[0].srcdata[dgl.NID].to(dev_id),
blocks[0].srcdata[dgl.NTYPE].to(dev_id),
feats = embed_layer(blocks[0].srcdata[dgl.NID],
blocks[0].srcdata[dgl.NTYPE],
blocks[0].srcdata['type_id'],
node_feats)
logits = model(blocks, feats)
loss = F.cross_entropy(logits, labels[seeds])
t1 = time.time()
optimizer.zero_grad()
if args.sparse_embedding:
emb_optimizer.zero_grad()
loss.backward()
optimizer.step()
if args.sparse_embedding:
......@@ -269,55 +326,65 @@ def run(proc_id, n_gpus, args, devices, dataset):
forward_time.append(t1 - t0)
backward_time.append(t2 - t1)
print("Epoch {:05d}:{:05d} | Train Forward Time(s) {:.4f} | Backward Time(s) {:.4f}".
format(epoch, i, forward_time[-1], backward_time[-1]))
train_acc = th.sum(logits.argmax(dim=1) == labels[seeds]).item() / len(seeds)
if i % 100 and proc_id == 0:
print("Train Accuracy: {:.4f} | Train Loss: {:.4f}".
format(train_acc, loss.item()))
print("Epoch {:05d}:{:05d} | Train Forward Time(s) {:.4f} | Backward Time(s) {:.4f}".
format(epoch, i, forward_time[-1], backward_time[-1]))
# only process 0 will do the evaluation
if (queue is not None) or (proc_id == 0):
val_logits, val_seeds = evaluate(model, embed_layer, val_loader, node_feats)
if queue is not None:
queue.put((val_logits, val_seeds))
# gather evaluation result from multiple processes
if proc_id == 0:
model.eval()
eval_logtis = []
eval_seeds = []
for i, sample_data in enumerate(val_loader):
seeds, blocks = sample_data
feats = embed_layer(blocks[0].srcdata[dgl.NID].to(dev_id),
blocks[0].srcdata[dgl.NTYPE].to(dev_id),
node_feats)
logits = model(blocks, feats)
eval_logtis.append(logits)
eval_seeds.append(seeds)
eval_logtis = th.cat(eval_logtis)
eval_seeds = th.cat(eval_seeds)
val_loss = F.cross_entropy(eval_logtis, labels[eval_seeds])
val_acc = th.sum(eval_logtis.argmax(dim=1) == labels[eval_seeds]).item() / len(eval_seeds)
if queue is not None:
val_logits = []
val_seeds = []
for i in range(n_gpus):
log = queue.get()
val_l, val_s = log
val_logits.append(val_l)
val_seeds.append(val_s)
val_logits = th.cat(val_logits)
val_seeds = th.cat(val_seeds)
val_loss = F.cross_entropy(val_logits, labels[val_seeds].cpu()).item()
val_acc = th.sum(val_logits.argmax(dim=1) == labels[val_seeds].cpu()).item() / len(val_seeds)
print("Validation Accuracy: {:.4f} | Validation loss: {:.4f}".
format(val_acc, val_loss.item()))
format(val_acc, val_loss))
if n_gpus > 1:
th.distributed.barrier()
print()
# only process 0 will do the testing
# only process 0 will do the evaluation
if (queue is not None) or (proc_id == 0):
test_logits, test_seeds = evaluate(model, embed_layer, test_loader, node_feats)
if queue is not None:
queue.put((test_logits, test_seeds))
# gather evaluation result from multiple processes
if proc_id == 0:
model.eval()
test_logtis = []
if queue is not None:
test_logits = []
test_seeds = []
for i, sample_data in enumerate(test_loader):
seeds, blocks = sample_data
feats = embed_layer(blocks[0].srcdata[dgl.NID].to(dev_id),
blocks[0].srcdata[dgl.NTYPE].to(dev_id),
[None] * num_of_ntype)
logits = model(blocks, feats)
test_logtis.append(logits)
test_seeds.append(seeds)
test_logtis = th.cat(test_logtis)
for i in range(n_gpus):
log = queue.get()
test_l, test_s = log
test_logits.append(test_l)
test_seeds.append(test_s)
test_logits = th.cat(test_logits)
test_seeds = th.cat(test_seeds)
test_loss = F.cross_entropy(test_logtis, labels[test_seeds])
test_acc = th.sum(test_logtis.argmax(dim=1) == labels[test_seeds]).item() / len(test_seeds)
print("Test Accuracy: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
test_loss = F.cross_entropy(test_logits, labels[test_seeds].cpu()).item()
test_acc = th.sum(test_logits.argmax(dim=1) == labels[test_seeds].cpu()).item() / len(test_seeds)
print("Test Accuracy: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss))
print()
# sync for test
if n_gpus > 1:
th.distributed.barrier()
print("{}/{} Mean forward time: {:4f}".format(proc_id, n_gpus,
np.mean(forward_time[len(forward_time) // 4:])))
print("{}/{} Mean backward time: {:4f}".format(proc_id, n_gpus,
......@@ -334,9 +401,50 @@ def main(args, devices):
dataset = BGSDataset()
elif args.dataset == 'am':
dataset = AMDataset()
elif args.dataset == 'ogbn-mag':
dataset = DglNodePropPredDataset(name=args.dataset)
ogb_dataset = True
else:
raise ValueError()
if ogb_dataset is True:
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]['paper']
val_idx = split_idx["valid"]['paper']
test_idx = split_idx["test"]['paper']
hg_orig, labels = dataset[0]
subgs = {}
for etype in hg_orig.canonical_etypes:
u, v = hg_orig.all_edges(etype=etype)
subgs[etype] = (u, v)
subgs[(etype[2], 'rev-'+etype[1], etype[0])] = (v, u)
hg = dgl.heterograph(subgs)
hg.nodes['paper'].data['feat'] = hg_orig.nodes['paper'].data['feat']
labels = labels['paper'].squeeze()
num_rels = len(hg.canonical_etypes)
num_of_ntype = len(hg.ntypes)
num_classes = dataset.num_classes
if args.dataset == 'ogbn-mag':
category = 'paper'
print('Number of relations: {}'.format(num_rels))
print('Number of class: {}'.format(num_classes))
print('Number of train: {}'.format(len(train_idx)))
print('Number of valid: {}'.format(len(val_idx)))
print('Number of test: {}'.format(len(test_idx)))
if args.node_feats:
node_feats = []
for ntype in hg.ntypes:
if len(hg.nodes[ntype].data) == 0:
node_feats.append(None)
else:
assert len(hg.nodes[ntype].data) == 1
feat = hg.nodes[ntype].data.pop('feat')
node_feats.append(feat.share_memory_())
else:
node_feats = [None] * num_of_ntype
else:
# Load from hetero-graph
hg = dataset[0]
......@@ -349,8 +457,11 @@ def main(args, devices):
labels = hg.nodes[category].data.pop('labels')
train_idx = th.nonzero(train_mask).squeeze()
test_idx = th.nonzero(test_mask).squeeze()
node_feats = [None] * num_of_ntype
# split dataset into train, validate, test
# AIFB, MUTAG, BGS and AM datasets do not provide validation set split.
# Split train set into train and validation if args.validation is set
# otherwise use train set as the validation set.
if args.validation:
val_idx = train_idx[:len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5:]
......@@ -358,6 +469,7 @@ def main(args, devices):
val_idx = train_idx
# calculate norm for each edge type and store in edge
if args.global_norm is False:
for canonical_etype in hg.canonical_etypes:
u, v, eid = hg.all_edges(form='all', etype=canonical_etype)
_, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
......@@ -373,6 +485,14 @@ def main(args, devices):
category_id = i
g = dgl.to_homo(hg)
if args.global_norm:
u, v, eid = g.all_edges(form='all')
_, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
degrees = count[inverse_index]
norm = th.ones(eid.shape[0]) / degrees
norm = norm.unsqueeze(1)
g.edata['norm'] = norm
g.ndata[dgl.NTYPE].share_memory_()
g.edata[dgl.ETYPE].share_memory_()
g.edata['norm'].share_memory_()
......@@ -383,31 +503,54 @@ def main(args, devices):
loc = (node_tids == category_id)
target_idx = node_ids[loc]
target_idx.share_memory_()
train_idx.share_memory_()
val_idx.share_memory_()
test_idx.share_memory_()
n_gpus = len(devices)
# cpu
if devices[0] == -1:
run(0, 0, args, ['cpu'],
(g, num_of_ntype, num_classes, num_rels, target_idx,
train_idx, val_idx, test_idx, labels))
(g, node_feats, num_of_ntype, num_classes, num_rels, target_idx,
train_idx, val_idx, test_idx, labels), None, None)
# gpu
elif n_gpus == 1:
run(0, n_gpus, args, devices,
(g, num_of_ntype, num_classes, num_rels, target_idx,
train_idx, val_idx, test_idx, labels))
(g, node_feats, num_of_ntype, num_classes, num_rels, target_idx,
train_idx, val_idx, test_idx, labels), None, None)
# multi gpu
else:
queue = mp.Queue(n_gpus)
procs = []
num_train_seeds = train_idx.shape[0]
num_valid_seeds = val_idx.shape[0]
num_test_seeds = test_idx.shape[0]
train_seeds = th.randperm(num_train_seeds)
valid_seeds = th.randperm(num_valid_seeds)
test_seeds = th.randperm(num_test_seeds)
tseeds_per_proc = num_train_seeds // n_gpus
vseeds_per_proc = num_valid_seeds // n_gpus
tstseeds_per_proc = num_test_seeds // n_gpus
for proc_id in range(n_gpus):
proc_train_seeds = train_idx[proc_id * tseeds_per_proc :
# we have multi-gpu for training, evaluation and testing
# so split train set, valid set and test set into num-of-gpu parts.
proc_train_seeds = train_seeds[proc_id * tseeds_per_proc :
(proc_id + 1) * tseeds_per_proc \
if (proc_id + 1) * tseeds_per_proc < num_train_seeds \
else num_train_seeds]
proc_valid_seeds = valid_seeds[proc_id * vseeds_per_proc :
(proc_id + 1) * vseeds_per_proc \
if (proc_id + 1) * vseeds_per_proc < num_valid_seeds \
else num_valid_seeds]
proc_test_seeds = test_seeds[proc_id * tstseeds_per_proc :
(proc_id + 1) * tstseeds_per_proc \
if (proc_id + 1) * tstseeds_per_proc < num_test_seeds \
else num_test_seeds]
p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices,
(g, num_of_ntype, num_classes, num_rels, target_idx,
proc_train_seeds, val_idx, test_idx, labels)))
(g, node_feats, num_of_ntype, num_classes, num_rels, target_idx,
train_idx, val_idx, test_idx, labels),
(proc_train_seeds, proc_valid_seeds, proc_test_seeds),
queue))
p.start()
procs.append(p)
for p in procs:
......@@ -436,7 +579,7 @@ def config():
help="l2 norm coef")
parser.add_argument("--relabel", default=False, action='store_true',
help="remove untouched nodes and relabel")
parser.add_argument("--fanout", type=int, default=4,
parser.add_argument("--fanout", type=str, default="4, 4",
help="Fan-out of neighbor sampling.")
parser.add_argument("--use-self-loop", default=False, action='store_true',
help="include self feature as a special relation")
......@@ -445,6 +588,8 @@ def config():
fp.add_argument('--testing', dest='validation', action='store_false')
parser.add_argument("--batch-size", type=int, default=100,
help="Mini-batch size. ")
parser.add_argument("--eval-batch-size", type=int, default=128,
help="Mini-batch size. ")
parser.add_argument("--num-workers", type=int, default=0,
help="Number of workers for dataloader.")
parser.add_argument("--low-mem", default=False, action='store_true',
......@@ -453,6 +598,12 @@ def config():
help="Whether store node embeddins in cpu")
parser.add_argument("--sparse-embedding", action='store_true',
help='Use sparse embedding for node embeddings.')
parser.add_argument('--node-feats', default=False, action='store_true',
help='Whether use node features')
parser.add_argument('--global-norm', default=False, action='store_true',
help='User global norm instead of per node type norm')
parser.add_argument('--layer-norm', default=False, action='store_true',
help='Use layer norm')
parser.set_defaults(validation=True)
args = parser.parse_args()
return args
......
......@@ -91,16 +91,15 @@ class RelGraphEmbedLayer(nn.Module):
for ntype in range(num_of_ntype):
if input_size[ntype] is not None:
loc = node_tids == ntype
input_emb_size = node_tids[loc].shape[0]
input_emb_size = input_size[ntype].shape[1]
embed = nn.Parameter(th.Tensor(input_emb_size, self.embed_size))
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(embed)
self.embeds[str(ntype)] = embed
self.node_embeds = th.nn.Embedding(node_tids.shape[0], self.embed_size, sparse=self.sparse_emb)
nn.init.uniform_(self.node_embeds.weight, -1.0, 1.0)
def forward(self, node_ids, node_tids, features):
def forward(self, node_ids, node_tids, type_ids, features):
"""Forward computation
Parameters
----------
......@@ -118,12 +117,14 @@ class RelGraphEmbedLayer(nn.Module):
tensor
embeddings as the input of the next layer
"""
tsd_idx = node_ids < self.num_nodes
tsd_ids = node_ids[tsd_idx]
embeds = self.node_embeds(tsd_ids)
tsd_ids = node_ids.to(self.node_embeds.weight.device)
embeds = th.empty(node_ids.shape[0], self.embed_size, device=self.dev_id)
for ntype in range(self.num_of_ntype):
if features[ntype] is not None:
loc = node_tids == ntype
embeds[loc] = features[ntype] @ self.embeds[str(ntype)]
embeds[loc] = features[ntype][type_ids[loc]].to(self.dev_id) @ self.embeds[str(ntype)].to(self.dev_id)
else:
loc = node_tids == ntype
embeds[loc] = self.node_embeds(tsd_ids[loc]).to(self.dev_id)
return embeds.to(self.dev_id)
return embeds
......@@ -61,6 +61,8 @@ class RelGraphConv(gluon.Block):
Default: False.
dropout : float, optional
Dropout rate. Default: 0.0
layer_norm: float, optional
Add layer norm. Default: False
"""
def __init__(self,
in_feat,
......@@ -72,7 +74,8 @@ class RelGraphConv(gluon.Block):
activation=None,
self_loop=False,
low_mem=False,
dropout=0.0):
dropout=0.0,
layer_norm=False):
super(RelGraphConv, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
......@@ -86,6 +89,7 @@ class RelGraphConv(gluon.Block):
self.self_loop = self_loop
assert low_mem is False, 'MXNet currently does not support low-memory implementation.'
assert layer_norm is False, 'MXNet currently does not support layer norm.'
if regularizer == "basis":
# add basis weights
......
......@@ -59,6 +59,8 @@ class RelGraphConv(nn.Module):
Turn it on when you encounter OOM problem during training or evaluation.
dropout : float, optional
Dropout rate. Default: 0.0
layer_norm: float, optional
Add layer norm. Default: False
"""
def __init__(self,
in_feat,
......@@ -70,7 +72,8 @@ class RelGraphConv(nn.Module):
activation=None,
self_loop=False,
low_mem=False,
dropout=0.0):
dropout=0.0,
layer_norm=False):
super(RelGraphConv, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
......@@ -83,6 +86,7 @@ class RelGraphConv(nn.Module):
self.activation = activation
self.self_loop = self_loop
self.low_mem = low_mem
self.layer_norm = layer_norm
if regularizer == "basis":
# add basis weights
......@@ -120,6 +124,10 @@ class RelGraphConv(nn.Module):
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
# layer norm
if self.layer_norm:
self.layer_norm_weight = nn.LayerNorm(n_hidden, elementwise_affine=True)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
......@@ -219,6 +227,8 @@ class RelGraphConv(nn.Module):
g.update_all(self.message_func, fn.sum(msg='msg', out='h'))
# apply bias and activation
node_repr = g.dstdata['h']
if self.layer_norm:
node_repr = self.layer_norm_weight(node_repr)
if self.bias:
node_repr = node_repr + self.h_bias
if self.self_loop:
......
......@@ -59,6 +59,8 @@ class RelGraphConv(layers.Layer):
Turn it on when you encounter OOM problem during training or evaluation.
dropout : float, optional
Dropout rate. Default: 0.0
layer_norm: float, optional
Add layer norm. Default: False
"""
def __init__(self,
......@@ -71,7 +73,8 @@ class RelGraphConv(layers.Layer):
activation=None,
self_loop=False,
low_mem=False,
dropout=0.0):
dropout=0.0,
layer_norm=False):
super(RelGraphConv, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
......@@ -85,6 +88,8 @@ class RelGraphConv(layers.Layer):
self.self_loop = self_loop
self.low_mem = low_mem
assert layer_norm is False, 'TensorFlow currently does not support layer norm.'
xinit = tf.keras.initializers.glorot_uniform()
zeroinit = tf.keras.initializers.zeros()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment