Unverified Commit 23d09057 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4642)



* [Misc] Black auto fix.

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a9f2acf3
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
from ogb.nodeproppred import DglNodePropPredDataset

import dgl
import dgl.nn as dglnn


class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(0.5)

    def forward(self, sg, x):
@@ -26,37 +29,43 @@ class SAGE(nn.Module):
        h = self.dropout(h)
        return h


dataset = dgl.data.AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
graph = dataset[
    0
]  # already prepares ndata['label'/'train_mask'/'val_mask'/'test_mask']
model = SAGE(graph.ndata["feat"].shape[1], 256, dataset.num_classes).cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

num_partitions = 1000
sampler = dgl.dataloading.ClusterGCNSampler(
    graph,
    num_partitions,
    prefetch_ndata=["feat", "label", "train_mask", "val_mask", "test_mask"],
)
# DataLoader for generic dataloading with a graph, a set of indices (any indices, like
# partition IDs here), and a graph sampler.
dataloader = dgl.dataloading.DataLoader(
    graph,
    torch.arange(num_partitions).to("cuda"),
    sampler,
    device="cuda",
    batch_size=100,
    shuffle=True,
    drop_last=False,
    num_workers=0,
    use_uva=True,
)

durations = []
for _ in range(10):
    t0 = time.time()
    model.train()
    for it, sg in enumerate(dataloader):
        x = sg.ndata["feat"]
        y = sg.ndata["label"]
        m = sg.ndata["train_mask"].bool()
        y_hat = model(sg, x)
        loss = F.cross_entropy(y_hat[m], y[m])
        opt.zero_grad()
@@ -65,7 +74,7 @@ for _ in range(10):
        if it % 20 == 0:
            acc = MF.accuracy(y_hat[m], y[m])
            mem = torch.cuda.max_memory_allocated() / 1000000
            print("Loss", loss.item(), "Acc", acc.item(), "GPU Mem", mem, "MB")
    tt = time.time()
    print(tt - t0)
    durations.append(tt - t0)
@@ -75,10 +84,10 @@ for _ in range(10):
    val_preds, test_preds = [], []
    val_labels, test_labels = [], []
    for it, sg in enumerate(dataloader):
        x = sg.ndata["feat"]
        y = sg.ndata["label"]
        m_val = sg.ndata["val_mask"].bool()
        m_test = sg.ndata["test_mask"].bool()
        y_hat = model(sg, x)
        val_preds.append(y_hat[m_val])
        val_labels.append(y[m_val])
@@ -90,6 +99,6 @@ for _ in range(10):
    test_labels = torch.cat(test_labels, 0)
    val_acc = MF.accuracy(val_preds, val_labels)
    test_acc = MF.accuracy(test_preds, test_labels)
    print("Validation acc:", val_acc.item(), "Test acc:", test_acc.item())
print(np.mean(durations[4:]), np.std(durations[4:]))
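The `dgl.dataloading.DataLoader` used above is deliberately generic: it pairs a graph, a set of indices (partition IDs here), and a graph sampler. As a point of reference only (a minimal sketch, not part of this commit), the same pattern covers node-wise neighbor sampling by swapping the partition IDs for seed-node IDs and the `ClusterGCNSampler` for a `NeighborSampler`; the fan-outs and batch size below are illustrative, not tuned values.

# Sketch only: same DataLoader pattern, sampling 3-hop neighborhoods of seed nodes.
nb_sampler = dgl.dataloading.NeighborSampler([15, 10, 5])
train_nids = torch.nonzero(graph.ndata["train_mask"], as_tuple=True)[0]
nb_dataloader = dgl.dataloading.DataLoader(
    graph,
    train_nids.to("cuda"),
    nb_sampler,
    device="cuda",
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0,
    use_uva=True,
)
# Each iteration then yields (input_nodes, output_nodes, blocks) rather than an induced subgraph.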
from collections import defaultdict as ddict

import numpy as np
import torch
from ordered_set import OrderedSet
from torch.utils.data import DataLoader, Dataset

import dgl


class TrainDataset(Dataset):
    """
@@ -18,6 +21,7 @@ class TrainDataset(Dataset):
    -------
    A training Dataset class instance used by DataLoader
    """

    def __init__(self, triples, num_ent, lbl_smooth):
        self.triples = triples
        self.num_ent = num_ent
@@ -29,11 +33,13 @@ class TrainDataset(Dataset):
    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele["triple"]), np.int32(ele["label"])
        trp_label = self.get_label(label)
        # label smoothing
        if self.lbl_smooth != 0.0:
            trp_label = (1.0 - self.lbl_smooth) * trp_label + (
                1.0 / self.num_ent
            )

        return triple, trp_label
@@ -48,10 +54,10 @@ class TrainDataset(Dataset):
        trp_label = torch.stack(labels, dim=0)
        return triple, trp_label

    # for edges that exist in the graph, the entry is 1.0, otherwise the entry is 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)
@@ -68,6 +74,7 @@ class TestDataset(Dataset):
    -------
    An evaluation Dataset class instance used by DataLoader for model evaluation
    """

    def __init__(self, triples, num_ent):
        self.triples = triples
        self.num_ent = num_ent
@@ -77,7 +84,7 @@ class TestDataset(Dataset):
    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele["triple"]), np.int32(ele["label"])
        label = self.get_label(label)

        return triple, label
@@ -93,19 +100,18 @@ class TestDataset(Dataset):
        label = torch.stack(labels, dim=0)
        return triple, label

    # for edges that exist in the graph, the entry is 1.0, otherwise the entry is 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)


class Data(object):
    def __init__(self, dataset, lbl_smooth, num_workers, batch_size):
        """
        Reads in the raw triples and converts them into a standard format.

        Parameters
        ----------
        dataset: The name of the dataset
@@ -133,18 +139,23 @@ class Data(object):
        self.num_workers = num_workers
        self.batch_size = batch_size

        # read in raw data and get mappings
        ent_set, rel_set = OrderedSet(), OrderedSet()
        for split in ["train", "test", "valid"]:
            for line in open("./{}/{}.txt".format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split("\t"))
                ent_set.add(sub)
                rel_set.add(rel)
                ent_set.add(obj)

        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
        self.rel2id.update(
            {
                rel + "_reverse": idx + len(self.rel2id)
                for idx, rel in enumerate(rel_set)
            }
        )
        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}
@@ -152,92 +163,121 @@ class Data(object):
        self.num_ent = len(self.ent2id)
        self.num_rel = len(self.rel2id) // 2

        # read in ids of subjects, relations, and objects for train/test/valid
        self.data = ddict(list)  # stores the triples
        sr2o = ddict(
            set
        )  # The key of sr2o is (subject, relation), and the items are all the successors following (subject, relation)
        src = []
        dst = []
        rels = []
        inver_src = []
        inver_dst = []
        inver_rels = []

        for split in ["train", "test", "valid"]:
            for line in open("./{}/{}.txt".format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split("\t"))
                sub_id, rel_id, obj_id = (
                    self.ent2id[sub],
                    self.rel2id[rel],
                    self.ent2id[obj],
                )
                self.data[split].append((sub_id, rel_id, obj_id))

                if split == "train":
                    sr2o[(sub_id, rel_id)].add(obj_id)
                    sr2o[(obj_id, rel_id + self.num_rel)].add(
                        sub_id
                    )  # append the reversed edges
                    src.append(sub_id)
                    dst.append(obj_id)
                    rels.append(rel_id)
                    inver_src.append(obj_id)
                    inver_dst.append(sub_id)
                    inver_rels.append(rel_id + self.num_rel)

        # construct dgl graph
        src = src + inver_src
        dst = dst + inver_dst
        rels = rels + inver_rels
        self.g = dgl.graph((src, dst), num_nodes=self.num_ent)
        self.g.edata["etype"] = torch.Tensor(rels).long()

        # identify in and out edges
        in_edges_mask = [True] * (self.g.num_edges() // 2) + [False] * (
            self.g.num_edges() // 2
        )
        out_edges_mask = [False] * (self.g.num_edges() // 2) + [True] * (
            self.g.num_edges() // 2
        )
        self.g.edata["in_edges_mask"] = torch.Tensor(in_edges_mask)
        self.g.edata["out_edges_mask"] = torch.Tensor(out_edges_mask)

        # Prepare train/valid/test data
        self.data = dict(self.data)
        self.sr2o = {
            k: list(v) for k, v in sr2o.items()
        }  # store only the train data

        for split in ["test", "valid"]:
            for sub, rel, obj in self.data[split]:
                sr2o[(sub, rel)].add(obj)
                sr2o[(obj, rel + self.num_rel)].add(sub)

        self.sr2o_all = {
            k: list(v) for k, v in sr2o.items()
        }  # store all the data
        self.triples = ddict(list)

        for (sub, rel), obj in self.sr2o.items():
            self.triples["train"].append(
                {"triple": (sub, rel, -1), "label": self.sr2o[(sub, rel)]}
            )

        for split in ["test", "valid"]:
            for sub, rel, obj in self.data[split]:
                rel_inv = rel + self.num_rel
                self.triples["{}_{}".format(split, "tail")].append(
                    {
                        "triple": (sub, rel, obj),
                        "label": self.sr2o_all[(sub, rel)],
                    }
                )
                self.triples["{}_{}".format(split, "head")].append(
                    {
                        "triple": (obj, rel_inv, sub),
                        "label": self.sr2o_all[(obj, rel_inv)],
                    }
                )
        self.triples = dict(self.triples)

        def get_train_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TrainDataset(
                    self.triples[split], self.num_ent, self.lbl_smooth
                ),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TrainDataset.collate_fn,
            )

        def get_test_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TestDataset(self.triples[split], self.num_ent),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TestDataset.collate_fn,
            )

        # train/valid/test dataloaders
        self.data_iter = {
            "train": get_train_data_loader("train", self.batch_size),
            "valid_head": get_test_data_loader("valid_head", self.batch_size),
            "valid_tail": get_test_data_loader("valid_tail", self.batch_size),
            "test_head": get_test_data_loader("test_head", self.batch_size),
            "test_tail": get_test_data_loader("test_tail", self.batch_size),
        }
\ No newline at end of file
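For orientation (a hypothetical usage sketch, not part of this diff), the `Data` class above is driven roughly as follows, assuming the `./FB15k-237/{train,valid,test}.txt` triple files are present; the variable names are illustrative only.

# Sketch only: build the graph and dataloaders, then peek at one training batch.
from data_loader import Data

data = Data("FB15k-237", lbl_smooth=0.1, num_workers=0, batch_size=1024)
g = data.g  # DGLGraph carrying 'etype', 'in_edges_mask', 'out_edges_mask' in edata
for triple, trp_label in data.data_iter["train"]:
    # triple: (batch, 3) LongTensor of (subject, relation, -1)
    # trp_label: (batch, num_ent) smoothed multi-hot label tensor
    print(triple.shape, trp_label.shape)
    break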
import argparse
from time import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from data_loader import Data
from models import CompGCN_ConvE
from utils import in_out_norm

import dgl.function as fn


# predict the tail for (head, rel, -1) or head for (-1, rel, tail)
def predict(model, graph, device, data_iter, split="valid", mode="tail"):
    model.eval()
    with th.no_grad():
        results = {}
        train_iter = iter(data_iter["{}_{}".format(split, mode)])

        for step, batch in enumerate(train_iter):
            triple, label = batch[0].to(device), batch[1].to(device)
            sub, rel, obj, label = (
                triple[:, 0],
                triple[:, 1],
                triple[:, 2],
                label,
            )
            pred = model(graph, sub, rel)
            b_range = th.arange(pred.size()[0], device=device)
            target_pred = pred[b_range, obj]
            pred = th.where(label.byte(), -th.ones_like(pred) * 10000000, pred)
            pred[b_range, obj] = target_pred

            # compute metrics
            ranks = (
                1
                + th.argsort(
                    th.argsort(pred, dim=1, descending=True),
                    dim=1,
                    descending=False,
                )[b_range, obj]
            )
            ranks = ranks.float()
            results["count"] = th.numel(ranks) + results.get("count", 0.0)
            results["mr"] = th.sum(ranks).item() + results.get("mr", 0.0)
            results["mrr"] = th.sum(1.0 / ranks).item() + results.get(
                "mrr", 0.0
            )
            for k in [1, 3, 10]:
                results["hits@{}".format(k)] = th.numel(
                    ranks[ranks <= (k)]
                ) + results.get("hits@{}".format(k), 0.0)

    return results


# evaluation function, evaluate the head and tail prediction and then combine the results
def evaluate(model, graph, device, data_iter, split="valid"):
    # predict for head and tail
    left_results = predict(model, graph, device, data_iter, split, mode="tail")
    right_results = predict(model, graph, device, data_iter, split, mode="head")
    results = {}
    count = float(left_results["count"])

    # combine the head and tail prediction results
    # Metrics: MRR, MR, and Hit@k
    results["left_mr"] = round(left_results["mr"] / count, 5)
    results["left_mrr"] = round(left_results["mrr"] / count, 5)
    results["right_mr"] = round(right_results["mr"] / count, 5)
    results["right_mrr"] = round(right_results["mrr"] / count, 5)
    results["mr"] = round(
        (left_results["mr"] + right_results["mr"]) / (2 * count), 5
    )
    results["mrr"] = round(
        (left_results["mrr"] + right_results["mrr"]) / (2 * count), 5
    )
    for k in [1, 3, 10]:
        results["left_hits@{}".format(k)] = round(
            left_results["hits@{}".format(k)] / count, 5
        )
        results["right_hits@{}".format(k)] = round(
            right_results["hits@{}".format(k)] / count, 5
        )
        results["hits@{}".format(k)] = round(
            (
                left_results["hits@{}".format(k)]
                + right_results["hits@{}".format(k)]
            )
            / (2 * count),
            5,
        )

    return results


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # construct graph, split in/out edges and prepare train/validation/test data_loader
    data = Data(
        args.dataset, args.lbl_smooth, args.num_workers, args.batch_size
    )
    data_iter = data.data_iter  # train/validation/test data_loader

    graph = data.g.to(device)
    num_rel = th.max(graph.edata["etype"]).item() + 1

    # Compute in/out edge norms and store in edata
    graph = in_out_norm(graph)

    # Step 2: Create model =================================================================== #
    compgcn_model = CompGCN_ConvE(
        num_bases=args.num_bases,
        num_rel=num_rel,
        num_ent=graph.num_nodes(),
        in_dim=args.init_dim,
        layer_size=args.layer_size,
        comp_fn=args.opn,
        batchnorm=True,
        dropout=args.dropout,
        layer_dropout=args.layer_dropout,
        num_filt=args.num_filt,
        hid_drop=args.hid_drop,
        feat_drop=args.feat_drop,
        ker_sz=args.ker_sz,
        k_w=args.k_w,
        k_h=args.k_h,
    )
    compgcn_model = compgcn_model.to(device)

    # Step 3: Create training components ===================================================== #
    loss_fn = th.nn.BCELoss()
    optimizer = optim.Adam(
        compgcn_model.parameters(), lr=args.lr, weight_decay=args.l2
    )

    # Step 4: training epochs ================================================================ #
    best_mrr = 0.0
    kill_cnt = 0
    for epoch in range(args.max_epochs):
        # Training and validation using a full graph
        compgcn_model.train()
        train_loss = []
        t0 = time()
        for step, batch in enumerate(data_iter["train"]):
            triple, label = batch[0].to(device), batch[1].to(device)
            sub, rel, obj, label = (
                triple[:, 0],
                triple[:, 1],
                triple[:, 2],
                label,
            )

            logits = compgcn_model(graph, sub, rel)

            # compute loss
            tr_loss = loss_fn(logits, label)
            train_loss.append(tr_loss.item())
@@ -129,66 +170,192 @@ def main(args):
        train_loss = np.sum(train_loss)

        t1 = time()
        val_results = evaluate(
            compgcn_model, graph, device, data_iter, split="valid"
        )
        t2 = time()

        # validate
        if val_results["mrr"] > best_mrr:
            best_mrr = val_results["mrr"]
            best_epoch = epoch
            th.save(
                compgcn_model.state_dict(), "comp_link" + "_" + args.dataset
            )
            kill_cnt = 0
            print("saving model...")
        else:
            kill_cnt += 1
            if kill_cnt > 100:
                print("early stop.")
                break
        print(
            "In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}\n, Train time: {}, Valid time: {}".format(
                epoch, train_loss, val_results["mrr"], t1 - t0, t2 - t1
            )
        )

    # test using the best model
    compgcn_model.eval()
    compgcn_model.load_state_dict(th.load("comp_link" + "_" + args.dataset))
    test_results = evaluate(
        compgcn_model, graph, device, data_iter, split="test"
    )
    print(
        "Test MRR: {:.5}\n, MR: {:.10}\n, H@10: {:.5}\n, H@3: {:.5}\n, H@1: {:.5}\n".format(
            test_results["mrr"],
            test_results["mr"],
            test_results["hits@10"],
            test_results["hits@3"],
            test_results["hits@1"],
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Parser For Arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--data",
        dest="dataset",
        default="FB15k-237",
        help="Dataset to use, default: FB15k-237",
    )
    parser.add_argument(
        "--model", dest="model", default="compgcn", help="Model Name"
    )
    parser.add_argument(
        "--score_func",
        dest="score_func",
        default="conve",
        help="Score Function for Link prediction",
    )
    parser.add_argument(
        "--opn",
        dest="opn",
        default="ccorr",
        help="Composition Operation to be used in CompGCN",
    )
    parser.add_argument(
        "--batch", dest="batch_size", default=1024, type=int, help="Batch size"
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default="0",
        help="Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0",
    )
    parser.add_argument(
        "--epoch",
        dest="max_epochs",
        type=int,
        default=500,
        help="Number of epochs",
    )
    parser.add_argument(
        "--l2", type=float, default=0.0, help="L2 Regularization for Optimizer"
    )
    parser.add_argument(
        "--lr", type=float, default=0.001, help="Starting Learning Rate"
    )
    parser.add_argument(
        "--lbl_smooth",
        dest="lbl_smooth",
        type=float,
        default=0.1,
        help="Label Smoothing",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=10,
        help="Number of processes to construct batches",
    )
    parser.add_argument(
        "--seed",
        dest="seed",
        default=41504,
        type=int,
        help="Seed for randomization",
    )

    parser.add_argument(
        "--num_bases",
        dest="num_bases",
        default=-1,
        type=int,
        help="Number of basis relation vectors to use",
    )
    parser.add_argument(
        "--init_dim",
        dest="init_dim",
        default=100,
        type=int,
        help="Initial dimension size for entities and relations",
    )
    parser.add_argument(
        "--layer_size",
        nargs="?",
        default="[200]",
        help="List of output size for each compGCN layer",
    )
    parser.add_argument(
        "--gcn_drop",
        dest="dropout",
        default=0.1,
        type=float,
        help="Dropout to use in GCN Layer",
    )
    parser.add_argument(
        "--layer_dropout",
        nargs="?",
        default="[0.3]",
        help="List of dropout value after each compGCN layer",
    )

    # ConvE specific hyperparameters
    parser.add_argument(
        "--hid_drop",
        dest="hid_drop",
        default=0.3,
        type=float,
        help="ConvE: Hidden dropout",
    )
    parser.add_argument(
        "--feat_drop",
        dest="feat_drop",
        default=0.3,
        type=float,
        help="ConvE: Feature Dropout",
    )
    parser.add_argument(
        "--k_w", dest="k_w", default=10, type=int, help="ConvE: k_w"
    )
    parser.add_argument(
        "--k_h", dest="k_h", default=20, type=int, help="ConvE: k_h"
    )
    parser.add_argument(
        "--num_filt",
        dest="num_filt",
        default=200,
        type=int,
        help="ConvE: Number of filters in convolution",
    )
    parser.add_argument(
        "--ker_sz",
        dest="ker_sz",
        default=7,
        type=int,
        help="ConvE: Kernel size to use",
    )

    args = parser.parse_args()

    np.random.seed(args.seed)
    th.manual_seed(args.seed)
@@ -198,4 +365,3 @@ if __name__ == '__main__':
    args.layer_dropout = eval(args.layer_dropout)

    main(args)
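One detail worth noting: `--layer_size` and `--layer_dropout` are string-valued flags, and (as the last hunk shows for `layer_dropout`) they are converted to Python lists with `eval` before `main(args)` runs. A hypothetical, more restrictive alternative, not something this commit changes, would be `ast.literal_eval`:

    # Sketch only: literal_eval accepts only Python literals such as "[200]" or "[0.3]".
    import ast

    args.layer_size = ast.literal_eval(args.layer_size)
    args.layer_dropout = ast.literal_eval(args.layer_dropout)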
@@ -3,55 +3,65 @@
# It implements the operation of circular convolution in the ccorr function and an additional in_out_norm function for norm computation.
import torch as th

import dgl


def com_mult(a, b):
    r1, i1 = a[..., 0], a[..., 1]
    r2, i2 = b[..., 0], b[..., 1]
    return th.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1)


def conj(a):
    a[..., 1] = -a[..., 1]
    return a


def ccorr(a, b):
    """
    Compute circular correlation of two tensors.
    Parameters
    ----------
    a: Tensor, 1D or 2D
    b: Tensor, 1D or 2D
    Notes
    -----
    Input a and b should have the same dimensions. And this operation supports broadcasting.
    Returns
    -------
    Tensor, having the same dimension as the input a.
    """
    return th.fft.irfftn(
        th.conj(th.fft.rfftn(a, (-1))) * th.fft.rfftn(b, (-1)), (-1)
    )


# identify in/out edges, compute edge norm for each and store in edata
def in_out_norm(graph):
    src, dst, EID = graph.edges(form="all")
    graph.edata["norm"] = th.ones(EID.shape[0]).to(graph.device)

    in_edges_idx = th.nonzero(
        graph.edata["in_edges_mask"], as_tuple=False
    ).squeeze()
    out_edges_idx = th.nonzero(
        graph.edata["out_edges_mask"], as_tuple=False
    ).squeeze()

    for idx in [in_edges_idx, out_edges_idx]:
        u, v = src[idx], dst[idx]
        deg = th.zeros(graph.num_nodes()).to(graph.device)
        n_idx, inverse_index, count = th.unique(
            v, return_inverse=True, return_counts=True
        )
        deg[n_idx] = count.float()
        deg_inv = deg.pow(-0.5)  # D^{-0.5}
        deg_inv[deg_inv == float("inf")] = 0
        norm = deg_inv[u] * deg_inv[v]
        graph.edata["norm"][idx] = norm
    graph.edata["norm"] = graph.edata["norm"].unsqueeze(1)

    return graph
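As a quick smoke test (a hypothetical sketch, assuming `in_out_norm` above is in scope and the same mask layout that `data_loader.py` prepares, where the second half of the edges are the reversed copies of the first half), the norm computation can be exercised on a toy graph:

# Sketch only: 3-node graph; edges 2 and 3 are the reverses of edges 0 and 1.
g = dgl.graph(([0, 1, 1, 2], [1, 2, 0, 1]), num_nodes=3)
g.edata["in_edges_mask"] = th.tensor([True, True, False, False])
g.edata["out_edges_mask"] = th.tensor([False, False, True, True])
g = in_out_norm(g)
print(g.edata["norm"].shape)  # (4, 1): one D^-0.5 * D^-0.5 weight per edge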
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn

@@ -9,7 +10,7 @@ class MLPLinear(nn.Module):
        super(MLPLinear, self).__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.reset_parameters()

    def reset_parameters(self):
        self.linear.reset_parameters()
@@ -18,7 +19,7 @@ class MLPLinear(nn.Module):
class MLP(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim, num_layers, dropout=0.0):
        super(MLP, self).__init__()
        assert num_layers >= 2
@@ -30,7 +31,7 @@ class MLP(nn.Module):
        for _ in range(num_layers - 2):
            self.linears.append(nn.Linear(hid_dim, hid_dim))
            self.bns.append(nn.BatchNorm1d(hid_dim))

        self.linears.append(nn.Linear(hid_dim, out_dim))
        self.dropout = dropout
        self.reset_parameters()
@@ -75,42 +76,49 @@ class LabelPropagation(nn.Module):
        'DA': D^-1 * A
        'AD': A * D^-1
    """

    def __init__(self, num_layers, alpha, adj="DAD"):
        super(LabelPropagation, self).__init__()

        self.num_layers = num_layers
        self.alpha = alpha
        self.adj = adj

    @torch.no_grad()
    def forward(
        self, g, labels, mask=None, post_step=lambda y: y.clamp_(0.0, 1.0)
    ):
        with g.local_scope():
            if labels.dtype == torch.long:
                labels = F.one_hot(labels.view(-1)).to(torch.float32)

            y = labels
            if mask is not None:
                y = torch.zeros_like(labels)
                y[mask] = labels[mask]

            last = (1 - self.alpha) * y
            degs = g.in_degrees().float().clamp(min=1)
            norm = (
                torch.pow(degs, -0.5 if self.adj == "DAD" else -1)
                .to(labels.device)
                .unsqueeze(1)
            )

            for _ in range(self.num_layers):
                # Assume the graphs to be undirected
                if self.adj in ["DAD", "AD"]:
                    y = norm * y

                g.ndata["h"] = y
                g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                y = self.alpha * g.ndata.pop("h")

                if self.adj in ["DAD", "DA"]:
                    y = y * norm

                y = post_step(last + y)

            return y

@@ -144,41 +152,50 @@ class CorrectAndSmooth(nn.Module):
    scale: float, optional
        The scaling factor :math:`\sigma`, in case :obj:`autoscale = False`. Default is 1.
    """

    def __init__(
        self,
        num_correction_layers,
        correction_alpha,
        correction_adj,
        num_smoothing_layers,
        smoothing_alpha,
        smoothing_adj,
        autoscale=True,
        scale=1.0,
    ):
        super(CorrectAndSmooth, self).__init__()

        self.autoscale = autoscale
        self.scale = scale

        self.prop1 = LabelPropagation(
            num_correction_layers, correction_alpha, correction_adj
        )
        self.prop2 = LabelPropagation(
            num_smoothing_layers, smoothing_alpha, smoothing_adj
        )

    def correct(self, g, y_soft, y_true, mask):
        with g.local_scope():
            assert abs(float(y_soft.sum()) / y_soft.size(0) - 1.0) < 1e-2
            numel = (
                int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
            )
            assert y_true.size(0) == numel

            if y_true.dtype == torch.long:
                y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to(
                    y_soft.dtype
                )

            error = torch.zeros_like(y_soft)
            error[mask] = y_true - y_soft[mask]

            if self.autoscale:
                smoothed_error = self.prop1(
                    g, error, post_step=lambda x: x.clamp_(-1.0, 1.0)
                )
                sigma = error[mask].abs().sum() / numel
                scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True)
                scale[scale.isinf() | (scale > 1000)] = 1.0
@@ -187,10 +204,11 @@ class CorrectAndSmooth(nn.Module):
                result[result.isnan()] = y_soft[result.isnan()]
                return result
            else:

                def fix_input(x):
                    x[mask] = error[mask]
                    return x

                smoothed_error = self.prop1(g, error, post_step=fix_input)
                result = y_soft + self.scale * smoothed_error
@@ -199,11 +217,15 @@ class CorrectAndSmooth(nn.Module):
    def smooth(self, g, y_soft, y_true, mask):
        with g.local_scope():
            numel = (
                int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
            )
            assert y_true.size(0) == numel

            if y_true.dtype == torch.long:
                y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to(
                    y_soft.dtype
                )

            y_soft[mask] = y_true
            return self.prop2(g, y_soft)
import random

import numpy as np
import torch
from torch.nn import functional as F


def evaluate(model, graph, feats, labels, idxs):
@@ -11,7 +12,9 @@ def evaluate(model, graph, feats, labels, idxs):
    results = ()
    for idx in idxs:
        loss = F.cross_entropy(logits[idx], labels[idx])
        acc = torch.sum(
            logits[idx].argmax(dim=1) == labels[idx]
        ).item() / len(idx)
        results += (loss, acc)
    return results
@@ -27,4 +30,4 @@ def set_random_state(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
\ No newline at end of file
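A brief usage note (hypothetical, with illustrative names): `evaluate` above returns a flat tuple containing one `(loss, acc)` pair per index set passed in `idxs`, so a train/validation call unpacks as:

# Sketch only: idxs is a list of index tensors, one per split.
train_loss, train_acc, val_loss, val_acc = evaluate(
    model, graph, feats, labels, [train_idx, val_idx]
)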