"""Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset

from model import EntityClassify_HeteroAPI

# Maps the ``--dataset`` command-line value to the DGL RDF dataset class.
_DATASETS = {
    "aifb": AIFBDataset,
    "mutag": MUTAGDataset,
    "bgs": BGSDataset,
    "am": AMDataset,
}


def _load_dataset(name):
    """Instantiate the RDF dataset registered under *name*.

    Raises:
        ValueError: if *name* is not one of the keys of ``_DATASETS``.
    """
    try:
        dataset_cls = _DATASETS[name]
    except KeyError:
        raise ValueError(
            "Unknown dataset {!r}; expected one of {}".format(
                name, sorted(_DATASETS)
            )
        ) from None
    return dataset_cls()


def _mask_to_index(mask):
    """Return a 1-D tensor of the positions where *mask* is nonzero.

    ``view(-1)`` is used instead of ``squeeze()`` so the result stays 1-D
    even when exactly one position is set; ``len()`` and indexing then work.
    """
    return th.nonzero(mask, as_tuple=False).view(-1)


def _accuracy(logits, labels, idx):
    """Return the fraction of rows in *idx* whose argmax of *logits* equals *labels*.

    Returns NaN for an empty *idx*, mirroring the NaN that a mean-reduced
    ``F.cross_entropy`` produces on empty input, instead of dividing by zero.
    """
    if len(idx) == 0:
        return float("nan")
    correct = (logits[idx].argmax(dim=1) == labels[idx]).sum().item()
    return correct / len(idx)


def main(args):
    """Train an R-GCN entity classifier on an RDF dataset and report test metrics.

    Args:
        args: parsed command-line namespace built by the ``__main__`` block
            (dataset, gpu, model hyper-parameters, epochs, model_path,
            validation).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    # load graph data
    dataset = _load_dataset(args.dataset)
    g = dataset[0]
    category = dataset.predict_category
    num_classes = dataset.num_classes

    # Remove masks and labels from the target node type's feature store
    # (presumably so the model cannot see them as input features).
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    labels = g.nodes[category].data.pop("labels")
    train_idx = _mask_to_index(train_mask)
    test_idx = _mask_to_index(test_mask)

    # split dataset into train, validate, test
    if args.validation:
        # Hold out the first fifth of the training nodes for validation.
        n_val = len(train_idx) // 5
        val_idx = train_idx[:n_val]
        train_idx = train_idx[n_val:]
    else:
        # --testing mode: no held-out split; "validation" metrics are
        # reported on the training nodes themselves.
        val_idx = train_idx

    # check cuda
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    if use_cuda:
        th.cuda.set_device(args.gpu)
        g = g.to("cuda:%d" % args.gpu)
        labels = labels.cuda()
        train_idx = train_idx.cuda()
        # Keep every index tensor on the same device as the logits.
        val_idx = val_idx.cuda()
        test_idx = test_idx.cuda()

    # create model
    model = EntityClassify_HeteroAPI(
        g,
        args.n_hidden,
        num_classes,
        num_bases=args.n_bases,
        # The input and output layers account for 2 of ``n_layers``.
        num_hidden_layers=args.n_layers - 2,
        dropout=args.dropout,
        use_self_loop=args.use_self_loop,
    )
    if use_cuda:
        model.cuda()

    # optimizer
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.l2norm
    )

    # training loop
    print("start training...")
    dur = []
    model.train()
    for epoch in range(args.n_epochs):
        optimizer.zero_grad()
        t0 = time.time()
        logits = model()[category]
        loss = F.cross_entropy(logits[train_idx], labels[train_idx])
        loss.backward()
        optimizer.step()
        if use_cuda:
            # CUDA kernels run asynchronously; wait for them so the epoch
            # time includes the GPU work.
            th.cuda.synchronize()
        t1 = time.time()
        dur.append(t1 - t0)

        # Metrics reuse this epoch's training-mode forward pass (dropout is
        # active); no_grad avoids building an autograd graph that is never
        # back-propagated.
        with th.no_grad():
            train_acc = _accuracy(logits, labels, train_idx)
            val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
            val_acc = _accuracy(logits, labels, val_idx)
        print(
            "Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format(
                epoch,
                train_acc,
                loss.item(),
                val_acc,
                val_loss.item(),
                np.average(dur),
            )
        )
    print()

    if args.model_path is not None:
        th.save(model.state_dict(), args.model_path)

    # Final evaluation: eval mode disables dropout; no gradients are needed.
    model.eval()
    with th.no_grad():
        logits = model()[category]
        test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
        test_acc = _accuracy(logits, labels, test_idx)
    print(
        "Test Acc: {:.4f} | Test loss: {:.4f}".format(
            test_acc, test_loss.item()
        )
    )
    print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    parser.add_argument(
        "--dropout", type=float, default=0, help="dropout probability"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden units"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    parser.add_argument(
        "--n-layers", type=int, default=2, help="number of propagation rounds"
    )
    parser.add_argument(
        "-e",
        "--n-epochs",
        type=int,
        default=50,
        help="number of training epochs",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="dataset to use"
    )
    parser.add_argument(
        "--model_path", type=str, default=None, help="path for save the model"
    )
    parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
    parser.add_argument(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )
    # --validation (default) holds out part of the training set;
    # --testing trains on the full training set.
    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument("--validation", dest="validation", action="store_true")
    fp.add_argument("--testing", dest="validation", action="store_false")
    parser.set_defaults(validation=True)

    args = parser.parse_args()
    print(args)
    main(args)