Unverified commit 23d09057, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4642)



* [Misc] Black auto fix.

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a9f2acf3
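
This commit only re-runs the auto-formatters over the example scripts below; no logic changes. As a minimal sketch of reproducing the same fix locally (the exact formatter configuration is not part of this diff, so the line length and isort settings below are assumptions inferred from how the reformatted lines wrap):

    import black
    import isort

    SRC = """import torch
    import numpy as np
    import dgl

    x = {'feat': 1}
    """

    # Sort imports first (the "* sort" step in the commit message), then let
    # Black normalize quotes, trailing commas, and line wrapping.
    # line_length=80 is an assumption; the repo's real settings may differ.
    sorted_src = isort.code(SRC)
    formatted = black.format_str(sorted_src, mode=black.Mode(line_length=80))
    print(formatted)

The reformatted example files follow.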
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
from ogb.nodeproppred import DglNodePropPredDataset

import dgl
import dgl.nn as dglnn


class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(0.5)

    def forward(self, sg, x):
@@ -26,37 +29,43 @@ class SAGE(nn.Module):
                h = self.dropout(h)
        return h


dataset = dgl.data.AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
graph = dataset[
    0
]  # already prepares ndata['label'/'train_mask'/'val_mask'/'test_mask']
model = SAGE(graph.ndata["feat"].shape[1], 256, dataset.num_classes).cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

num_partitions = 1000
sampler = dgl.dataloading.ClusterGCNSampler(
    graph,
    num_partitions,
    prefetch_ndata=["feat", "label", "train_mask", "val_mask", "test_mask"],
)
# DataLoader for generic dataloading with a graph, a set of indices (any indices, like
# partition IDs here), and a graph sampler.
dataloader = dgl.dataloading.DataLoader(
    graph,
    torch.arange(num_partitions).to("cuda"),
    sampler,
    device="cuda",
    batch_size=100,
    shuffle=True,
    drop_last=False,
    num_workers=0,
    use_uva=True,
)

durations = []
for _ in range(10):
    t0 = time.time()
    model.train()
    for it, sg in enumerate(dataloader):
        x = sg.ndata["feat"]
        y = sg.ndata["label"]
        m = sg.ndata["train_mask"].bool()
        y_hat = model(sg, x)
        loss = F.cross_entropy(y_hat[m], y[m])
        opt.zero_grad()
@@ -65,7 +74,7 @@ for _ in range(10):
        if it % 20 == 0:
            acc = MF.accuracy(y_hat[m], y[m])
            mem = torch.cuda.max_memory_allocated() / 1000000
            print("Loss", loss.item(), "Acc", acc.item(), "GPU Mem", mem, "MB")
    tt = time.time()
    print(tt - t0)
    durations.append(tt - t0)
@@ -75,10 +84,10 @@ for _ in range(10):
        val_preds, test_preds = [], []
        val_labels, test_labels = [], []
        for it, sg in enumerate(dataloader):
            x = sg.ndata["feat"]
            y = sg.ndata["label"]
            m_val = sg.ndata["val_mask"].bool()
            m_test = sg.ndata["test_mask"].bool()
            y_hat = model(sg, x)
            val_preds.append(y_hat[m_val])
            val_labels.append(y[m_val])
@@ -90,6 +99,6 @@ for _ in range(10):
        test_labels = torch.cat(test_labels, 0)
        val_acc = MF.accuracy(val_preds, val_labels)
        test_acc = MF.accuracy(test_preds, test_labels)
        print("Validation acc:", val_acc.item(), "Test acc:", test_acc.item())
print(np.mean(durations[4:]), np.std(durations[4:]))
from collections import defaultdict as ddict

import numpy as np
import torch
from ordered_set import OrderedSet
from torch.utils.data import DataLoader, Dataset

import dgl


class TrainDataset(Dataset):
    """
@@ -18,6 +21,7 @@ class TrainDataset(Dataset):
    -------
    A training Dataset class instance used by DataLoader
    """

    def __init__(self, triples, num_ent, lbl_smooth):
        self.triples = triples
        self.num_ent = num_ent
@@ -29,11 +33,13 @@ class TrainDataset(Dataset):
    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele["triple"]), np.int32(ele["label"])
        trp_label = self.get_label(label)
        # label smoothing
        if self.lbl_smooth != 0.0:
            trp_label = (1.0 - self.lbl_smooth) * trp_label + (
                1.0 / self.num_ent
            )
        return triple, trp_label
@@ -48,10 +54,10 @@ class TrainDataset(Dataset):
        trp_label = torch.stack(labels, dim=0)
        return triple, trp_label

    # for edges that exist in the graph, the entry is 1.0, otherwise the entry is 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)
@@ -68,6 +74,7 @@ class TestDataset(Dataset):
    -------
    An evaluation Dataset class instance used by DataLoader for model evaluation
    """

    def __init__(self, triples, num_ent):
        self.triples = triples
        self.num_ent = num_ent
@@ -77,7 +84,7 @@ class TestDataset(Dataset):
    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele["triple"]), np.int32(ele["label"])
        label = self.get_label(label)
        return triple, label
@@ -93,19 +100,18 @@ class TestDataset(Dataset):
        label = torch.stack(labels, dim=0)
        return triple, label

    # for edges that exist in the graph, the entry is 1.0, otherwise the entry is 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)


class Data(object):
    def __init__(self, dataset, lbl_smooth, num_workers, batch_size):
        """
        Reading in raw triples and converts it into a standard format.
        Parameters
        ----------
        dataset: The name of the dataset
@@ -133,18 +139,23 @@ class Data(object):
        self.num_workers = num_workers
        self.batch_size = batch_size

        # read in raw data and get mappings
        ent_set, rel_set = OrderedSet(), OrderedSet()
        for split in ["train", "test", "valid"]:
            for line in open("./{}/{}.txt".format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split("\t"))
                ent_set.add(sub)
                rel_set.add(rel)
                ent_set.add(obj)

        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
        self.rel2id.update(
            {
                rel + "_reverse": idx + len(self.rel2id)
                for idx, rel in enumerate(rel_set)
            }
        )

        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}
@@ -152,92 +163,121 @@ class Data(object):
        self.num_ent = len(self.ent2id)
        self.num_rel = len(self.rel2id) // 2

        # read in ids of subjects, relations, and objects for train/test/valid
        self.data = ddict(list)  # stores the triples
        sr2o = ddict(
            set
        )  # The key of sr20 is (subject, relation), and the items are all the successors following (subject, relation)
        src = []
        dst = []
        rels = []
        inver_src = []
        inver_dst = []
        inver_rels = []

        for split in ["train", "test", "valid"]:
            for line in open("./{}/{}.txt".format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split("\t"))
                sub_id, rel_id, obj_id = (
                    self.ent2id[sub],
                    self.rel2id[rel],
                    self.ent2id[obj],
                )
                self.data[split].append((sub_id, rel_id, obj_id))

                if split == "train":
                    sr2o[(sub_id, rel_id)].add(obj_id)
                    sr2o[(obj_id, rel_id + self.num_rel)].add(
                        sub_id
                    )  # append the reversed edges
                    src.append(sub_id)
                    dst.append(obj_id)
                    rels.append(rel_id)
                    inver_src.append(obj_id)
                    inver_dst.append(sub_id)
                    inver_rels.append(rel_id + self.num_rel)

        # construct dgl graph
        src = src + inver_src
        dst = dst + inver_dst
        rels = rels + inver_rels
        self.g = dgl.graph((src, dst), num_nodes=self.num_ent)
        self.g.edata["etype"] = torch.Tensor(rels).long()

        # identify in and out edges
        in_edges_mask = [True] * (self.g.num_edges() // 2) + [False] * (
            self.g.num_edges() // 2
        )
        out_edges_mask = [False] * (self.g.num_edges() // 2) + [True] * (
            self.g.num_edges() // 2
        )
        self.g.edata["in_edges_mask"] = torch.Tensor(in_edges_mask)
        self.g.edata["out_edges_mask"] = torch.Tensor(out_edges_mask)

        # Prepare train/valid/test data
        self.data = dict(self.data)
        self.sr2o = {
            k: list(v) for k, v in sr2o.items()
        }  # store only the train data

        for split in ["test", "valid"]:
            for sub, rel, obj in self.data[split]:
                sr2o[(sub, rel)].add(obj)
                sr2o[(obj, rel + self.num_rel)].add(sub)

        self.sr2o_all = {
            k: list(v) for k, v in sr2o.items()
        }  # store all the data
        self.triples = ddict(list)

        for (sub, rel), obj in self.sr2o.items():
            self.triples["train"].append(
                {"triple": (sub, rel, -1), "label": self.sr2o[(sub, rel)]}
            )

        for split in ["test", "valid"]:
            for sub, rel, obj in self.data[split]:
                rel_inv = rel + self.num_rel
                self.triples["{}_{}".format(split, "tail")].append(
                    {
                        "triple": (sub, rel, obj),
                        "label": self.sr2o_all[(sub, rel)],
                    }
                )
                self.triples["{}_{}".format(split, "head")].append(
                    {
                        "triple": (obj, rel_inv, sub),
                        "label": self.sr2o_all[(obj, rel_inv)],
                    }
                )

        self.triples = dict(self.triples)

        def get_train_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TrainDataset(
                    self.triples[split], self.num_ent, self.lbl_smooth
                ),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TrainDataset.collate_fn,
            )

        def get_test_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TestDataset(self.triples[split], self.num_ent),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TestDataset.collate_fn,
            )

        # train/valid/test dataloaders
        self.data_iter = {
            "train": get_train_data_loader("train", self.batch_size),
            "valid_head": get_test_data_loader("valid_head", self.batch_size),
            "valid_tail": get_test_data_loader("valid_tail", self.batch_size),
            "test_head": get_test_data_loader("test_head", self.batch_size),
            "test_tail": get_test_data_loader("test_tail", self.batch_size),
        }
\ No newline at end of file
import argparse
from time import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from data_loader import Data
from models import CompGCN_ConvE
from utils import in_out_norm

import dgl.function as fn


# predict the tail for (head, rel, -1) or head for (-1, rel, tail)
def predict(model, graph, device, data_iter, split="valid", mode="tail"):
    model.eval()
    with th.no_grad():
        results = {}
        train_iter = iter(data_iter["{}_{}".format(split, mode)])
        for step, batch in enumerate(train_iter):
            triple, label = batch[0].to(device), batch[1].to(device)
            sub, rel, obj, label = (
                triple[:, 0],
                triple[:, 1],
                triple[:, 2],
                label,
            )
            pred = model(graph, sub, rel)
            b_range = th.arange(pred.size()[0], device=device)
            target_pred = pred[b_range, obj]
            pred = th.where(label.byte(), -th.ones_like(pred) * 10000000, pred)
            pred[b_range, obj] = target_pred

            # compute metrics
            ranks = (
                1
                + th.argsort(
                    th.argsort(pred, dim=1, descending=True),
                    dim=1,
                    descending=False,
                )[b_range, obj]
            )
            ranks = ranks.float()
            results["count"] = th.numel(ranks) + results.get("count", 0.0)
            results["mr"] = th.sum(ranks).item() + results.get("mr", 0.0)
            results["mrr"] = th.sum(1.0 / ranks).item() + results.get(
                "mrr", 0.0
            )
            for k in [1, 3, 10]:
                results["hits@{}".format(k)] = th.numel(
                    ranks[ranks <= (k)]
                ) + results.get("hits@{}".format(k), 0.0)
    return results


# evaluation function, evaluate the head and tail prediction and then combine the results
def evaluate(model, graph, device, data_iter, split="valid"):
    # predict for head and tail
    left_results = predict(model, graph, device, data_iter, split, mode="tail")
    right_results = predict(model, graph, device, data_iter, split, mode="head")
    results = {}
    count = float(left_results["count"])

    # combine the head and tail prediction results
    # Metrics: MRR, MR, and Hit@k
    results["left_mr"] = round(left_results["mr"] / count, 5)
    results["left_mrr"] = round(left_results["mrr"] / count, 5)
    results["right_mr"] = round(right_results["mr"] / count, 5)
    results["right_mrr"] = round(right_results["mrr"] / count, 5)
    results["mr"] = round(
        (left_results["mr"] + right_results["mr"]) / (2 * count), 5
    )
    results["mrr"] = round(
        (left_results["mrr"] + right_results["mrr"]) / (2 * count), 5
    )
    for k in [1, 3, 10]:
        results["left_hits@{}".format(k)] = round(
            left_results["hits@{}".format(k)] / count, 5
        )
        results["right_hits@{}".format(k)] = round(
            right_results["hits@{}".format(k)] / count, 5
        )
        results["hits@{}".format(k)] = round(
            (
                left_results["hits@{}".format(k)]
                + right_results["hits@{}".format(k)]
            )
            / (2 * count),
            5,
        )
    return results


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # construct graph, split in/out edges and prepare train/validation/test data_loader
    data = Data(
        args.dataset, args.lbl_smooth, args.num_workers, args.batch_size
    )
    data_iter = data.data_iter  # train/validation/test data_loader
    graph = data.g.to(device)
    num_rel = th.max(graph.edata["etype"]).item() + 1

    # Compute in/out edge norms and store in edata
    graph = in_out_norm(graph)

    # Step 2: Create model =================================================================== #
    compgcn_model = CompGCN_ConvE(
        num_bases=args.num_bases,
        num_rel=num_rel,
        num_ent=graph.num_nodes(),
        in_dim=args.init_dim,
        layer_size=args.layer_size,
        comp_fn=args.opn,
        batchnorm=True,
        dropout=args.dropout,
        layer_dropout=args.layer_dropout,
        num_filt=args.num_filt,
        hid_drop=args.hid_drop,
        feat_drop=args.feat_drop,
        ker_sz=args.ker_sz,
        k_w=args.k_w,
        k_h=args.k_h,
    )
    compgcn_model = compgcn_model.to(device)

    # Step 3: Create training components ===================================================== #
    loss_fn = th.nn.BCELoss()
    optimizer = optim.Adam(
        compgcn_model.parameters(), lr=args.lr, weight_decay=args.l2
    )

    # Step 4: training epoches =============================================================== #
    best_mrr = 0.0
    kill_cnt = 0
    for epoch in range(args.max_epochs):
        # Training and validation using a full graph
        compgcn_model.train()
        train_loss = []
        t0 = time()
        for step, batch in enumerate(data_iter["train"]):
            triple, label = batch[0].to(device), batch[1].to(device)
            sub, rel, obj, label = (
                triple[:, 0],
                triple[:, 1],
                triple[:, 2],
                label,
            )
            logits = compgcn_model(graph, sub, rel)

            # compute loss
            tr_loss = loss_fn(logits, label)
            train_loss.append(tr_loss.item())
@@ -129,66 +170,192 @@ def main(args):
        train_loss = np.sum(train_loss)
        t1 = time()

        val_results = evaluate(
            compgcn_model, graph, device, data_iter, split="valid"
        )
        t2 = time()

        # validate
        if val_results["mrr"] > best_mrr:
            best_mrr = val_results["mrr"]
            best_epoch = epoch
            th.save(
                compgcn_model.state_dict(), "comp_link" + "_" + args.dataset
            )
            kill_cnt = 0
            print("saving model...")
        else:
            kill_cnt += 1
            if kill_cnt > 100:
                print("early stop.")
                break
        print(
            "In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}\n, Train time: {}, Valid time: {}".format(
                epoch, train_loss, val_results["mrr"], t1 - t0, t2 - t1
            )
        )

    # test use the best model
    compgcn_model.eval()
    compgcn_model.load_state_dict(th.load("comp_link" + "_" + args.dataset))
    test_results = evaluate(
        compgcn_model, graph, device, data_iter, split="test"
    )
    print(
        "Test MRR: {:.5}\n, MR: {:.10}\n, H@10: {:.5}\n, H@3: {:.5}\n, H@1: {:.5}\n".format(
            test_results["mrr"],
            test_results["mr"],
            test_results["hits@10"],
            test_results["hits@3"],
            test_results["hits@1"],
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Parser For Arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--data",
        dest="dataset",
        default="FB15k-237",
        help="Dataset to use, default: FB15k-237",
    )
    parser.add_argument(
        "--model", dest="model", default="compgcn", help="Model Name"
    )
    parser.add_argument(
        "--score_func",
        dest="score_func",
        default="conve",
        help="Score Function for Link prediction",
    )
    parser.add_argument(
        "--opn",
        dest="opn",
        default="ccorr",
        help="Composition Operation to be used in CompGCN",
    )
    parser.add_argument(
        "--batch", dest="batch_size", default=1024, type=int, help="Batch size"
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default="0",
        help="Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0",
    )
    parser.add_argument(
        "--epoch",
        dest="max_epochs",
        type=int,
        default=500,
        help="Number of epochs",
    )
    parser.add_argument(
        "--l2", type=float, default=0.0, help="L2 Regularization for Optimizer"
    )
    parser.add_argument(
        "--lr", type=float, default=0.001, help="Starting Learning Rate"
    )
    parser.add_argument(
        "--lbl_smooth",
        dest="lbl_smooth",
        type=float,
        default=0.1,
        help="Label Smoothing",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=10,
        help="Number of processes to construct batches",
    )
    parser.add_argument(
        "--seed",
        dest="seed",
        default=41504,
        type=int,
        help="Seed for randomization",
    )
    parser.add_argument(
        "--num_bases",
        dest="num_bases",
        default=-1,
        type=int,
        help="Number of basis relation vectors to use",
    )
    parser.add_argument(
        "--init_dim",
        dest="init_dim",
        default=100,
        type=int,
        help="Initial dimension size for entities and relations",
    )
    parser.add_argument(
        "--layer_size",
        nargs="?",
        default="[200]",
        help="List of output size for each compGCN layer",
    )
    parser.add_argument(
        "--gcn_drop",
        dest="dropout",
        default=0.1,
        type=float,
        help="Dropout to use in GCN Layer",
    )
    parser.add_argument(
        "--layer_dropout",
        nargs="?",
        default="[0.3]",
        help="List of dropout value after each compGCN layer",
    )

    # ConvE specific hyperparameters
    parser.add_argument(
        "--hid_drop",
        dest="hid_drop",
        default=0.3,
        type=float,
        help="ConvE: Hidden dropout",
    )
    parser.add_argument(
        "--feat_drop",
        dest="feat_drop",
        default=0.3,
        type=float,
        help="ConvE: Feature Dropout",
    )
    parser.add_argument(
        "--k_w", dest="k_w", default=10, type=int, help="ConvE: k_w"
    )
    parser.add_argument(
        "--k_h", dest="k_h", default=20, type=int, help="ConvE: k_h"
    )
    parser.add_argument(
        "--num_filt",
        dest="num_filt",
        default=200,
        type=int,
        help="ConvE: Number of filters in convolution",
    )
    parser.add_argument(
        "--ker_sz",
        dest="ker_sz",
        default=7,
        type=int,
        help="ConvE: Kernel size to use",
    )

    args = parser.parse_args()

    np.random.seed(args.seed)
    th.manual_seed(args.seed)
@@ -198,4 +365,3 @@ if __name__ == '__main__':
    args.layer_dropout = eval(args.layer_dropout)

    main(args)
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from utils import ccorr

import dgl
import dgl.function as fn


class CompGraphConv(nn.Module):
    """One layer of CompGCN."""

    def __init__(
        self, in_dim, out_dim, comp_fn="sub", batchnorm=True, dropout=0.1
    ):
        super(CompGraphConv, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
@@ -40,63 +36,74 @@ class CompGraphConv(nn.Module):
        # define relation transform layer
        self.W_R = nn.Linear(self.in_dim, self.out_dim)

        # self loop embedding
        self.loop_rel = nn.Parameter(th.Tensor(1, self.in_dim))
        nn.init.xavier_normal_(self.loop_rel)

    def forward(self, g, n_in_feats, r_feats):
        with g.local_scope():
            # Assign values to source nodes. In a homogeneous graph, this is equal to
            # assigning them to all nodes.
            g.srcdata["h"] = n_in_feats

            # append loop_rel embedding to r_feats
            r_feats = th.cat((r_feats, self.loop_rel), 0)

            # Assign features to all edges with the corresponding relation embeddings
            g.edata["h"] = r_feats[g.edata["etype"]] * g.edata["norm"]

            # Compute composition function in 4 steps
            # Step 1: compute composition by edge in the edge direction, and store results in edges.
            if self.comp_fn == "sub":
                g.apply_edges(fn.u_sub_e("h", "h", out="comp_h"))
            elif self.comp_fn == "mul":
                g.apply_edges(fn.u_mul_e("h", "h", out="comp_h"))
            elif self.comp_fn == "ccorr":
                g.apply_edges(
                    lambda edges: {
                        "comp_h": ccorr(edges.src["h"], edges.data["h"])
                    }
                )
            else:
                raise Exception("Only supports sub, mul, and ccorr")

            # Step 2: use extracted edge direction to compute in and out edges
            comp_h = g.edata["comp_h"]
            in_edges_idx = th.nonzero(
                g.edata["in_edges_mask"], as_tuple=False
            ).squeeze()
            out_edges_idx = th.nonzero(
                g.edata["out_edges_mask"], as_tuple=False
            ).squeeze()

            comp_h_O = self.W_O(comp_h[out_edges_idx])
            comp_h_I = self.W_I(comp_h[in_edges_idx])

            new_comp_h = th.zeros(comp_h.shape[0], self.out_dim).to(
                comp_h.device
            )
            new_comp_h[out_edges_idx] = comp_h_O
            new_comp_h[in_edges_idx] = comp_h_I
            g.edata["new_comp_h"] = new_comp_h

            # Step 3: sum comp results to both src and dst nodes
            g.update_all(fn.copy_e("new_comp_h", "m"), fn.sum("m", "comp_edge"))

            # Step 4: add results of self-loop
            if self.comp_fn == "sub":
                comp_h_s = n_in_feats - r_feats[-1]
            elif self.comp_fn == "mul":
                comp_h_s = n_in_feats * r_feats[-1]
            elif self.comp_fn == "ccorr":
                comp_h_s = ccorr(n_in_feats, r_feats[-1])
            else:
                raise Exception("Only supports sub, mul, and ccorr")

            # Sum all of the comp results as output of nodes and dropout
            n_out_feats = (
                self.W_S(comp_h_s) + self.dropout(g.ndata["comp_edge"])
            ) * (1 / 3)

            # Compute relation output
            r_out_feats = self.W_R(r_feats)
@@ -113,16 +120,18 @@ class CompGraphConv(nn.Module):


class CompGCN(nn.Module):
    def __init__(
        self,
        num_bases,
        num_rel,
        num_ent,
        in_dim=100,
        layer_size=[200],
        comp_fn="sub",
        batchnorm=True,
        dropout=0.1,
        layer_dropout=[0.3],
    ):
        super(CompGCN, self).__init__()

        self.num_bases = num_bases
@@ -136,17 +145,29 @@ class CompGCN(nn.Module):
        self.layer_dropout = layer_dropout
        self.num_layer = len(layer_size)

        # CompGCN layers
        self.layers = nn.ModuleList()
        self.layers.append(
            CompGraphConv(
                self.in_dim,
                self.layer_size[0],
                comp_fn=self.comp_fn,
                batchnorm=self.batchnorm,
                dropout=self.dropout,
            )
        )
        for i in range(self.num_layer - 1):
            self.layers.append(
                CompGraphConv(
                    self.layer_size[i],
                    self.layer_size[i + 1],
                    comp_fn=self.comp_fn,
                    batchnorm=self.batchnorm,
                    dropout=self.dropout,
                )
            )

        # Initial relation embeddings
        if self.num_bases > 0:
            self.basis = nn.Parameter(th.Tensor(self.num_bases, self.in_dim))
            self.weights = nn.Parameter(th.Tensor(self.num_rel, self.num_bases))
@@ -156,20 +177,17 @@ class CompGCN(nn.Module):
            self.rel_embds = nn.Parameter(th.Tensor(self.num_rel, self.in_dim))
            nn.init.xavier_normal_(self.rel_embds)

        # Node embeddings
        self.n_embds = nn.Parameter(th.Tensor(self.num_ent, self.in_dim))
        nn.init.xavier_normal_(self.n_embds)

        # Dropout after compGCN layers
        self.dropouts = nn.ModuleList()
        for i in range(self.num_layer):
            self.dropouts.append(nn.Dropout(self.layer_dropout[i]))

    def forward(self, graph):
        # node and relation features
        n_feats = self.n_embds
        if self.num_bases > 0:
            r_embds = th.mm(self.weights, self.basis)
@@ -183,73 +201,94 @@ class CompGCN(nn.Module):
        return n_feats, r_feats


# Use convE as the score function
class CompGCN_ConvE(nn.Module):
    def __init__(
        self,
        num_bases,
        num_rel,
        num_ent,
        in_dim,
        layer_size,
        comp_fn="sub",
        batchnorm=True,
        dropout=0.1,
        layer_dropout=[0.3],
        num_filt=200,
        hid_drop=0.3,
        feat_drop=0.3,
        ker_sz=5,
        k_w=5,
        k_h=5,
    ):
        super(CompGCN_ConvE, self).__init__()

        self.embed_dim = layer_size[-1]
        self.hid_drop = hid_drop
        self.feat_drop = feat_drop
        self.ker_sz = ker_sz
        self.k_w = k_w
        self.k_h = k_h
        self.num_filt = num_filt

        # compGCN model to get sub/rel embs
        self.compGCN_Model = CompGCN(
            num_bases,
            num_rel,
            num_ent,
            in_dim,
            layer_size,
            comp_fn,
            batchnorm,
            dropout,
            layer_dropout,
        )

        # batchnorms to the combined (sub+rel) emb
        self.bn0 = th.nn.BatchNorm2d(1)
        self.bn1 = th.nn.BatchNorm2d(self.num_filt)
        self.bn2 = th.nn.BatchNorm1d(self.embed_dim)

        # dropouts and conv module to the combined (sub+rel) emb
        self.hidden_drop = th.nn.Dropout(self.hid_drop)
        self.feature_drop = th.nn.Dropout(self.feat_drop)
        self.m_conv1 = th.nn.Conv2d(
            1,
            out_channels=self.num_filt,
            kernel_size=(self.ker_sz, self.ker_sz),
            stride=1,
            padding=0,
            bias=False,
        )

        flat_sz_h = int(2 * self.k_w) - self.ker_sz + 1
        flat_sz_w = self.k_h - self.ker_sz + 1
        self.flat_sz = flat_sz_h * flat_sz_w * self.num_filt
        self.fc = th.nn.Linear(self.flat_sz, self.embed_dim)

        # bias to the score
        self.bias = nn.Parameter(th.zeros(num_ent))

    # combine entity embeddings and relation embeddings
    def concat(self, e1_embed, rel_embed):
        e1_embed = e1_embed.view(-1, 1, self.embed_dim)
        rel_embed = rel_embed.view(-1, 1, self.embed_dim)
        stack_inp = th.cat([e1_embed, rel_embed], 1)
        stack_inp = th.transpose(stack_inp, 2, 1).reshape(
            (-1, 1, 2 * self.k_w, self.k_h)
        )
        return stack_inp

    def forward(self, graph, sub, rel):
        # get sub_emb and rel_emb via compGCN
        n_feats, r_feats = self.compGCN_Model(graph)
        sub_emb = n_feats[sub, :]
        rel_emb = r_feats[rel, :]

        # combine the sub_emb and rel_emb
        stk_inp = self.concat(sub_emb, rel_emb)

        # use convE to score the combined emb
        x = self.bn0(stk_inp)
        x = self.m_conv1(x)
        x = self.bn1(x)
@@ -260,11 +299,9 @@ class CompGCN_ConvE(nn.Module):
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)

        # compute score
        x = th.mm(x, n_feats.transpose(1, 0))

        # add in bias
        x += self.bias.expand_as(x)
        score = th.sigmoid(x)
        return score
@@ -3,55 +3,65 @@
# It implements the operation of circular convolution in the ccorr function and an additional in_out_norm function for norm computation.
import torch as th
import dgl


def com_mult(a, b):
    r1, i1 = a[..., 0], a[..., 1]
    r2, i2 = b[..., 0], b[..., 1]
    return th.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1)


def conj(a):
    a[..., 1] = -a[..., 1]
    return a


def ccorr(a, b):
    """
    Compute circular correlation of two tensors.
    Parameters
    ----------
    a: Tensor, 1D or 2D
    b: Tensor, 1D or 2D

    Notes
    -----
    Input a and b should have the same dimensions. And this operation supports broadcasting.

    Returns
    -------
    Tensor, having the same dimension as the input a.
    """
    return th.fft.irfftn(
        th.conj(th.fft.rfftn(a, (-1))) * th.fft.rfftn(b, (-1)), (-1)
    )


# identify in/out edges, compute edge norm for each and store in edata
def in_out_norm(graph):
    src, dst, EID = graph.edges(form="all")
    graph.edata["norm"] = th.ones(EID.shape[0]).to(graph.device)

    in_edges_idx = th.nonzero(
        graph.edata["in_edges_mask"], as_tuple=False
    ).squeeze()
    out_edges_idx = th.nonzero(
        graph.edata["out_edges_mask"], as_tuple=False
    ).squeeze()

    for idx in [in_edges_idx, out_edges_idx]:
        u, v = src[idx], dst[idx]
        deg = th.zeros(graph.num_nodes()).to(graph.device)
        n_idx, inverse_index, count = th.unique(
            v, return_inverse=True, return_counts=True
        )
        deg[n_idx] = count.float()
        deg_inv = deg.pow(-0.5)  # D^{-0.5}
        deg_inv[deg_inv == float("inf")] = 0
        norm = deg_inv[u] * deg_inv[v]
        graph.edata["norm"][idx] = norm
    graph.edata["norm"] = graph.edata["norm"].unsqueeze(1)

    return graph
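
A side note on the ccorr helper above (not part of this commit): it computes circular correlation through an FFT identity, and for 1-D real inputs it should agree with the direct roll-based definition. A hedged sanity sketch, written with explicit dim/n keyword arguments rather than the positional form used in the file:

    import torch as th

    def ccorr_fft(a, b):
        # same identity as utils.ccorr above: irfft(conj(rfft(a)) * rfft(b))
        fa = th.conj(th.fft.rfft(a, dim=-1))
        fb = th.fft.rfft(b, dim=-1)
        return th.fft.irfft(fa * fb, n=a.shape[-1], dim=-1)

    def ccorr_naive(a, b):
        # direct definition: result[k] = sum_i a[i] * b[(i + k) % n]
        n = a.shape[-1]
        return th.stack(
            [(a * th.roll(b, shifts=-k, dims=-1)).sum(dim=-1) for k in range(n)],
            dim=-1,
        )

    a, b = th.randn(8), th.randn(8)
    assert th.allclose(ccorr_fft(a, b), ccorr_naive(a, b), atol=1e-5)

The last reformatted file, the Correct & Smooth base predictor, follows.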
import argparse import argparse
import copy import copy
import os import os
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import dgl from model import MLP, CorrectAndSmooth, MLPLinear
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from model import MLP, MLPLinear, CorrectAndSmooth
import dgl
def evaluate(y_pred, y_true, idx, evaluator): def evaluate(y_pred, y_true, idx, evaluator):
return evaluator.eval({ return evaluator.eval({"y_true": y_true[idx], "y_pred": y_pred[idx]})["acc"]
'y_true': y_true[idx],
'y_pred': y_pred[idx]
})['acc']
def main(): def main():
# check cuda # check cuda
device = f'cuda:{args.gpu}' if torch.cuda.is_available() and args.gpu >= 0 else 'cpu' device = (
f"cuda:{args.gpu}"
if torch.cuda.is_available() and args.gpu >= 0
else "cpu"
)
# load data # load data
dataset = DglNodePropPredDataset(name=args.dataset) dataset = DglNodePropPredDataset(name=args.dataset)
evaluator = Evaluator(name=args.dataset) evaluator = Evaluator(name=args.dataset)
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
g, labels = dataset[0] # graph: DGLGraph object, label: torch tensor of shape (num_nodes, num_tasks) g, labels = dataset[
0
if args.dataset == 'ogbn-arxiv': ] # graph: DGLGraph object, label: torch tensor of shape (num_nodes, num_tasks)
if args.dataset == "ogbn-arxiv":
g = dgl.to_bidirected(g, copy_ndata=True) g = dgl.to_bidirected(g, copy_ndata=True)
feat = g.ndata['feat'] feat = g.ndata["feat"]
feat = (feat - feat.mean(0)) / feat.std(0) feat = (feat - feat.mean(0)) / feat.std(0)
g.ndata['feat'] = feat g.ndata["feat"] = feat
g = g.to(device) g = g.to(device)
feats = g.ndata['feat'] feats = g.ndata["feat"]
labels = labels.to(device) labels = labels.to(device)
# load masks for train / validation / test # load masks for train / validation / test
...@@ -44,21 +49,25 @@ def main(): ...@@ -44,21 +49,25 @@ def main():
n_features = feats.size()[-1] n_features = feats.size()[-1]
n_classes = dataset.num_classes n_classes = dataset.num_classes
# load model # load model
if args.model == 'mlp': if args.model == "mlp":
model = MLP(n_features, args.hid_dim, n_classes, args.num_layers, args.dropout) model = MLP(
elif args.model == 'linear': n_features, args.hid_dim, n_classes, args.num_layers, args.dropout
)
elif args.model == "linear":
model = MLPLinear(n_features, n_classes) model = MLPLinear(n_features, n_classes)
else: else:
raise NotImplementedError(f'Model {args.model} is not supported.') raise NotImplementedError(f"Model {args.model} is not supported.")
model = model.to(device) model = model.to(device)
print(f'Model parameters: {sum(p.numel() for p in model.parameters())}') print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
if args.pretrain: if args.pretrain:
print('---------- Before ----------') print("---------- Before ----------")
model.load_state_dict(torch.load(f'base/{args.dataset}-{args.model}.pt')) model.load_state_dict(
torch.load(f"base/{args.dataset}-{args.model}.pt")
)
model.eval() model.eval()
y_soft = model(feats).exp() y_soft = model(feats).exp()
...@@ -66,24 +75,26 @@ def main(): ...@@ -66,24 +75,26 @@ def main():
y_pred = y_soft.argmax(dim=-1, keepdim=True) y_pred = y_soft.argmax(dim=-1, keepdim=True)
valid_acc = evaluate(y_pred, labels, valid_idx, evaluator) valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)
test_acc = evaluate(y_pred, labels, test_idx, evaluator) test_acc = evaluate(y_pred, labels, test_idx, evaluator)
print(f'Valid acc: {valid_acc:.4f} | Test acc: {test_acc:.4f}') print(f"Valid acc: {valid_acc:.4f} | Test acc: {test_acc:.4f}")
print('---------- Correct & Smoothing ----------') print("---------- Correct & Smoothing ----------")
cs = CorrectAndSmooth(num_correction_layers=args.num_correction_layers, cs = CorrectAndSmooth(
correction_alpha=args.correction_alpha, num_correction_layers=args.num_correction_layers,
correction_adj=args.correction_adj, correction_alpha=args.correction_alpha,
num_smoothing_layers=args.num_smoothing_layers, correction_adj=args.correction_adj,
smoothing_alpha=args.smoothing_alpha, num_smoothing_layers=args.num_smoothing_layers,
smoothing_adj=args.smoothing_adj, smoothing_alpha=args.smoothing_alpha,
autoscale=args.autoscale, smoothing_adj=args.smoothing_adj,
scale=args.scale) autoscale=args.autoscale,
scale=args.scale,
)
y_soft = cs.correct(g, y_soft, labels[train_idx], train_idx) y_soft = cs.correct(g, y_soft, labels[train_idx], train_idx)
y_soft = cs.smooth(g, y_soft, labels[train_idx], train_idx) y_soft = cs.smooth(g, y_soft, labels[train_idx], train_idx)
y_pred = y_soft.argmax(dim=-1, keepdim=True) y_pred = y_soft.argmax(dim=-1, keepdim=True)
valid_acc = evaluate(y_pred, labels, valid_idx, evaluator) valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)
test_acc = evaluate(y_pred, labels, test_idx, evaluator) test_acc = evaluate(y_pred, labels, test_idx, evaluator)
print(f'Valid acc: {valid_acc:.4f} | Test acc: {test_acc:.4f}') print(f"Valid acc: {valid_acc:.4f} | Test acc: {test_acc:.4f}")
else: else:
opt = optim.Adam(model.parameters(), lr=args.lr) opt = optim.Adam(model.parameters(), lr=args.lr)
...@@ -91,79 +102,94 @@ def main(): ...@@ -91,79 +102,94 @@ def main():
best_model = copy.deepcopy(model) best_model = copy.deepcopy(model)
# training # training
print('---------- Training ----------') print("---------- Training ----------")
for i in range(args.epochs): for i in range(args.epochs):
model.train() model.train()
opt.zero_grad() opt.zero_grad()
logits = model(feats) logits = model(feats)
train_loss = F.nll_loss(logits[train_idx], labels.squeeze(1)[train_idx]) train_loss = F.nll_loss(
logits[train_idx], labels.squeeze(1)[train_idx]
)
train_loss.backward() train_loss.backward()
opt.step() opt.step()
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
logits = model(feats) logits = model(feats)
y_pred = logits.argmax(dim=-1, keepdim=True) y_pred = logits.argmax(dim=-1, keepdim=True)
train_acc = evaluate(y_pred, labels, train_idx, evaluator) train_acc = evaluate(y_pred, labels, train_idx, evaluator)
valid_acc = evaluate(y_pred, labels, valid_idx, evaluator) valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)
print(f'Epoch {i} | Train loss: {train_loss.item():.4f} | Train acc: {train_acc:.4f} | Valid acc {valid_acc:.4f}') print(
f"Epoch {i} | Train loss: {train_loss.item():.4f} | Train acc: {train_acc:.4f} | Valid acc {valid_acc:.4f}"
)
if valid_acc > best_acc: if valid_acc > best_acc:
best_acc = valid_acc best_acc = valid_acc
best_model = copy.deepcopy(model) best_model = copy.deepcopy(model)
# testing & saving model # testing & saving model
print('---------- Testing ----------') print("---------- Testing ----------")
best_model.eval() best_model.eval()
logits = best_model(feats) logits = best_model(feats)
y_pred = logits.argmax(dim=-1, keepdim=True) y_pred = logits.argmax(dim=-1, keepdim=True)
test_acc = evaluate(y_pred, labels, test_idx, evaluator) test_acc = evaluate(y_pred, labels, test_idx, evaluator)
print(f'Test acc: {test_acc:.4f}') print(f"Test acc: {test_acc:.4f}")
if not os.path.exists('base'): if not os.path.exists("base"):
os.makedirs('base') os.makedirs("base")
torch.save(best_model.state_dict(), f'base/{args.dataset}-{args.model}.pt') torch.save(
best_model.state_dict(), f"base/{args.dataset}-{args.model}.pt"
)
if __name__ == '__main__': if __name__ == "__main__":
""" """
Correct & Smoothing Hyperparameters Correct & Smoothing Hyperparameters
""" """
parser = argparse.ArgumentParser(description='Base predictor (C&S)') parser = argparse.ArgumentParser(description="Base predictor (C&S)")
# Dataset # Dataset
parser.add_argument('--gpu', type=int, default=0, help='-1 for cpu') parser.add_argument("--gpu", type=int, default=0, help="-1 for cpu")
parser.add_argument('--dataset', type=str, default='ogbn-arxiv', choices=['ogbn-arxiv', 'ogbn-products']) parser.add_argument(
"--dataset",
type=str,
default="ogbn-arxiv",
choices=["ogbn-arxiv", "ogbn-products"],
)
# Base predictor # Base predictor
parser.add_argument('--model', type=str, default='mlp', choices=['mlp', 'linear']) parser.add_argument(
parser.add_argument('--num-layers', type=int, default=3) "--model", type=str, default="mlp", choices=["mlp", "linear"]
parser.add_argument('--hid-dim', type=int, default=256) )
parser.add_argument('--dropout', type=float, default=0.4) parser.add_argument("--num-layers", type=int, default=3)
parser.add_argument('--lr', type=float, default=0.01) parser.add_argument("--hid-dim", type=int, default=256)
parser.add_argument('--epochs', type=int, default=300) parser.add_argument("--dropout", type=float, default=0.4)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--epochs", type=int, default=300)
# extra options for gat # extra options for gat
parser.add_argument('--n-heads', type=int, default=3) parser.add_argument("--n-heads", type=int, default=3)
parser.add_argument('--attn_drop', type=float, default=0.05) parser.add_argument("--attn_drop", type=float, default=0.05)
# C & S # C & S
parser.add_argument('--pretrain', action='store_true', help='Whether to perform C & S') parser.add_argument(
parser.add_argument('--num-correction-layers', type=int, default=50) "--pretrain", action="store_true", help="Whether to perform C & S"
parser.add_argument('--correction-alpha', type=float, default=0.979) )
parser.add_argument('--correction-adj', type=str, default='DAD') parser.add_argument("--num-correction-layers", type=int, default=50)
parser.add_argument('--num-smoothing-layers', type=int, default=50) parser.add_argument("--correction-alpha", type=float, default=0.979)
parser.add_argument('--smoothing-alpha', type=float, default=0.756) parser.add_argument("--correction-adj", type=str, default="DAD")
parser.add_argument('--smoothing-adj', type=str, default='DAD') parser.add_argument("--num-smoothing-layers", type=int, default=50)
parser.add_argument('--autoscale', action='store_true') parser.add_argument("--smoothing-alpha", type=float, default=0.756)
parser.add_argument('--scale', type=float, default=20.) parser.add_argument("--smoothing-adj", type=str, default="DAD")
parser.add_argument("--autoscale", action="store_true")
parser.add_argument("--scale", type=float, default=20.0)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl.function as fn import dgl.function as fn
...@@ -9,7 +10,7 @@ class MLPLinear(nn.Module): ...@@ -9,7 +10,7 @@ class MLPLinear(nn.Module):
super(MLPLinear, self).__init__() super(MLPLinear, self).__init__()
self.linear = nn.Linear(in_dim, out_dim) self.linear = nn.Linear(in_dim, out_dim)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
self.linear.reset_parameters() self.linear.reset_parameters()
...@@ -18,7 +19,7 @@ class MLPLinear(nn.Module): ...@@ -18,7 +19,7 @@ class MLPLinear(nn.Module):
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, in_dim, hid_dim, out_dim, num_layers, dropout=0.): def __init__(self, in_dim, hid_dim, out_dim, num_layers, dropout=0.0):
super(MLP, self).__init__() super(MLP, self).__init__()
assert num_layers >= 2 assert num_layers >= 2
...@@ -30,7 +31,7 @@ class MLP(nn.Module): ...@@ -30,7 +31,7 @@ class MLP(nn.Module):
for _ in range(num_layers - 2): for _ in range(num_layers - 2):
self.linears.append(nn.Linear(hid_dim, hid_dim)) self.linears.append(nn.Linear(hid_dim, hid_dim))
self.bns.append(nn.BatchNorm1d(hid_dim)) self.bns.append(nn.BatchNorm1d(hid_dim))
self.linears.append(nn.Linear(hid_dim, out_dim)) self.linears.append(nn.Linear(hid_dim, out_dim))
self.dropout = dropout self.dropout = dropout
self.reset_parameters() self.reset_parameters()
...@@ -75,42 +76,49 @@ class LabelPropagation(nn.Module): ...@@ -75,42 +76,49 @@ class LabelPropagation(nn.Module):
'DA': D^-1 * A 'DA': D^-1 * A
'AD': A * D^-1 'AD': A * D^-1
""" """
def __init__(self, num_layers, alpha, adj='DAD'):
def __init__(self, num_layers, alpha, adj="DAD"):
super(LabelPropagation, self).__init__() super(LabelPropagation, self).__init__()
self.num_layers = num_layers self.num_layers = num_layers
self.alpha = alpha self.alpha = alpha
self.adj = adj self.adj = adj
@torch.no_grad() @torch.no_grad()
def forward(self, g, labels, mask=None, post_step=lambda y: y.clamp_(0., 1.)): def forward(
self, g, labels, mask=None, post_step=lambda y: y.clamp_(0.0, 1.0)
):
with g.local_scope(): with g.local_scope():
if labels.dtype == torch.long: if labels.dtype == torch.long:
labels = F.one_hot(labels.view(-1)).to(torch.float32) labels = F.one_hot(labels.view(-1)).to(torch.float32)
y = labels y = labels
if mask is not None: if mask is not None:
y = torch.zeros_like(labels) y = torch.zeros_like(labels)
y[mask] = labels[mask] y[mask] = labels[mask]
last = (1 - self.alpha) * y last = (1 - self.alpha) * y
degs = g.in_degrees().float().clamp(min=1) degs = g.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, -0.5 if self.adj == 'DAD' else -1).to(labels.device).unsqueeze(1) norm = (
torch.pow(degs, -0.5 if self.adj == "DAD" else -1)
.to(labels.device)
.unsqueeze(1)
)
for _ in range(self.num_layers): for _ in range(self.num_layers):
# Assume the graphs to be undirected # Assume the graphs to be undirected
if self.adj in ['DAD', 'AD']: if self.adj in ["DAD", "AD"]:
y = norm * y y = norm * y
g.ndata['h'] = y
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
y = self.alpha * g.ndata.pop('h')
if self.adj in ['DAD', 'DA']: g.ndata["h"] = y
g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
y = self.alpha * g.ndata.pop("h")
if self.adj in ["DAD", "DA"]:
y = y * norm y = y * norm
y = post_step(last + y) y = post_step(last + y)
return y return y
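# For reference, the loop above implements the standard label-propagation update
# (written for adj == "DAD"; "DA" / "AD" keep only one of the two degree factors):
#   Y^{(t+1)} = post_step( (1 - \alpha) Y^{(0)} + \alpha D^{-1/2} A D^{-1/2} Y^{(t)} ),
# where Y^{(0)} is the (optionally masked) one-hot label matrix and post_step
# defaults to clamping onto [0, 1].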
...@@ -144,41 +152,50 @@ class CorrectAndSmooth(nn.Module): ...@@ -144,41 +152,50 @@ class CorrectAndSmooth(nn.Module):
scale: float, optional scale: float, optional
The scaling factor :math:`\sigma`, in case :obj:`autoscale = False`. Default is 1. The scaling factor :math:`\sigma`, in case :obj:`autoscale = False`. Default is 1.
""" """
def __init__(self,
num_correction_layers, def __init__(
correction_alpha, self,
correction_adj, num_correction_layers,
num_smoothing_layers, correction_alpha,
smoothing_alpha, correction_adj,
smoothing_adj, num_smoothing_layers,
autoscale=True, smoothing_alpha,
scale=1.): smoothing_adj,
autoscale=True,
scale=1.0,
):
super(CorrectAndSmooth, self).__init__() super(CorrectAndSmooth, self).__init__()
self.autoscale = autoscale self.autoscale = autoscale
self.scale = scale self.scale = scale
self.prop1 = LabelPropagation(num_correction_layers, self.prop1 = LabelPropagation(
correction_alpha, num_correction_layers, correction_alpha, correction_adj
correction_adj) )
self.prop2 = LabelPropagation(num_smoothing_layers, self.prop2 = LabelPropagation(
smoothing_alpha, num_smoothing_layers, smoothing_alpha, smoothing_adj
smoothing_adj) )
def correct(self, g, y_soft, y_true, mask): def correct(self, g, y_soft, y_true, mask):
with g.local_scope(): with g.local_scope():
assert abs(float(y_soft.sum()) / y_soft.size(0) - 1.0) < 1e-2 assert abs(float(y_soft.sum()) / y_soft.size(0) - 1.0) < 1e-2
numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0) numel = (
int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
)
assert y_true.size(0) == numel assert y_true.size(0) == numel
if y_true.dtype == torch.long: if y_true.dtype == torch.long:
y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to(y_soft.dtype) y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to(
y_soft.dtype
)
error = torch.zeros_like(y_soft) error = torch.zeros_like(y_soft)
error[mask] = y_true - y_soft[mask] error[mask] = y_true - y_soft[mask]
if self.autoscale: if self.autoscale:
smoothed_error = self.prop1(g, error, post_step=lambda x: x.clamp_(-1., 1.)) smoothed_error = self.prop1(
g, error, post_step=lambda x: x.clamp_(-1.0, 1.0)
)
sigma = error[mask].abs().sum() / numel sigma = error[mask].abs().sum() / numel
scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True) scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True)
scale[scale.isinf() | (scale > 1000)] = 1.0 scale[scale.isinf() | (scale > 1000)] = 1.0
...@@ -187,10 +204,11 @@ class CorrectAndSmooth(nn.Module): ...@@ -187,10 +204,11 @@ class CorrectAndSmooth(nn.Module):
result[result.isnan()] = y_soft[result.isnan()] result[result.isnan()] = y_soft[result.isnan()]
return result return result
else: else:
def fix_input(x): def fix_input(x):
x[mask] = error[mask] x[mask] = error[mask]
return x return x
smoothed_error = self.prop1(g, error, post_step=fix_input) smoothed_error = self.prop1(g, error, post_step=fix_input)
result = y_soft + self.scale * smoothed_error result = y_soft + self.scale * smoothed_error
...@@ -199,11 +217,15 @@ class CorrectAndSmooth(nn.Module): ...@@ -199,11 +217,15 @@ class CorrectAndSmooth(nn.Module):
def smooth(self, g, y_soft, y_true, mask): def smooth(self, g, y_soft, y_true, mask):
with g.local_scope(): with g.local_scope():
numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0) numel = (
int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
)
assert y_true.size(0) == numel assert y_true.size(0) == numel
if y_true.dtype == torch.long: if y_true.dtype == torch.long:
y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to(y_soft.dtype) y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to(
y_soft.dtype
)
y_soft[mask] = y_true y_soft[mask] = y_true
return self.prop2(g, y_soft) return self.prop2(g, y_soft)
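# Taken together, correct() and smooth() implement the two C&S stages on top of the
# base predictions Y_soft (a sketch; the autoscaled "result" line is partly elided above):
#   correct:  E[train] = Y_true - Y_soft[train], E = 0 elsewhere;   E_hat = prop1(E)
#             autoscale: s_i = (mean over train of ||E_j||_1) / ||E_hat_i||_1,  Z = Y_soft + s * E_hat
#             otherwise: Z = Y_soft + scale * E_hat
#   smooth:   Z[train] <- Y_true, then output = prop2(Z)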
import argparse import argparse
import numpy as np
import torch
from torch import nn from torch import nn
from torch.nn import Parameter from torch.nn import Parameter
import dgl.function as fn
from torch.nn import functional as F from torch.nn import functional as F
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
import numpy as np
import torch
from tqdm import trange from tqdm import trange
from utils import generate_random_seeds, set_random_state, evaluate from utils import evaluate, generate_random_seeds, set_random_state
import dgl.function as fn
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
class DAGNNConv(nn.Module): class DAGNNConv(nn.Module):
def __init__(self, def __init__(self, in_dim, k):
in_dim,
k):
super(DAGNNConv, self).__init__() super(DAGNNConv, self).__init__()
self.s = Parameter(torch.FloatTensor(in_dim, 1)) self.s = Parameter(torch.FloatTensor(in_dim, 1))
...@@ -22,7 +22,7 @@ class DAGNNConv(nn.Module): ...@@ -22,7 +22,7 @@ class DAGNNConv(nn.Module):
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
gain = nn.init.calculate_gain('sigmoid') gain = nn.init.calculate_gain("sigmoid")
nn.init.xavier_uniform_(self.s, gain=gain) nn.init.xavier_uniform_(self.s, gain=gain)
def forward(self, graph, feats): def forward(self, graph, feats):
...@@ -36,10 +36,9 @@ class DAGNNConv(nn.Module): ...@@ -36,10 +36,9 @@ class DAGNNConv(nn.Module):
for _ in range(self.k): for _ in range(self.k):
feats = feats * norm feats = feats * norm
graph.ndata['h'] = feats graph.ndata["h"] = feats
graph.update_all(fn.copy_u('h', 'm'), graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
fn.sum('m', 'h')) feats = graph.ndata["h"]
feats = graph.ndata['h']
feats = feats * norm feats = feats * norm
results.append(feats) results.append(feats)
...@@ -52,12 +51,7 @@ class DAGNNConv(nn.Module): ...@@ -52,12 +51,7 @@ class DAGNNConv(nn.Module):
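# The visible propagation loop in DAGNNConv.forward() collects the k-hop features
#   H^{(k)} = (D^{-1/2} A D^{-1/2})^k H^{(0)},   k = 1, ..., K,
# into `results`; the elided remainder presumably stacks them and combines them with
# gating scores derived from the learnable vector self.s (note the sigmoid gain used
# in reset_parameters above).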
class MLPLayer(nn.Module): class MLPLayer(nn.Module):
def __init__(self, def __init__(self, in_dim, out_dim, bias=True, activation=None, dropout=0):
in_dim,
out_dim,
bias=True,
activation=None,
dropout=0):
super(MLPLayer, self).__init__() super(MLPLayer, self).__init__()
self.linear = nn.Linear(in_dim, out_dim, bias=bias) self.linear = nn.Linear(in_dim, out_dim, bias=bias)
...@@ -66,9 +60,9 @@ class MLPLayer(nn.Module): ...@@ -66,9 +60,9 @@ class MLPLayer(nn.Module):
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
gain = 1. gain = 1.0
if self.activation is F.relu: if self.activation is F.relu:
gain = nn.init.calculate_gain('relu') gain = nn.init.calculate_gain("relu")
nn.init.xavier_uniform_(self.linear.weight, gain=gain) nn.init.xavier_uniform_(self.linear.weight, gain=gain)
if self.linear.bias is not None: if self.linear.bias is not None:
nn.init.zeros_(self.linear.bias) nn.init.zeros_(self.linear.bias)
...@@ -84,20 +78,36 @@ class MLPLayer(nn.Module): ...@@ -84,20 +78,36 @@ class MLPLayer(nn.Module):
class DAGNN(nn.Module): class DAGNN(nn.Module):
def __init__(self, def __init__(
k, self,
in_dim, k,
hid_dim, in_dim,
out_dim, hid_dim,
bias=True, out_dim,
activation=F.relu, bias=True,
dropout=0, ): activation=F.relu,
dropout=0,
):
super(DAGNN, self).__init__() super(DAGNN, self).__init__()
self.mlp = nn.ModuleList() self.mlp = nn.ModuleList()
self.mlp.append(MLPLayer(in_dim=in_dim, out_dim=hid_dim, bias=bias, self.mlp.append(
activation=activation, dropout=dropout)) MLPLayer(
self.mlp.append(MLPLayer(in_dim=hid_dim, out_dim=out_dim, bias=bias, in_dim=in_dim,
activation=None, dropout=dropout)) out_dim=hid_dim,
bias=bias,
activation=activation,
dropout=dropout,
)
)
self.mlp.append(
MLPLayer(
in_dim=hid_dim,
out_dim=out_dim,
bias=bias,
activation=None,
dropout=dropout,
)
)
self.dagnn = DAGNNConv(in_dim=out_dim, k=k) self.dagnn = DAGNNConv(in_dim=out_dim, k=k)
def forward(self, graph, feats): def forward(self, graph, feats):
...@@ -110,38 +120,38 @@ class DAGNN(nn.Module): ...@@ -110,38 +120,38 @@ class DAGNN(nn.Module):
def main(args): def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test indices ============================= # # Step 1: Prepare graph data and retrieve train/validation/test indices ============================= #
# Load from DGL dataset # Load from DGL dataset
if args.dataset == 'Cora': if args.dataset == "Cora":
dataset = CoraGraphDataset() dataset = CoraGraphDataset()
elif args.dataset == 'Citeseer': elif args.dataset == "Citeseer":
dataset = CiteseerGraphDataset() dataset = CiteseerGraphDataset()
elif args.dataset == 'Pubmed': elif args.dataset == "Pubmed":
dataset = PubmedGraphDataset() dataset = PubmedGraphDataset()
else: else:
raise ValueError('Dataset {} is invalid.'.format(args.dataset)) raise ValueError("Dataset {} is invalid.".format(args.dataset))
graph = dataset[0] graph = dataset[0]
graph = graph.add_self_loop() graph = graph.add_self_loop()
# check cuda # check cuda
if args.gpu >= 0 and torch.cuda.is_available(): if args.gpu >= 0 and torch.cuda.is_available():
device = 'cuda:{}'.format(args.gpu) device = "cuda:{}".format(args.gpu)
else: else:
device = 'cpu' device = "cpu"
# retrieve the number of classes # retrieve the number of classes
n_classes = dataset.num_classes n_classes = dataset.num_classes
# retrieve labels of ground truth # retrieve labels of ground truth
labels = graph.ndata.pop('label').to(device).long() labels = graph.ndata.pop("label").to(device).long()
# Extract node features # Extract node features
feats = graph.ndata.pop('feat').to(device) feats = graph.ndata.pop("feat").to(device)
n_features = feats.shape[-1] n_features = feats.shape[-1]
# retrieve masks for train/validation/test # retrieve masks for train/validation/test
train_mask = graph.ndata.pop('train_mask') train_mask = graph.ndata.pop("train_mask")
val_mask = graph.ndata.pop('val_mask') val_mask = graph.ndata.pop("val_mask")
test_mask = graph.ndata.pop('test_mask') test_mask = graph.ndata.pop("test_mask")
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device) train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device) val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)
...@@ -150,22 +160,26 @@ def main(args): ...@@ -150,22 +160,26 @@ def main(args):
graph = graph.to(device) graph = graph.to(device)
# Step 2: Create model =================================================================== # # Step 2: Create model =================================================================== #
model = DAGNN(k=args.k, model = DAGNN(
in_dim=n_features, k=args.k,
hid_dim=args.hid_dim, in_dim=n_features,
out_dim=n_classes, hid_dim=args.hid_dim,
dropout=args.dropout) out_dim=n_classes,
dropout=args.dropout,
)
model = model.to(device) model = model.to(device)
# Step 3: Create training components ===================================================== # # Step 3: Create training components ===================================================== #
loss_fn = F.cross_entropy loss_fn = F.cross_entropy
opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.lamb) opt = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.lamb
)
# Step 4: training epochs =============================================================== # # Step 4: training epochs =============================================================== #
loss = float('inf') loss = float("inf")
best_acc = 0 best_acc = 0
no_improvement = 0 no_improvement = 0
epochs = trange(args.epochs, desc='Accuracy & Loss') epochs = trange(args.epochs, desc="Accuracy & Loss")
for _ in epochs: for _ in epochs:
model.train() model.train()
...@@ -180,17 +194,28 @@ def main(args): ...@@ -180,17 +194,28 @@ def main(args):
train_loss.backward() train_loss.backward()
opt.step() opt.step()
train_loss, train_acc, valid_loss, valid_acc, test_loss, test_acc = evaluate(model, graph, feats, labels, (
(train_idx, val_idx, test_idx)) train_loss,
train_acc,
valid_loss,
valid_acc,
test_loss,
test_acc,
) = evaluate(
model, graph, feats, labels, (train_idx, val_idx, test_idx)
)
# Print out performance # Print out performance
epochs.set_description('Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}'.format( epochs.set_description(
train_acc, train_loss.item(), valid_acc, valid_loss.item())) "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
train_acc, train_loss.item(), valid_acc, valid_loss.item()
)
)
if valid_loss > loss: if valid_loss > loss:
no_improvement += 1 no_improvement += 1
if no_improvement == args.early_stopping: if no_improvement == args.early_stopping:
print('Early stop.') print("Early stop.")
break break
else: else:
no_improvement = 0 no_improvement = 0
...@@ -203,23 +228,42 @@ def main(args): ...@@ -203,23 +228,42 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
""" """
DAGNN Model Hyperparameters DAGNN Model Hyperparameters
""" """
parser = argparse.ArgumentParser(description='DAGNN') parser = argparse.ArgumentParser(description="DAGNN")
# data source params # data source params
parser.add_argument('--dataset', type=str, default='Cora', choices=["Cora", "Citeseer", "Pubmed"], help='Name of dataset.') parser.add_argument(
"--dataset",
type=str,
default="Cora",
choices=["Cora", "Citeseer", "Pubmed"],
help="Name of dataset.",
)
# cuda params # cuda params
parser.add_argument('--gpu', type=int, default=-1, help='GPU index. Default: -1, using CPU.') parser.add_argument(
"--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
)
# training params # training params
parser.add_argument('--runs', type=int, default=1, help='Training runs.') parser.add_argument("--runs", type=int, default=1, help="Training runs.")
parser.add_argument('--epochs', type=int, default=1500, help='Training epochs.') parser.add_argument(
parser.add_argument('--early-stopping', type=int, default=100, help='Number of patience epochs to wait before early stopping.') parser.add_argument(
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.') )
parser.add_argument('--lamb', type=float, default=0.005, help='L2 reg.') parser.add_argument(
"--early-stopping",
type=int,
default=100,
help="Patient epochs to wait before early stopping.",
)
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
parser.add_argument("--lamb", type=float, default=0.005, help="L2 reg.")
# model params # model params
parser.add_argument('--k', type=int, default=12, help='Number of propagation layers.') parser.add_argument(
parser.add_argument("--hid-dim", type=int, default=64, help='Hidden layer dimensionalities.') "--k", type=int, default=12, help="Number of propagation layers."
parser.add_argument('--dropout', type=float, default=0.8, help='dropout') )
parser.add_argument(
"--hid-dim", type=int, default=64, help="Hidden layer dimensionalities."
)
parser.add_argument("--dropout", type=float, default=0.8, help="dropout")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
...@@ -235,6 +279,6 @@ if __name__ == "__main__": ...@@ -235,6 +279,6 @@ if __name__ == "__main__":
mean = np.around(np.mean(acc_lists, axis=0), decimals=4) mean = np.around(np.mean(acc_lists, axis=0), decimals=4)
std = np.around(np.std(acc_lists, axis=0), decimals=4) std = np.around(np.std(acc_lists, axis=0), decimals=4)
print('Total acc: ', acc_lists) print("Total acc: ", acc_lists)
print('mean', mean) print("mean", mean)
print('std', std) print("std", std)
\ No newline at end of file
import numpy as np
import random import random
from torch.nn import functional as F
import numpy as np
import torch import torch
from torch.nn import functional as F
def evaluate(model, graph, feats, labels, idxs): def evaluate(model, graph, feats, labels, idxs):
...@@ -11,7 +12,9 @@ def evaluate(model, graph, feats, labels, idxs): ...@@ -11,7 +12,9 @@ def evaluate(model, graph, feats, labels, idxs):
results = () results = ()
for idx in idxs: for idx in idxs:
loss = F.cross_entropy(logits[idx], labels[idx]) loss = F.cross_entropy(logits[idx], labels[idx])
acc = torch.sum(logits[idx].argmax(dim=1) == labels[idx]).item() / len(idx) acc = torch.sum(
logits[idx].argmax(dim=1) == labels[idx]
).item() / len(idx)
results += (loss, acc) results += (loss, acc)
return results return results
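# evaluate() returns one (loss, accuracy) pair per index set, in order, so with
# idxs = (train_idx, val_idx, test_idx) it unpacks exactly as main() does above:
#   train_loss, train_acc, valid_loss, valid_acc, test_loss, test_acc = evaluate(
#       model, graph, feats, labels, (train_idx, val_idx, test_idx)
#   )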
...@@ -27,4 +30,4 @@ def set_random_state(seed): ...@@ -27,4 +30,4 @@ def set_random_state(seed):
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
\ No newline at end of file
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl.function as fn from modules import MLP, MessageNorm
from ogb.graphproppred.mol_encoder import BondEncoder from ogb.graphproppred.mol_encoder import BondEncoder
import dgl.function as fn
from dgl.nn.functional import edge_softmax from dgl.nn.functional import edge_softmax
from modules import MLP, MessageNorm
class GENConv(nn.Module): class GENConv(nn.Module):
r""" r"""
Description Description
----------- -----------
Generalized Message Aggregator was introduced in "DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>" Generalized Message Aggregator was introduced in "DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>"
...@@ -40,20 +40,23 @@ class GENConv(nn.Module): ...@@ -40,20 +40,23 @@ class GENConv(nn.Module):
eps: float eps: float
A small positive constant in message construction function. Default is 1e-7. A small positive constant in message construction function. Default is 1e-7.
""" """
def __init__(self,
in_dim, def __init__(
out_dim, self,
aggregator='softmax', in_dim,
beta=1.0, out_dim,
learn_beta=False, aggregator="softmax",
p=1.0, beta=1.0,
learn_p=False, learn_beta=False,
msg_norm=False, p=1.0,
learn_msg_scale=False, learn_p=False,
mlp_layers=1, msg_norm=False,
eps=1e-7): learn_msg_scale=False,
mlp_layers=1,
eps=1e-7,
):
super(GENConv, self).__init__() super(GENConv, self).__init__()
self.aggr = aggregator self.aggr = aggregator
self.eps = eps self.eps = eps
...@@ -65,38 +68,52 @@ class GENConv(nn.Module): ...@@ -65,38 +68,52 @@ class GENConv(nn.Module):
self.mlp = MLP(channels) self.mlp = MLP(channels)
self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None
self.beta = nn.Parameter(torch.Tensor([beta]), requires_grad=True) if learn_beta and self.aggr == 'softmax' else beta self.beta = (
self.p = nn.Parameter(torch.Tensor([p]), requires_grad=True) if learn_p else p nn.Parameter(torch.Tensor([beta]), requires_grad=True)
if learn_beta and self.aggr == "softmax"
else beta
)
self.p = (
nn.Parameter(torch.Tensor([p]), requires_grad=True)
if learn_p
else p
)
self.edge_encoder = BondEncoder(in_dim) self.edge_encoder = BondEncoder(in_dim)
def forward(self, g, node_feats, edge_feats): def forward(self, g, node_feats, edge_feats):
with g.local_scope(): with g.local_scope():
# Node and edge feature size need to match. # Node and edge feature size need to match.
g.ndata['h'] = node_feats g.ndata["h"] = node_feats
g.edata['h'] = self.edge_encoder(edge_feats) g.edata["h"] = self.edge_encoder(edge_feats)
g.apply_edges(fn.u_add_e('h', 'h', 'm')) g.apply_edges(fn.u_add_e("h", "h", "m"))
if self.aggr == 'softmax': if self.aggr == "softmax":
g.edata['m'] = F.relu(g.edata['m']) + self.eps g.edata["m"] = F.relu(g.edata["m"]) + self.eps
g.edata['a'] = edge_softmax(g, g.edata['m'] * self.beta) g.edata["a"] = edge_softmax(g, g.edata["m"] * self.beta)
g.update_all(lambda edge: {'x': edge.data['m'] * edge.data['a']}, g.update_all(
fn.sum('x', 'm')) lambda edge: {"x": edge.data["m"] * edge.data["a"]},
fn.sum("x", "m"),
elif self.aggr == 'power': )
elif self.aggr == "power":
minv, maxv = 1e-7, 1e1 minv, maxv = 1e-7, 1e1
torch.clamp_(g.edata['m'], minv, maxv) torch.clamp_(g.edata["m"], minv, maxv)
g.update_all(lambda edge: {'x': torch.pow(edge.data['m'], self.p)}, g.update_all(
fn.mean('x', 'm')) lambda edge: {"x": torch.pow(edge.data["m"], self.p)},
torch.clamp_(g.ndata['m'], minv, maxv) fn.mean("x", "m"),
g.ndata['m'] = torch.pow(g.ndata['m'], self.p) )
torch.clamp_(g.ndata["m"], minv, maxv)
g.ndata["m"] = torch.pow(g.ndata["m"], self.p)
else: else:
raise NotImplementedError(f'Aggregator {self.aggr} is not supported.') raise NotImplementedError(
f"Aggregator {self.aggr} is not supported."
)
if self.msg_norm is not None: if self.msg_norm is not None:
g.ndata['m'] = self.msg_norm(node_feats, g.ndata['m']) g.ndata["m"] = self.msg_norm(node_feats, g.ndata["m"])
feats = node_feats + g.ndata['m'] feats = node_feats + g.ndata["m"]
return self.mlp(feats) return self.mlp(feats)
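# For reference, the "softmax" branch above computes, per edge (u, v) and node v:
#   m_{uv} = ReLU(h_u + e_{uv}) + eps
#   a_{uv} = softmax over N(v) of (beta * m_{uv})          # edge_softmax
#   agg_v  = sum over u in N(v) of a_{uv} * m_{uv}
# and the output is MLP(h_v + MsgNorm(h_v, agg_v)) when msg_norm is enabled,
# MLP(h_v + agg_v) otherwise. The "power" branch instead averages clamped
# element-wise powers of the messages.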
import argparse import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import copy import copy
import time import time
from ogb.graphproppred import DglGraphPropPredDataset, collate_dgl import torch
from torch.utils.data import DataLoader import torch.nn as nn
from ogb.graphproppred import Evaluator import torch.optim as optim
from models import DeeperGCN from models import DeeperGCN
from ogb.graphproppred import DglGraphPropPredDataset, Evaluator, collate_dgl
from torch.utils.data import DataLoader
def train(model, device, data_loader, opt, loss_fn): def train(model, device, data_loader, opt, loss_fn):
model.train() model.train()
train_loss = [] train_loss = []
for g, labels in data_loader: for g, labels in data_loader:
g = g.to(device) g = g.to(device)
labels = labels.to(torch.float32).to(device) labels = labels.to(torch.float32).to(device)
logits = model(g, g.edata['feat'], g.ndata['feat']) logits = model(g, g.edata["feat"], g.ndata["feat"])
loss = loss_fn(logits, labels) loss = loss_fn(logits, labels)
train_loss.append(loss.item()) train_loss.append(loss.item())
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
opt.step() opt.step()
...@@ -36,57 +35,66 @@ def test(model, device, data_loader, evaluator): ...@@ -36,57 +35,66 @@ def test(model, device, data_loader, evaluator):
for g, labels in data_loader: for g, labels in data_loader:
g = g.to(device) g = g.to(device)
logits = model(g, g.edata['feat'], g.ndata['feat']) logits = model(g, g.edata["feat"], g.ndata["feat"])
y_true.append(labels.detach().cpu()) y_true.append(labels.detach().cpu())
y_pred.append(logits.detach().cpu()) y_pred.append(logits.detach().cpu())
y_true = torch.cat(y_true, dim=0).numpy() y_true = torch.cat(y_true, dim=0).numpy()
y_pred = torch.cat(y_pred, dim=0).numpy() y_pred = torch.cat(y_pred, dim=0).numpy()
return evaluator.eval({ return evaluator.eval({"y_true": y_true, "y_pred": y_pred})["rocauc"]
'y_true': y_true,
'y_pred': y_pred
})['rocauc']
def main(): def main():
# check cuda # check cuda
device = f'cuda:{args.gpu}' if args.gpu >= 0 and torch.cuda.is_available() else 'cpu' device = (
f"cuda:{args.gpu}"
if args.gpu >= 0 and torch.cuda.is_available()
else "cpu"
)
# load ogb dataset & evaluator # load ogb dataset & evaluator
dataset = DglGraphPropPredDataset(name='ogbg-molhiv') dataset = DglGraphPropPredDataset(name="ogbg-molhiv")
evaluator = Evaluator(name='ogbg-molhiv') evaluator = Evaluator(name="ogbg-molhiv")
g, _ = dataset[0] g, _ = dataset[0]
node_feat_dim = g.ndata['feat'].size()[-1] node_feat_dim = g.ndata["feat"].size()[-1]
edge_feat_dim = g.edata['feat'].size()[-1] edge_feat_dim = g.edata["feat"].size()[-1]
n_classes = dataset.num_tasks n_classes = dataset.num_tasks
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
train_loader = DataLoader(dataset[split_idx["train"]], train_loader = DataLoader(
batch_size=args.batch_size, dataset[split_idx["train"]],
shuffle=True, batch_size=args.batch_size,
collate_fn=collate_dgl) shuffle=True,
valid_loader = DataLoader(dataset[split_idx["valid"]], collate_fn=collate_dgl,
batch_size=args.batch_size, )
shuffle=False, valid_loader = DataLoader(
collate_fn=collate_dgl) dataset[split_idx["valid"]],
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size,
batch_size=args.batch_size, shuffle=False,
shuffle=False, collate_fn=collate_dgl,
collate_fn=collate_dgl) )
test_loader = DataLoader(
dataset[split_idx["test"]],
batch_size=args.batch_size,
shuffle=False,
collate_fn=collate_dgl,
)
# load model # load model
model = DeeperGCN(node_feat_dim=node_feat_dim, model = DeeperGCN(
edge_feat_dim=edge_feat_dim, node_feat_dim=node_feat_dim,
hid_dim=args.hid_dim, edge_feat_dim=edge_feat_dim,
out_dim=n_classes, hid_dim=args.hid_dim,
num_layers=args.num_layers, out_dim=n_classes,
dropout=args.dropout, num_layers=args.num_layers,
learn_beta=args.learn_beta).to(device) dropout=args.dropout,
learn_beta=args.learn_beta,
).to(device)
print(model) print(model)
opt = optim.Adam(model.parameters(), lr=args.lr) opt = optim.Adam(model.parameters(), lr=args.lr)
loss_fn = nn.BCEWithLogitsLoss() loss_fn = nn.BCEWithLogitsLoss()
...@@ -95,7 +103,7 @@ def main(): ...@@ -95,7 +103,7 @@ def main():
best_model = copy.deepcopy(model) best_model = copy.deepcopy(model)
times = [] times = []
print('---------- Training ----------') print("---------- Training ----------")
for i in range(args.epochs): for i in range(args.epochs):
t1 = time.time() t1 = time.time()
train_loss = train(model, device, train_loader, opt, loss_fn) train_loss = train(model, device, train_loader, opt, loss_fn)
...@@ -107,35 +115,49 @@ def main(): ...@@ -107,35 +115,49 @@ def main():
train_auc = test(model, device, train_loader, evaluator) train_auc = test(model, device, train_loader, evaluator)
valid_auc = test(model, device, valid_loader, evaluator) valid_auc = test(model, device, valid_loader, evaluator)
print(f'Epoch {i} | Train Loss: {train_loss:.4f} | Train Auc: {train_auc:.4f} | Valid Auc: {valid_auc:.4f}') print(
f"Epoch {i} | Train Loss: {train_loss:.4f} | Train Auc: {train_auc:.4f} | Valid Auc: {valid_auc:.4f}"
)
if valid_auc > best_auc: if valid_auc > best_auc:
best_auc = valid_auc best_auc = valid_auc
best_model = copy.deepcopy(model) best_model = copy.deepcopy(model)
print('---------- Testing ----------') print("---------- Testing ----------")
test_auc = test(best_model, device, test_loader, evaluator) test_auc = test(best_model, device, test_loader, evaluator)
print(f'Test Auc: {test_auc}') print(f"Test Auc: {test_auc}")
if len(times) > 0: if len(times) > 0:
print('Times/epoch: ', sum(times) / len(times)) print("Times/epoch: ", sum(times) / len(times))
if __name__ == '__main__': if __name__ == "__main__":
""" """
DeeperGCN Hyperparameters DeeperGCN Hyperparameters
""" """
parser = argparse.ArgumentParser(description='DeeperGCN') parser = argparse.ArgumentParser(description="DeeperGCN")
# training # training
parser.add_argument('--gpu', type=int, default=-1, help='GPU index, -1 for CPU.') parser.add_argument(
parser.add_argument('--epochs', type=int, default=300, help='Number of epochs to train.') "--gpu", type=int, default=-1, help="GPU index, -1 for CPU."
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.') )
parser.add_argument('--dropout', type=float, default=0.2, help='Dropout rate.') parser.add_argument(
parser.add_argument('--batch-size', type=int, default=2048, help='Batch size.') "--epochs", type=int, default=300, help="Number of epochs to train."
)
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
parser.add_argument(
"--dropout", type=float, default=0.2, help="Dropout rate."
)
parser.add_argument(
"--batch-size", type=int, default=2048, help="Batch size."
)
# model # model
parser.add_argument('--num-layers', type=int, default=7, help='Number of GNN layers.') parser.add_argument(
parser.add_argument('--hid-dim', type=int, default=256, help='Hidden channel size.') "--num-layers", type=int, default=7, help="Number of GNN layers."
)
parser.add_argument(
"--hid-dim", type=int, default=256, help="Hidden channel size."
)
# learnable parameters in aggr # learnable parameters in aggr
parser.add_argument('--learn-beta', action='store_true') parser.add_argument("--learn-beta", action="store_true")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl.function as fn from layers import GENConv
from ogb.graphproppred.mol_encoder import AtomEncoder from ogb.graphproppred.mol_encoder import AtomEncoder
import dgl.function as fn
from dgl.nn.pytorch.glob import AvgPooling from dgl.nn.pytorch.glob import AvgPooling
from layers import GENConv
class DeeperGCN(nn.Module): class DeeperGCN(nn.Module):
...@@ -37,32 +37,37 @@ class DeeperGCN(nn.Module): ...@@ -37,32 +37,37 @@ class DeeperGCN(nn.Module):
mlp_layers: int mlp_layers: int
Number of MLP layers in message normalization. Default is 1. Number of MLP layers in message normalization. Default is 1.
""" """
def __init__(self,
node_feat_dim, def __init__(
edge_feat_dim, self,
hid_dim, node_feat_dim,
out_dim, edge_feat_dim,
num_layers, hid_dim,
dropout=0., out_dim,
beta=1.0, num_layers,
learn_beta=False, dropout=0.0,
aggr='softmax', beta=1.0,
mlp_layers=1): learn_beta=False,
aggr="softmax",
mlp_layers=1,
):
super(DeeperGCN, self).__init__() super(DeeperGCN, self).__init__()
self.num_layers = num_layers self.num_layers = num_layers
self.dropout = dropout self.dropout = dropout
self.gcns = nn.ModuleList() self.gcns = nn.ModuleList()
self.norms = nn.ModuleList() self.norms = nn.ModuleList()
for _ in range(self.num_layers): for _ in range(self.num_layers):
conv = GENConv(in_dim=hid_dim, conv = GENConv(
out_dim=hid_dim, in_dim=hid_dim,
aggregator=aggr, out_dim=hid_dim,
beta=beta, aggregator=aggr,
learn_beta=learn_beta, beta=beta,
mlp_layers=mlp_layers) learn_beta=learn_beta,
mlp_layers=mlp_layers,
)
self.gcns.append(conv) self.gcns.append(conv)
self.norms.append(nn.BatchNorm1d(hid_dim, affine=True)) self.norms.append(nn.BatchNorm1d(hid_dim, affine=True))
......
...@@ -10,26 +10,23 @@ class MLP(nn.Sequential): ...@@ -10,26 +10,23 @@ class MLP(nn.Sequential):
----------- -----------
From equation (5) in "DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>" From equation (5) in "DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>"
""" """
def __init__(self,
channels, def __init__(self, channels, act="relu", dropout=0.0, bias=True):
act='relu',
dropout=0.,
bias=True):
layers = [] layers = []
for i in range(1, len(channels)): for i in range(1, len(channels)):
layers.append(nn.Linear(channels[i - 1], channels[i], bias)) layers.append(nn.Linear(channels[i - 1], channels[i], bias))
if i < len(channels) - 1: if i < len(channels) - 1:
layers.append(nn.BatchNorm1d(channels[i], affine=True)) layers.append(nn.BatchNorm1d(channels[i], affine=True))
layers.append(nn.ReLU()) layers.append(nn.ReLU())
layers.append(nn.Dropout(dropout)) layers.append(nn.Dropout(dropout))
super(MLP, self).__init__(*layers) super(MLP, self).__init__(*layers)
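# A minimal usage sketch with hypothetical channel sizes: MLP([64, 128, 64]) builds
#   Linear(64, 128) -> BatchNorm1d(128) -> ReLU -> Dropout -> Linear(128, 64),
# i.e. BatchNorm/ReLU/Dropout follow every linear layer except the last one.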
class MessageNorm(nn.Module): class MessageNorm(nn.Module):
r""" r"""
Description Description
----------- -----------
Message normalization was introduced in "DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>" Message normalization was introduced in "DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>"
...@@ -39,9 +36,12 @@ class MessageNorm(nn.Module): ...@@ -39,9 +36,12 @@ class MessageNorm(nn.Module):
learn_scale: bool learn_scale: bool
Whether s is a learnable scaling factor or not. Default is False. Whether s is a learnable scaling factor or not. Default is False.
""" """
def __init__(self, learn_scale=False): def __init__(self, learn_scale=False):
super(MessageNorm, self).__init__() super(MessageNorm, self).__init__()
self.scale = nn.Parameter(torch.FloatTensor([1.0]), requires_grad=learn_scale) self.scale = nn.Parameter(
torch.FloatTensor([1.0]), requires_grad=learn_scale
)
def forward(self, feats, msg, p=2): def forward(self, feats, msg, p=2):
msg = F.normalize(msg, p=2, dim=-1) msg = F.normalize(msg, p=2, dim=-1)
......
...@@ -7,16 +7,20 @@ Papers: https://arxiv.org/abs/1809.10341 ...@@ -7,16 +7,20 @@ Papers: https://arxiv.org/abs/1809.10341
Author's code: https://github.com/PetarV-/DGI Author's code: https://github.com/PetarV-/DGI
""" """
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import math
from gcn import GCN from gcn import GCN
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout): def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.g = g self.g = g
self.conv = GCN(g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout) self.conv = GCN(
g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout
)
def forward(self, features, corrupt=False): def forward(self, features, corrupt=False):
if corrupt: if corrupt:
...@@ -49,7 +53,9 @@ class Discriminator(nn.Module): ...@@ -49,7 +53,9 @@ class Discriminator(nn.Module):
class DGI(nn.Module): class DGI(nn.Module):
def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout): def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
super(DGI, self).__init__() super(DGI, self).__init__()
self.encoder = Encoder(g, in_feats, n_hidden, n_layers, activation, dropout) self.encoder = Encoder(
g, in_feats, n_hidden, n_layers, activation, dropout
)
self.discriminator = Discriminator(n_hidden) self.discriminator = Discriminator(n_hidden)
self.loss = nn.BCEWithLogitsLoss() self.loss = nn.BCEWithLogitsLoss()
......
...@@ -3,17 +3,14 @@ This code was copied from the GCN implementation in DGL examples. ...@@ -3,17 +3,14 @@ This code was copied from the GCN implementation in DGL examples.
""" """
import torch import torch
import torch.nn as nn import torch.nn as nn
from dgl.nn.pytorch import GraphConv from dgl.nn.pytorch import GraphConv
class GCN(nn.Module): class GCN(nn.Module):
def __init__(self, def __init__(
g, self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout
in_feats, ):
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(GCN, self).__init__() super(GCN, self).__init__()
self.g = g self.g = g
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
...@@ -21,7 +18,9 @@ class GCN(nn.Module): ...@@ -21,7 +18,9 @@ class GCN(nn.Module):
self.layers.append(GraphConv(in_feats, n_hidden, activation=activation)) self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
# hidden layers # hidden layers
for i in range(n_layers - 1): for i in range(n_layers - 1):
self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation)) self.layers.append(
GraphConv(n_hidden, n_hidden, activation=activation)
)
# output layer # output layer
self.layers.append(GraphConv(n_hidden, n_classes)) self.layers.append(GraphConv(n_hidden, n_classes))
self.dropout = nn.Dropout(p=dropout) self.dropout = nn.Dropout(p=dropout)
......
...@@ -5,10 +5,10 @@ and will be loaded when setting up.""" ...@@ -5,10 +5,10 @@ and will be loaded when setting up."""
def dataset_based_configure(opts): def dataset_based_configure(opts):
if opts['dataset'] == 'cycles': if opts["dataset"] == "cycles":
ds_configure = cycles_configure ds_configure = cycles_configure
else: else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset'])) raise ValueError("Unsupported dataset: {}".format(opts["dataset"]))
opts = {**opts, **ds_configure} opts = {**opts, **ds_configure}
...@@ -16,19 +16,19 @@ def dataset_based_configure(opts): ...@@ -16,19 +16,19 @@ def dataset_based_configure(opts):
synthetic_dataset_configure = { synthetic_dataset_configure = {
'node_hidden_size': 16, "node_hidden_size": 16,
'num_propagation_rounds': 2, "num_propagation_rounds": 2,
'optimizer': 'Adam', "optimizer": "Adam",
'nepochs': 25, "nepochs": 25,
'ds_size': 4000, "ds_size": 4000,
'num_generated_samples': 10000, "num_generated_samples": 10000,
} }
cycles_configure = { cycles_configure = {
**synthetic_dataset_configure, **synthetic_dataset_configure,
**{ **{
'min_size': 10, "min_size": 10,
'max_size': 20, "max_size": 20,
'lr': 5e-4, "lr": 5e-4,
} },
} }
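# Merge semantics in dataset_based_configure(): `opts = {**opts, **ds_configure}`
# lets the dataset-specific values override same-named command-line defaults, so for
# the "cycles" dataset the effective opts end up with, e.g.,
#   nepochs == 25, ds_size == 4000, min_size == 10, max_size == 20, lr == 5e-4.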
import matplotlib.pyplot as plt
import networkx as nx
import os import os
import pickle import pickle
import random import random
import matplotlib.pyplot as plt
import networkx as nx
from torch.utils.data import Dataset from torch.utils.data import Dataset
...@@ -53,7 +54,9 @@ def get_decision_sequence(size): ...@@ -53,7 +54,9 @@ def get_decision_sequence(size):
if i != 0: if i != 0:
decision_sequence.append(0) # Add edge decision_sequence.append(0) # Add edge
decision_sequence.append(i - 1) # Set destination to be previous node. decision_sequence.append(
i - 1
) # Set destination to be previous node.
if i == size - 1: if i == size - 1:
decision_sequence.append(0) # Add edge decision_sequence.append(0) # Add edge
...@@ -72,7 +75,7 @@ def generate_dataset(v_min, v_max, n_samples, fname): ...@@ -72,7 +75,7 @@ def generate_dataset(v_min, v_max, n_samples, fname):
size = random.randint(v_min, v_max) size = random.randint(v_min, v_max)
samples.append(get_decision_sequence(size)) samples.append(get_decision_sequence(size))
with open(fname, 'wb') as f: with open(fname, "wb") as f:
pickle.dump(samples, f) pickle.dump(samples, f)
...@@ -80,7 +83,7 @@ class CycleDataset(Dataset): ...@@ -80,7 +83,7 @@ class CycleDataset(Dataset):
def __init__(self, fname): def __init__(self, fname):
super(CycleDataset, self).__init__() super(CycleDataset, self).__init__()
with open(fname, 'rb') as f: with open(fname, "rb") as f:
self.dataset = pickle.load(f) self.dataset = pickle.load(f)
def __len__(self): def __len__(self):
...@@ -90,7 +93,7 @@ class CycleDataset(Dataset): ...@@ -90,7 +93,7 @@ class CycleDataset(Dataset):
return self.dataset[index] return self.dataset[index]
def collate_single(self, batch): def collate_single(self, batch):
assert len(batch) == 1, 'Currently we do not support batched training' assert len(batch) == 1, "Currently we do not support batched training"
return batch[0] return batch[0]
def collate_batch(self, batch): def collate_batch(self, batch):
...@@ -116,7 +119,7 @@ class CycleModelEvaluation(object): ...@@ -116,7 +119,7 @@ class CycleModelEvaluation(object):
self.dir = dir self.dir = dir
def rollout_and_examine(self, model, num_samples): def rollout_and_examine(self, model, num_samples):
assert not model.training, 'You need to call model.eval().' assert not model.training, "You need to call model.eval()."
num_total_size = 0 num_total_size = 0
num_valid_size = 0 num_valid_size = 0
...@@ -139,7 +142,7 @@ class CycleModelEvaluation(object): ...@@ -139,7 +142,7 @@ class CycleModelEvaluation(object):
adj_lists_to_plot.append(sampled_adj_list) adj_lists_to_plot.append(sampled_adj_list)
graph_size = sampled_graph.number_of_nodes() graph_size = sampled_graph.number_of_nodes()
valid_size = (self.v_min <= graph_size <= self.v_max) valid_size = self.v_min <= graph_size <= self.v_max
cycle = is_cycle(sampled_graph) cycle = is_cycle(sampled_graph)
num_total_size += graph_size num_total_size += graph_size
...@@ -158,10 +161,13 @@ class CycleModelEvaluation(object): ...@@ -158,10 +161,13 @@ class CycleModelEvaluation(object):
fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2) fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)
axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3} axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3}
for i in range(4): for i in range(4):
nx.draw_circular(nx.from_dict_of_lists(adj_lists_to_plot[i]), nx.draw_circular(
with_labels=True, ax=axes[i]) nx.from_dict_of_lists(adj_lists_to_plot[i]),
with_labels=True,
ax=axes[i],
)
plt.savefig(self.dir + '/samples/{:d}'.format(plot_times)) plt.savefig(self.dir + "/samples/{:d}".format(plot_times))
plt.close() plt.close()
adj_lists_to_plot = [] adj_lists_to_plot = []
...@@ -173,33 +179,32 @@ class CycleModelEvaluation(object): ...@@ -173,33 +179,32 @@ class CycleModelEvaluation(object):
self.valid_ratio = num_valid / num_samples self.valid_ratio = num_valid / num_samples
def write_summary(self): def write_summary(self):
def _format_value(v): def _format_value(v):
if isinstance(v, float): if isinstance(v, float):
return '{:.4f}'.format(v) return "{:.4f}".format(v)
elif isinstance(v, int): elif isinstance(v, int):
return '{:d}'.format(v) return "{:d}".format(v)
else: else:
return '{}'.format(v) return "{}".format(v)
statistics = { statistics = {
'num_samples': self.num_samples_examined, "num_samples": self.num_samples_examined,
'v_min': self.v_min, "v_min": self.v_min,
'v_max': self.v_max, "v_max": self.v_max,
'average_size': self.average_size, "average_size": self.average_size,
'valid_size_ratio': self.valid_size_ratio, "valid_size_ratio": self.valid_size_ratio,
'cycle_ratio': self.cycle_ratio, "cycle_ratio": self.cycle_ratio,
'valid_ratio': self.valid_ratio "valid_ratio": self.valid_ratio,
} }
model_eval_path = os.path.join(self.dir, 'model_eval.txt') model_eval_path = os.path.join(self.dir, "model_eval.txt")
with open(model_eval_path, 'w') as f: with open(model_eval_path, "w") as f:
for key, value in statistics.items(): for key, value in statistics.items():
msg = '{}\t{}\n'.format(key, _format_value(value)) msg = "{}\t{}\n".format(key, _format_value(value))
f.write(msg) f.write(msg)
print('Saved model evaluation statistics to {}'.format(model_eval_path)) print("Saved model evaluation statistics to {}".format(model_eval_path))
class CyclePrinting(object): class CyclePrinting(object):
...@@ -213,8 +218,9 @@ class CyclePrinting(object): ...@@ -213,8 +218,9 @@ class CyclePrinting(object):
def update(self, epoch, metrics): def update(self, epoch, metrics):
self.batch_count = (self.batch_count) % self.num_batches + 1 self.batch_count = (self.batch_count) % self.num_batches + 1
msg = 'epoch {:d}/{:d}, batch {:d}/{:d}'.format(epoch, self.num_epochs, msg = "epoch {:d}/{:d}, batch {:d}/{:d}".format(
self.batch_count, self.num_batches) epoch, self.num_epochs, self.batch_count, self.num_batches
)
for key, value in metrics.items(): for key, value in metrics.items():
msg += ', {}: {:4f}'.format(key, value) msg += ", {}: {:4f}".format(key, value)
print(msg) print(msg)
...@@ -7,49 +7,58 @@ This implementation works with a minibatch of size 1 only for both training and ...@@ -7,49 +7,58 @@ This implementation works with a minibatch of size 1 only for both training and
import argparse import argparse
import datetime import datetime
import time import time
import torch import torch
from model import DGMG
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from model import DGMG
def main(opts): def main(opts):
t1 = time.time() t1 = time.time()
# Setup dataset and data loader # Setup dataset and data loader
if opts['dataset'] == 'cycles': if opts["dataset"] == "cycles":
from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting
dataset = CycleDataset(fname=opts['path_to_dataset']) dataset = CycleDataset(fname=opts["path_to_dataset"])
evaluator = CycleModelEvaluation(v_min=opts['min_size'], evaluator = CycleModelEvaluation(
v_max=opts['max_size'], v_min=opts["min_size"], v_max=opts["max_size"], dir=opts["log_dir"]
dir=opts['log_dir']) )
printer = CyclePrinting(num_epochs=opts['nepochs'], printer = CyclePrinting(
num_batches=opts['ds_size'] // opts['batch_size']) num_epochs=opts["nepochs"],
num_batches=opts["ds_size"] // opts["batch_size"],
)
else: else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset'])) raise ValueError("Unsupported dataset: {}".format(opts["dataset"]))
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, data_loader = DataLoader(
collate_fn=dataset.collate_single) dataset,
batch_size=1,
shuffle=True,
num_workers=0,
collate_fn=dataset.collate_single,
)
# Initialize_model # Initialize_model
model = DGMG(v_max=opts['max_size'], model = DGMG(
node_hidden_size=opts['node_hidden_size'], v_max=opts["max_size"],
num_prop_rounds=opts['num_propagation_rounds']) node_hidden_size=opts["node_hidden_size"],
num_prop_rounds=opts["num_propagation_rounds"],
)
# Initialize optimizer # Initialize optimizer
if opts['optimizer'] == 'Adam': if opts["optimizer"] == "Adam":
optimizer = Adam(model.parameters(), lr=opts['lr']) optimizer = Adam(model.parameters(), lr=opts["lr"])
else: else:
raise ValueError('Unsupported argument for the optimizer') raise ValueError("Unsupported argument for the optimizer")
t2 = time.time() t2 = time.time()
# Training # Training
model.train() model.train()
for epoch in range(opts['nepochs']): for epoch in range(opts["nepochs"]):
batch_count = 0 batch_count = 0
batch_loss = 0 batch_loss = 0
batch_prob = 0 batch_prob = 0
...@@ -60,8 +69,8 @@ def main(opts): ...@@ -60,8 +69,8 @@ def main(opts):
log_prob = model(actions=data) log_prob = model(actions=data)
prob = log_prob.detach().exp() prob = log_prob.detach().exp()
loss = - log_prob / opts['batch_size'] loss = -log_prob / opts["batch_size"]
prob_averaged = prob / opts['batch_size'] prob_averaged = prob / opts["batch_size"]
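# model(actions=data) returns the log-likelihood of a single decision sequence; the
# division by opts["batch_size"] plus the per-sample loss.backward() calls accumulate
# gradients so that optimizer.step() below effectively uses an average over
# `batch_size` graphs, even though the DataLoader batch size is 1.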
loss.backward() loss.backward()
...@@ -69,12 +78,14 @@ def main(opts): ...@@ -69,12 +78,14 @@ def main(opts):
batch_prob += prob_averaged.item() batch_prob += prob_averaged.item()
batch_count += 1 batch_count += 1
if batch_count % opts['batch_size'] == 0: if batch_count % opts["batch_size"] == 0:
printer.update(epoch + 1, {'averaged_loss': batch_loss, printer.update(
'averaged_prob': batch_prob}) epoch + 1,
{"averaged_loss": batch_loss, "averaged_prob": batch_prob},
)
if opts['clip_grad']: if opts["clip_grad"]:
clip_grad_norm_(model.parameters(), opts['clip_bound']) clip_grad_norm_(model.parameters(), opts["clip_bound"])
optimizer.step() optimizer.step()
...@@ -85,50 +96,84 @@ def main(opts): ...@@ -85,50 +96,84 @@ def main(opts):
t3 = time.time() t3 = time.time()
model.eval() model.eval()
evaluator.rollout_and_examine(model, opts['num_generated_samples']) evaluator.rollout_and_examine(model, opts["num_generated_samples"])
evaluator.write_summary() evaluator.write_summary()
t4 = time.time() t4 = time.time()
print('It took {} to set up.'.format(datetime.timedelta(seconds=t2-t1))) print("It took {} to set up.".format(datetime.timedelta(seconds=t2 - t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3-t2))) print(
print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4-t3))) "It took {} to finish training.".format(
print('--------------------------------------------------------------------------') datetime.timedelta(seconds=t3 - t2)
print('On average, an epoch takes {}.'.format(datetime.timedelta( )
seconds=(t3-t2) / opts['nepochs']))) )
print(
"It took {} to finish evaluation.".format(
datetime.timedelta(seconds=t4 - t3)
)
)
print(
"--------------------------------------------------------------------------"
)
print(
"On average, an epoch takes {}.".format(
datetime.timedelta(seconds=(t3 - t2) / opts["nepochs"])
)
)
del model.g del model.g
torch.save(model, './model.pth') torch.save(model, "./model.pth")
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description='DGMG') parser = argparse.ArgumentParser(description="DGMG")
# configure # configure
parser.add_argument('--seed', type=int, default=9284, help='random seed') parser.add_argument("--seed", type=int, default=9284, help="random seed")
# dataset # dataset
parser.add_argument('--dataset', choices=['cycles'], default='cycles', parser.add_argument(
help='dataset to use') "--dataset", choices=["cycles"], default="cycles", help="dataset to use"
parser.add_argument('--path-to-dataset', type=str, default='cycles.p', )
help='load the dataset if it exists, ' parser.add_argument(
'generate it and save to the path otherwise') "--path-to-dataset",
type=str,
default="cycles.p",
help="load the dataset if it exists, "
"generate it and save to the path otherwise",
)
# log # log
parser.add_argument('--log-dir', default='./results', parser.add_argument(
help='folder to save info like experiment configuration ' "--log-dir",
'or model evaluation results') default="./results",
help="folder to save info like experiment configuration "
"or model evaluation results",
)
# optimization # optimization
parser.add_argument('--batch-size', type=int, default=10, parser.add_argument(
help='batch size to use for training') "--batch-size",
parser.add_argument('--clip-grad', action='store_true', default=True, type=int,
help='gradient clipping is required to prevent gradient explosion') default=10,
parser.add_argument('--clip-bound', type=float, default=0.25, help="batch size to use for training",
help='constraint of gradient norm for gradient clipping') )
parser.add_argument(
"--clip-grad",
action="store_true",
default=True,
help="gradient clipping is required to prevent gradient explosion",
)
parser.add_argument(
"--clip-bound",
type=float,
default=0.25,
help="constraint of gradient norm for gradient clipping",
)
args = parser.parse_args() args = parser.parse_args()
from utils import setup from utils import setup
opts = setup(args) opts = setup(args)
main(opts) main(opts)
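As a quick sanity check of the CLI wiring above, parse_args can also be handed an explicit argument list; the values below are illustrative and are not defaults taken from this commit.

# Hypothetical invocation of the parser defined above (values are illustrative).
args = parser.parse_args(["--seed", "42", "--batch-size", "20", "--clip-bound", "0.25"])
print(args.seed, args.batch_size, args.clip_bound)  # 42 20 0.25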
import datetime import datetime
import matplotlib.pyplot as plt
import os import os
import random import random
from pprint import pprint
import matplotlib.pyplot as plt
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.nn as nn import torch.nn as nn
import torch.nn.init as init import torch.nn.init as init
from pprint import pprint
######################################################################################################################## ########################################################################################################################
# configuration # # configuration #
######################################################################################################################## ########################################################################################################################
def mkdir_p(path): def mkdir_p(path):
import errno import errno
try: try:
os.makedirs(path) os.makedirs(path)
print('Created directory {}'.format(path)) print("Created directory {}".format(path))
except OSError as exc: except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path): if exc.errno == errno.EEXIST and os.path.isdir(path):
print('Directory {} already exists.'.format(path)) print("Directory {} already exists.".format(path))
else: else:
raise raise
def date_filename(base_dir='./'):
def date_filename(base_dir="./"):
dt = datetime.datetime.now() dt = datetime.datetime.now()
return os.path.join(base_dir, '{}_{:02d}-{:02d}-{:02d}'.format( return os.path.join(
dt.date(), dt.hour, dt.minute, dt.second base_dir,
)) "{}_{:02d}-{:02d}-{:02d}".format(
dt.date(), dt.hour, dt.minute, dt.second
),
)
def setup_log_dir(opts): def setup_log_dir(opts):
log_dir = '{}'.format(date_filename(opts['log_dir'])) log_dir = "{}".format(date_filename(opts["log_dir"]))
mkdir_p(log_dir) mkdir_p(log_dir)
return log_dir return log_dir
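date_filename and setup_log_dir together create one timestamped directory per run; the call below is a sketch, and the printed path only illustrates the "<date>_<HH>-<MM>-<SS>" naming format.

# Sketch: creates and returns something like "./results/2024-05-01_13-07-42".
run_dir = setup_log_dir({"log_dir": "./results"})
print(run_dir)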
def save_arg_dict(opts, filename='settings.txt'):
def save_arg_dict(opts, filename="settings.txt"):
def _format_value(v): def _format_value(v):
if isinstance(v, float): if isinstance(v, float):
return '{:.4f}'.format(v) return "{:.4f}".format(v)
elif isinstance(v, int): elif isinstance(v, int):
return '{:d}'.format(v) return "{:d}".format(v)
else: else:
return '{}'.format(v) return "{}".format(v)
save_path = os.path.join(opts['log_dir'], filename) save_path = os.path.join(opts["log_dir"], filename)
with open(save_path, 'w') as f: with open(save_path, "w") as f:
for key, value in opts.items(): for key, value in opts.items():
f.write('{}\t{}\n'.format(key, _format_value(value))) f.write("{}\t{}\n".format(key, _format_value(value)))
print('Saved settings to {}'.format(save_path)) print("Saved settings to {}".format(save_path))
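The settings file written above is plain text with one tab-separated "key<TAB>value" line per option, floats rendered to four decimal places. A hypothetical call:

# Sketch: writes e.g. "clip_bound\t0.2500" and "seed\t9284" into ./settings_demo.txt.
save_arg_dict({"log_dir": ".", "seed": 9284, "clip_bound": 0.25}, filename="settings_demo.txt")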
def setup(args): def setup(args):
opts = args.__dict__.copy() opts = args.__dict__.copy()
...@@ -57,52 +66,64 @@ def setup(args): ...@@ -57,52 +66,64 @@ def setup(args):
cudnn.deterministic = True cudnn.deterministic = True
# Seed # Seed
if opts['seed'] is None: if opts["seed"] is None:
opts['seed'] = random.randint(1, 10000) opts["seed"] = random.randint(1, 10000)
random.seed(opts['seed']) random.seed(opts["seed"])
torch.manual_seed(opts['seed']) torch.manual_seed(opts["seed"])
# Dataset # Dataset
from configure import dataset_based_configure from configure import dataset_based_configure
opts = dataset_based_configure(opts) opts = dataset_based_configure(opts)
assert opts['path_to_dataset'] is not None, 'Expect path to dataset to be set.' assert (
if not os.path.exists(opts['path_to_dataset']): opts["path_to_dataset"] is not None
if opts['dataset'] == 'cycles': ), "Expect path to dataset to be set."
if not os.path.exists(opts["path_to_dataset"]):
if opts["dataset"] == "cycles":
from cycles import generate_dataset from cycles import generate_dataset
generate_dataset(opts['min_size'], opts['max_size'],
opts['ds_size'], opts['path_to_dataset']) generate_dataset(
opts["min_size"],
opts["max_size"],
opts["ds_size"],
opts["path_to_dataset"],
)
else: else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset'])) raise ValueError("Unsupported dataset: {}".format(opts["dataset"]))
# Optimization # Optimization
if opts['clip_grad']: if opts["clip_grad"]:
assert opts['clip_bound'] is not None, 'Expect the gradient norm constraint to be set.' assert (
opts["clip_bound"] is not None
), "Expect the gradient norm constraint to be set."
# Log # Log
print('Prepare logging directory...') print("Prepare logging directory...")
log_dir = setup_log_dir(opts) log_dir = setup_log_dir(opts)
opts['log_dir'] = log_dir opts["log_dir"] = log_dir
mkdir_p(log_dir + '/samples') mkdir_p(log_dir + "/samples")
plt.switch_backend('Agg') plt.switch_backend("Agg")
save_arg_dict(opts) save_arg_dict(opts)
pprint(opts) pprint(opts)
return opts return opts
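The seeding and cuDNN settings applied inside setup can be summarized by a small helper; the name seed_everything is illustrative and not part of this codebase.

# Illustrative helper mirroring the reproducibility setup above.
def seed_everything(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    cudnn.deterministic = True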
######################################################################################################################## ########################################################################################################################
# model # # model #
######################################################################################################################## ########################################################################################################################
def weights_init(m): def weights_init(m):
''' """
Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5 Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
Usage: Usage:
model = Model() model = Model()
model.apply(weights_init) model.apply(weights_init)
''' """
if isinstance(m, nn.Linear): if isinstance(m, nn.Linear):
init.xavier_normal_(m.weight.data) init.xavier_normal_(m.weight.data)
init.normal_(m.bias.data) init.normal_(m.bias.data)
...@@ -113,18 +134,20 @@ def weights_init(m): ...@@ -113,18 +134,20 @@ def weights_init(m):
else: else:
init.normal_(param.data) init.normal_(param.data)
def dgmg_message_weight_init(m): def dgmg_message_weight_init(m):
""" """
This is similar to the function above, where we initialize linear layers from a normal distribution with std This is similar to the function above, where we initialize linear layers from a normal distribution with std
1./10 as suggested by the author. This should only be used for the message passing functions, i.e. fe's in the 1./10 as suggested by the author. This should only be used for the message passing functions, i.e. fe's in the
paper. paper.
""" """
def _weight_init(m): def _weight_init(m):
if isinstance(m, nn.Linear): if isinstance(m, nn.Linear):
init.normal_(m.weight.data, std=1./10) init.normal_(m.weight.data, std=1.0 / 10)
init.normal_(m.bias.data, std=1./10) init.normal_(m.bias.data, std=1.0 / 10)
else: else:
raise ValueError('Expected the input to be of type nn.Linear!') raise ValueError("Expected the input to be of type nn.Linear!")
if isinstance(m, nn.ModuleList): if isinstance(m, nn.ModuleList):
for layer in m: for layer in m:
......
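Per the docstrings above, weights_init is meant to be used with Module.apply, while dgmg_message_weight_init is called on the ModuleList holding the message functions. The modules below are stand-ins, not the actual DGMG components.

# Stand-in modules for illustration; the real DGMG model wires these up internally.
mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
mlp.apply(weights_init)  # Xavier-normal weights, normal biases for nn.Linear layers
message_funcs = nn.ModuleList([nn.Linear(16, 16), nn.Linear(16, 16)])
dgmg_message_weight_init(message_funcs)  # per-layer normal(std=0.1) initialization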
...@@ -3,9 +3,9 @@ import torch ...@@ -3,9 +3,9 @@ import torch
def one_hotify(labels, pad=-1): def one_hotify(labels, pad=-1):
''' """
Cast labels to one-hot vectors. Cast labels to one-hot vectors.
''' """
num_instances = len(labels) num_instances = len(labels)
if pad <= 0: if pad <= 0:
dim_embedding = np.max(labels) + 1 # zero-indexed assumed dim_embedding = np.max(labels) + 1 # zero-indexed assumed
...@@ -24,17 +24,17 @@ def pre_process(dataset, prog_args): ...@@ -24,17 +24,17 @@ def pre_process(dataset, prog_args):
""" """
if prog_args.data_mode != "default": if prog_args.data_mode != "default":
print("overwrite node attributes with DiffPool's preprocess setting") print("overwrite node attributes with DiffPool's preprocess setting")
if prog_args.data_mode == 'id': if prog_args.data_mode == "id":
for g, _ in dataset: for g, _ in dataset:
id_list = np.arange(g.number_of_nodes()) id_list = np.arange(g.number_of_nodes())
g.ndata['feat'] = one_hotify(id_list, pad=dataset.max_num_node) g.ndata["feat"] = one_hotify(id_list, pad=dataset.max_num_node)
elif prog_args.data_mode == 'deg-num': elif prog_args.data_mode == "deg-num":
for g, _ in dataset: for g, _ in dataset:
g.ndata['feat'] = np.expand_dims(g.in_degrees(), axis=1) g.ndata["feat"] = np.expand_dims(g.in_degrees(), axis=1)
elif prog_args.data_mode == 'deg': elif prog_args.data_mode == "deg":
for g in dataset: for g in dataset:
degs = list(g.in_degrees()) degs = list(g.in_degrees())
degs_one_hot = one_hotify(degs, pad=dataset.max_degrees) degs_one_hot = one_hotify(degs, pad=dataset.max_degrees)
g.ndata['feat'] = degs_one_hot g.ndata["feat"] = degs_one_hot
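To illustrate the "deg" branch above, the sketch below one-hot encodes node in-degrees using torch.nn.functional.one_hot as a stand-in for one_hotify; it is not code from this commit.

import torch
import torch.nn.functional as F

# Stand-in for one_hotify(degs, pad=max_degrees): one row per node, one column per degree.
degs = torch.tensor([1, 2, 2, 3])
feat = F.one_hot(degs, num_classes=5).float()
print(feat.shape)  # torch.Size([4, 5])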