"tests/python/vscode:/vscode.git/clone" did not exist on "8204fe1912d95bac865797af98f01dafc2ba2b65"
Unverified commit 77ae4d53 authored by xiang song (charlie.song), committed by GitHub

[Example] Fix RGCN hetero nondeterminism in graph and model initialization (#1244)

* Fix model save and load problem.

Add offline inference example

* Fix test

* Fix randomness problem between test runs

* Make rdf loader deterministic
parent eeeb52f4
@@ -55,3 +55,30 @@ AM: accuracy 91.41% (DGL), 89.29% (paper)
```
python3 entity_classify.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0
```
### Offline Inference
A trained model can be exported by passing the `--model_path <PATH>` argument to `entity_classify.py`; `test_classify.py` can then load the saved model and run the evaluation offline. A minimal save/load sketch follows the command listing below.
AIFB:
```
python3 entity_classify.py -d aifb --testing --gpu 0 --model_path "aifb.pt"
python3 test_classify.py -d aifb --gpu 0 --model_path "aifb.pt"
```
MUTAG:
```
python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0 --model_path "mutag.pt"
python3 test_classify.py -d mutag --n-bases 30 --gpu 0 --model_path "mutag.pt"
```
BGS:
```
python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --model_path "bgs.pt"
python3 test_classify.py -d bgs --n-bases 40 --gpu 0 --model_path "bgs.pt"
```
AM:
```
python3 entity_classify.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --model_path "am.pt"
python3 test_classify.py -d am --n-bases 40 --gpu 0 --model_path "am.pt"
```
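The round trip behind these commands is the standard PyTorch `state_dict` save/load pattern. Below is a minimal, self-contained sketch of that pattern; `TinyModel` is a stand-in for `EntityClassify` and `"tiny.pt"` is a placeholder path, not names used by the scripts.
```
import torch as th
import torch.nn as nn

# TinyModel stands in for EntityClassify; "tiny.pt" is a placeholder path.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 2)

    def forward(self, x):
        return self.lin(x)

model = TinyModel()
th.save(model.state_dict(), "tiny.pt")        # what entity_classify.py does with --model_path

restored = TinyModel()                        # rebuild the same architecture...
restored.load_state_dict(th.load("tiny.pt"))  # ...then restore the weights (test_classify.py side)
restored.eval()
```
Note that the testing script must construct the model with the same hyperparameters used for training (e.g. matching `--n-bases`, `--n-hidden`, `--n-layers`), otherwise the parameter shapes in the checkpoint will not match.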
@@ -156,11 +156,11 @@ class RelGraphConvHeteroEmbed(nn.Module):
        self.self_loop = self_loop

        # create weight embeddings for each node for each relation
        self.embeds = nn.ParameterList()
        self.embeds = nn.ParameterDict()
        for srctype, etype, dsttype in g.canonical_etypes:
            embed = nn.Parameter(th.Tensor(g.number_of_nodes(srctype), self.embed_size))
            nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
            self.embeds.append(embed)
            self.embeds["{}-{}-{}".format(srctype, etype, dsttype)] = embed

        # bias
        if self.bias:
@@ -189,7 +189,7 @@ class RelGraphConvHeteroEmbed(nn.Module):
        g = self.g.local_var()
        funcs = {}
        for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
            g.nodes[srctype].data['embed-%d' % i] = self.embeds[i]
            g.nodes[srctype].data['embed-%d' % i] = self.embeds["{}-{}-{}".format(srctype, etype, dsttype)]
            funcs[(srctype, etype, dsttype)] = (fn.copy_u('embed-%d' % i, 'm'), fn.mean('m', 'h'))
        g.multi_update_all(funcs, 'sum')
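The switch from `nn.ParameterList` to `nn.ParameterDict` above is what makes checkpoints reusable across runs: a `ParameterList` stores its entries under positional keys (`embeds.0`, `embeds.1`, ...), so the meaning of each slot depends on the iteration order of `g.canonical_etypes`, while a `ParameterDict` keys each embedding by its canonical edge type. A small sketch of the difference in `state_dict` keys, using toy shapes and made-up relation names:
```
import torch as th
import torch.nn as nn

# Positional keys: which relation '0' refers to depends on creation order.
plist = nn.ParameterList([nn.Parameter(th.zeros(3, 4)), nn.Parameter(th.zeros(5, 4))])
print(list(plist.state_dict().keys()))   # ['0', '1']

# Named keys: each embedding is stored under its relation string, independent of order.
pdict = nn.ParameterDict()
pdict["author-writes-paper"] = nn.Parameter(th.zeros(3, 4))
pdict["paper-cites-paper"] = nn.Parameter(th.zeros(5, 4))
print(list(pdict.state_dict().keys()))   # ['author-writes-paper', 'paper-cites-paper']
```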
@@ -220,6 +220,7 @@ class EntityClassify(nn.Module):
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.rel_names = list(set(g.etypes))
        self.rel_names.sort()
        self.num_bases = None if num_bases < 0 else num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
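`self.rel_names.sort()` is the other half of the fix: `set(g.etypes)` deduplicates the relation names, but iterating a set of strings is generally not reproducible across interpreter runs (string hashing is randomized unless `PYTHONHASHSEED` is fixed), so per-relation weights could be created, and therefore saved, in a different order each run. Sorting pins the order. A short illustration with made-up relation names:
```
# Made-up relation names; only the ordering matters here.
etypes = ["writes", "cites", "writes", "affiliated_with"]

rel_names = list(set(etypes))   # deduplicated, but order may vary between runs
rel_names.sort()                # deterministic: ['affiliated_with', 'cites', 'writes']
print(rel_names)
```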
@@ -285,7 +286,6 @@ def main(args):
        labels = labels.cuda()
        train_idx = train_idx.cuda()
        test_idx = test_idx.cuda()
        labels = labels.cuda()

    # create model
    model = EntityClassify(g,
@@ -324,6 +324,8 @@ def main(args):
print("Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".
format(epoch, train_acc, loss.item(), val_acc, val_loss.item(), np.average(dur)))
print()
if args.model_path is not None:
th.save(model.state_dict(), args.model_path)
model.eval()
logits = model.forward()[category_id]
@@ -350,6 +352,8 @@ if __name__ == '__main__':
help="number of training epochs")
parser.add_argument("-d", "--dataset", type=str, required=True,
help="dataset to use")
parser.add_argument("--model_path", type=str, default=None,
help='path for save the model')
parser.add_argument("--l2norm", type=float, default=0,
help="l2 norm coef")
parser.add_argument("--use-self-loop", default=False, action='store_true',
"""Infering Relational Data with Graph Convolutional Networks
"""
import argparse
import torch as th
from functools import partial
import torch.nn.functional as F
from dgl.data.rdf import AIFB, MUTAG, BGS, AM
from entity_classify import EntityClassify

def main(args):
    # load graph data
    if args.dataset == 'aifb':
        dataset = AIFB()
    elif args.dataset == 'mutag':
        dataset = MUTAG()
    elif args.dataset == 'bgs':
        dataset = BGS()
    elif args.dataset == 'am':
        dataset = AM()
    else:
        raise ValueError()

    g = dataset.graph
    category = dataset.predict_category
    num_classes = dataset.num_classes
    test_idx = dataset.test_idx
    labels = dataset.labels
    category_id = len(g.ntypes)
    for i, ntype in enumerate(g.ntypes):
        if ntype == category:
            category_id = i

    # check cuda
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    if use_cuda:
        th.cuda.set_device(args.gpu)
        labels = labels.cuda()
        test_idx = test_idx.cuda()

    # create model
    model = EntityClassify(g,
                           args.n_hidden,
                           num_classes,
                           num_bases=args.n_bases,
                           num_hidden_layers=args.n_layers - 2,
                           use_self_loop=args.use_self_loop)
    # load the trained model parameters
    model.load_state_dict(th.load(args.model_path))
    if use_cuda:
        model.cuda()

    print("start testing...")
    model.eval()
    logits = model.forward()[category_id]
    test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
    test_acc = th.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
    print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
    print()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN')
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden units")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--n-bases", type=int, default=-1,
                        help="number of filter weight matrices, default: -1 [use all]")
    parser.add_argument("--n-layers", type=int, default=2,
                        help="number of propagation rounds")
    parser.add_argument("-d", "--dataset", type=str, required=True,
                        help="dataset to use")
    parser.add_argument("--model_path", type=str,
                        help='path of the saved model to load')
    parser.add_argument("--use-self-loop", default=False, action='store_true',
                        help="include self feature as a special relation")

    args = parser.parse_args()
    print(args)
    main(args)
@@ -142,7 +142,12 @@ class RDFGraphDataset:
        dst = []
        ntid = []
        etid = []
        for i, (sbj, pred, obj) in enumerate(raw_tuples):
        sorted_tuples = []
        for t in raw_tuples:
            sorted_tuples.append(t)
        sorted_tuples.sort()

        for i, (sbj, pred, obj) in enumerate(sorted_tuples):
            if i % self._print_every == 0:
                print('Processed %d tuples, found %d valid tuples.' % (i, len(src)))
            sbjent = self.parse_entity(sbj)
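The sort above serves the same goal in the data loader: entity and relation IDs are assigned in the order the raw RDF tuples are enumerated, so any variation in that order changes the node and edge numbering of the constructed graph. Sorting the tuples first keeps the graph identical across runs. A toy illustration with hypothetical tuples:
```
# Hypothetical (subject, predicate, object) tuples; only their ordering matters here.
raw_tuples = {
    ("http://ex.org/b", "http://ex.org/knows", "http://ex.org/a"),
    ("http://ex.org/a", "http://ex.org/likes", "http://ex.org/c"),
}

# IDs handed out by enumeration are stable only if the iteration order is stable.
for i, (sbj, pred, obj) in enumerate(sorted(raw_tuples)):
    print(i, sbj, pred, obj)
```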