Unverified commit 0b9df9d7, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4652)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f19f05ce
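The change below is mechanical: Black normalizes string literals to double quotes, splits call-argument lists that no longer fit on one line onto one argument per line, and separates top-level definitions with two blank lines. A change like this is typically produced by a single run of Black over the affected example directories (for instance "black --line-length 79 <paths>"); the exact command and configuration used for this commit are not recorded in the diff, so the 79-character limit is an inference from the wrapping seen below.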
import torch


def shuffle_walks(walks):
    seeds = torch.randperm(walks.size()[0])
    return walks[seeds]


def sum_up_params(model):
    """Count the model parameters"""
    n = []
    n.append(model.u_embeddings.weight.cpu().data.numel() * 2)
    n.append(model.lookup_table.cpu().numel())
...
import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ogb.graphproppred import Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder
from preprocessing import prepare_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
import dgl
from dgl.dataloading import GraphDataLoader


def aggregate_mean(h, vector_field, h_in):
    return torch.mean(h, dim=1)


def aggregate_max(h, vector_field, h_in):
    return torch.max(h, dim=1)[0]


def aggregate_sum(h, vector_field, h_in):
    return torch.sum(h, dim=1)


def aggregate_dir_dx(h, vector_field, h_in, eig_idx=1):
    eig_w = (
        (vector_field[:, :, eig_idx])
        / (
            torch.sum(
                torch.abs(vector_field[:, :, eig_idx]), keepdim=True, dim=1
            )
            + 1e-8
        )
    ).unsqueeze(-1)
    h_mod = torch.mul(h, eig_w)
    return torch.abs(torch.sum(h_mod, dim=1) - torch.sum(eig_w, dim=1) * h_in)


class FCLayer(nn.Module):
    def __init__(self, in_size, out_size):
        super(FCLayer, self).__init__()

@@ -46,6 +60,7 @@ class FCLayer(nn.Module):
        h = self.linear(x)
        return h


class MLP(nn.Module):
    def __init__(self, in_size, out_size):
        super(MLP, self).__init__()

@@ -58,6 +73,7 @@ class MLP(nn.Module):
        x = self.fc(x)
        return x


class DGNLayer(nn.Module):
    def __init__(self, in_dim, out_dim, dropout, aggregators):
        super().__init__()

@@ -68,36 +84,47 @@ class DGNLayer(nn.Module):
        self.batchnorm_h = nn.BatchNorm1d(out_dim)
        self.pretrans = MLP(in_size=2 * in_dim, out_size=in_dim)
        self.posttrans = MLP(
            in_size=(len(aggregators) * 1 + 1) * in_dim, out_size=out_dim
        )

    def pretrans_edges(self, edges):
        z2 = torch.cat([edges.src["h"], edges.dst["h"]], dim=1)
        vector_field = edges.data["eig"]
        return {"e": self.pretrans(z2), "vector_field": vector_field}

    def message_func(self, edges):
        return {
            "e": edges.data["e"],
            "vector_field": edges.data["vector_field"],
        }

    def reduce_func(self, nodes):
        h_in = nodes.data["h"]
        h = nodes.mailbox["e"]
        vector_field = nodes.mailbox["vector_field"]
        h = torch.cat(
            [
                aggregate(h, vector_field, h_in)
                for aggregate in self.aggregators
            ],
            dim=1,
        )
        return {"h": h}

    def forward(self, g, h, snorm_n):
        g.ndata["h"] = h

        # pretransformation
        g.apply_edges(self.pretrans_edges)

        # aggregation
        g.update_all(self.message_func, self.reduce_func)
        h = torch.cat([h, g.ndata["h"]], dim=1)

        # posttransformation
        h = self.posttrans(h)

@@ -111,12 +138,17 @@ class DGNLayer(nn.Module):
        return h


class MLPReadout(nn.Module):
    def __init__(self, input_dim, output_dim, L=2):  # L=nb_hidden_layers
        super().__init__()
        list_FC_layers = [
            nn.Linear(input_dim // 2**l, input_dim // 2 ** (l + 1), bias=True)
            for l in range(L)
        ]
        list_FC_layers.append(
            nn.Linear(input_dim // 2**L, output_dim, bias=True)
        )
        self.FC_layers = nn.ModuleList(list_FC_layers)
        self.L = L

@@ -128,17 +160,38 @@ class MLPReadout(nn.Module):
        y = self.FC_layers[self.L](y)
        return y


class DGNNet(nn.Module):
    def __init__(self, hidden_dim=420, out_dim=420, dropout=0.2, n_layers=4):
        super().__init__()
        self.embedding_h = AtomEncoder(emb_dim=hidden_dim)
        self.aggregators = [
            aggregate_mean,
            aggregate_sum,
            aggregate_max,
            aggregate_dir_dx,
        ]
        self.layers = nn.ModuleList(
            [
                DGNLayer(
                    in_dim=hidden_dim,
                    out_dim=hidden_dim,
                    dropout=dropout,
                    aggregators=self.aggregators,
                )
                for _ in range(n_layers - 1)
            ]
        )
        self.layers.append(
            DGNLayer(
                in_dim=hidden_dim,
                out_dim=out_dim,
                dropout=dropout,
                aggregators=self.aggregators,
            )
        )

        # 128 out dim since ogbg-molpcba has 128 tasks
        self.MLP_layer = MLPReadout(out_dim, 128)

@@ -150,32 +203,37 @@ class DGNNet(nn.Module):
            h_t = conv(g, h, snorm_n)
            h = h_t

        g.ndata["h"] = h
        hg = dgl.mean_nodes(g, "h")

        return self.MLP_layer(hg)

    def loss(self, scores, labels):
        is_labeled = labels == labels
        loss = nn.BCEWithLogitsLoss()(
            scores[is_labeled], labels[is_labeled].float()
        )
        return loss


def train_epoch(model, optimizer, device, data_loader):
    model.train()
    epoch_loss = 0
    epoch_train_AP = 0
    list_scores = []
    list_labels = []
    for iter, (batch_graphs, batch_labels, batch_snorm_n) in enumerate(
        data_loader
    ):
        batch_graphs = batch_graphs.to(device)
        batch_x = batch_graphs.ndata["feat"]  # num x feat
        batch_snorm_n = batch_snorm_n.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()

        batch_scores = model(batch_graphs, batch_x, batch_snorm_n)
        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()

@@ -183,14 +241,16 @@ def train_epoch(model, optimizer, device, data_loader):
        list_scores.append(batch_scores)
        list_labels.append(batch_labels)

    epoch_loss /= iter + 1

    evaluator = Evaluator(name="ogbg-molpcba")
    epoch_train_AP = evaluator.eval(
        {"y_pred": torch.cat(list_scores), "y_true": torch.cat(list_labels)}
    )["ap"]

    return epoch_loss, epoch_train_AP


def evaluate_network(model, device, data_loader):
    model.eval()
    epoch_test_loss = 0

@@ -198,9 +258,11 @@ def evaluate_network(model, device, data_loader):
    with torch.no_grad():
        list_scores = []
        list_labels = []
        for iter, (batch_graphs, batch_labels, batch_snorm_n) in enumerate(
            data_loader
        ):
            batch_graphs = batch_graphs.to(device)
            batch_x = batch_graphs.ndata["feat"]
            batch_snorm_n = batch_snorm_n.to(device)
            batch_labels = batch_labels.to(device)

@@ -211,14 +273,16 @@ def evaluate_network(model, device, data_loader):
            list_scores.append(batch_scores)
            list_labels.append(batch_labels)

        epoch_test_loss /= iter + 1

        evaluator = Evaluator(name="ogbg-molpcba")
        epoch_test_AP = evaluator.eval(
            {"y_pred": torch.cat(list_scores), "y_true": torch.cat(list_labels)}
        )["ap"]

    return epoch_test_loss, epoch_test_AP


def train(dataset, params):
    trainset, valset, testset = dataset.train, dataset.val, dataset.test

@@ -236,27 +300,48 @@ def train(dataset, params):
    print("MODEL DETAILS:\n")
    for param in model.parameters():
        total_param += np.prod(list(param.data.size()))
    print("DGN Total parameters:", total_param)

    optimizer = optim.Adam(model.parameters(), lr=0.0008, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.8, patience=8, verbose=True
    )

    epoch_train_losses, epoch_val_losses = [], []
    epoch_train_APs, epoch_val_APs, epoch_test_APs = [], [], []

    train_loader = GraphDataLoader(
        trainset,
        batch_size=params.batch_size,
        shuffle=True,
        collate_fn=dataset.collate,
        pin_memory=True,
    )
    val_loader = GraphDataLoader(
        valset,
        batch_size=params.batch_size,
        shuffle=False,
        collate_fn=dataset.collate,
        pin_memory=True,
    )
    test_loader = GraphDataLoader(
        testset,
        batch_size=params.batch_size,
        shuffle=False,
        collate_fn=dataset.collate,
        pin_memory=True,
    )

    with tqdm(range(450), unit="epoch") as t:
        for epoch in t:
            t.set_description("Epoch %d" % epoch)

            epoch_train_loss, epoch_train_ap = train_epoch(
                model, optimizer, device, train_loader
            )
            epoch_val_loss, epoch_val_ap = evaluate_network(
                model, device, val_loader
            )

            epoch_train_losses.append(epoch_train_loss)
            epoch_val_losses.append(epoch_val_loss)

@@ -267,17 +352,20 @@ def train(dataset, params):
            epoch_test_APs.append(epoch_test_ap.item())

            t.set_postfix(
                train_loss=epoch_train_loss,
                train_AP=epoch_train_ap.item(),
                val_AP=epoch_val_ap.item(),
                refresh=False,
            )

            scheduler.step(-epoch_val_ap.item())

            if optimizer.param_groups[0]["lr"] < 1e-5:
                print("\n!! LR EQUAL TO MIN LR SET.")
                break

            print("")

    best_val_epoch = np.argmax(np.array(epoch_val_APs))
    best_train_epoch = np.argmax(np.array(epoch_train_APs))

@@ -291,6 +379,7 @@ def train(dataset, params):
    print("Test AP of Best Val: {:.4f}".format(best_val_test_ap))
    print("Train AP of Best Val: {:.4f}".format(best_val_train_ap))


class Subset(object):
    def __init__(self, dataset, labels, indices):
        dataset = [dataset[idx] for idx in indices]

@@ -308,23 +397,35 @@ class Subset(object):
    def __len__(self):
        return self.len


class PCBADataset(Dataset):
    def __init__(self, name):
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        self.dataset, self.split_idx = prepare_dataset(name)

        print("One hot encoding substructure counts... ", end="")
        self.d_id = [1] * self.dataset[0].edata["subgraph_counts"].shape[1]

        for g in self.dataset:
            g.edata["eig"] = g.edata["subgraph_counts"].float()

        self.train = Subset(
            self.dataset, self.split_idx["label"], self.split_idx["train"]
        )
        self.val = Subset(
            self.dataset, self.split_idx["label"], self.split_idx["valid"]
        )
        self.test = Subset(
            self.dataset, self.split_idx["label"], self.split_idx["test"]
        )

        print(
            "train, test, val sizes :",
            len(self.train),
            len(self.test),
            len(self.val),
        )
        print("[I] Finished loading.")

    # form a mini batch from a given list of samples = [(graph, label) pairs]

@@ -334,22 +435,36 @@ class PCBADataset(Dataset):
        labels = torch.stack(labels)

        tab_sizes_n = [g.num_nodes() for g in graphs]
        tab_snorm_n = [
            torch.FloatTensor(size, 1).fill_(1.0 / size) for size in tab_sizes_n
        ]
        snorm_n = torch.cat(tab_snorm_n).sqrt()
        batched_graph = dgl.batch(graphs)
        return batched_graph, labels, snorm_n


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gpu_id", default=0, type=int, help="Please give a value for gpu id"
    )
    parser.add_argument(
        "--seed", default=41, type=int, help="Please give a value for seed"
    )
    parser.add_argument(
        "--batch_size",
        default=2048,
        type=int,
        help="Please give a value for batch_size",
    )
    args = parser.parse_args()

    # device
    args.device = torch.device(
        "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu"
    )

    # setting seeds
    random.seed(args.seed)
    np.random.seed(args.seed)

@@ -358,4 +473,4 @@ if __name__ == '__main__':
        torch.cuda.manual_seed(args.seed)

    dataset = PCBADataset("ogbg-molpcba")
    train(dataset, args)
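As a side note on the collate function above: the snorm_n tensor it builds holds one value per node, namely sqrt(1 / num_nodes) of that node's graph. A minimal standalone sketch of what it evaluates to (illustration only, not part of the commit):

    import torch

    # A graph with 4 nodes gets snorm_n = sqrt(1 / 4) = 0.5 for every node.
    size = 4
    snorm = torch.FloatTensor(size, 1).fill_(1.0 / size).sqrt()
    assert torch.allclose(snorm, torch.full((size, 1), 0.5))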
import os
import graph_tool as gt
import graph_tool.topology as gt_topology
import networkx as nx
import numpy as np
import torch
from ogb.graphproppred import DglGraphPropPredDataset
from tqdm import tqdm
from dgl.data.utils import load_graphs, save_graphs


def to_undirected(edge_index):
    row, col = edge_index.transpose(1, 0)
    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
    edge_index = torch.stack([row, col], dim=0)
    return edge_index.transpose(1, 0).tolist()


def induced_edge_automorphism_orbits(edge_list):
    ##### node automorphism orbits #####
    graph = gt.Graph(directed=False)
    graph.add_edge_list(edge_list)
    gt.stats.remove_self_loops(graph)
    gt.stats.remove_parallel_edges(graph)

    # compute the node automorphism group
    aut_group = gt_topology.subgraph_isomorphism(
        graph, graph, induced=False, subgraph=True, generator=False
    )

    orbit_membership = {}
    for v in graph.get_vertices():
        orbit_membership[v] = v

    # whenever two nodes can be mapped via some automorphism, they are assigned the same orbit
    for aut in aut_group:
        for original, node in enumerate(aut):
            role = min(original, orbit_membership[node])
            orbit_membership[node] = role

    orbit_membership_list = [[], []]
    for node, om_curr in orbit_membership.items():
        orbit_membership_list[0].append(node)
        orbit_membership_list[1].append(om_curr)

    # make orbit list contiguous (i.e. 0,1,2,...O)
    _, contiguous_orbit_membership = np.unique(
        orbit_membership_list[1], return_inverse=True
    )

    orbit_membership = {
        node: contiguous_orbit_membership[i]
        for i, node in enumerate(orbit_membership_list[0])
    }

    aut_count = len(aut_group)

@@ -53,12 +64,14 @@ def induced_edge_automorphism_orbits(edge_list):
    edge_orbit_membership = dict()
    edge_orbits2inds = dict()
    ind = 0

    edge_list = to_undirected(torch.tensor(graph.get_edges()))

    # infer edge automorphisms from the node automorphisms
    for i, edge in enumerate(edge_list):
        edge_orbit = frozenset(
            [orbit_membership[edge[0]], orbit_membership[edge[1]]]
        )
        if edge_orbit not in edge_orbits2inds:
            edge_orbits2inds[edge_orbit] = ind
            ind_edge_orbit = ind

@@ -69,78 +82,97 @@ def induced_edge_automorphism_orbits(edge_list):
        if ind_edge_orbit not in edge_orbit_partition:
            edge_orbit_partition[ind_edge_orbit] = [tuple(edge)]
        else:
            edge_orbit_partition[ind_edge_orbit] += [tuple(edge)]

        edge_orbit_membership[i] = ind_edge_orbit

    print(
        "Edge orbit partition of given substructure: {}".format(
            edge_orbit_partition
        )
    )
    print("Number of edge orbits: {}".format(len(edge_orbit_partition)))
    print("Graph (node) automorphism count: {}".format(aut_count))

    return graph, edge_orbit_partition, edge_orbit_membership, aut_count


def subgraph_isomorphism_edge_counts(edge_index, subgraph_dict):
    ##### edge structural identifiers #####

    edge_index = edge_index.transpose(1, 0).cpu().numpy()
    edge_dict = {}
    for i, edge in enumerate(edge_index):
        edge_dict[tuple(edge)] = i

    subgraph_edges = to_undirected(
        torch.tensor(subgraph_dict["subgraph"].get_edges().tolist())
    )

    G_gt = gt.Graph(directed=False)
    G_gt.add_edge_list(list(edge_index))
    gt.stats.remove_self_loops(G_gt)
    gt.stats.remove_parallel_edges(G_gt)

    # compute all subgraph isomorphisms
    sub_iso = gt_topology.subgraph_isomorphism(
        subgraph_dict["subgraph"],
        G_gt,
        induced=True,
        subgraph=True,
        generator=True,
    )

    counts = np.zeros(
        (edge_index.shape[0], len(subgraph_dict["orbit_partition"]))
    )

    for sub_iso_curr in sub_iso:
        mapping = sub_iso_curr.get_array()
        for i, edge in enumerate(subgraph_edges):
            # for every edge in the graph H, find the edge in the subgraph G_S to which it is mapped
            # (by finding where its endpoints are matched).
            # Then, increase the count of the matched edge w.r.t. the corresponding orbit
            # Repeat for the reverse edge (the one with the opposite direction)
            edge_orbit = subgraph_dict["orbit_membership"][i]
            mapped_edge = tuple([mapping[edge[0]], mapping[edge[1]]])
            counts[edge_dict[mapped_edge], edge_orbit] += 1

    counts = counts / subgraph_dict["aut_count"]

    counts = torch.tensor(counts)

    return counts


def prepare_dataset(name):
    # maximum size of cycle graph
    k = 8
    path = os.path.join("./", "dataset", name)
    data_folder = os.path.join(path, "processed")
    os.makedirs(data_folder, exist_ok=True)
    data_file = os.path.join(
        data_folder, "cycle_graph_induced_{}.bin".format(k)
    )

    # try to load
    if os.path.exists(data_file):  # load
        print("Loading dataset from {}".format(data_file))
        g_list, split_idx = load_graphs(data_file)
    else:  # generate
        g_list, split_idx = generate_dataset(path, name)
        print("Saving dataset to {}".format(data_file))
        save_graphs(data_file, g_list, split_idx)

    return g_list, split_idx


def generate_dataset(path, name):
    ### compute the orbits of each substructure in the list, as well as the node automorphism count

@@ -152,14 +184,25 @@ def generate_dataset(path, name):
        edge_lists.append(list(graphs_nx.edges))

    for edge_list in edge_lists:
        (
            subgraph,
            orbit_partition,
            orbit_membership,
            aut_count,
        ) = induced_edge_automorphism_orbits(edge_list=edge_list)
        subgraph_dicts.append(
            {
                "subgraph": subgraph,
                "orbit_partition": orbit_partition,
                "orbit_membership": orbit_membership,
                "aut_count": aut_count,
            }
        )

    ### load and preprocess dataset
    dataset = DglGraphPropPredDataset(name=name, root=path)
    split_idx = dataset.get_idx_split()

    # computation of subgraph isomorphisms & creation of data structure
    graphs_dgl = list()
    split_idx["label"] = []

@@ -172,19 +215,25 @@ def generate_dataset(path, name):
    split_idx["label"] = torch.stack(split_idx["label"])

    return graphs_dgl, split_idx


def _prepare(g, subgraph_dicts):
    edge_index = torch.stack(g.edges())

    identifiers = None
    for subgraph_dict in subgraph_dicts:
        counts = subgraph_isomorphism_edge_counts(edge_index, subgraph_dict)
        identifiers = (
            counts
            if identifiers is None
            else torch.cat((identifiers, counts), 1)
        )

    g.edata["subgraph_counts"] = identifiers.long()

    return g


if __name__ == "__main__":
    prepare_dataset("ogbg-molpcba")
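A quick sanity check of the to_undirected helper defined above (illustration only, not part of the commit): every directed edge gains its reverse, and the result comes back as a plain list of pairs.

    import torch

    # to_undirected is the helper defined at the top of the file above.
    edges = torch.tensor([[0, 1], [1, 2]])
    assert to_undirected(edges) == [[0, 1], [1, 2], [1, 0], [2, 1]]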
import argparse
import os
import random
import time
import numpy as np
import torch
import torch.multiprocessing as mp
from model import SkipGramModel
from reading_data import LineDataset
from torch.utils.data import DataLoader
from utils import check_args, sum_up_params
import dgl


class LineTrainer:
    def __init__(self, args):
        """Initializing the trainer with the input arguments"""
        self.args = args
        self.dataset = LineDataset(
            net_file=args.data_file,

@@ -27,20 +29,22 @@ class LineTrainer:
            ogbn_name=args.ogbn_name,
            load_from_ogbn=args.load_from_ogbn,
            num_samples=args.num_samples * 1000000,
        )
        self.emb_size = self.dataset.G.number_of_nodes()
        self.emb_model = None

    def init_device_emb(self):
        """set the device before training
        will be called once in fast_train_mp / fast_train
        """
        choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
        assert (
            choices == 1
        ), "Must choose only *one* training mode in [only_cpu, only_gpu, mix]"

        # initializing embedding on CPU
        self.emb_model = SkipGramModel(
            emb_size=self.emb_size,
            emb_dimension=self.args.dim,
            batch_size=self.args.batch_size,
            only_cpu=self.args.only_cpu,

@@ -56,8 +60,8 @@ class LineTrainer:
            record_loss=self.args.print_loss,
            async_update=self.args.async_update,
            num_threads=self.args.num_threads,
        )

        torch.set_num_threads(self.args.num_threads)
        if self.args.only_gpu:
            print("Run in 1 GPU")

@@ -66,20 +70,22 @@ class LineTrainer:
        elif self.args.mix:
            print("Mix CPU with %d GPU" % len(self.args.gpus))
            if len(self.args.gpus) == 1:
                assert (
                    self.args.gpus[0] >= 0
                ), "mix CPU with GPU should have avaliable GPU"
                self.emb_model.set_device(self.args.gpus[0])
        else:
            print("Run in CPU process")

    def train(self):
        """train the embedding"""
        if len(self.args.gpus) > 1:
            self.fast_train_mp()
        else:
            self.fast_train()

    def fast_train_mp(self):
        """multi-cpu-core or mix cpu & multi-gpu"""
        self.init_device_emb()
        self.emb_model.share_memory()

@@ -89,24 +95,30 @@ class LineTrainer:
        ps = []

        for i in range(len(self.args.gpus)):
            p = mp.Process(
                target=self.fast_train_sp, args=(i, self.args.gpus[i])
            )
            ps.append(p)
            p.start()

        for p in ps:
            p.join()

        print("Used time: %.2fs" % (time.time() - start_all))
        if self.args.save_in_pt:
            self.emb_model.save_embedding_pt(
                self.dataset, self.args.output_emb_file
            )
        else:
            self.emb_model.save_embedding(
                self.dataset, self.args.output_emb_file
            )

    def fast_train_sp(self, rank, gpu_id):
        """a subprocess for fast_train_mp"""
        if self.args.mix:
            self.emb_model.set_device(gpu_id)

        torch.set_num_threads(self.args.num_threads)
        if self.args.async_update:
            self.emb_model.create_async_update()

@@ -120,9 +132,12 @@ class LineTrainer:
            shuffle=False,
            drop_last=False,
            num_workers=self.args.num_sampler_threads,
        )
        num_batches = len(dataloader)
        print(
            "num batchs: %d in process [%d] GPU [%d]"
            % (num_batches, rank, gpu_id)
        )

        start = time.time()
        with torch.no_grad():

@@ -133,35 +148,65 @@ class LineTrainer:
                # do negative sampling
                bs = edges.size()[0]
                neg_nodes = torch.LongTensor(
                    np.random.choice(
                        self.dataset.neg_table,
                        bs * self.args.negative,
                        replace=True,
                    )
                )
                self.emb_model.fast_learn(edges, neg_nodes=neg_nodes)

                if i > 0 and i % self.args.print_interval == 0:
                    if self.args.print_loss:
                        if self.args.only_fst:
                            print(
                                "GPU-[%d] batch %d time: %.2fs fst-loss: %.4f"
                                % (
                                    gpu_id,
                                    i,
                                    time.time() - start,
                                    -sum(self.emb_model.loss_fst)
                                    / self.args.print_interval,
                                )
                            )
                        elif self.args.only_snd:
                            print(
                                "GPU-[%d] batch %d time: %.2fs snd-loss: %.4f"
                                % (
                                    gpu_id,
                                    i,
                                    time.time() - start,
                                    -sum(self.emb_model.loss_snd)
                                    / self.args.print_interval,
                                )
                            )
                        else:
                            print(
                                "GPU-[%d] batch %d time: %.2fs fst-loss: %.4f snd-loss: %.4f"
                                % (
                                    gpu_id,
                                    i,
                                    time.time() - start,
                                    -sum(self.emb_model.loss_fst)
                                    / self.args.print_interval,
                                    -sum(self.emb_model.loss_snd)
                                    / self.args.print_interval,
                                )
                            )
                        self.emb_model.loss_fst = []
                        self.emb_model.loss_snd = []
                    else:
                        print(
                            "GPU-[%d] batch %d time: %.2fs"
                            % (gpu_id, i, time.time() - start)
                        )
                    start = time.time()

            if self.args.async_update:
                self.emb_model.finish_async_update()

    def fast_train(self):
        """fast train with dataloader with only gpu / only cpu"""
        self.init_device_emb()

        if self.args.async_update:

@@ -179,8 +224,8 @@ class LineTrainer:
            shuffle=False,
            drop_last=False,
            num_workers=self.args.num_sampler_threads,
        )

        num_batches = len(dataloader)
        print("num batchs: %d\n" % num_batches)

@@ -194,105 +239,220 @@ class LineTrainer:
                # do negative sampling
                bs = edges.size()[0]
                neg_nodes = torch.LongTensor(
                    np.random.choice(
                        self.dataset.neg_table,
                        bs * self.args.negative,
                        replace=True,
                    )
                )

                self.emb_model.fast_learn(edges, neg_nodes=neg_nodes)

                if i > 0 and i % self.args.print_interval == 0:
                    if self.args.print_loss:
                        if self.args.only_fst:
                            print(
                                "Batch %d time: %.2fs fst-loss: %.4f"
                                % (
                                    i,
                                    time.time() - start,
                                    -sum(self.emb_model.loss_fst)
                                    / self.args.print_interval,
                                )
                            )
                        elif self.args.only_snd:
                            print(
                                "Batch %d time: %.2fs snd-loss: %.4f"
                                % (
                                    i,
                                    time.time() - start,
                                    -sum(self.emb_model.loss_snd)
                                    / self.args.print_interval,
                                )
                            )
                        else:
                            print(
                                "Batch %d time: %.2fs fst-loss: %.4f snd-loss: %.4f"
                                % (
                                    i,
                                    time.time() - start,
                                    -sum(self.emb_model.loss_fst)
                                    / self.args.print_interval,
                                    -sum(self.emb_model.loss_snd)
                                    / self.args.print_interval,
                                )
                            )
                        self.emb_model.loss_fst = []
                        self.emb_model.loss_snd = []
                    else:
                        print(
                            "Batch %d, training time: %.2fs"
                            % (i, time.time() - start)
                        )
                    start = time.time()

            if self.args.async_update:
                self.emb_model.finish_async_update()

        print("Training used time: %.2fs" % (time.time() - start_all))
        if self.args.save_in_pt:
            self.emb_model.save_embedding_pt(
                self.dataset, self.args.output_emb_file
            )
        else:
            self.emb_model.save_embedding(
                self.dataset, self.args.output_emb_file
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Implementation of LINE.")
    # input files
    ## personal datasets
    parser.add_argument("--data_file", type=str, help="path of dgl graphs")
    ## ogbl datasets
    parser.add_argument(
        "--ogbl_name", type=str, help="name of ogbl dataset, e.g. ogbl-ddi"
    )
    parser.add_argument(
        "--load_from_ogbl",
        default=False,
        action="store_true",
        help="whether load dataset from ogbl",
    )
    parser.add_argument(
        "--ogbn_name", type=str, help="name of ogbn dataset, e.g. ogbn-proteins"
    )
    parser.add_argument(
        "--load_from_ogbn",
        default=False,
        action="store_true",
        help="whether load dataset from ogbn",
    )

    # output files
    parser.add_argument(
        "--save_in_pt",
        default=False,
        action="store_true",
        help="Whether save dat in pt format or npy",
    )
    parser.add_argument(
        "--output_emb_file",
        type=str,
        default="emb.npy",
        help="path of the output npy embedding file",
    )

    # model parameters
    parser.add_argument(
        "--dim", default=128, type=int, help="embedding dimensions"
    )
    parser.add_argument(
        "--num_samples",
        default=1,
        type=int,
        help="number of samples during training (million)",
    )
    parser.add_argument(
        "--negative",
        default=1,
        type=int,
        help="negative samples for each positve node pair",
    )
    parser.add_argument(
        "--batch_size",
        default=128,
        type=int,
        help="number of edges in each batch",
    )
    parser.add_argument(
        "--neg_weight", default=1.0, type=float, help="negative weight"
    )
    parser.add_argument(
        "--lap_norm",
        default=0.01,
        type=float,
        help="weight of laplacian normalization",
    )

    # training parameters
    parser.add_argument(
        "--only_fst",
        default=False,
        action="store_true",
        help="only do first-order proximity embedding",
    )
    parser.add_argument(
        "--only_snd",
        default=False,
        action="store_true",
        help="only do second-order proximity embedding",
    )
    parser.add_argument(
        "--print_interval",
        default=100,
        type=int,
        help="number of batches between printing",
    )
    parser.add_argument(
        "--print_loss",
        default=False,
        action="store_true",
        help="whether print loss during training",
    )
    parser.add_argument("--lr", default=0.2, type=float, help="learning rate")

    # optimization settings
    parser.add_argument(
        "--mix",
        default=False,
        action="store_true",
        help="mixed training with CPU and GPU",
    )
    parser.add_argument(
        "--gpus",
        type=int,
        default=[-1],
        nargs="+",
        help="a list of active gpu ids, e.g. 0, used with --mix",
    )
    parser.add_argument(
        "--only_cpu",
        default=False,
        action="store_true",
        help="training with CPU",
    )
    parser.add_argument(
        "--only_gpu",
        default=False,
        action="store_true",
        help="training with a single GPU (all of the parameters are moved on the GPU)",
    )
    parser.add_argument(
        "--async_update",
        default=False,
        action="store_true",
        help="mixed training asynchronously, recommend not to use this",
    )

    parser.add_argument(
        "--fast_neg",
        default=False,
        action="store_true",
        help="do negative sampling inside a batch",
    )
    parser.add_argument(
        "--num_threads",
        default=2,
        type=int,
        help="number of threads used for each CPU-core/GPU",
    )
    parser.add_argument(
        "--num_sampler_threads",
        default=2,
        type=int,
        help="number of threads used for sampling",
    )

    args = parser.parse_args()
...
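For orientation, the trainer above is normally launched from the command line with the flags defined in this block, along the lines of "python line.py --load_from_ogbl --ogbl_name ogbl-collab --gpus 0 --mix --print_loss"; the script's filename (line.py here) is not shown in this diff and the flag values are illustrative.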
""" load dataset from ogb """ """ load dataset from ogb """
import argparse import argparse
from ogb.linkproppred import DglLinkPropPredDataset from ogb.linkproppred import DglLinkPropPredDataset
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
import dgl import dgl
def load_from_ogbl_with_name(name):
choices = ['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation'] def load_from_ogbl_with_name(name):
choices = ["ogbl-collab", "ogbl-ddi", "ogbl-ppa", "ogbl-citation"]
assert name in choices, "name must be selected from " + str(choices) assert name in choices, "name must be selected from " + str(choices)
dataset = DglLinkPropPredDataset(name) dataset = DglLinkPropPredDataset(name)
return dataset[0] return dataset[0]
def load_from_ogbn_with_name(name):
choices = ['ogbn-products', 'ogbn-proteins', 'ogbn-arxiv', 'ogbn-papers100M'] def load_from_ogbn_with_name(name):
choices = [
"ogbn-products",
"ogbn-proteins",
"ogbn-arxiv",
"ogbn-papers100M",
]
assert name in choices, "name must be selected from " + str(choices) assert name in choices, "name must be selected from " + str(choices)
dataset, label = DglNodePropPredDataset(name)[0] dataset, label = DglNodePropPredDataset(name)[0]
return dataset return dataset
if __name__ == "__main__": if __name__ == "__main__":
""" load datasets as net.txt format """ """load datasets as net.txt format"""
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, parser.add_argument(
choices=['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation', "--name",
'ogbn-products', 'ogbn-proteins', 'ogbn-arxiv', 'ogbn-papers100M'], type=str,
default='ogbl-collab', choices=[
help="name of datasets by ogb") "ogbl-collab",
"ogbl-ddi",
"ogbl-ppa",
"ogbl-citation",
"ogbn-products",
"ogbn-proteins",
"ogbn-arxiv",
"ogbn-papers100M",
],
default="ogbl-collab",
help="name of datasets by ogb",
)
args = parser.parse_args() args = parser.parse_args()
name = args.name name = args.name
...@@ -33,4 +54,4 @@ if __name__ == "__main__": ...@@ -33,4 +54,4 @@ if __name__ == "__main__":
else: else:
g = load_from_ogbn_with_name(name=name) g = load_from_ogbn_with_name(name=name)
dgl.save_graphs(name + "-graph.bin", g) dgl.save_graphs(name + "-graph.bin", g)
\ No newline at end of file
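Run on its own with, for example, "--name ogbl-collab", the script above loads the chosen OGB dataset (downloading it if necessary) and writes it out as "<name>-graph.bin" via dgl.save_graphs; the script's filename is not shown in this diff.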
import random
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.multiprocessing import Queue
from torch.nn import init


def init_emb2neg_index(negative, batch_size):
    """select embedding of negative nodes from a batch of node embeddings
    for fast negative sampling

    Return
    ------
    index_emb_negu torch.LongTensor : the indices of u_embeddings

@@ -20,7 +22,7 @@ def init_emb2neg_index(negative, batch_size):
    -----
    # emb_u.shape: [batch_size, dim]
    batch_emb2negu = torch.index_select(emb_u, 0, index_emb_negu)
    """
    idx_list_u = list(range(batch_size)) * negative
    idx_list_v = list(range(batch_size)) * negative
    random.shuffle(idx_list_v)

@@ -30,21 +32,22 @@ def init_emb2neg_index(negative, batch_size):
    return index_emb_negu, index_emb_negv


def adam(grad, state_sum, nodes, lr, device, only_gpu):
    """calculate gradients according to adam"""
    grad_sum = (grad * grad).mean(1)
    if not only_gpu:
        grad_sum = grad_sum.cpu()
    state_sum.index_add_(0, nodes, grad_sum)  # cpu
    std = state_sum[nodes].to(device)  # gpu
    std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
    grad = lr * grad / std_values  # gpu

    return grad


def async_update(num_threads, model, queue):
    """Asynchronous embedding update for entity embeddings."""
    torch.set_num_threads(num_threads)
    print("async start")
    while True:

@@ -53,20 +56,35 @@ def async_update(num_threads, model, queue):
            return
        with torch.no_grad():
            if first_flag:
                model.fst_u_embeddings.weight.data.index_add_(
                    0, nodes[:, 0], grad_u
                )
                model.fst_u_embeddings.weight.data.index_add_(
                    0, nodes[:, 1], grad_v
                )
                if neg_nodes is not None:
                    model.fst_u_embeddings.weight.data.index_add_(
                        0, neg_nodes, grad_v_neg
                    )
            else:
                model.snd_u_embeddings.weight.data.index_add_(
                    0, nodes[:, 0], grad_u
                )
                model.snd_v_embeddings.weight.data.index_add_(
                    0, nodes[:, 1], grad_v
                )
                if neg_nodes is not None:
                    model.snd_v_embeddings.weight.data.index_add_(
                        0, neg_nodes, grad_v_neg
                    )


class SkipGramModel(nn.Module):
    """Negative sampling based skip-gram"""

    def __init__(
        self,
        emb_size,
        emb_dimension,
        batch_size,
        only_cpu,

@@ -82,8 +100,8 @@ class SkipGramModel(nn.Module):
        record_loss,
        async_update,
        num_threads,
    ):
        """initialize embedding on CPU

        Paremeters
        ----------

@@ -130,7 +148,7 @@ class SkipGramModel(nn.Module):
        self.record_loss = record_loss
        self.async_update = async_update
        self.num_threads = num_threads

        # initialize the device as cpu
        self.device = torch.device("cpu")

@@ -138,27 +156,38 @@ class SkipGramModel(nn.Module):
        initrange = 1.0 / self.emb_dimension
        if self.fst:
            self.fst_u_embeddings = nn.Embedding(
                self.emb_size, self.emb_dimension, sparse=True
            )
            init.uniform_(
                self.fst_u_embeddings.weight.data, -initrange, initrange
            )
        if self.snd:
            self.snd_u_embeddings = nn.Embedding(
                self.emb_size, self.emb_dimension, sparse=True
            )
            init.uniform_(
                self.snd_u_embeddings.weight.data, -initrange, initrange
            )
            self.snd_v_embeddings = nn.Embedding(
                self.emb_size, self.emb_dimension, sparse=True
            )
            init.constant_(self.snd_v_embeddings.weight.data, 0)

        # lookup_table is used for fast sigmoid computing
        self.lookup_table = torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))
        self.lookup_table[0] = 0.0
        self.lookup_table[-1] = 1.0
        if self.record_loss:
            self.logsigmoid_table = torch.log(
                torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))
            )
            self.loss_fst = []
            self.loss_snd = []

        # indexes to select positive/negative node pairs from batch_walks
        self.index_emb_negu, self.index_emb_negv = init_emb2neg_index(
            self.negative, self.batch_size
        )

        # adam
        if self.fst:

@@ -168,20 +197,20 @@ class SkipGramModel(nn.Module):
            self.snd_state_sum_v = torch.zeros(self.emb_size)

    def create_async_update(self):
        """Set up the async update subprocess."""
        self.async_q = Queue(1)
        self.async_p = mp.Process(
            target=async_update, args=(self.num_threads, self, self.async_q)
        )
        self.async_p.start()

    def finish_async_update(self):
        """Notify the async update subprocess to quit."""
        self.async_q.put((None, None, None, None, None))
        self.async_p.join()

    def share_memory(self):
        """share the parameters across subprocesses"""
        if self.fst:
            self.fst_u_embeddings.weight.share_memory_()
            self.fst_state_sum_u.share_memory_()

@@ -192,7 +221,7 @@ class SkipGramModel(nn.Module):
            self.snd_state_sum_v.share_memory_()

    def set_device(self, gpu_id):
        """set gpu device"""
        self.device = torch.device("cuda:%d" % gpu_id)
        print("The device is", self.device)
        self.lookup_table = self.lookup_table.to(self.device)

@@ -202,7 +231,7 @@ class SkipGramModel(nn.Module):
        self.index_emb_negv = self.index_emb_negv.to(self.device)

    def all_to_device(self, gpu_id):
        """move all of the parameters to a single GPU"""
self.device = torch.device("cuda:%d" % gpu_id) self.device = torch.device("cuda:%d" % gpu_id)
self.set_device(gpu_id) self.set_device(gpu_id)
if self.fst: if self.fst:
...@@ -215,31 +244,39 @@ class SkipGramModel(nn.Module): ...@@ -215,31 +244,39 @@ class SkipGramModel(nn.Module):
self.snd_state_sum_v = self.snd_state_sum_v.to(self.device) self.snd_state_sum_v = self.snd_state_sum_v.to(self.device)
def fast_sigmoid(self, score): def fast_sigmoid(self, score):
""" do fast sigmoid by looking up in a pre-defined table """ """do fast sigmoid by looking up in a pre-defined table"""
idx = torch.floor((score + 6.01) / 0.01).long() idx = torch.floor((score + 6.01) / 0.01).long()
return self.lookup_table[idx] return self.lookup_table[idx]
def fast_logsigmoid(self, score): def fast_logsigmoid(self, score):
""" do fast logsigmoid by looking up in a pre-defined table """ """do fast logsigmoid by looking up in a pre-defined table"""
idx = torch.floor((score + 6.01) / 0.01).long() idx = torch.floor((score + 6.01) / 0.01).long()
return self.logsigmoid_table[idx] return self.logsigmoid_table[idx]
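The two lookup helpers above trade a tiny amount of accuracy for speed by pre-tabulating sigmoid on [-6.01, 6.01] at 0.01 resolution. A minimal illustration of the same idea (scores must be clamped to [-6, 6] first, as fast_pos_bp/fast_neg_bp do):

import torch

lookup_table = torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))

def table_sigmoid(score):
    score = torch.clamp(score, min=-6, max=6)
    idx = torch.floor((score + 6.01) / 0.01).long()
    return lookup_table[idx]

x = torch.randn(5)
print(table_sigmoid(x))
print(torch.sigmoid(x))   # close to the tabulated values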
def fast_pos_bp(self, emb_pos_u, emb_pos_v, first_flag): def fast_pos_bp(self, emb_pos_u, emb_pos_v, first_flag):
""" get grad for positve samples """ """get grad for positve samples"""
pos_score = torch.sum(torch.mul(emb_pos_u, emb_pos_v), dim=1) pos_score = torch.sum(torch.mul(emb_pos_u, emb_pos_v), dim=1)
pos_score = torch.clamp(pos_score, max=6, min=-6) pos_score = torch.clamp(pos_score, max=6, min=-6)
# [batch_size, 1] # [batch_size, 1]
score = (1 - self.fast_sigmoid(pos_score)).unsqueeze(1) score = (1 - self.fast_sigmoid(pos_score)).unsqueeze(1)
if self.record_loss: if self.record_loss:
if first_flag: if first_flag:
self.loss_fst.append(torch.mean(self.fast_logsigmoid(pos_score)).item()) self.loss_fst.append(
torch.mean(self.fast_logsigmoid(pos_score)).item()
)
else: else:
self.loss_snd.append(torch.mean(self.fast_logsigmoid(pos_score)).item()) self.loss_snd.append(
torch.mean(self.fast_logsigmoid(pos_score)).item()
)
# [batch_size, dim] # [batch_size, dim]
if self.lap_norm > 0: if self.lap_norm > 0:
grad_u_pos = score * emb_pos_v + self.lap_norm * (emb_pos_v - emb_pos_u) grad_u_pos = score * emb_pos_v + self.lap_norm * (
grad_v_pos = score * emb_pos_u + self.lap_norm * (emb_pos_u - emb_pos_v) emb_pos_v - emb_pos_u
)
grad_v_pos = score * emb_pos_u + self.lap_norm * (
emb_pos_u - emb_pos_v
)
else: else:
grad_u_pos = score * emb_pos_v grad_u_pos = score * emb_pos_v
grad_v_pos = score * emb_pos_u grad_v_pos = score * emb_pos_u
...@@ -247,16 +284,24 @@ class SkipGramModel(nn.Module): ...@@ -247,16 +284,24 @@ class SkipGramModel(nn.Module):
return grad_u_pos, grad_v_pos return grad_u_pos, grad_v_pos
def fast_neg_bp(self, emb_neg_u, emb_neg_v, first_flag): def fast_neg_bp(self, emb_neg_u, emb_neg_v, first_flag):
""" get grad for negative samples """ """get grad for negative samples"""
neg_score = torch.sum(torch.mul(emb_neg_u, emb_neg_v), dim=1) neg_score = torch.sum(torch.mul(emb_neg_u, emb_neg_v), dim=1)
neg_score = torch.clamp(neg_score, max=6, min=-6) neg_score = torch.clamp(neg_score, max=6, min=-6)
# [batch_size * negative, 1] # [batch_size * negative, 1]
score = - self.fast_sigmoid(neg_score).unsqueeze(1) score = -self.fast_sigmoid(neg_score).unsqueeze(1)
if self.record_loss: if self.record_loss:
if first_flag: if first_flag:
self.loss_fst.append(self.negative * self.neg_weight * torch.mean(self.fast_logsigmoid(-neg_score)).item()) self.loss_fst.append(
self.negative
* self.neg_weight
* torch.mean(self.fast_logsigmoid(-neg_score)).item()
)
else: else:
self.loss_snd.append(self.negative * self.neg_weight * torch.mean(self.fast_logsigmoid(-neg_score)).item()) self.loss_snd.append(
self.negative
* self.neg_weight
* torch.mean(self.fast_logsigmoid(-neg_score)).item()
)
grad_u_neg = self.neg_weight * score * emb_neg_v grad_u_neg = self.neg_weight * score * emb_neg_v
grad_v_neg = self.neg_weight * score * emb_neg_u grad_v_neg = self.neg_weight * score * emb_neg_u
...@@ -264,7 +309,7 @@ class SkipGramModel(nn.Module): ...@@ -264,7 +309,7 @@ class SkipGramModel(nn.Module):
return grad_u_neg, grad_v_neg return grad_u_neg, grad_v_neg
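fast_pos_bp and fast_neg_bp encode the closed-form skip-gram gradients: for a positive pair the gradient of log sigmoid(u·v) with respect to u is (1 - sigmoid(u·v))·v, and for a negative pair the gradient of log sigmoid(-u·v) is -sigmoid(u·v)·v, which is exactly the `score * emb` products above. A small autograd sanity check of the positive case:

import torch

u = torch.randn(8, 16, requires_grad=True)
v = torch.randn(8, 16)
torch.log(torch.sigmoid((u * v).sum(dim=1))).sum().backward()
manual = (1 - torch.sigmoid((u * v).sum(dim=1))).unsqueeze(1) * v
print(torch.allclose(u.grad, manual, atol=1e-5))   # True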
def fast_learn(self, batch_edges, neg_nodes=None): def fast_learn(self, batch_edges, neg_nodes=None):
""" Learn a batch of edges in a fast way. It has the following features: """Learn a batch of edges in a fast way. It has the following features:
1. It calculates the gradients directly without the forward operation. 1. It calculates the gradients directly without the forward operation.
2. It computes sigmoid with a lookup table. 2. It computes sigmoid with a lookup table.
...@@ -296,30 +341,46 @@ class SkipGramModel(nn.Module): ...@@ -296,30 +341,46 @@ class SkipGramModel(nn.Module):
bs = len(nodes) bs = len(nodes)
if self.fst: if self.fst:
emb_u = self.fst_u_embeddings(nodes[:, 0]).view(-1, self.emb_dimension).to(self.device) emb_u = (
emb_v = self.fst_u_embeddings(nodes[:, 1]).view(-1, self.emb_dimension).to(self.device) self.fst_u_embeddings(nodes[:, 0])
.view(-1, self.emb_dimension)
.to(self.device)
)
emb_v = (
self.fst_u_embeddings(nodes[:, 1])
.view(-1, self.emb_dimension)
.to(self.device)
)
## Positive ## Positive
emb_pos_u, emb_pos_v = emb_u, emb_v emb_pos_u, emb_pos_v = emb_u, emb_v
grad_u_pos, grad_v_pos = self.fast_pos_bp(emb_pos_u, emb_pos_v, True) grad_u_pos, grad_v_pos = self.fast_pos_bp(
emb_pos_u, emb_pos_v, True
)
## Negative ## Negative
emb_neg_u = emb_pos_u.repeat((self.negative, 1)) emb_neg_u = emb_pos_u.repeat((self.negative, 1))
if bs < self.batch_size: if bs < self.batch_size:
index_emb_negu, index_emb_negv = init_emb2neg_index(self.negative, bs) index_emb_negu, index_emb_negv = init_emb2neg_index(
self.negative, bs
)
index_emb_negu = index_emb_negu.to(self.device) index_emb_negu = index_emb_negu.to(self.device)
index_emb_negv = index_emb_negv.to(self.device) index_emb_negv = index_emb_negv.to(self.device)
else: else:
index_emb_negu = self.index_emb_negu index_emb_negu = self.index_emb_negu
index_emb_negv = self.index_emb_negv index_emb_negv = self.index_emb_negv
if neg_nodes is None: if neg_nodes is None:
emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv) emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv)
else: else:
emb_neg_v = self.fst_u_embeddings.weight[neg_nodes].to(self.device) emb_neg_v = self.fst_u_embeddings.weight[neg_nodes].to(
self.device
)
grad_u_neg, grad_v_neg = self.fast_neg_bp(emb_neg_u, emb_neg_v, True) grad_u_neg, grad_v_neg = self.fast_neg_bp(
emb_neg_u, emb_neg_v, True
)
## Update ## Update
grad_u_pos.index_add_(0, index_emb_negu, grad_u_neg) grad_u_pos.index_add_(0, index_emb_negu, grad_u_neg)
...@@ -329,12 +390,33 @@ class SkipGramModel(nn.Module): ...@@ -329,12 +390,33 @@ class SkipGramModel(nn.Module):
grad_v = grad_v_pos grad_v = grad_v_pos
else: else:
grad_v = grad_v_pos grad_v = grad_v_pos
# use adam optimizer # use adam optimizer
grad_u = adam(grad_u, self.fst_state_sum_u, nodes[:, 0], lr, self.device, self.only_gpu) grad_u = adam(
grad_v = adam(grad_v, self.fst_state_sum_u, nodes[:, 1], lr, self.device, self.only_gpu) grad_u,
self.fst_state_sum_u,
nodes[:, 0],
lr,
self.device,
self.only_gpu,
)
grad_v = adam(
grad_v,
self.fst_state_sum_u,
nodes[:, 1],
lr,
self.device,
self.only_gpu,
)
if neg_nodes is not None: if neg_nodes is not None:
grad_v_neg = adam(grad_v_neg, self.fst_state_sum_u, neg_nodes, lr, self.device, self.only_gpu) grad_v_neg = adam(
grad_v_neg,
self.fst_state_sum_u,
neg_nodes,
lr,
self.device,
self.only_gpu,
)
if self.mixed_train: if self.mixed_train:
grad_u = grad_u.cpu() grad_u = grad_u.cpu()
...@@ -351,27 +433,47 @@ class SkipGramModel(nn.Module): ...@@ -351,27 +433,47 @@ class SkipGramModel(nn.Module):
if neg_nodes is not None: if neg_nodes is not None:
neg_nodes.share_memory_() neg_nodes.share_memory_()
grad_v_neg.share_memory_() grad_v_neg.share_memory_()
self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes, True)) self.async_q.put(
(grad_u, grad_v, grad_v_neg, nodes, neg_nodes, True)
)
if not self.async_update: if not self.async_update:
self.fst_u_embeddings.weight.data.index_add_(0, nodes[:, 0], grad_u) self.fst_u_embeddings.weight.data.index_add_(
self.fst_u_embeddings.weight.data.index_add_(0, nodes[:, 1], grad_v) 0, nodes[:, 0], grad_u
)
self.fst_u_embeddings.weight.data.index_add_(
0, nodes[:, 1], grad_v
)
if neg_nodes is not None: if neg_nodes is not None:
self.fst_u_embeddings.weight.data.index_add_(0, neg_nodes, grad_v_neg) self.fst_u_embeddings.weight.data.index_add_(
0, neg_nodes, grad_v_neg
)
if self.snd: if self.snd:
emb_u = self.snd_u_embeddings(nodes[:, 0]).view(-1, self.emb_dimension).to(self.device) emb_u = (
emb_v = self.snd_v_embeddings(nodes[:, 1]).view(-1, self.emb_dimension).to(self.device) self.snd_u_embeddings(nodes[:, 0])
.view(-1, self.emb_dimension)
.to(self.device)
)
emb_v = (
self.snd_v_embeddings(nodes[:, 1])
.view(-1, self.emb_dimension)
.to(self.device)
)
## Positive ## Positive
emb_pos_u, emb_pos_v = emb_u, emb_v emb_pos_u, emb_pos_v = emb_u, emb_v
grad_u_pos, grad_v_pos = self.fast_pos_bp(emb_pos_u, emb_pos_v, False) grad_u_pos, grad_v_pos = self.fast_pos_bp(
emb_pos_u, emb_pos_v, False
)
## Negative ## Negative
emb_neg_u = emb_pos_u.repeat((self.negative, 1)) emb_neg_u = emb_pos_u.repeat((self.negative, 1))
if bs < self.batch_size: if bs < self.batch_size:
index_emb_negu, index_emb_negv = init_emb2neg_index(self.negative, bs) index_emb_negu, index_emb_negv = init_emb2neg_index(
self.negative, bs
)
index_emb_negu = index_emb_negu.to(self.device) index_emb_negu = index_emb_negu.to(self.device)
index_emb_negv = index_emb_negv.to(self.device) index_emb_negv = index_emb_negv.to(self.device)
else: else:
...@@ -381,9 +483,13 @@ class SkipGramModel(nn.Module): ...@@ -381,9 +483,13 @@ class SkipGramModel(nn.Module):
if neg_nodes is None: if neg_nodes is None:
emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv) emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv)
else: else:
emb_neg_v = self.snd_v_embeddings.weight[neg_nodes].to(self.device) emb_neg_v = self.snd_v_embeddings.weight[neg_nodes].to(
self.device
)
grad_u_neg, grad_v_neg = self.fast_neg_bp(emb_neg_u, emb_neg_v, False) grad_u_neg, grad_v_neg = self.fast_neg_bp(
emb_neg_u, emb_neg_v, False
)
## Update ## Update
grad_u_pos.index_add_(0, index_emb_negu, grad_u_neg) grad_u_pos.index_add_(0, index_emb_negu, grad_u_neg)
...@@ -393,12 +499,33 @@ class SkipGramModel(nn.Module): ...@@ -393,12 +499,33 @@ class SkipGramModel(nn.Module):
grad_v = grad_v_pos grad_v = grad_v_pos
else: else:
grad_v = grad_v_pos grad_v = grad_v_pos
# use adam optimizer # use adam optimizer
grad_u = adam(grad_u, self.snd_state_sum_u, nodes[:, 0], lr, self.device, self.only_gpu) grad_u = adam(
grad_v = adam(grad_v, self.snd_state_sum_v, nodes[:, 1], lr, self.device, self.only_gpu) grad_u,
self.snd_state_sum_u,
nodes[:, 0],
lr,
self.device,
self.only_gpu,
)
grad_v = adam(
grad_v,
self.snd_state_sum_v,
nodes[:, 1],
lr,
self.device,
self.only_gpu,
)
if neg_nodes is not None: if neg_nodes is not None:
grad_v_neg = adam(grad_v_neg, self.snd_state_sum_v, neg_nodes, lr, self.device, self.only_gpu) grad_v_neg = adam(
grad_v_neg,
self.snd_state_sum_v,
neg_nodes,
lr,
self.device,
self.only_gpu,
)
if self.mixed_train: if self.mixed_train:
grad_u = grad_u.cpu() grad_u = grad_u.cpu()
...@@ -415,37 +542,51 @@ class SkipGramModel(nn.Module): ...@@ -415,37 +542,51 @@ class SkipGramModel(nn.Module):
if neg_nodes is not None: if neg_nodes is not None:
neg_nodes.share_memory_() neg_nodes.share_memory_()
grad_v_neg.share_memory_() grad_v_neg.share_memory_()
self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes, False)) self.async_q.put(
(grad_u, grad_v, grad_v_neg, nodes, neg_nodes, False)
)
if not self.async_update: if not self.async_update:
self.snd_u_embeddings.weight.data.index_add_(0, nodes[:, 0], grad_u) self.snd_u_embeddings.weight.data.index_add_(
self.snd_v_embeddings.weight.data.index_add_(0, nodes[:, 1], grad_v) 0, nodes[:, 0], grad_u
)
self.snd_v_embeddings.weight.data.index_add_(
0, nodes[:, 1], grad_v
)
if neg_nodes is not None: if neg_nodes is not None:
self.snd_v_embeddings.weight.data.index_add_(0, neg_nodes, grad_v_neg) self.snd_v_embeddings.weight.data.index_add_(
0, neg_nodes, grad_v_neg
)
return return
def get_embedding(self): def get_embedding(self):
if self.fst: if self.fst:
embedding_fst = self.fst_u_embeddings.weight.cpu().data.numpy() embedding_fst = self.fst_u_embeddings.weight.cpu().data.numpy()
embedding_fst /= np.sqrt(np.sum(embedding_fst * embedding_fst, 1)).reshape(-1, 1) embedding_fst /= np.sqrt(
np.sum(embedding_fst * embedding_fst, 1)
).reshape(-1, 1)
if self.snd: if self.snd:
embedding_snd = self.snd_u_embeddings.weight.cpu().data.numpy() embedding_snd = self.snd_u_embeddings.weight.cpu().data.numpy()
embedding_snd /= np.sqrt(np.sum(embedding_snd * embedding_snd, 1)).reshape(-1, 1) embedding_snd /= np.sqrt(
np.sum(embedding_snd * embedding_snd, 1)
).reshape(-1, 1)
if self.fst and self.snd: if self.fst and self.snd:
embedding = np.concatenate((embedding_fst, embedding_snd), 1) embedding = np.concatenate((embedding_fst, embedding_snd), 1)
embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(-1, 1) embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(
-1, 1
)
elif self.fst and not self.snd: elif self.fst and not self.snd:
embedding = embedding_fst embedding = embedding_fst
elif self.snd and not self.fst: elif self.snd and not self.fst:
embedding = embedding_snd embedding = embedding_snd
else: else:
pass pass
return embedding return embedding
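get_embedding L2-normalizes each table row-wise and, when both orders are trained, concatenates the first- and second-order vectors and normalizes again (the usual LINE recipe). The same steps on toy arrays:

import numpy as np

def l2_normalize_rows(emb):
    return emb / np.sqrt(np.sum(emb * emb, axis=1, keepdims=True))

fst = l2_normalize_rows(np.random.rand(4, 8))    # stand-in for fst_u_embeddings
snd = l2_normalize_rows(np.random.rand(4, 8))    # stand-in for snd_u_embeddings
emb = l2_normalize_rows(np.concatenate((fst, snd), axis=1))
print(emb.shape, np.linalg.norm(emb, axis=1))    # (4, 16), all rows have unit norm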
def save_embedding(self, dataset, file_name): def save_embedding(self, dataset, file_name):
""" Write embedding to local file. Only used when node ids are numbers. """Write embedding to local file. Only used when node ids are numbers.
Parameter Parameter
--------- ---------
...@@ -456,7 +597,7 @@ class SkipGramModel(nn.Module): ...@@ -456,7 +597,7 @@ class SkipGramModel(nn.Module):
np.save(file_name, embedding) np.save(file_name, embedding)
def save_embedding_pt(self, dataset, file_name): def save_embedding_pt(self, dataset, file_name):
""" For ogb leaderboard. """ """For ogb leaderboard."""
embedding = torch.Tensor(self.get_embedding()).cpu() embedding = torch.Tensor(self.get_embedding()).cpu()
embedding_empty = torch.zeros_like(embedding.data) embedding_empty = torch.zeros_like(embedding.data)
valid_nodes = torch.LongTensor(dataset.valid_nodes) valid_nodes = torch.LongTensor(dataset.valid_nodes)
......
import os import os
import pickle
import random
import time
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
import pickle
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from dgl.data.utils import download, _get_dgl_url, get_download_dir, extract_archive
import random
import time
import dgl import dgl
from dgl.data.utils import (
_get_dgl_url,
download,
extract_archive,
get_download_dir,
)
def ReadTxtNet(file_path="", undirected=True): def ReadTxtNet(file_path="", undirected=True):
""" Read the txt network file. """Read the txt network file.
Note: The network is unweighted. Note: The network is unweighted.
Parameters Parameters
...@@ -21,16 +29,20 @@ def ReadTxtNet(file_path="", undirected=True): ...@@ -21,16 +29,20 @@ def ReadTxtNet(file_path="", undirected=True):
Return Return
------ ------
net dict : a dict recording the connections in the graph net dict : a dict recording the connections in the graph
node2id dict : a dict mapping the nodes to their embedding indices node2id dict : a dict mapping the nodes to their embedding indices
id2node dict : a dict mapping embedding indices back to the nodes id2node dict : a dict mapping embedding indices back to the nodes
""" """
if file_path == 'youtube' or file_path == 'blog': if file_path == "youtube" or file_path == "blog":
name = file_path name = file_path
dir = get_download_dir() dir = get_download_dir()
zip_file_path='{}/{}.zip'.format(dir, name) zip_file_path = "{}/{}.zip".format(dir, name)
download(_get_dgl_url(os.path.join('dataset/DeepWalk/', '{}.zip'.format(file_path))), path=zip_file_path) download(
extract_archive(zip_file_path, _get_dgl_url(
'{}/{}'.format(dir, name)) os.path.join("dataset/DeepWalk/", "{}.zip".format(file_path))
),
path=zip_file_path,
)
extract_archive(zip_file_path, "{}/{}".format(dir, name))
file_path = "{}/{}/{}-net.txt".format(dir, name, name) file_path = "{}/{}/{}-net.txt".format(dir, name, name)
node2id = {} node2id = {}
...@@ -44,7 +56,10 @@ def ReadTxtNet(file_path="", undirected=True): ...@@ -44,7 +56,10 @@ def ReadTxtNet(file_path="", undirected=True):
with open(file_path, "r") as f: with open(file_path, "r") as f:
for line in f.readlines(): for line in f.readlines():
tup = list(map(int, line.strip().split(" "))) tup = list(map(int, line.strip().split(" ")))
assert len(tup) in [2, 3], "The format of network file is unrecognizable." assert len(tup) in [
2,
3,
], "The format of network file is unrecognizable."
if len(tup) == 3: if len(tup) == 3:
n1, n2, w = tup n1, n2, w = tup
elif len(tup) == 2: elif len(tup) == 2:
...@@ -71,7 +86,7 @@ def ReadTxtNet(file_path="", undirected=True): ...@@ -71,7 +86,7 @@ def ReadTxtNet(file_path="", undirected=True):
src.append(n1) src.append(n1)
dst.append(n2) dst.append(n2)
weight.append(w) weight.append(w)
if undirected: if undirected:
if n2 not in net: if n2 not in net:
net[n2] = {n1: w} net[n2] = {n1: w}
...@@ -88,16 +103,15 @@ def ReadTxtNet(file_path="", undirected=True): ...@@ -88,16 +103,15 @@ def ReadTxtNet(file_path="", undirected=True):
print("edge num: %d" % len(src)) print("edge num: %d" % len(src))
assert max(net.keys()) == len(net) - 1, "error reading net, quit" assert max(net.keys()) == len(net) - 1, "error reading net, quit"
sm = sp.coo_matrix( sm = sp.coo_matrix((np.array(weight), (src, dst)), dtype=np.float32)
(np.array(weight), (src, dst)),
dtype=np.float32)
return net, node2id, id2node, sm return net, node2id, id2node, sm
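ReadTxtNet expects one edge per line, either "src dst" or "src dst weight", with integer node ids; passing the literal names "youtube" or "blog" downloads the bundled datasets instead. A hedged usage example (the local file name is illustrative):

# net: adjacency dict, node2id/id2node: id mappings, sm: scipy.sparse coo_matrix
net, node2id, id2node, sm = ReadTxtNet("youtube")                        # downloads the dataset
# net, node2id, id2node, sm = ReadTxtNet("my-net.txt", undirected=True)  # local edge list
print(len(node2id), sm.shape)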
def net2graph(net_sm): def net2graph(net_sm):
""" Transform the network to DGL graph """Transform the network to DGL graph
Return Return
------ ------
G DGLGraph : graph by DGL G DGLGraph : graph by DGL
""" """
...@@ -108,29 +122,33 @@ def net2graph(net_sm): ...@@ -108,29 +122,33 @@ def net2graph(net_sm):
print("Building DGLGraph in %.2fs" % t) print("Building DGLGraph in %.2fs" % t)
return G return G
def make_undirected(G): def make_undirected(G):
#G.readonly(False) # G.readonly(False)
G.add_edges(G.edges()[1], G.edges()[0]) G.add_edges(G.edges()[1], G.edges()[0])
return G return G
def find_connected_nodes(G): def find_connected_nodes(G):
nodes = torch.nonzero(G.out_degrees(), as_tuple=False).squeeze(-1) nodes = torch.nonzero(G.out_degrees(), as_tuple=False).squeeze(-1)
return nodes return nodes
class LineDataset: class LineDataset:
def __init__(self, def __init__(
net_file, self,
batch_size, net_file,
num_samples, batch_size,
negative=5, num_samples,
gpus=[0], negative=5,
fast_neg=True, gpus=[0],
ogbl_name="", fast_neg=True,
load_from_ogbl=False, ogbl_name="",
ogbn_name="", load_from_ogbl=False,
load_from_ogbn=False, ogbn_name="",
): load_from_ogbn=False,
""" This class has the following functions: ):
"""This class has the following functions:
1. Transform the txt network file into DGL graph; 1. Transform the txt network file into DGL graph;
2. Generate random walk sequences for the trainer; 2. Generate random walk sequences for the trainer;
3. Provide the negative table if the user hopes to sample negative 3. Provide the negative table if the user hopes to sample negative
...@@ -153,12 +171,18 @@ class LineDataset: ...@@ -153,12 +171,18 @@ class LineDataset:
self.fast_neg = fast_neg self.fast_neg = fast_neg
if load_from_ogbl: if load_from_ogbl:
assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training." assert (
len(gpus) == 1
), "ogb.linkproppred is not compatible with multi-gpu training."
from load_dataset import load_from_ogbl_with_name from load_dataset import load_from_ogbl_with_name
self.G = load_from_ogbl_with_name(ogbl_name) self.G = load_from_ogbl_with_name(ogbl_name)
elif load_from_ogbn: elif load_from_ogbn:
assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training." assert (
len(gpus) == 1
), "ogb.linkproppred is not compatible with multi-gpu training."
from load_dataset import load_from_ogbn_with_name from load_dataset import load_from_ogbn_with_name
self.G = load_from_ogbn_with_name(ogbn_name) self.G = load_from_ogbn_with_name(ogbn_name)
else: else:
self.G = dgl.load_graphs(net_file)[0][0] self.G = dgl.load_graphs(net_file)[0][0]
...@@ -168,12 +192,14 @@ class LineDataset: ...@@ -168,12 +192,14 @@ class LineDataset:
self.num_nodes = self.G.number_of_nodes() self.num_nodes = self.G.number_of_nodes()
start = time.time() start = time.time()
seeds = np.random.choice(np.arange(self.G.number_of_edges()), seeds = np.random.choice(
self.num_samples, np.arange(self.G.number_of_edges()), self.num_samples, replace=True
replace=True) # edge index ) # edge index
self.seeds = torch.split(torch.LongTensor(seeds), self.seeds = torch.split(
int(np.ceil(self.num_samples / self.num_procs)), torch.LongTensor(seeds),
0) int(np.ceil(self.num_samples / self.num_procs)),
0,
)
end = time.time() end = time.time()
t = end - start t = end - start
print("generate %d samples in %.2fs" % (len(seeds), t)) print("generate %d samples in %.2fs" % (len(seeds), t))
...@@ -186,7 +212,7 @@ class LineDataset: ...@@ -186,7 +212,7 @@ class LineDataset:
node_degree /= np.sum(node_degree) node_degree /= np.sum(node_degree)
node_degree = np.array(node_degree * 1e8, dtype=np.int64) node_degree = np.array(node_degree * 1e8, dtype=np.int64)
self.neg_table = [] self.neg_table = []
for idx, node in enumerate(self.valid_nodes): for idx, node in enumerate(self.valid_nodes):
self.neg_table += [node] * node_degree[idx] self.neg_table += [node] * node_degree[idx]
self.neg_table_size = len(self.neg_table) self.neg_table_size = len(self.neg_table)
...@@ -194,19 +220,22 @@ class LineDataset: ...@@ -194,19 +220,22 @@ class LineDataset:
del node_degree del node_degree
def create_sampler(self, i): def create_sampler(self, i):
""" create random walk sampler """ """create random walk sampler"""
return EdgeSampler(self.G, self.seeds[i]) return EdgeSampler(self.G, self.seeds[i])
def save_mapping(self, map_file): def save_mapping(self, map_file):
with open(map_file, "wb") as f: with open(map_file, "wb") as f:
pickle.dump(self.node2id, f) pickle.dump(self.node2id, f)
class EdgeSampler(object): class EdgeSampler(object):
def __init__(self, G, seeds): def __init__(self, G, seeds):
self.G = G self.G = G
self.seeds = seeds self.seeds = seeds
self.edges = torch.cat((self.G.edges()[0].unsqueeze(0), self.G.edges()[1].unsqueeze(0)), 0).t() self.edges = torch.cat(
(self.G.edges()[0].unsqueeze(0), self.G.edges()[1].unsqueeze(0)), 0
).t()
def sample(self, seeds): def sample(self, seeds):
""" seeds torch.LongTensor : a batch of indices of edges """ """seeds torch.LongTensor : a batch of indices of edges"""
return self.edges[torch.LongTensor(seeds)] return self.edges[torch.LongTensor(seeds)]
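How the sampler is consumed is outside this diff, but the typical pattern in these examples is a torch DataLoader over the per-process seed chunk with sample() as the collate function. A hedged sketch with illustrative names (`dataset` is a LineDataset instance, `model.fast_learn` the trainer step above):

from torch.utils.data import DataLoader

sampler = dataset.create_sampler(0)          # one seed chunk per worker
loader = DataLoader(
    sampler.seeds, batch_size=128, shuffle=True, collate_fn=sampler.sample
)
for batch_edges in loader:                   # LongTensor of shape [batch, 2]
    model.fast_learn(batch_edges)            # hypothetical trainer loop
    break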
import torch import torch
def check_args(args): def check_args(args):
flag = sum([args.only_1st, args.only_2nd]) flag = sum([args.only_1st, args.only_2nd])
assert flag <= 1, "no more than one selection from --only_1st and --only_2nd" assert (
flag <= 1
), "no more than one selection from --only_1st and --only_2nd"
if flag == 0: if flag == 0:
assert args.dim % 2 == 0, "embedding dimension must be an even number" assert args.dim % 2 == 0, "embedding dimension must be an even number"
if args.async_update: if args.async_update:
assert args.mix, "please use --async_update with --mix" assert args.mix, "please use --async_update with --mix"
def sum_up_params(model): def sum_up_params(model):
""" Count the model parameters """ """Count the model parameters"""
n = [] n = []
if model.fst: if model.fst:
p = model.fst_u_embeddings.weight.cpu().data.numel() p = model.fst_u_embeddings.weight.cpu().data.numel()
......
...@@ -3,14 +3,14 @@ import math ...@@ -3,14 +3,14 @@ import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from torch.nn import Linear from torch.nn import Linear
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import dgl import dgl
from dgl.nn.pytorch import GraphConv, SAGEConv
from dgl.dataloading.negative_sampler import GlobalUniform from dgl.dataloading.negative_sampler import GlobalUniform
from dgl.nn.pytorch import GraphConv, SAGEConv
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
class Logger(object): class Logger(object):
def __init__(self, runs, info=None): def __init__(self, runs, info=None):
...@@ -56,9 +56,13 @@ class Logger(object): ...@@ -56,9 +56,13 @@ class Logger(object):
class NGNN_GCNConv(torch.nn.Module): class NGNN_GCNConv(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_nonl_layers): def __init__(
self, in_channels, hidden_channels, out_channels, num_nonl_layers
):
super(NGNN_GCNConv, self).__init__() super(NGNN_GCNConv, self).__init__()
self.num_nonl_layers = num_nonl_layers # number of nonlinear layers in each conv layer self.num_nonl_layers = (
num_nonl_layers # number of nonlinear layers in each conv layer
)
self.conv = GraphConv(in_channels, hidden_channels) self.conv = GraphConv(in_channels, hidden_channels)
self.fc = Linear(hidden_channels, hidden_channels) self.fc = Linear(hidden_channels, hidden_channels)
self.fc2 = Linear(hidden_channels, out_channels) self.fc2 = Linear(hidden_channels, out_channels)
...@@ -66,7 +70,7 @@ class NGNN_GCNConv(torch.nn.Module): ...@@ -66,7 +70,7 @@ class NGNN_GCNConv(torch.nn.Module):
def reset_parameters(self): def reset_parameters(self):
self.conv.reset_parameters() self.conv.reset_parameters()
gain = torch.nn.init.calculate_gain('relu') gain = torch.nn.init.calculate_gain("relu")
torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain) torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain)
torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain) torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain)
for bias in [self.fc.bias, self.fc2.bias]: for bias in [self.fc.bias, self.fc2.bias]:
...@@ -79,28 +83,54 @@ class NGNN_GCNConv(torch.nn.Module): ...@@ -79,28 +83,54 @@ class NGNN_GCNConv(torch.nn.Module):
if self.num_nonl_layers == 2: if self.num_nonl_layers == 2:
x = F.relu(x) x = F.relu(x)
x = self.fc(x) x = self.fc(x)
x = F.relu(x) x = F.relu(x)
x = self.fc2(x) x = self.fc2(x)
return x return x
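NGNN_GCNConv is the "network-in-GNN" idea: a GraphConv followed by one or two extra nonlinear Linear layers. A toy forward pass, assuming the (graph, feat) call signature that GraphConv itself uses:

import torch
import dgl

g = dgl.add_self_loop(dgl.rand_graph(10, 30))
x = torch.randn(10, 16)
conv = NGNN_GCNConv(16, 32, 32, num_nonl_layers=2)
print(conv(g, x).shape)   # torch.Size([10, 32]) under the assumed signature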
class GCN(torch.nn.Module): class GCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, ngnn_type, dataset): def __init__(
self,
in_channels,
hidden_channels,
out_channels,
num_layers,
dropout,
ngnn_type,
dataset,
):
super(GCN, self).__init__() super(GCN, self).__init__()
self.dataset = dataset self.dataset = dataset
self.convs = torch.nn.ModuleList() self.convs = torch.nn.ModuleList()
num_nonl_layers = 1 if num_layers <= 2 else 2 # number of nonlinear layers in each conv layer num_nonl_layers = (
if ngnn_type == 'input': 1 if num_layers <= 2 else 2
self.convs.append(NGNN_GCNConv(in_channels, hidden_channels, hidden_channels, num_nonl_layers)) ) # number of nonlinear layers in each conv layer
if ngnn_type == "input":
self.convs.append(
NGNN_GCNConv(
in_channels,
hidden_channels,
hidden_channels,
num_nonl_layers,
)
)
for _ in range(num_layers - 2): for _ in range(num_layers - 2):
self.convs.append(GraphConv(hidden_channels, hidden_channels)) self.convs.append(GraphConv(hidden_channels, hidden_channels))
elif ngnn_type == 'hidden': elif ngnn_type == "hidden":
self.convs.append(GraphConv(in_channels, hidden_channels)) self.convs.append(GraphConv(in_channels, hidden_channels))
for _ in range(num_layers - 2): for _ in range(num_layers - 2):
self.convs.append(NGNN_GCNConv(hidden_channels, hidden_channels, hidden_channels, num_nonl_layers)) self.convs.append(
NGNN_GCNConv(
hidden_channels,
hidden_channels,
hidden_channels,
num_nonl_layers,
)
)
self.convs.append(GraphConv(hidden_channels, out_channels)) self.convs.append(GraphConv(hidden_channels, out_channels))
self.dropout = dropout self.dropout = dropout
...@@ -120,10 +150,19 @@ class GCN(torch.nn.Module): ...@@ -120,10 +150,19 @@ class GCN(torch.nn.Module):
class NGNN_SAGEConv(torch.nn.Module): class NGNN_SAGEConv(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_nonl_layers, def __init__(
*, reduce): self,
in_channels,
hidden_channels,
out_channels,
num_nonl_layers,
*,
reduce,
):
super(NGNN_SAGEConv, self).__init__() super(NGNN_SAGEConv, self).__init__()
self.num_nonl_layers = num_nonl_layers # number of nonlinear layers in each conv layer self.num_nonl_layers = (
num_nonl_layers # number of nonlinear layers in each conv layer
)
self.conv = SAGEConv(in_channels, hidden_channels, reduce) self.conv = SAGEConv(in_channels, hidden_channels, reduce)
self.fc = Linear(hidden_channels, hidden_channels) self.fc = Linear(hidden_channels, hidden_channels)
self.fc2 = Linear(hidden_channels, out_channels) self.fc2 = Linear(hidden_channels, out_channels)
...@@ -131,7 +170,7 @@ class NGNN_SAGEConv(torch.nn.Module): ...@@ -131,7 +170,7 @@ class NGNN_SAGEConv(torch.nn.Module):
def reset_parameters(self): def reset_parameters(self):
self.conv.reset_parameters() self.conv.reset_parameters()
gain = torch.nn.init.calculate_gain('relu') gain = torch.nn.init.calculate_gain("relu")
torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain) torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain)
torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain) torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain)
for bias in [self.fc.bias, self.fc2.bias]: for bias in [self.fc.bias, self.fc2.bias]:
...@@ -144,28 +183,59 @@ class NGNN_SAGEConv(torch.nn.Module): ...@@ -144,28 +183,59 @@ class NGNN_SAGEConv(torch.nn.Module):
if self.num_nonl_layers == 2: if self.num_nonl_layers == 2:
x = F.relu(x) x = F.relu(x)
x = self.fc(x) x = self.fc(x)
x = F.relu(x) x = F.relu(x)
x = self.fc2(x) x = self.fc2(x)
return x return x
class SAGE(torch.nn.Module): class SAGE(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, ngnn_type, dataset, reduce='mean'): def __init__(
self,
in_channels,
hidden_channels,
out_channels,
num_layers,
dropout,
ngnn_type,
dataset,
reduce="mean",
):
super(SAGE, self).__init__() super(SAGE, self).__init__()
self.dataset = dataset self.dataset = dataset
self.convs = torch.nn.ModuleList() self.convs = torch.nn.ModuleList()
num_nonl_layers = 1 if num_layers <= 2 else 2 # number of nonlinear layers in each conv layer num_nonl_layers = (
if ngnn_type == 'input': 1 if num_layers <= 2 else 2
self.convs.append(NGNN_SAGEConv(in_channels, hidden_channels, hidden_channels, num_nonl_layers, reduce=reduce)) ) # number of nonlinear layers in each conv layer
if ngnn_type == "input":
self.convs.append(
NGNN_SAGEConv(
in_channels,
hidden_channels,
hidden_channels,
num_nonl_layers,
reduce=reduce,
)
)
for _ in range(num_layers - 2): for _ in range(num_layers - 2):
self.convs.append(SAGEConv(hidden_channels, hidden_channels, reduce)) self.convs.append(
elif ngnn_type == 'hidden': SAGEConv(hidden_channels, hidden_channels, reduce)
)
elif ngnn_type == "hidden":
self.convs.append(SAGEConv(in_channels, hidden_channels, reduce)) self.convs.append(SAGEConv(in_channels, hidden_channels, reduce))
for _ in range(num_layers - 2): for _ in range(num_layers - 2):
self.convs.append(NGNN_SAGEConv(hidden_channels, hidden_channels, hidden_channels, num_nonl_layers, reduce=reduce)) self.convs.append(
NGNN_SAGEConv(
hidden_channels,
hidden_channels,
hidden_channels,
num_nonl_layers,
reduce=reduce,
)
)
self.convs.append(SAGEConv(hidden_channels, out_channels, reduce)) self.convs.append(SAGEConv(hidden_channels, out_channels, reduce))
self.dropout = dropout self.dropout = dropout
...@@ -185,7 +255,9 @@ class SAGE(torch.nn.Module): ...@@ -185,7 +255,9 @@ class SAGE(torch.nn.Module):
class LinkPredictor(torch.nn.Module): class LinkPredictor(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): def __init__(
self, in_channels, hidden_channels, out_channels, num_layers, dropout
):
super(LinkPredictor, self).__init__() super(LinkPredictor, self).__init__()
self.lins = torch.nn.ModuleList() self.lins = torch.nn.ModuleList()
...@@ -215,11 +287,12 @@ def train(model, predictor, g, x, split_edge, optimizer, batch_size): ...@@ -215,11 +287,12 @@ def train(model, predictor, g, x, split_edge, optimizer, batch_size):
model.train() model.train()
predictor.train() predictor.train()
pos_train_edge = split_edge['train']['edge'].to(x.device) pos_train_edge = split_edge["train"]["edge"].to(x.device)
neg_sampler = GlobalUniform(1) neg_sampler = GlobalUniform(1)
total_loss = total_examples = 0 total_loss = total_examples = 0
for perm in DataLoader(range(pos_train_edge.size(0)), batch_size, for perm in DataLoader(
shuffle=True): range(pos_train_edge.size(0)), batch_size, shuffle=True
):
optimizer.zero_grad() optimizer.zero_grad()
h = model(g, x) h = model(g, x)
...@@ -237,7 +310,7 @@ def train(model, predictor, g, x, split_edge, optimizer, batch_size): ...@@ -237,7 +310,7 @@ def train(model, predictor, g, x, split_edge, optimizer, batch_size):
loss = pos_loss + neg_loss loss = pos_loss + neg_loss
loss.backward() loss.backward()
if model.dataset == 'ogbl-ddi': if model.dataset == "ogbl-ddi":
torch.nn.utils.clip_grad_norm_(x, 1.0) torch.nn.utils.clip_grad_norm_(x, 1.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)
...@@ -258,11 +331,11 @@ def test(model, predictor, g, x, split_edge, evaluator, batch_size): ...@@ -258,11 +331,11 @@ def test(model, predictor, g, x, split_edge, evaluator, batch_size):
h = model(g, x) h = model(g, x)
pos_train_edge = split_edge['eval_train']['edge'].to(h.device) pos_train_edge = split_edge["eval_train"]["edge"].to(h.device)
pos_valid_edge = split_edge['valid']['edge'].to(h.device) pos_valid_edge = split_edge["valid"]["edge"].to(h.device)
neg_valid_edge = split_edge['valid']['edge_neg'].to(h.device) neg_valid_edge = split_edge["valid"]["edge_neg"].to(h.device)
pos_test_edge = split_edge['test']['edge'].to(h.device) pos_test_edge = split_edge["test"]["edge"].to(h.device)
neg_test_edge = split_edge['test']['edge_neg'].to(h.device) neg_test_edge = split_edge["test"]["edge_neg"].to(h.device)
def get_pred(test_edges, h): def get_pred(test_edges, h):
preds = [] preds = []
...@@ -271,7 +344,7 @@ def test(model, predictor, g, x, split_edge, evaluator, batch_size): ...@@ -271,7 +344,7 @@ def test(model, predictor, g, x, split_edge, evaluator, batch_size):
preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
pred = torch.cat(preds, dim=0) pred = torch.cat(preds, dim=0)
return pred return pred
pos_train_pred = get_pred(pos_train_edge, h) pos_train_pred = get_pred(pos_train_edge, h)
pos_valid_pred = get_pred(pos_valid_edge, h) pos_valid_pred = get_pred(pos_valid_edge, h)
neg_valid_pred = get_pred(neg_valid_edge, h) neg_valid_pred = get_pred(neg_valid_edge, h)
...@@ -281,50 +354,84 @@ def test(model, predictor, g, x, split_edge, evaluator, batch_size): ...@@ -281,50 +354,84 @@ def test(model, predictor, g, x, split_edge, evaluator, batch_size):
results = {} results = {}
for K in [20, 50, 100]: for K in [20, 50, 100]:
evaluator.K = K evaluator.K = K
train_hits = evaluator.eval({ train_hits = evaluator.eval(
'y_pred_pos': pos_train_pred, {
'y_pred_neg': neg_valid_pred, "y_pred_pos": pos_train_pred,
})[f'hits@{K}'] "y_pred_neg": neg_valid_pred,
valid_hits = evaluator.eval({ }
'y_pred_pos': pos_valid_pred, )[f"hits@{K}"]
'y_pred_neg': neg_valid_pred, valid_hits = evaluator.eval(
})[f'hits@{K}'] {
test_hits = evaluator.eval({ "y_pred_pos": pos_valid_pred,
'y_pred_pos': pos_test_pred, "y_pred_neg": neg_valid_pred,
'y_pred_neg': neg_test_pred, }
})[f'hits@{K}'] )[f"hits@{K}"]
test_hits = evaluator.eval(
results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits) {
"y_pred_pos": pos_test_pred,
"y_pred_neg": neg_test_pred,
}
)[f"hits@{K}"]
results[f"Hits@{K}"] = (train_hits, valid_hits, test_hits)
return results return results
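The Hits@K numbers above come straight from the OGB Evaluator; a minimal standalone example of the same call (random scores, so the value itself is meaningless):

import torch
from ogb.linkproppred import Evaluator

evaluator = Evaluator(name="ogbl-ddi")
evaluator.K = 20
hits = evaluator.eval(
    {"y_pred_pos": torch.rand(100), "y_pred_neg": torch.rand(500)}
)["hits@20"]
print(hits)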
def main(): def main():
parser = argparse.ArgumentParser(description='OGBL(Full Batch GCN/GraphSage + NGNN)') parser = argparse.ArgumentParser(
description="OGBL(Full Batch GCN/GraphSage + NGNN)"
)
# dataset setting # dataset setting
parser.add_argument('--dataset', type=str, default='ogbl-ddi', choices=['ogbl-ddi', 'ogbl-collab', 'ogbl-ppa']) parser.add_argument(
"--dataset",
type=str,
default="ogbl-ddi",
choices=["ogbl-ddi", "ogbl-collab", "ogbl-ppa"],
)
# device setting # device setting
parser.add_argument('--device', type=int, default=0, help='GPU device ID. Use -1 for CPU training.') parser.add_argument(
"--device",
type=int,
default=0,
help="GPU device ID. Use -1 for CPU training.",
)
# model structure settings # model structure settings
parser.add_argument('--use_sage', action='store_true', help='If not set, use GCN by default.') parser.add_argument(
parser.add_argument('--ngnn_type', type=str, default="input", choices=['input', 'hidden'], help="You can set this value from 'input' or 'hidden' to apply NGNN to different GNN layers.") "--use_sage",
parser.add_argument('--num_layers', type=int, default=3, help='number of GNN layers') action="store_true",
parser.add_argument('--hidden_channels', type=int, default=256) help="If not set, use GCN by default.",
parser.add_argument('--dropout', type=float, default=0.0) )
parser.add_argument('--batch_size', type=int, default=64 * 1024) parser.add_argument(
parser.add_argument('--lr', type=float, default=0.001) "--ngnn_type",
parser.add_argument('--epochs', type=int, default=400) type=str,
default="input",
choices=["input", "hidden"],
help="You can set this value from 'input' or 'hidden' to apply NGNN to different GNN layers.",
)
parser.add_argument(
"--num_layers", type=int, default=3, help="number of GNN layers"
)
parser.add_argument("--hidden_channels", type=int, default=256)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--batch_size", type=int, default=64 * 1024)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=400)
# training settings # training settings
parser.add_argument('--eval_steps', type=int, default=1) parser.add_argument("--eval_steps", type=int, default=1)
parser.add_argument('--runs', type=int, default=10) parser.add_argument("--runs", type=int, default=10)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
device = f'cuda:{args.device}' if args.device != -1 and torch.cuda.is_available() else 'cpu' device = (
f"cuda:{args.device}"
if args.device != -1 and torch.cuda.is_available()
else "cpu"
)
device = torch.device(device) device = torch.device(device)
dataset = DglLinkPropPredDataset(name=args.dataset) dataset = DglLinkPropPredDataset(name=args.dataset)
...@@ -332,70 +439,101 @@ def main(): ...@@ -332,70 +439,101 @@ def main():
split_edge = dataset.get_edge_split() split_edge = dataset.get_edge_split()
# We randomly pick some training samples that we want to evaluate on: # We randomly pick some training samples that we want to evaluate on:
idx = torch.randperm(split_edge['train']['edge'].size(0)) idx = torch.randperm(split_edge["train"]["edge"].size(0))
idx = idx[:split_edge['valid']['edge'].size(0)] idx = idx[: split_edge["valid"]["edge"].size(0)]
split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]} split_edge["eval_train"] = {"edge": split_edge["train"]["edge"][idx]}
if dataset.name == 'ogbl-ppa': if dataset.name == "ogbl-ppa":
g.ndata['feat'] = g.ndata['feat'].to(torch.float) g.ndata["feat"] = g.ndata["feat"].to(torch.float)
if dataset.name == 'ogbl-ddi': if dataset.name == "ogbl-ddi":
emb = torch.nn.Embedding(g.num_nodes(), args.hidden_channels).to(device) emb = torch.nn.Embedding(g.num_nodes(), args.hidden_channels).to(device)
in_channels = args.hidden_channels in_channels = args.hidden_channels
else: # ogbl-collab, ogbl-ppa else: # ogbl-collab, ogbl-ppa
in_channels = g.ndata['feat'].size(-1) in_channels = g.ndata["feat"].size(-1)
# select model # select model
if args.use_sage: if args.use_sage:
model = SAGE(in_channels, args.hidden_channels, model = SAGE(
args.hidden_channels, args.num_layers, in_channels,
args.dropout, args.ngnn_type, dataset.name) args.hidden_channels,
else: # GCN args.hidden_channels,
args.num_layers,
args.dropout,
args.ngnn_type,
dataset.name,
)
else: # GCN
g = dgl.add_self_loop(g) g = dgl.add_self_loop(g)
model = GCN(in_channels, args.hidden_channels, model = GCN(
args.hidden_channels, args.num_layers, in_channels,
args.dropout, args.ngnn_type, dataset.name) args.hidden_channels,
args.hidden_channels,
predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1, 3, args.dropout) args.num_layers,
args.dropout,
args.ngnn_type,
dataset.name,
)
predictor = LinkPredictor(
args.hidden_channels, args.hidden_channels, 1, 3, args.dropout
)
g, model, predictor = map(lambda x: x.to(device), (g, model, predictor)) g, model, predictor = map(lambda x: x.to(device), (g, model, predictor))
evaluator = Evaluator(name=dataset.name) evaluator = Evaluator(name=dataset.name)
loggers = { loggers = {
'Hits@20': Logger(args.runs, args), "Hits@20": Logger(args.runs, args),
'Hits@50': Logger(args.runs, args), "Hits@50": Logger(args.runs, args),
'Hits@100': Logger(args.runs, args), "Hits@100": Logger(args.runs, args),
} }
for run in range(args.runs): for run in range(args.runs):
model.reset_parameters() model.reset_parameters()
predictor.reset_parameters() predictor.reset_parameters()
if dataset.name == 'ogbl-ddi': if dataset.name == "ogbl-ddi":
torch.nn.init.xavier_uniform_(emb.weight) torch.nn.init.xavier_uniform_(emb.weight)
g.ndata['feat'] = emb.weight g.ndata["feat"] = emb.weight
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
list(model.parameters()) + list(predictor.parameters()) + ( list(model.parameters())
list(emb.parameters()) if dataset.name == 'ogbl-ddi' else [] + list(predictor.parameters())
), + (list(emb.parameters()) if dataset.name == "ogbl-ddi" else []),
lr=args.lr) lr=args.lr,
)
for epoch in range(1, 1 + args.epochs): for epoch in range(1, 1 + args.epochs):
loss = train(model, predictor, g, g.ndata['feat'], split_edge, optimizer, loss = train(
args.batch_size) model,
predictor,
g,
g.ndata["feat"],
split_edge,
optimizer,
args.batch_size,
)
if epoch % args.eval_steps == 0: if epoch % args.eval_steps == 0:
results = test(model, predictor, g, g.ndata['feat'], split_edge, evaluator, results = test(
args.batch_size) model,
predictor,
g,
g.ndata["feat"],
split_edge,
evaluator,
args.batch_size,
)
for key, result in results.items(): for key, result in results.items():
loggers[key].add_result(run, result) loggers[key].add_result(run, result)
train_hits, valid_hits, test_hits = result train_hits, valid_hits, test_hits = result
print(key) print(key)
print(f'Run: {run + 1:02d}, ' print(
f'Epoch: {epoch:02d}, ' f"Run: {run + 1:02d}, "
f'Loss: {loss:.4f}, ' f"Epoch: {epoch:02d}, "
f'Train: {100 * train_hits:.2f}%, ' f"Loss: {loss:.4f}, "
f'Valid: {100 * valid_hits:.2f}%, ' f"Train: {100 * train_hits:.2f}%, "
f'Test: {100 * test_hits:.2f}%') f"Valid: {100 * valid_hits:.2f}%, "
print('---') f"Test: {100 * test_hits:.2f}%"
)
print("---")
for key in loggers.keys(): for key in loggers.keys():
print(key) print(key)
......
...@@ -4,9 +4,10 @@ import glob ...@@ -4,9 +4,10 @@ import glob
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from dgl import function as fn
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from dgl import function as fn
device = None device = None
dataset = "ogbn-arxiv" dataset = "ogbn-arxiv"
...@@ -20,7 +21,11 @@ def load_data(dataset): ...@@ -20,7 +21,11 @@ def load_data(dataset):
evaluator = Evaluator(name=dataset) evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split() splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"] train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0] graph, labels = data[0]
n_node_feats = graph.ndata["feat"].shape[1] n_node_feats = graph.ndata["feat"].shape[1]
...@@ -46,7 +51,9 @@ def preprocess(graph): ...@@ -46,7 +51,9 @@ def preprocess(graph):
return graph return graph
def general_outcome_correlation(graph, y0, n_prop=50, alpha=0.8, use_norm=False, post_step=None): def general_outcome_correlation(
graph, y0, n_prop=50, alpha=0.8, use_norm=False, post_step=None
):
with graph.local_scope(): with graph.local_scope():
y = y0 y = y0
for _ in range(n_prop): for _ in range(n_prop):
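The body of this propagation loop lies outside the hunk; conceptually, general_outcome_correlation runs a Correct & Smooth style iteration y <- alpha * A_hat @ y + (1 - alpha) * y0 for n_prop steps, optionally with a normalized adjacency and a post_step clamp. A dense-matrix sketch of that iteration, as an illustration of the idea rather than the elided DGL message-passing code:

import torch

def outcome_correlation_dense(adj, y0, n_prop=50, alpha=0.8, post_step=None):
    # adj: a (crudely) normalized dense [N, N] adjacency, purely illustrative
    y = y0
    for _ in range(n_prop):
        y = alpha * (adj @ y) + (1 - alpha) * y0
        if post_step is not None:
            y = post_step(y)
    return y

adj = torch.rand(5, 5)
adj = adj / adj.sum(dim=1, keepdim=True)
y0 = torch.rand(5, 3)
print(outcome_correlation_dense(adj, y0, post_step=lambda x: x.clamp(0, 1)).shape)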
...@@ -94,7 +101,9 @@ def run(args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator): ...@@ -94,7 +101,9 @@ def run(args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator):
# dy = torch.zeros(graph.number_of_nodes(), n_classes, device=device) # dy = torch.zeros(graph.number_of_nodes(), n_classes, device=device)
# dy[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1) - pred[train_idx] # dy[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1) - pred[train_idx]
_train_acc, val_acc, test_acc = evaluate(labels, y, train_idx, val_idx, test_idx, evaluator_wrapper) _train_acc, val_acc, test_acc = evaluate(
labels, y, train_idx, val_idx, test_idx, evaluator_wrapper
)
# print("train acc:", _train_acc) # print("train acc:", _train_acc)
print("original val acc:", val_acc) print("original val acc:", val_acc)
...@@ -110,10 +119,16 @@ def run(args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator): ...@@ -110,10 +119,16 @@ def run(args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator):
# y = y + args.alpha2 * smoothed_dy # .clamp(0, 1) # y = y + args.alpha2 * smoothed_dy # .clamp(0, 1)
smoothed_y = general_outcome_correlation( smoothed_y = general_outcome_correlation(
graph, y, alpha=args.alpha, use_norm=args.use_norm, post_step=lambda x: x.clamp(0, 1) graph,
y,
alpha=args.alpha,
use_norm=args.use_norm,
post_step=lambda x: x.clamp(0, 1),
) )
_train_acc, val_acc, test_acc = evaluate(labels, smoothed_y, train_idx, val_idx, test_idx, evaluator_wrapper) _train_acc, val_acc, test_acc = evaluate(
labels, smoothed_y, train_idx, val_idx, test_idx, evaluator_wrapper
)
# print("train acc:", _train_acc) # print("train acc:", _train_acc)
print("val acc:", val_acc) print("val acc:", val_acc)
...@@ -126,11 +141,24 @@ def main(): ...@@ -126,11 +141,24 @@ def main():
global device global device
argparser = argparse.ArgumentParser(description="implementation of C&S") argparser = argparse.ArgumentParser(description="implementation of C&S")
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.") argparser.add_argument(
"--cpu",
action="store_true",
help="CPU mode. This option overrides --gpu.",
)
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.") argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
argparser.add_argument("--use-norm", action="store_true", help="Use symmetrically normalized adjacency matrix.") argparser.add_argument(
"--use-norm",
action="store_true",
help="Use symmetrically normalized adjacency matrix.",
)
argparser.add_argument("--alpha", type=float, default=0.6, help="alpha") argparser.add_argument("--alpha", type=float, default=0.6, help="alpha")
argparser.add_argument("--pred-files", type=str, default="./output/*.pt", help="address of prediction files") argparser.add_argument(
"--pred-files",
type=str,
default="./output/*.pt",
help="address of prediction files",
)
args = argparser.parse_args() args = argparser.parse_args()
if args.cpu: if args.cpu:
...@@ -152,7 +180,9 @@ def main(): ...@@ -152,7 +180,9 @@ def main():
for pred_file in glob.iglob(args.pred_files): for pred_file in glob.iglob(args.pred_files):
print("load:", pred_file) print("load:", pred_file)
pred = torch.load(pred_file) pred = torch.load(pred_file)
val_acc, test_acc = run(args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator) val_acc, test_acc = run(
args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator
)
val_accs.append(val_acc) val_accs.append(val_acc)
test_accs.append(test_acc) test_accs.append(test_acc)
......
...@@ -7,16 +7,16 @@ import os ...@@ -7,16 +7,16 @@ import os
import random import random
import time import time
import dgl
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from models import GAT
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from models import GAT import dgl
epsilon = 1 - math.log(2) epsilon = 1 - math.log(2)
...@@ -44,7 +44,11 @@ def load_data(dataset): ...@@ -44,7 +44,11 @@ def load_data(dataset):
evaluator = Evaluator(name=dataset) evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split() splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"] train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0] graph, labels = data[0]
n_node_feats = graph.ndata["feat"].shape[1] n_node_feats = graph.ndata["feat"].shape[1]
...@@ -113,7 +117,17 @@ def adjust_learning_rate(optimizer, lr, epoch): ...@@ -113,7 +117,17 @@ def adjust_learning_rate(optimizer, lr, epoch):
param_group["lr"] = lr * epoch / 50 param_group["lr"] = lr * epoch / 50
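adjust_learning_rate implements a linear warmup, scaling the base lr by epoch/50 during the early epochs (the surrounding guard sits outside this hunk). The schedule in isolation, with the post-warmup behaviour assumed constant:

def warmup_lr(base_lr, epoch, warmup_epochs=50):
    # linear warmup for the first `warmup_epochs` epochs, constant afterwards (assumed)
    return base_lr * min(epoch / warmup_epochs, 1.0)

print([warmup_lr(0.002, e) for e in (1, 25, 50, 200)])   # [4e-05, 0.001, 0.002, 0.002]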
def train(args, model, graph, labels, train_idx, val_idx, test_idx, optimizer, evaluator): def train(
args,
model,
graph,
labels,
train_idx,
val_idx,
test_idx,
optimizer,
evaluator,
):
model.train() model.train()
feat = graph.ndata["feat"] feat = graph.ndata["feat"]
...@@ -138,7 +152,9 @@ def train(args, model, graph, labels, train_idx, val_idx, test_idx, optimizer, e ...@@ -138,7 +152,9 @@ def train(args, model, graph, labels, train_idx, val_idx, test_idx, optimizer, e
for _ in range(args.n_label_iters): for _ in range(args.n_label_iters):
pred = pred.detach() pred = pred.detach()
torch.cuda.empty_cache() torch.cuda.empty_cache()
feat[unlabel_idx, -n_classes:] = F.softmax(pred[unlabel_idx], dim=-1) feat[unlabel_idx, -n_classes:] = F.softmax(
pred[unlabel_idx], dim=-1
)
pred = model(graph, feat) pred = model(graph, feat)
loss = custom_loss_function(pred[train_pred_idx], labels[train_pred_idx]) loss = custom_loss_function(pred[train_pred_idx], labels[train_pred_idx])
...@@ -149,7 +165,9 @@ def train(args, model, graph, labels, train_idx, val_idx, test_idx, optimizer, e ...@@ -149,7 +165,9 @@ def train(args, model, graph, labels, train_idx, val_idx, test_idx, optimizer, e
@torch.no_grad() @torch.no_grad()
def evaluate(args, model, graph, labels, train_idx, val_idx, test_idx, evaluator): def evaluate(
args, model, graph, labels, train_idx, val_idx, test_idx, evaluator
):
model.eval() model.eval()
feat = graph.ndata["feat"] feat = graph.ndata["feat"]
...@@ -162,7 +180,9 @@ def evaluate(args, model, graph, labels, train_idx, val_idx, test_idx, evaluator ...@@ -162,7 +180,9 @@ def evaluate(args, model, graph, labels, train_idx, val_idx, test_idx, evaluator
if args.n_label_iters > 0: if args.n_label_iters > 0:
unlabel_idx = torch.cat([val_idx, test_idx]) unlabel_idx = torch.cat([val_idx, test_idx])
for _ in range(args.n_label_iters): for _ in range(args.n_label_iters):
feat[unlabel_idx, -n_classes:] = F.softmax(pred[unlabel_idx], dim=-1) feat[unlabel_idx, -n_classes:] = F.softmax(
pred[unlabel_idx], dim=-1
)
pred = model(graph, feat) pred = model(graph, feat)
train_loss = custom_loss_function(pred[train_idx], labels[train_idx]) train_loss = custom_loss_function(pred[train_idx], labels[train_idx])
...@@ -180,14 +200,18 @@ def evaluate(args, model, graph, labels, train_idx, val_idx, test_idx, evaluator ...@@ -180,14 +200,18 @@ def evaluate(args, model, graph, labels, train_idx, val_idx, test_idx, evaluator
) )
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running): def run(
args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running
):
evaluator_wrapper = lambda pred, labels: evaluator.eval( evaluator_wrapper = lambda pred, labels: evaluator.eval(
{"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels} {"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
)["acc"] )["acc"]
# define model and optimizer # define model and optimizer
model = gen_model(args).to(device) model = gen_model(args).to(device)
optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wd) optimizer = optim.RMSprop(
model.parameters(), lr=args.lr, weight_decay=args.wd
)
# training loop # training loop
total_time = 0 total_time = 0
...@@ -202,10 +226,35 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -202,10 +226,35 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
adjust_learning_rate(optimizer, args.lr, epoch) adjust_learning_rate(optimizer, args.lr, epoch)
acc, loss = train(args, model, graph, labels, train_idx, val_idx, test_idx, optimizer, evaluator_wrapper) acc, loss = train(
args,
model,
graph,
labels,
train_idx,
val_idx,
test_idx,
optimizer,
evaluator_wrapper,
)
train_acc, val_acc, test_acc, train_loss, val_loss, test_loss, pred = evaluate( (
args, model, graph, labels, train_idx, val_idx, test_idx, evaluator_wrapper train_acc,
val_acc,
test_acc,
train_loss,
val_loss,
test_loss,
pred,
) = evaluate(
args,
model,
graph,
labels,
train_idx,
val_idx,
test_idx,
evaluator_wrapper,
) )
toc = time.time() toc = time.time()
...@@ -226,8 +275,26 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -226,8 +275,26 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
) )
for l, e in zip( for l, e in zip(
[accs, train_accs, val_accs, test_accs, losses, train_losses, val_losses, test_losses], [
[acc, train_acc, val_acc, test_acc, loss, train_loss, val_loss, test_loss], accs,
train_accs,
val_accs,
test_accs,
losses,
train_losses,
val_losses,
test_losses,
],
[
acc,
train_acc,
val_acc,
test_acc,
loss,
train_loss,
val_loss,
test_loss,
],
): ):
l.append(e) l.append(e)
...@@ -242,7 +309,10 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -242,7 +309,10 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_xticks(np.arange(0, args.n_epochs, 100)) ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101)) ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True) ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([accs, train_accs, val_accs, test_accs], ["acc", "train acc", "val acc", "test acc"]): for y, label in zip(
[accs, train_accs, val_accs, test_accs],
["acc", "train acc", "val acc", "test acc"],
):
plt.plot(range(args.n_epochs), y, label=label, linewidth=1) plt.plot(range(args.n_epochs), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(100)) ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(AutoMinorLocator(1)) ax.xaxis.set_minor_locator(AutoMinorLocator(1))
...@@ -259,7 +329,8 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -259,7 +329,8 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_xticks(np.arange(0, args.n_epochs, 100)) ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True) ax.tick_params(labeltop=True, labelright=True)
for y, label in zip( for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"] [losses, train_losses, val_losses, test_losses],
["loss", "train loss", "val loss", "test loss"],
): ):
plt.plot(range(args.n_epochs), y, label=label, linewidth=1) plt.plot(range(args.n_epochs), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(100)) ax.xaxis.set_major_locator(MultipleLocator(100))
...@@ -288,36 +359,84 @@ def main(): ...@@ -288,36 +359,84 @@ def main():
global device, n_node_feats, n_classes, epsilon global device, n_node_feats, n_classes, epsilon
argparser = argparse.ArgumentParser( argparser = argparse.ArgumentParser(
"GAT implementation on ogbn-arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter "GAT implementation on ogbn-arxiv",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
argparser.add_argument(
"--cpu",
action="store_true",
help="CPU mode. This option overrides --gpu.",
) )
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.") argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
argparser.add_argument("--seed", type=int, default=0, help="seed") argparser.add_argument("--seed", type=int, default=0, help="seed")
argparser.add_argument("--n-runs", type=int, default=10, help="running times")
argparser.add_argument("--n-epochs", type=int, default=2000, help="number of epochs")
argparser.add_argument( argparser.add_argument(
"--use-labels", action="store_true", help="Use labels in the training set as input features." "--n-runs", type=int, default=10, help="running times"
)
argparser.add_argument(
"--n-epochs", type=int, default=2000, help="number of epochs"
)
argparser.add_argument(
"--use-labels",
action="store_true",
help="Use labels in the training set as input features.",
)
argparser.add_argument(
"--n-label-iters",
type=int,
default=0,
help="number of label iterations",
)
argparser.add_argument(
"--mask-rate", type=float, default=0.5, help="mask rate"
)
argparser.add_argument(
"--no-attn-dst", action="store_true", help="Don't use attn_dst."
)
argparser.add_argument(
"--use-norm",
action="store_true",
help="Use symmetrically normalized adjacency matrix.",
)
argparser.add_argument(
"--lr", type=float, default=0.002, help="learning rate"
)
argparser.add_argument(
"--n-layers", type=int, default=3, help="number of layers"
)
argparser.add_argument(
"--n-heads", type=int, default=3, help="number of heads"
)
argparser.add_argument(
"--n-hidden", type=int, default=250, help="number of hidden units"
)
argparser.add_argument(
"--dropout", type=float, default=0.75, help="dropout rate"
)
argparser.add_argument(
"--input-drop", type=float, default=0.1, help="input drop rate"
)
argparser.add_argument(
"--attn-drop", type=float, default=0.0, help="attention drop rate"
)
argparser.add_argument(
"--edge-drop", type=float, default=0.0, help="edge drop rate"
) )
argparser.add_argument("--n-label-iters", type=int, default=0, help="number of label iterations")
argparser.add_argument("--mask-rate", type=float, default=0.5, help="mask rate")
argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.")
argparser.add_argument("--use-norm", action="store_true", help="Use symmetrically normalized adjacency matrix.")
argparser.add_argument("--lr", type=float, default=0.002, help="learning rate")
argparser.add_argument("--n-layers", type=int, default=3, help="number of layers")
argparser.add_argument("--n-heads", type=int, default=3, help="number of heads")
argparser.add_argument("--n-hidden", type=int, default=250, help="number of hidden units")
argparser.add_argument("--dropout", type=float, default=0.75, help="dropout rate")
argparser.add_argument("--input-drop", type=float, default=0.1, help="input drop rate")
argparser.add_argument("--attn-drop", type=float, default=0.0, help="attention drop rate")
argparser.add_argument("--edge-drop", type=float, default=0.0, help="edge drop rate")
argparser.add_argument("--wd", type=float, default=0, help="weight decay") argparser.add_argument("--wd", type=float, default=0, help="weight decay")
argparser.add_argument("--log-every", type=int, default=20, help="log every LOG_EVERY epochs") argparser.add_argument(
argparser.add_argument("--plot-curves", action="store_true", help="plot learning curves") "--log-every", type=int, default=20, help="log every LOG_EVERY epochs"
argparser.add_argument("--save-pred", action="store_true", help="save final predictions") )
argparser.add_argument(
"--plot-curves", action="store_true", help="plot learning curves"
)
argparser.add_argument(
"--save-pred", action="store_true", help="save final predictions"
)
args = argparser.parse_args() args = argparser.parse_args()
if not args.use_labels and args.n_label_iters > 0: if not args.use_labels and args.n_label_iters > 0:
raise ValueError("'--use-labels' must be enabled when n_label_iters > 0") raise ValueError(
"'--use-labels' must be enabled when n_label_iters > 0"
)
if args.cpu: if args.cpu:
device = torch.device("cpu") device = torch.device("cpu")
...@@ -337,7 +456,9 @@ def main(): ...@@ -337,7 +456,9 @@ def main():
for i in range(args.n_runs): for i in range(args.n_runs):
seed(args.seed + i) seed(args.seed + i)
val_acc, test_acc = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1) val_acc, test_acc = run(
args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1
)
val_accs.append(val_acc) val_accs.append(val_acc)
test_accs.append(test_acc) test_accs.append(test_acc)
......
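Most hunks in the file above follow a single mechanical rule: a call or signature that no longer fits on one line (the argparse setup, the train()/evaluate() parameter lists, the zip() bookkeeping) is exploded to one argument per line with a trailing comma. As a minimal, hypothetical sketch of that rule — evaluate_stub and its arguments are placeholders, not code from this repository — the exploded form is what the formatter emits once the one-line form exceeds the configured length, and the trailing ("magic") comma keeps it exploded on later runs:

def evaluate_stub(args=None, model=None, graph=None, labels=None, idx=None):
    # Stand-in for the real evaluate(); only the call shape matters here.
    return args, model, graph, labels, idx


# One-line form (kept while it fits within the configured line length):
#     result = evaluate_stub(args=a, model=m, graph=g, labels=y, idx=i)
# Exploded form (emitted once the line is too long; the trailing comma
# keeps the layout stable across future formatting runs):
result = evaluate_stub(
    args="args",
    model="model",
    graph="graph",
    labels="labels",
    idx="idx",
)
print(result)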
...@@ -11,9 +11,8 @@ import torch.nn.functional as F ...@@ -11,9 +11,8 @@ import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from models import GCN from models import GCN
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
device = None device = None
in_feats, n_classes = None, None in_feats, n_classes = None, None
...@@ -23,10 +22,24 @@ epsilon = 1 - math.log(2) ...@@ -23,10 +22,24 @@ epsilon = 1 - math.log(2)
def gen_model(args): def gen_model(args):
if args.use_labels: if args.use_labels:
model = GCN( model = GCN(
in_feats + n_classes, args.n_hidden, n_classes, args.n_layers, F.relu, args.dropout, args.use_linear in_feats + n_classes,
args.n_hidden,
n_classes,
args.n_layers,
F.relu,
args.dropout,
args.use_linear,
) )
else: else:
model = GCN(in_feats, args.n_hidden, n_classes, args.n_layers, F.relu, args.dropout, args.use_linear) model = GCN(
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
F.relu,
args.dropout,
args.use_linear,
)
return model return model
...@@ -37,7 +50,9 @@ def cross_entropy(x, labels): ...@@ -37,7 +50,9 @@ def cross_entropy(x, labels):
def compute_acc(pred, labels, evaluator): def compute_acc(pred, labels, evaluator):
return evaluator.eval({"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels})["acc"] return evaluator.eval(
{"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
)["acc"]
def add_labels(feat, labels, idx): def add_labels(feat, labels, idx):
...@@ -81,7 +96,9 @@ def train(model, graph, labels, train_idx, optimizer, use_labels): ...@@ -81,7 +96,9 @@ def train(model, graph, labels, train_idx, optimizer, use_labels):
@th.no_grad() @th.no_grad()
def evaluate(model, graph, labels, train_idx, val_idx, test_idx, use_labels, evaluator): def evaluate(
model, graph, labels, train_idx, val_idx, test_idx, use_labels, evaluator
):
model.eval() model.eval()
feat = graph.ndata["feat"] feat = graph.ndata["feat"]
...@@ -104,14 +121,23 @@ def evaluate(model, graph, labels, train_idx, val_idx, test_idx, use_labels, eva ...@@ -104,14 +121,23 @@ def evaluate(model, graph, labels, train_idx, val_idx, test_idx, use_labels, eva
) )
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running): def run(
args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running
):
# define model and optimizer # define model and optimizer
model = gen_model(args) model = gen_model(args)
model = model.to(device) model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd) optimizer = optim.AdamW(
model.parameters(), lr=args.lr, weight_decay=args.wd
)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=100, verbose=True, min_lr=1e-3 optimizer,
mode="min",
factor=0.5,
patience=100,
verbose=True,
min_lr=1e-3,
) )
# training loop # training loop
...@@ -126,11 +152,27 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -126,11 +152,27 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
adjust_learning_rate(optimizer, args.lr, epoch) adjust_learning_rate(optimizer, args.lr, epoch)
loss, pred = train(model, graph, labels, train_idx, optimizer, args.use_labels) loss, pred = train(
model, graph, labels, train_idx, optimizer, args.use_labels
)
acc = compute_acc(pred[train_idx], labels[train_idx], evaluator) acc = compute_acc(pred[train_idx], labels[train_idx], evaluator)
train_acc, val_acc, test_acc, train_loss, val_loss, test_loss = evaluate( (
model, graph, labels, train_idx, val_idx, test_idx, args.use_labels, evaluator train_acc,
val_acc,
test_acc,
train_loss,
val_loss,
test_loss,
) = evaluate(
model,
graph,
labels,
train_idx,
val_idx,
test_idx,
args.use_labels,
evaluator,
) )
lr_scheduler.step(loss) lr_scheduler.step(loss)
...@@ -152,8 +194,26 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -152,8 +194,26 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
) )
for l, e in zip( for l, e in zip(
[accs, train_accs, val_accs, test_accs, losses, train_losses, val_losses, test_losses], [
[acc, train_acc, val_acc, test_acc, loss, train_loss, val_loss, test_loss], accs,
train_accs,
val_accs,
test_accs,
losses,
train_losses,
val_losses,
test_losses,
],
[
acc,
train_acc,
val_acc,
test_acc,
loss,
train_loss,
val_loss,
test_loss,
],
): ):
l.append(e) l.append(e)
...@@ -167,7 +227,10 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -167,7 +227,10 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_xticks(np.arange(0, args.n_epochs, 100)) ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101)) ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True) ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([accs, train_accs, val_accs, test_accs], ["acc", "train acc", "val acc", "test acc"]): for y, label in zip(
[accs, train_accs, val_accs, test_accs],
["acc", "train acc", "val acc", "test acc"],
):
plt.plot(range(args.n_epochs), y, label=label) plt.plot(range(args.n_epochs), y, label=label)
ax.xaxis.set_major_locator(MultipleLocator(100)) ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(AutoMinorLocator(1)) ax.xaxis.set_minor_locator(AutoMinorLocator(1))
...@@ -184,7 +247,8 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -184,7 +247,8 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_xticks(np.arange(0, args.n_epochs, 100)) ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True) ax.tick_params(labeltop=True, labelright=True)
for y, label in zip( for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"] [losses, train_losses, val_losses, test_losses],
["loss", "train loss", "val loss", "test loss"],
): ):
plt.plot(range(args.n_epochs), y, label=label) plt.plot(range(args.n_epochs), y, label=label)
ax.xaxis.set_major_locator(MultipleLocator(100)) ax.xaxis.set_major_locator(MultipleLocator(100))
...@@ -202,28 +266,57 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -202,28 +266,57 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
def count_parameters(args): def count_parameters(args):
model = gen_model(args) model = gen_model(args)
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad]) return sum(
[np.prod(p.size()) for p in model.parameters() if p.requires_grad]
)
def main(): def main():
global device, in_feats, n_classes global device, in_feats, n_classes
argparser = argparse.ArgumentParser("GCN on OGBN-Arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter) argparser = argparse.ArgumentParser(
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.") "GCN on OGBN-Arxiv",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
argparser.add_argument(
"--cpu",
action="store_true",
help="CPU mode. This option overrides --gpu.",
)
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.") argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
argparser.add_argument("--n-runs", type=int, default=10, help="running times")
argparser.add_argument("--n-epochs", type=int, default=1000, help="number of epochs")
argparser.add_argument( argparser.add_argument(
"--use-labels", action="store_true", help="Use labels in the training set as input features." "--n-runs", type=int, default=10, help="running times"
)
argparser.add_argument(
"--n-epochs", type=int, default=1000, help="number of epochs"
)
argparser.add_argument(
"--use-labels",
action="store_true",
help="Use labels in the training set as input features.",
)
argparser.add_argument(
"--use-linear", action="store_true", help="Use linear layer."
)
argparser.add_argument(
"--lr", type=float, default=0.005, help="learning rate"
)
argparser.add_argument(
"--n-layers", type=int, default=3, help="number of layers"
)
argparser.add_argument(
"--n-hidden", type=int, default=256, help="number of hidden units"
)
argparser.add_argument(
"--dropout", type=float, default=0.5, help="dropout rate"
) )
argparser.add_argument("--use-linear", action="store_true", help="Use linear layer.")
argparser.add_argument("--lr", type=float, default=0.005, help="learning rate")
argparser.add_argument("--n-layers", type=int, default=3, help="number of layers")
argparser.add_argument("--n-hidden", type=int, default=256, help="number of hidden units")
argparser.add_argument("--dropout", type=float, default=0.5, help="dropout rate")
argparser.add_argument("--wd", type=float, default=0, help="weight decay") argparser.add_argument("--wd", type=float, default=0, help="weight decay")
argparser.add_argument("--log-every", type=int, default=20, help="log every LOG_EVERY epochs") argparser.add_argument(
argparser.add_argument("--plot-curves", action="store_true", help="plot learning curves") "--log-every", type=int, default=20, help="log every LOG_EVERY epochs"
)
argparser.add_argument(
"--plot-curves", action="store_true", help="plot learning curves"
)
args = argparser.parse_args() args = argparser.parse_args()
if args.cpu: if args.cpu:
...@@ -236,7 +329,11 @@ def main(): ...@@ -236,7 +329,11 @@ def main():
evaluator = Evaluator(name="ogbn-arxiv") evaluator = Evaluator(name="ogbn-arxiv")
splitted_idx = data.get_idx_split() splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"] train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0] graph, labels = data[0]
# add reverse edges # add reverse edges
...@@ -263,7 +360,9 @@ def main(): ...@@ -263,7 +360,9 @@ def main():
test_accs = [] test_accs = []
for i in range(args.n_runs): for i in range(args.n_runs):
val_acc, test_acc = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i) val_acc, test_acc = run(
args, graph, labels, train_idx, val_idx, test_idx, evaluator, i
)
val_accs.append(val_acc) val_accs.append(val_acc)
test_accs.append(test_acc) test_accs.append(test_acc)
......
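The hunks above and in the next few files also reorder imports: standard-library modules first, third-party packages next in alphabetical order, and in several files the dgl imports are gathered into their own trailing group, which looks like an isort-style configuration that treats dgl as the project's own package. The exact tool and settings are not visible in this diff, so the grouping below is only an assumed illustration, not a prescription:

# Assumed grouping, for illustration only.
import argparse  # standard library
import math

import numpy as np  # third-party, alphabetized
import torch
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

import dgl  # dgl kept as its own (project/first-party) section
from dgl.dataloading import DataLoader

print(dgl.__version__, torch.__version__)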
import dgl.nn.pytorch as dglnn
import torch import torch
import torch.nn as nn import torch.nn as nn
import dgl.nn.pytorch as dglnn
from dgl import function as fn from dgl import function as fn
from dgl.ops import edge_softmax from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair from dgl.utils import expand_as_pair
...@@ -42,7 +43,16 @@ class ElementWiseLinear(nn.Module): ...@@ -42,7 +43,16 @@ class ElementWiseLinear(nn.Module):
class GCN(nn.Module): class GCN(nn.Module):
def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, use_linear): def __init__(
self,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout,
use_linear,
):
super().__init__() super().__init__()
self.n_layers = n_layers self.n_layers = n_layers
self.n_hidden = n_hidden self.n_hidden = n_hidden
...@@ -59,7 +69,9 @@ class GCN(nn.Module): ...@@ -59,7 +69,9 @@ class GCN(nn.Module):
out_hidden = n_hidden if i < n_layers - 1 else n_classes out_hidden = n_hidden if i < n_layers - 1 else n_classes
bias = i == n_layers - 1 bias = i == n_layers - 1
self.convs.append(dglnn.GraphConv(in_hidden, out_hidden, "both", bias=bias)) self.convs.append(
dglnn.GraphConv(in_hidden, out_hidden, "both", bias=bias)
)
if use_linear: if use_linear:
self.linear.append(nn.Linear(in_hidden, out_hidden, bias=False)) self.linear.append(nn.Linear(in_hidden, out_hidden, bias=False))
if i < n_layers - 1: if i < n_layers - 1:
...@@ -113,13 +125,23 @@ class GATConv(nn.Module): ...@@ -113,13 +125,23 @@ class GATConv(nn.Module):
self._allow_zero_in_degree = allow_zero_in_degree self._allow_zero_in_degree = allow_zero_in_degree
self._use_symmetric_norm = use_symmetric_norm self._use_symmetric_norm = use_symmetric_norm
if isinstance(in_feats, tuple): if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) self.fc_src = nn.Linear(
self.fc_dst = nn.Linear(self._in_dst_feats, out_feats * num_heads, bias=False) self._in_src_feats, out_feats * num_heads, bias=False
)
self.fc_dst = nn.Linear(
self._in_dst_feats, out_feats * num_heads, bias=False
)
else: else:
self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) self.fc = nn.Linear(
self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) self._in_src_feats, out_feats * num_heads, bias=False
)
self.attn_l = nn.Parameter(
torch.FloatTensor(size=(1, num_heads, out_feats))
)
if use_attn_dst: if use_attn_dst:
self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) self.attn_r = nn.Parameter(
torch.FloatTensor(size=(1, num_heads, out_feats))
)
else: else:
self.register_buffer("attn_r", None) self.register_buffer("attn_r", None)
self.feat_drop = nn.Dropout(feat_drop) self.feat_drop = nn.Dropout(feat_drop)
...@@ -127,7 +149,9 @@ class GATConv(nn.Module): ...@@ -127,7 +149,9 @@ class GATConv(nn.Module):
self.edge_drop = edge_drop self.edge_drop = edge_drop
self.leaky_relu = nn.LeakyReLU(negative_slope) self.leaky_relu = nn.LeakyReLU(negative_slope)
if residual: if residual:
self.res_fc = nn.Linear(self._in_dst_feats, num_heads * out_feats, bias=False) self.res_fc = nn.Linear(
self._in_dst_feats, num_heads * out_feats, bias=False
)
else: else:
self.register_buffer("res_fc", None) self.register_buffer("res_fc", None)
self.reset_parameters() self.reset_parameters()
...@@ -161,12 +185,18 @@ class GATConv(nn.Module): ...@@ -161,12 +185,18 @@ class GATConv(nn.Module):
if not hasattr(self, "fc_src"): if not hasattr(self, "fc_src"):
self.fc_src, self.fc_dst = self.fc, self.fc self.fc_src, self.fc_dst = self.fc, self.fc
feat_src, feat_dst = h_src, h_dst feat_src, feat_dst = h_src, h_dst
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats) feat_src = self.fc_src(h_src).view(
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats) -1, self._num_heads, self._out_feats
)
feat_dst = self.fc_dst(h_dst).view(
-1, self._num_heads, self._out_feats
)
else: else:
h_src = self.feat_drop(feat) h_src = self.feat_drop(feat)
feat_src = h_src feat_src = h_src
feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats) feat_src = self.fc(h_src).view(
-1, self._num_heads, self._out_feats
)
if graph.is_block: if graph.is_block:
h_dst = h_src[: graph.number_of_dst_nodes()] h_dst = h_src[: graph.number_of_dst_nodes()]
feat_dst = feat_src[: graph.number_of_dst_nodes()] feat_dst = feat_src[: graph.number_of_dst_nodes()]
...@@ -207,7 +237,9 @@ class GATConv(nn.Module): ...@@ -207,7 +237,9 @@ class GATConv(nn.Module):
bound = int(graph.number_of_edges() * self.edge_drop) bound = int(graph.number_of_edges() * self.edge_drop)
eids = perm[bound:] eids = perm[bound:]
graph.edata["a"] = torch.zeros_like(e) graph.edata["a"] = torch.zeros_like(e)
graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids)) graph.edata["a"][eids] = self.attn_drop(
edge_softmax(graph, e[eids], eids=eids)
)
else: else:
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e)) graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
...@@ -224,7 +256,9 @@ class GATConv(nn.Module): ...@@ -224,7 +256,9 @@ class GATConv(nn.Module):
# residual # residual
if self.res_fc is not None: if self.res_fc is not None:
resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats) resval = self.res_fc(h_dst).view(
h_dst.shape[0], -1, self._out_feats
)
rst = rst + resval rst = rst + resval
# activation # activation
...@@ -282,7 +316,9 @@ class GAT(nn.Module): ...@@ -282,7 +316,9 @@ class GAT(nn.Module):
if i < n_layers - 1: if i < n_layers - 1:
self.norms.append(nn.BatchNorm1d(out_channels * out_hidden)) self.norms.append(nn.BatchNorm1d(out_channels * out_hidden))
self.bias_last = ElementWiseLinear(n_classes, weight=False, bias=True, inplace=True) self.bias_last = ElementWiseLinear(
n_classes, weight=False, bias=True, inplace=True
)
self.input_drop = nn.Dropout(input_drop) self.input_drop = nn.Dropout(input_drop)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
......
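None of the rewrapped constructor calls above (nn.Linear, nn.Parameter, the GraphConv and ElementWiseLinear setup) change behaviour; only the layout moves. A quick way to check that property for any hunk is to compare parse trees of the old and new text — a standalone, stdlib-only sketch below; Black runs an equivalent AST comparison internally as its own safety check:

import ast

before = (
    "self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads,"
    " bias=False)"
)
after = (
    "self.fc = nn.Linear(\n"
    "    self._in_src_feats, out_feats * num_heads, bias=False\n"
    ")"
)

# ast.dump() omits line/column info by default, so equal dumps mean the
# rewrapped statement is semantically identical to the original one-liner.
assert ast.dump(ast.parse(before)) == ast.dump(ast.parse(after))
print("old and new hunks parse to the same AST")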
import argparse import argparse
import itertools import itertools
from tqdm import tqdm
import dgl
import dgl.nn as dglnn
from dgl.nn import HeteroEmbedding
from dgl import Compose, AddReverse, ToSimple
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from tqdm import tqdm
import dgl
import dgl.nn as dglnn
from dgl import AddReverse, Compose, ToSimple
from dgl.nn import HeteroEmbedding
def prepare_data(args): def prepare_data(args):
dataset = DglNodePropPredDataset(name="ogbn-mag") dataset = DglNodePropPredDataset(name="ogbn-mag")
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
# graph: dgl graph object, label: torch tensor of shape (num_nodes, num_tasks) # graph: dgl graph object, label: torch tensor of shape (num_nodes, num_tasks)
g, labels = dataset[0] g, labels = dataset[0]
labels = labels['paper'].flatten() labels = labels["paper"].flatten()
transform = Compose([ToSimple(), AddReverse()]) transform = Compose([ToSimple(), AddReverse()])
g = transform(g) g = transform(g)
...@@ -28,34 +30,38 @@ def prepare_data(args): ...@@ -28,34 +30,38 @@ def prepare_data(args):
# train sampler # train sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20]) sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])
train_loader = dgl.dataloading.DataLoader( train_loader = dgl.dataloading.DataLoader(
g, split_idx['train'], sampler, g,
batch_size=1024, shuffle=True, num_workers=0) split_idx["train"],
sampler,
batch_size=1024,
shuffle=True,
num_workers=0,
)
return g, labels, dataset.num_classes, split_idx, logger, train_loader return g, labels, dataset.num_classes, split_idx, logger, train_loader
def extract_embed(node_embed, input_nodes): def extract_embed(node_embed, input_nodes):
emb = node_embed({ emb = node_embed(
ntype: input_nodes[ntype] for ntype in input_nodes if ntype != 'paper' {ntype: input_nodes[ntype] for ntype in input_nodes if ntype != "paper"}
}) )
return emb return emb
def rel_graph_embed(graph, embed_size): def rel_graph_embed(graph, embed_size):
node_num = {} node_num = {}
for ntype in graph.ntypes: for ntype in graph.ntypes:
if ntype == 'paper': if ntype == "paper":
continue continue
node_num[ntype] = graph.num_nodes(ntype) node_num[ntype] = graph.num_nodes(ntype)
embeds = HeteroEmbedding(node_num, embed_size) embeds = HeteroEmbedding(node_num, embed_size)
return embeds return embeds
class RelGraphConvLayer(nn.Module): class RelGraphConvLayer(nn.Module):
def __init__(self, def __init__(
in_feat, self, in_feat, out_feat, ntypes, rel_names, activation=None, dropout=0.0
out_feat, ):
ntypes,
rel_names,
activation=None,
dropout=0.0):
super(RelGraphConvLayer, self).__init__() super(RelGraphConvLayer, self).__init__()
self.in_feat = in_feat self.in_feat = in_feat
self.out_feat = out_feat self.out_feat = out_feat
...@@ -63,21 +69,29 @@ class RelGraphConvLayer(nn.Module): ...@@ -63,21 +69,29 @@ class RelGraphConvLayer(nn.Module):
self.rel_names = rel_names self.rel_names = rel_names
self.activation = activation self.activation = activation
self.conv = dglnn.HeteroGraphConv({ self.conv = dglnn.HeteroGraphConv(
rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False) {
rel: dglnn.GraphConv(
in_feat, out_feat, norm="right", weight=False, bias=False
)
for rel in rel_names for rel in rel_names
}) }
)
self.weight = nn.ModuleDict({ self.weight = nn.ModuleDict(
rel_name: nn.Linear(in_feat, out_feat, bias=False) {
for rel_name in self.rel_names rel_name: nn.Linear(in_feat, out_feat, bias=False)
}) for rel_name in self.rel_names
}
)
# weight for self loop # weight for self loop
self.loop_weights = nn.ModuleDict({ self.loop_weights = nn.ModuleDict(
ntype: nn.Linear(in_feat, out_feat, bias=True) {
for ntype in self.ntypes ntype: nn.Linear(in_feat, out_feat, bias=True)
}) for ntype in self.ntypes
}
)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.reset_parameters() self.reset_parameters()
...@@ -104,10 +118,14 @@ class RelGraphConvLayer(nn.Module): ...@@ -104,10 +118,14 @@ class RelGraphConvLayer(nn.Module):
New node features for each node type. New node features for each node type.
""" """
g = g.local_var() g = g.local_var()
wdict = {rel_name: {'weight': self.weight[rel_name].weight.T} wdict = {
for rel_name in self.rel_names} rel_name: {"weight": self.weight[rel_name].weight.T}
for rel_name in self.rel_names
}
inputs_dst = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()} inputs_dst = {
k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
}
hs = self.conv(g, inputs, mod_kwargs=wdict) hs = self.conv(g, inputs, mod_kwargs=wdict)
...@@ -117,7 +135,8 @@ class RelGraphConvLayer(nn.Module): ...@@ -117,7 +135,8 @@ class RelGraphConvLayer(nn.Module):
h = self.activation(h) h = self.activation(h)
return self.dropout(h) return self.dropout(h)
return {ntype : _apply(ntype, h) for ntype, h in hs.items()} return {ntype: _apply(ntype, h) for ntype, h in hs.items()}
class EntityClassify(nn.Module): class EntityClassify(nn.Module):
def __init__(self, g, in_dim, out_dim): def __init__(self, g, in_dim, out_dim):
...@@ -131,14 +150,27 @@ class EntityClassify(nn.Module): ...@@ -131,14 +150,27 @@ class EntityClassify(nn.Module):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
# i2h # i2h
self.layers.append(RelGraphConvLayer( self.layers.append(
self.in_dim, self.h_dim, g.ntypes, self.rel_names, RelGraphConvLayer(
activation=F.relu, dropout=self.dropout)) self.in_dim,
self.h_dim,
g.ntypes,
self.rel_names,
activation=F.relu,
dropout=self.dropout,
)
)
# h2o # h2o
self.layers.append(RelGraphConvLayer( self.layers.append(
self.h_dim, self.out_dim, g.ntypes, self.rel_names, RelGraphConvLayer(
activation=None)) self.h_dim,
self.out_dim,
g.ntypes,
self.rel_names,
activation=None,
)
)
def reset_parameters(self): def reset_parameters(self):
for layer in self.layers: for layer in self.layers:
...@@ -149,6 +181,7 @@ class EntityClassify(nn.Module): ...@@ -149,6 +181,7 @@ class EntityClassify(nn.Module):
h = layer(block, h) h = layer(block, h)
return h return h
class Logger(object): class Logger(object):
r""" r"""
This class was taken directly from the PyG implementation and can be found This class was taken directly from the PyG implementation and can be found
...@@ -156,6 +189,7 @@ class Logger(object): ...@@ -156,6 +189,7 @@ class Logger(object):
This was done to ensure that performance was measured in precisely the same way This was done to ensure that performance was measured in precisely the same way
""" """
def __init__(self, runs): def __init__(self, runs):
self.results = [[] for _ in range(runs)] self.results = [[] for _ in range(runs)]
...@@ -168,11 +202,11 @@ class Logger(object): ...@@ -168,11 +202,11 @@ class Logger(object):
if run is not None: if run is not None:
result = 100 * th.tensor(self.results[run]) result = 100 * th.tensor(self.results[run])
argmax = result[:, 1].argmax().item() argmax = result[:, 1].argmax().item()
print(f'Run {run + 1:02d}:') print(f"Run {run + 1:02d}:")
print(f'Highest Train: {result[:, 0].max():.2f}') print(f"Highest Train: {result[:, 0].max():.2f}")
print(f'Highest Valid: {result[:, 1].max():.2f}') print(f"Highest Valid: {result[:, 1].max():.2f}")
print(f' Final Train: {result[argmax, 0]:.2f}') print(f" Final Train: {result[argmax, 0]:.2f}")
print(f' Final Test: {result[argmax, 2]:.2f}') print(f" Final Test: {result[argmax, 2]:.2f}")
else: else:
result = 100 * th.tensor(self.results) result = 100 * th.tensor(self.results)
...@@ -186,39 +220,54 @@ class Logger(object): ...@@ -186,39 +220,54 @@ class Logger(object):
best_result = th.tensor(best_results) best_result = th.tensor(best_results)
print(f'All runs:') print(f"All runs:")
r = best_result[:, 0] r = best_result[:, 0]
print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') print(f"Highest Train: {r.mean():.2f} ± {r.std():.2f}")
r = best_result[:, 1] r = best_result[:, 1]
print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') print(f"Highest Valid: {r.mean():.2f} ± {r.std():.2f}")
r = best_result[:, 2] r = best_result[:, 2]
print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') print(f" Final Train: {r.mean():.2f} ± {r.std():.2f}")
r = best_result[:, 3] r = best_result[:, 3]
print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}') print(f" Final Test: {r.mean():.2f} ± {r.std():.2f}")
def train(g, model, node_embed, optimizer, train_loader, split_idx,
labels, logger, device, run): def train(
g,
model,
node_embed,
optimizer,
train_loader,
split_idx,
labels,
logger,
device,
run,
):
print("start training...") print("start training...")
category = 'paper' category = "paper"
for epoch in range(3): for epoch in range(3):
num_train = split_idx['train'][category].shape[0] num_train = split_idx["train"][category].shape[0]
pbar = tqdm(total=num_train) pbar = tqdm(total=num_train)
pbar.set_description(f'Epoch {epoch:02d}') pbar.set_description(f"Epoch {epoch:02d}")
model.train() model.train()
total_loss = 0 total_loss = 0
for input_nodes, seeds, blocks in train_loader: for input_nodes, seeds, blocks in train_loader:
blocks = [blk.to(device) for blk in blocks] blocks = [blk.to(device) for blk in blocks]
seeds = seeds[category] # we only predict the nodes with type "category" seeds = seeds[
category
] # we only predict the nodes with type "category"
batch_size = seeds.shape[0] batch_size = seeds.shape[0]
emb = extract_embed(node_embed, input_nodes) emb = extract_embed(node_embed, input_nodes)
# Add the batch's raw "paper" features # Add the batch's raw "paper" features
emb.update({'paper': g.ndata['feat']['paper'][input_nodes['paper']]}) emb.update(
{"paper": g.ndata["feat"]["paper"][input_nodes["paper"]]}
)
emb = {k : e.to(device) for k, e in emb.items()} emb = {k: e.to(device) for k, e in emb.items()}
lbl = labels[seeds].to(device) lbl = labels[seeds].to(device)
optimizer.zero_grad() optimizer.zero_grad()
...@@ -238,41 +287,51 @@ def train(g, model, node_embed, optimizer, train_loader, split_idx, ...@@ -238,41 +287,51 @@ def train(g, model, node_embed, optimizer, train_loader, split_idx,
result = test(g, model, node_embed, labels, device, split_idx) result = test(g, model, node_embed, labels, device, split_idx)
logger.add_result(run, result) logger.add_result(run, result)
train_acc, valid_acc, test_acc = result train_acc, valid_acc, test_acc = result
print(f'Run: {run + 1:02d}, ' print(
f'Epoch: {epoch +1 :02d}, ' f"Run: {run + 1:02d}, "
f'Loss: {loss:.4f}, ' f"Epoch: {epoch +1 :02d}, "
f'Train: {100 * train_acc:.2f}%, ' f"Loss: {loss:.4f}, "
f'Valid: {100 * valid_acc:.2f}%, ' f"Train: {100 * train_acc:.2f}%, "
f'Test: {100 * test_acc:.2f}%') f"Valid: {100 * valid_acc:.2f}%, "
f"Test: {100 * test_acc:.2f}%"
)
return logger return logger
@th.no_grad() @th.no_grad()
def test(g, model, node_embed, y_true, device, split_idx): def test(g, model, node_embed, y_true, device, split_idx):
model.eval() model.eval()
category = 'paper' category = "paper"
evaluator = Evaluator(name='ogbn-mag') evaluator = Evaluator(name="ogbn-mag")
# 2 GNN layers # 2 GNN layers
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
loader = dgl.dataloading.DataLoader( loader = dgl.dataloading.DataLoader(
g, {'paper': th.arange(g.num_nodes('paper'))}, sampler, g,
batch_size=16384, shuffle=False, num_workers=0) {"paper": th.arange(g.num_nodes("paper"))},
sampler,
batch_size=16384,
shuffle=False,
num_workers=0,
)
pbar = tqdm(total=y_true.size(0)) pbar = tqdm(total=y_true.size(0))
pbar.set_description(f'Inference') pbar.set_description(f"Inference")
y_hats = list() y_hats = list()
for input_nodes, seeds, blocks in loader: for input_nodes, seeds, blocks in loader:
blocks = [blk.to(device) for blk in blocks] blocks = [blk.to(device) for blk in blocks]
seeds = seeds[category] # we only predict the nodes with type "category" seeds = seeds[
category
] # we only predict the nodes with type "category"
batch_size = seeds.shape[0] batch_size = seeds.shape[0]
emb = extract_embed(node_embed, input_nodes) emb = extract_embed(node_embed, input_nodes)
# Get the batch's raw "paper" features # Get the batch's raw "paper" features
emb.update({'paper': g.ndata['feat']['paper'][input_nodes['paper']]}) emb.update({"paper": g.ndata["feat"]["paper"][input_nodes["paper"]]})
emb = {k : e.to(device) for k, e in emb.items()} emb = {k: e.to(device) for k, e in emb.items()}
logits = model(emb, blocks)[category] logits = model(emb, blocks)[category]
y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True) y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)
...@@ -285,31 +344,42 @@ def test(g, model, node_embed, y_true, device, split_idx): ...@@ -285,31 +344,42 @@ def test(g, model, node_embed, y_true, device, split_idx):
y_pred = th.cat(y_hats, dim=0) y_pred = th.cat(y_hats, dim=0)
y_true = th.unsqueeze(y_true, 1) y_true = th.unsqueeze(y_true, 1)
train_acc = evaluator.eval({ train_acc = evaluator.eval(
'y_true': y_true[split_idx['train']['paper']], {
'y_pred': y_pred[split_idx['train']['paper']], "y_true": y_true[split_idx["train"]["paper"]],
})['acc'] "y_pred": y_pred[split_idx["train"]["paper"]],
valid_acc = evaluator.eval({ }
'y_true': y_true[split_idx['valid']['paper']], )["acc"]
'y_pred': y_pred[split_idx['valid']['paper']], valid_acc = evaluator.eval(
})['acc'] {
test_acc = evaluator.eval({ "y_true": y_true[split_idx["valid"]["paper"]],
'y_true': y_true[split_idx['test']['paper']], "y_pred": y_pred[split_idx["valid"]["paper"]],
'y_pred': y_pred[split_idx['test']['paper']], }
})['acc'] )["acc"]
test_acc = evaluator.eval(
{
"y_true": y_true[split_idx["test"]["paper"]],
"y_pred": y_pred[split_idx["test"]["paper"]],
}
)["acc"]
return train_acc, valid_acc, test_acc return train_acc, valid_acc, test_acc
def main(args): def main(args):
device = f'cuda:0' if th.cuda.is_available() else 'cpu' device = f"cuda:0" if th.cuda.is_available() else "cpu"
g, labels, num_classes, split_idx, logger, train_loader = prepare_data(args) g, labels, num_classes, split_idx, logger, train_loader = prepare_data(args)
embed_layer = rel_graph_embed(g, 128) embed_layer = rel_graph_embed(g, 128)
model = EntityClassify(g, 128, num_classes).to(device) model = EntityClassify(g, 128, num_classes).to(device)
print(f"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}") print(
print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}") f"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}"
)
print(
f"Number of model parameters: {sum(p.numel() for p in model.parameters())}"
)
for run in range(args.runs): for run in range(args.runs):
...@@ -317,19 +387,32 @@ def main(args): ...@@ -317,19 +387,32 @@ def main(args):
model.reset_parameters() model.reset_parameters()
# optimizer # optimizer
all_params = itertools.chain(model.parameters(), embed_layer.parameters()) all_params = itertools.chain(
model.parameters(), embed_layer.parameters()
)
optimizer = th.optim.Adam(all_params, lr=0.01) optimizer = th.optim.Adam(all_params, lr=0.01)
logger = train(g, model, embed_layer, optimizer, train_loader, split_idx, logger = train(
labels, logger, device, run) g,
model,
embed_layer,
optimizer,
train_loader,
split_idx,
labels,
logger,
device,
run,
)
logger.print_statistics(run) logger.print_statistics(run)
print("Final performance: ") print("Final performance: ")
logger.print_statistics() logger.print_statistics()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN') if __name__ == "__main__":
parser.add_argument('--runs', type=int, default=10) parser = argparse.ArgumentParser(description="RGCN")
parser.add_argument("--runs", type=int, default=10)
args = parser.parse_args() args = parser.parse_args()
......
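The heterograph training script above also shows the other purely cosmetic change applied throughout the commit: string and f-string literals are normalized from single to double quotes ('paper' becomes "paper", f'Inference' becomes f"Inference"), and the dict literals passed to the OGB evaluator are re-indented. Quote style has no effect on the values, as the small check below (reusing literals that appear in the hunks above) confirms:

# Single- and double-quoted literals produce identical string values.
assert 'paper' == "paper"

run = 0
assert f'Run {run + 1:02d}:' == f"Run {run + 1:02d}:"
print("quote normalization is purely cosmetic")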
...@@ -7,21 +7,24 @@ import random ...@@ -7,21 +7,24 @@ import random
import time import time
from collections import OrderedDict from collections import OrderedDict
import dgl
import dgl.function as fn
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading import DataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from models import GAT
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn from torch import nn
from tqdm import tqdm from tqdm import tqdm
from models import GAT import dgl
import dgl.function as fn
from dgl.dataloading import (
DataLoader,
MultiLayerFullNeighborSampler,
MultiLayerNeighborSampler,
)
epsilon = 1 - math.log(2) epsilon = 1 - math.log(2)
...@@ -46,7 +49,11 @@ def load_data(dataset): ...@@ -46,7 +49,11 @@ def load_data(dataset):
evaluator = Evaluator(name=dataset) evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split() splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"] train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0] graph, labels = data[0]
graph.ndata["labels"] = labels graph.ndata["labels"] = labels
...@@ -61,10 +68,14 @@ def preprocess(graph, labels, train_idx): ...@@ -61,10 +68,14 @@ def preprocess(graph, labels, train_idx):
# graph = graph.remove_self_loop().add_self_loop() # graph = graph.remove_self_loop().add_self_loop()
n_node_feats = graph.ndata["feat"].shape[-1] n_node_feats = graph.ndata["feat"].shape[-1]
graph.ndata["train_labels_onehot"] = torch.zeros(graph.number_of_nodes(), n_classes) graph.ndata["train_labels_onehot"] = torch.zeros(
graph.number_of_nodes(), n_classes
)
graph.ndata["train_labels_onehot"][train_idx, labels[train_idx, 0]] = 1 graph.ndata["train_labels_onehot"][train_idx, labels[train_idx, 0]] = 1
graph.ndata["is_train"] = torch.zeros(graph.number_of_nodes(), dtype=torch.bool) graph.ndata["is_train"] = torch.zeros(
graph.number_of_nodes(), dtype=torch.bool
)
graph.ndata["is_train"][train_idx] = 1 graph.ndata["is_train"][train_idx] = 1
graph.create_formats_() graph.create_formats_()
...@@ -112,12 +123,18 @@ def add_soft_labels(graph, soft_labels): ...@@ -112,12 +123,18 @@ def add_soft_labels(graph, soft_labels):
def update_hard_labels(graph, idx=None): def update_hard_labels(graph, idx=None):
if idx is None: if idx is None:
idx = torch.arange(graph.srcdata["is_train"].shape[0])[graph.srcdata["is_train"]] idx = torch.arange(graph.srcdata["is_train"].shape[0])[
graph.srcdata["is_train"]
]
graph.srcdata["feat"][idx, -n_classes:] = graph.srcdata["train_labels_onehot"][idx] graph.srcdata["feat"][idx, -n_classes:] = graph.srcdata[
"train_labels_onehot"
][idx]
def train(args, model, dataloader, labels, train_idx, criterion, optimizer, evaluator): def train(
args, model, dataloader, labels, train_idx, criterion, optimizer, evaluator
):
model.train() model.train()
loss_sum, total = 0, 0 loss_sum, total = 0, 0
...@@ -133,10 +150,18 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval ...@@ -133,10 +150,18 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval
if args.use_labels: if args.use_labels:
mask = torch.rand(new_train_idx.shape) < args.mask_rate mask = torch.rand(new_train_idx.shape) < args.mask_rate
train_labels_idx = torch.cat([new_train_idx[~mask], torch.arange(len(output_nodes), len(input_nodes))]) train_labels_idx = torch.cat(
[
new_train_idx[~mask],
torch.arange(len(output_nodes), len(input_nodes)),
]
)
train_pred_idx = new_train_idx[mask] train_pred_idx = new_train_idx[mask]
add_soft_labels(subgraphs[0], F.softmax(preds_old[input_nodes].to(device), dim=-1)) add_soft_labels(
subgraphs[0],
F.softmax(preds_old[input_nodes].to(device), dim=-1),
)
update_hard_labels(subgraphs[0], train_labels_idx) update_hard_labels(subgraphs[0], train_labels_idx)
else: else:
train_pred_idx = new_train_idx train_pred_idx = new_train_idx
...@@ -148,7 +173,10 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval ...@@ -148,7 +173,10 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval
# NOTE: This is not a complete implementation of label reuse, since it is too expensive # NOTE: This is not a complete implementation of label reuse, since it is too expensive
# to predict the nodes in validation and test set during training time. # to predict the nodes in validation and test set during training time.
if it == args.n_label_iters: if it == args.n_label_iters:
loss = criterion(pred[train_pred_idx], subgraphs[-1].dstdata["labels"][train_pred_idx]) loss = criterion(
pred[train_pred_idx],
subgraphs[-1].dstdata["labels"][train_pred_idx],
)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
...@@ -166,7 +194,17 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval ...@@ -166,7 +194,17 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval
@torch.no_grad() @torch.no_grad()
def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator): def evaluate(
args,
model,
dataloader,
labels,
train_idx,
val_idx,
test_idx,
criterion,
evaluator,
):
model.eval() model.eval()
# Due to the limitation of memory capacity, we calculate the average of logits 'eval_times' times. # Due to the limitation of memory capacity, we calculate the average of logits 'eval_times' times.
...@@ -182,7 +220,10 @@ def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, crit ...@@ -182,7 +220,10 @@ def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, crit
subgraphs = [b.to(device) for b in subgraphs] subgraphs = [b.to(device) for b in subgraphs]
if args.use_labels: if args.use_labels:
add_soft_labels(subgraphs[0], F.softmax(preds_old[input_nodes].to(device), dim=-1)) add_soft_labels(
subgraphs[0],
F.softmax(preds_old[input_nodes].to(device), dim=-1),
)
update_hard_labels(subgraphs[0]) update_hard_labels(subgraphs[0])
pred = model(subgraphs, inference=True) pred = model(subgraphs, inference=True)
...@@ -209,7 +250,9 @@ def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, crit ...@@ -209,7 +250,9 @@ def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, crit
) )
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running): def run(
args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running
):
evaluator_wrapper = lambda pred, labels: evaluator.eval( evaluator_wrapper = lambda pred, labels: evaluator.eval(
{"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels} {"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
)["acc"] )["acc"]
...@@ -217,37 +260,52 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -217,37 +260,52 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
n_train_samples = train_idx.shape[0] n_train_samples = train_idx.shape[0]
train_batch_size = (n_train_samples + 29) // 30 train_batch_size = (n_train_samples + 29) // 30
train_sampler = MultiLayerNeighborSampler([10 for _ in range(args.n_layers)]) train_sampler = MultiLayerNeighborSampler(
train_dataloader = DataLoader( [10 for _ in range(args.n_layers)]
)
train_dataloader = DataLoader(
graph.cpu(), graph.cpu(),
train_idx.cpu(), train_idx.cpu(),
train_sampler, train_sampler,
batch_size=train_batch_size, shuffle=True, batch_size=train_batch_size,
num_workers=4, shuffle=True,
num_workers=4,
) )
eval_batch_size = 32768 eval_batch_size = 32768
eval_sampler = MultiLayerNeighborSampler([15 for _ in range(args.n_layers)]) eval_sampler = MultiLayerNeighborSampler([15 for _ in range(args.n_layers)])
if args.estimation_mode: if args.estimation_mode:
test_idx_during_training = test_idx[torch.arange(start=0, end=len(test_idx), step=45)] test_idx_during_training = test_idx[
torch.arange(start=0, end=len(test_idx), step=45)
]
else: else:
test_idx_during_training = test_idx test_idx_during_training = test_idx
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()]) eval_idx = torch.cat(
eval_dataloader = DataLoader( [train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()]
)
eval_dataloader = DataLoader(
graph.cpu(), graph.cpu(),
eval_idx, eval_idx,
eval_sampler, eval_sampler,
batch_size=eval_batch_size, shuffle=False, batch_size=eval_batch_size,
num_workers=4, shuffle=False,
num_workers=4,
) )
model = gen_model(args).to(device) model = gen_model(args).to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd) optimizer = optim.AdamW(
model.parameters(), lr=args.lr, weight_decay=args.wd
)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.7, patience=20, verbose=True, min_lr=1e-4 optimizer,
mode="max",
factor=0.7,
patience=20,
verbose=True,
min_lr=1e-4,
) )
best_model_state_dict = None best_model_state_dict = None
...@@ -261,13 +319,33 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -261,13 +319,33 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
for epoch in range(1, args.n_epochs + 1): for epoch in range(1, args.n_epochs + 1):
tic = time.time() tic = time.time()
score, loss = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper) score, loss = train(
args,
model,
train_dataloader,
labels,
train_idx,
criterion,
optimizer,
evaluator_wrapper,
)
toc = time.time() toc = time.time()
total_time += toc - tic total_time += toc - tic
if epoch == args.n_epochs or epoch % args.eval_every == 0 or epoch % args.log_every == 0: if (
train_score, val_score, test_score, train_loss, val_loss, test_loss = evaluate( epoch == args.n_epochs
or epoch % args.eval_every == 0
or epoch % args.log_every == 0
):
(
train_score,
val_score,
test_score,
train_loss,
val_loss,
test_loss,
) = evaluate(
args, args,
model, model,
eval_dataloader, eval_dataloader,
...@@ -283,7 +361,9 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -283,7 +361,9 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
best_val_score = val_score best_val_score = val_score
final_test_score = test_score final_test_score = test_score
if args.estimation_mode: if args.estimation_mode:
best_model_state_dict = {k: v.to("cpu") for k, v in model.state_dict().items()} best_model_state_dict = {
k: v.to("cpu") for k, v in model.state_dict().items()
}
if epoch == args.n_epochs or epoch % args.log_every == 0: if epoch == args.n_epochs or epoch % args.log_every == 0:
print( print(
...@@ -294,8 +374,26 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -294,8 +374,26 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
) )
for l, e in zip( for l, e in zip(
[scores, train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses], [
[score, train_score, val_score, test_score, loss, train_loss, val_loss, test_loss], scores,
train_scores,
val_scores,
test_scores,
losses,
train_losses,
val_losses,
test_losses,
],
[
score,
train_score,
val_score,
test_score,
loss,
train_loss,
val_loss,
test_loss,
],
): ):
l.append(e) l.append(e)
...@@ -303,19 +401,30 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -303,19 +401,30 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
if args.estimation_mode: if args.estimation_mode:
model.load_state_dict(best_model_state_dict) model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoader( eval_dataloader = DataLoader(
graph.cpu(), graph.cpu(),
test_idx.cpu(), test_idx.cpu(),
eval_sampler, eval_sampler,
batch_size=eval_batch_size, shuffle=False, batch_size=eval_batch_size,
num_workers=4, shuffle=False,
num_workers=4,
) )
final_test_score = evaluate( final_test_score = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper args,
model,
eval_dataloader,
labels,
train_idx,
val_idx,
test_idx,
criterion,
evaluator_wrapper,
)[2] )[2]
print("*" * 50) print("*" * 50)
print(f"Best val score: {best_val_score}, Final test score: {final_test_score}") print(
f"Best val score: {best_val_score}, Final test score: {final_test_score}"
)
print("*" * 50) print("*" * 50)
if args.plot_curves: if args.plot_curves:
...@@ -324,8 +433,16 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -324,8 +433,16 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_xticks(np.arange(0, args.n_epochs, 100)) ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101)) ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True) ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([train_scores, val_scores, test_scores], ["train score", "val score", "test score"]): for y, label in zip(
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1) [train_scores, val_scores, test_scores],
["train score", "val score", "test score"],
):
plt.plot(
range(1, args.n_epochs + 1, args.log_every),
y,
label=label,
linewidth=1,
)
ax.xaxis.set_major_locator(MultipleLocator(10)) ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_minor_locator(AutoMinorLocator(1)) ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.01)) ax.yaxis.set_major_locator(MultipleLocator(0.01))
...@@ -341,9 +458,15 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -341,9 +458,15 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_xticks(np.arange(0, args.n_epochs, 100)) ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True) ax.tick_params(labeltop=True, labelright=True)
for y, label in zip( for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"] [losses, train_losses, val_losses, test_losses],
["loss", "train loss", "val loss", "test loss"],
): ):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1) plt.plot(
range(1, args.n_epochs + 1, args.log_every),
y,
label=label,
linewidth=1,
)
ax.xaxis.set_major_locator(MultipleLocator(10)) ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_minor_locator(AutoMinorLocator(1)) ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.1)) ax.yaxis.set_major_locator(MultipleLocator(0.1))
...@@ -359,41 +482,87 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -359,41 +482,87 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
def count_parameters(args): def count_parameters(args):
model = gen_model(args) model = gen_model(args)
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad]) return sum(
[np.prod(p.size()) for p in model.parameters() if p.requires_grad]
)
def main(): def main():
global device global device
argparser = argparse.ArgumentParser( argparser = argparse.ArgumentParser(
"GAT implementation on ogbn-products", formatter_class=argparse.ArgumentDefaultsHelpFormatter "GAT implementation on ogbn-products",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
argparser.add_argument(
"--cpu",
action="store_true",
help="CPU mode. This option overrides '--gpu'.",
) )
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides '--gpu'.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID") argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
argparser.add_argument("--seed", type=int, default=0, help="seed") argparser.add_argument("--seed", type=int, default=0, help="seed")
argparser.add_argument("--n-runs", type=int, default=10, help="running times")
argparser.add_argument("--n-epochs", type=int, default=250, help="number of epochs")
argparser.add_argument( argparser.add_argument(
"--use-labels", action="store_true", help="Use labels in the training set as input features." "--n-runs", type=int, default=10, help="running times"
) )
argparser.add_argument("--n-label-iters", type=int, default=0, help="number of label iterations") argparser.add_argument(
argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.") "--n-epochs", type=int, default=250, help="number of epochs"
argparser.add_argument("--mask-rate", type=float, default=0.5, help="mask rate") )
argparser.add_argument("--n-heads", type=int, default=4, help="number of heads") argparser.add_argument(
argparser.add_argument("--lr", type=float, default=0.01, help="learning rate") "--use-labels",
argparser.add_argument("--n-layers", type=int, default=3, help="number of layers") action="store_true",
argparser.add_argument("--n-hidden", type=int, default=120, help="number of hidden units") help="Use labels in the training set as input features.",
argparser.add_argument("--dropout", type=float, default=0.5, help="dropout rate") )
argparser.add_argument("--input-drop", type=float, default=0.1, help="input drop rate") argparser.add_argument(
argparser.add_argument("--attn-dropout", type=float, default=0.0, help="attention drop rate") "--n-label-iters",
argparser.add_argument("--edge-drop", type=float, default=0.1, help="edge drop rate") type=int,
default=0,
help="number of label iterations",
)
argparser.add_argument(
"--no-attn-dst", action="store_true", help="Don't use attn_dst."
)
argparser.add_argument(
"--mask-rate", type=float, default=0.5, help="mask rate"
)
argparser.add_argument(
"--n-heads", type=int, default=4, help="number of heads"
)
argparser.add_argument(
"--lr", type=float, default=0.01, help="learning rate"
)
argparser.add_argument(
"--n-layers", type=int, default=3, help="number of layers"
)
argparser.add_argument(
"--n-hidden", type=int, default=120, help="number of hidden units"
)
argparser.add_argument(
"--dropout", type=float, default=0.5, help="dropout rate"
)
argparser.add_argument(
"--input-drop", type=float, default=0.1, help="input drop rate"
)
argparser.add_argument(
"--attn-dropout", type=float, default=0.0, help="attention drop rate"
)
argparser.add_argument(
"--edge-drop", type=float, default=0.1, help="edge drop rate"
)
argparser.add_argument("--wd", type=float, default=0, help="weight decay") argparser.add_argument("--wd", type=float, default=0, help="weight decay")
argparser.add_argument("--eval-every", type=int, default=2, help="log every EVAL_EVERY epochs")
argparser.add_argument( argparser.add_argument(
"--estimation-mode", action="store_true", help="Estimate the score of test set for speed during training." "--eval-every", type=int, default=2, help="log every EVAL_EVERY epochs"
)
argparser.add_argument(
"--estimation-mode",
action="store_true",
help="Estimate the score of test set for speed during training.",
)
argparser.add_argument(
"--log-every", type=int, default=2, help="log every LOG_EVERY epochs"
)
argparser.add_argument(
"--plot-curves", action="store_true", help="plot learning curves"
) )
argparser.add_argument("--log-every", type=int, default=2, help="log every LOG_EVERY epochs")
argparser.add_argument("--plot-curves", action="store_true", help="plot learning curves")
args = argparser.parse_args() args = argparser.parse_args()
if args.cpu: if args.cpu:
...@@ -405,14 +574,18 @@ def main(): ...@@ -405,14 +574,18 @@ def main():
graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset) graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
graph, labels = preprocess(graph, labels, train_idx) graph, labels = preprocess(graph, labels, train_idx)
labels, train_idx, val_idx, test_idx = map(lambda x: x.to(device), (labels, train_idx, val_idx, test_idx)) labels, train_idx, val_idx, test_idx = map(
lambda x: x.to(device), (labels, train_idx, val_idx, test_idx)
)
# run # run
val_scores, test_scores = [], [] val_scores, test_scores = [], []
for i in range(1, args.n_runs + 1): for i in range(1, args.n_runs + 1):
seed(args.seed + i) seed(args.seed + i)
val_score, test_score = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i) val_score, test_score = run(
args, graph, labels, train_idx, val_idx, test_idx, evaluator, i
)
val_scores.append(val_score) val_scores.append(val_score)
test_scores.append(test_score) test_scores.append(test_score)
......
import dgl import argparse
import time
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm import tqdm
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
import dgl
import dgl.nn.pytorch as dglnn
class GAT(nn.Module): class GAT(nn.Module):
def __init__(self, def __init__(
in_feats, self, in_feats, n_hidden, n_classes, n_layers, num_heads, activation
n_hidden, ):
n_classes,
n_layers,
num_heads,
activation):
super().__init__() super().__init__()
self.n_layers = n_layers self.n_layers = n_layers
self.n_hidden = n_hidden self.n_hidden = n_hidden
self.n_classes = n_classes self.n_classes = n_classes
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.layers.append(dglnn.GATConv((in_feats, in_feats), n_hidden, num_heads=num_heads, activation=activation)) self.layers.append(
dglnn.GATConv(
(in_feats, in_feats),
n_hidden,
num_heads=num_heads,
activation=activation,
)
)
for i in range(1, n_layers - 1): for i in range(1, n_layers - 1):
self.layers.append(dglnn.GATConv((n_hidden * num_heads, n_hidden * num_heads), n_hidden, self.layers.append(
num_heads=num_heads, activation=activation)) dglnn.GATConv(
self.layers.append(dglnn.GATConv((n_hidden * num_heads, n_hidden * num_heads), n_classes, (n_hidden * num_heads, n_hidden * num_heads),
num_heads=num_heads, activation=None)) n_hidden,
num_heads=num_heads,
activation=activation,
)
)
self.layers.append(
dglnn.GATConv(
(n_hidden * num_heads, n_hidden * num_heads),
n_classes,
num_heads=num_heads,
activation=None,
)
)
def forward(self, blocks, x): def forward(self, blocks, x):
h = x h = x
...@@ -38,7 +55,7 @@ class GAT(nn.Module): ...@@ -38,7 +55,7 @@ class GAT(nn.Module):
# appropriate nodes on the LHS. # appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D) # would be (num_nodes_RHS, D)
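# Added note (not part of the original commit): DGL blocks place the
# destination nodes first among the source nodes, which is why the slice
# below works. Hypothetical sizes: with 1000 src and 200 dst nodes,
# h has shape (1000, D) and h[:200] is exactly the dst-node features.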
h_dst = h[:block.num_dst_nodes()] h_dst = h[: block.num_dst_nodes()]
# Then we compute the updated representation on the RHS. # Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D) # The shape of h now becomes (num_nodes_RHS, D)
if l < self.n_layers - 1: if l < self.n_layers - 1:
...@@ -63,9 +80,19 @@ class GAT(nn.Module): ...@@ -63,9 +80,19 @@ class GAT(nn.Module):
# TODO: can we standardize this? # TODO: can we standardize this?
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
if l < self.n_layers - 1: if l < self.n_layers - 1:
y = th.zeros(g.num_nodes(), self.n_hidden * num_heads if l != len(self.layers) - 1 else self.n_classes) y = th.zeros(
g.num_nodes(),
self.n_hidden * num_heads
if l != len(self.layers) - 1
else self.n_classes,
)
else: else:
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) y = th.zeros(
g.num_nodes(),
self.n_hidden
if l != len(self.layers) - 1
else self.n_classes,
)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader( dataloader = dgl.dataloading.DataLoader(
...@@ -75,15 +102,16 @@ class GAT(nn.Module): ...@@ -75,15 +102,16 @@ class GAT(nn.Module):
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=args.num_workers) num_workers=args.num_workers,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0].int().to(device) block = blocks[0].int().to(device)
h = x[input_nodes].to(device) h = x[input_nodes].to(device)
h_dst = h[:block.num_dst_nodes()] h_dst = h[: block.num_dst_nodes()]
if l < self.n_layers - 1: if l < self.n_layers - 1:
h = layer(block, (h, h_dst)).flatten(1) h = layer(block, (h, h_dst)).flatten(1)
else: else:
h = layer(block, (h, h_dst)) h = layer(block, (h, h_dst))
h = h.mean(1) h = h.mean(1)
...@@ -94,12 +122,14 @@ class GAT(nn.Module): ...@@ -94,12 +122,14 @@ class GAT(nn.Module):
x = y x = y
return y.to(device) return y.to(device)
def compute_acc(pred, labels): def compute_acc(pred, labels):
""" """
Compute the accuracy of prediction given the labels. Compute the accuracy of prediction given the labels.
""" """
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred) return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
def evaluate(model, g, nfeat, labels, val_nid, test_nid, num_heads, device): def evaluate(model, g, nfeat, labels, val_nid, test_nid, num_heads, device):
""" """
Evaluate the model on the validation set specified by ``val_nid``. Evaluate the model on the validation set specified by ``val_nid``.
...@@ -114,7 +144,12 @@ def evaluate(model, g, nfeat, labels, val_nid, test_nid, num_heads, device): ...@@ -114,7 +144,12 @@ def evaluate(model, g, nfeat, labels, val_nid, test_nid, num_heads, device):
with th.no_grad(): with th.no_grad():
pred = model.inference(g, nfeat, num_heads, device) pred = model.inference(g, nfeat, num_heads, device)
model.train() model.train()
return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(pred[test_nid], labels[test_nid]), pred return (
compute_acc(pred[val_nid], labels[val_nid]),
compute_acc(pred[test_nid], labels[test_nid]),
pred,
)
def load_subtensor(nfeat, labels, seeds, input_nodes): def load_subtensor(nfeat, labels, seeds, input_nodes):
""" """
...@@ -124,14 +159,26 @@ def load_subtensor(nfeat, labels, seeds, input_nodes): ...@@ -124,14 +159,26 @@ def load_subtensor(nfeat, labels, seeds, input_nodes):
batch_labels = labels[seeds] batch_labels = labels[seeds]
return batch_inputs, batch_labels return batch_inputs, batch_labels
#### Entry point #### Entry point
def run(args, device, data): def run(args, device, data):
# Unpack data # Unpack data
train_nid, val_nid, test_nid, in_feats, labels, n_classes, nfeat, g, num_heads = data (
train_nid,
val_nid,
test_nid,
in_feats,
labels,
n_classes,
nfeat,
g,
num_heads,
) = data
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler( sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')]) [int(fanout) for fanout in args.fan_out.split(",")]
)
dataloader = dgl.dataloading.DataLoader( dataloader = dgl.dataloading.DataLoader(
g, g,
train_nid, train_nid,
...@@ -139,10 +186,13 @@ def run(args, device, data): ...@@ -139,10 +186,13 @@ def run(args, device, data):
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=args.num_workers) num_workers=args.num_workers,
)
# Define model and optimizer # Define model and optimizer
model = GAT(in_feats, args.num_hidden, n_classes, args.num_layers, num_heads, F.relu) model = GAT(
in_feats, args.num_hidden, n_classes, args.num_layers, num_heads, F.relu
)
model = model.to(device) model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
...@@ -163,7 +213,9 @@ def run(args, device, data): ...@@ -163,7 +213,9 @@ def run(args, device, data):
blocks = [blk.to(device) for blk in blocks] blocks = [blk.to(device) for blk in blocks]
# Load the input features as well as output labels # Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(nfeat, labels, seeds, input_nodes) batch_inputs, batch_labels = load_subtensor(
nfeat, labels, seeds, input_nodes
)
# Compute loss and prediction # Compute loss and prediction
batch_pred = model(blocks, batch_inputs) batch_pred = model(blocks, batch_inputs)
...@@ -175,63 +227,98 @@ def run(args, device, data): ...@@ -175,63 +227,98 @@ def run(args, device, data):
iter_tput.append(len(seeds) / (time.time() - tic_step)) iter_tput.append(len(seeds) / (time.time() - tic_step))
if step % args.log_every == 0: if step % args.log_every == 0:
acc = compute_acc(batch_pred, batch_labels) acc = compute_acc(batch_pred, batch_labels)
gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0 gpu_mem_alloc = (
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB'.format( th.cuda.max_memory_allocated() / 1000000
epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc)) if th.cuda.is_available()
else 0
)
print(
"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB".format(
epoch,
step,
loss.item(),
acc.item(),
np.mean(iter_tput[3:]),
gpu_mem_alloc,
)
)
toc = time.time() toc = time.time()
print('Epoch Time(s): {:.4f}'.format(toc - tic)) print("Epoch Time(s): {:.4f}".format(toc - tic))
if epoch >= 5: if epoch >= 5:
avg += toc - tic avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0: if epoch % args.eval_every == 0 and epoch != 0:
eval_acc, test_acc, pred = evaluate(model, g, nfeat, labels, val_nid, test_nid, num_heads, device) eval_acc, test_acc, pred = evaluate(
model, g, nfeat, labels, val_nid, test_nid, num_heads, device
)
if args.save_pred: if args.save_pred:
np.savetxt(args.save_pred + '%02d' % epoch, pred.argmax(1).cpu().numpy(), '%d') np.savetxt(
print('Eval Acc {:.4f}'.format(eval_acc)) args.save_pred + "%02d" % epoch,
pred.argmax(1).cpu().numpy(),
"%d",
)
print("Eval Acc {:.4f}".format(eval_acc))
if eval_acc > best_eval_acc: if eval_acc > best_eval_acc:
best_eval_acc = eval_acc best_eval_acc = eval_acc
best_test_acc = test_acc best_test_acc = test_acc
print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc)) print(
"Best Eval Acc {:.4f} Test Acc {:.4f}".format(
best_eval_acc, best_test_acc
)
)
print('Avg epoch time: {}'.format(avg / (epoch - 4))) print("Avg epoch time: {}".format(avg / (epoch - 4)))
return best_test_acc return best_test_acc
if __name__ == '__main__':
if __name__ == "__main__":
argparser = argparse.ArgumentParser("multi-gpu training") argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--gpu', type=int, default=0, argparser.add_argument(
help="GPU device ID. Use -1 for CPU training") "--gpu",
argparser.add_argument('--num-epochs', type=int, default=100) type=int,
argparser.add_argument('--num-hidden', type=int, default=128) default=0,
argparser.add_argument('--num-layers', type=int, default=3) help="GPU device ID. Use -1 for CPU training",
argparser.add_argument('--fan-out', type=str, default='10,10,10') )
argparser.add_argument('--batch-size', type=int, default=512) argparser.add_argument("--num-epochs", type=int, default=100)
argparser.add_argument('--val-batch-size', type=int, default=512) argparser.add_argument("--num-hidden", type=int, default=128)
argparser.add_argument('--log-every', type=int, default=20) argparser.add_argument("--num-layers", type=int, default=3)
argparser.add_argument('--eval-every', type=int, default=1) argparser.add_argument("--fan-out", type=str, default="10,10,10")
argparser.add_argument('--lr', type=float, default=0.001) argparser.add_argument("--batch-size", type=int, default=512)
argparser.add_argument('--num-workers', type=int, default=8, argparser.add_argument("--val-batch-size", type=int, default=512)
help="Number of sampling processes. Use 0 for no extra process.") argparser.add_argument("--log-every", type=int, default=20)
argparser.add_argument('--save-pred', type=str, default='') argparser.add_argument("--eval-every", type=int, default=1)
argparser.add_argument('--head', type=int, default=4) argparser.add_argument("--lr", type=float, default=0.001)
argparser.add_argument('--wd', type=float, default=0) argparser.add_argument(
"--num-workers",
type=int,
default=8,
help="Number of sampling processes. Use 0 for no extra process.",
)
argparser.add_argument("--save-pred", type=str, default="")
argparser.add_argument("--head", type=int, default=4)
argparser.add_argument("--wd", type=float, default=0)
args = argparser.parse_args() args = argparser.parse_args()
if args.gpu >= 0: if args.gpu >= 0:
device = th.device('cuda:%d' % args.gpu) device = th.device("cuda:%d" % args.gpu)
else: else:
device = th.device('cpu') device = th.device("cpu")
# load data # load data
data = DglNodePropPredDataset(name='ogbn-products') data = DglNodePropPredDataset(name="ogbn-products")
splitted_idx = data.get_idx_split() splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx['train'], splitted_idx['valid'], splitted_idx['test'] train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0] graph, labels = data[0]
nfeat = graph.ndata.pop('feat').to(device) nfeat = graph.ndata.pop("feat").to(device)
labels = labels[:, 0].to(device) labels = labels[:, 0].to(device)
print('Total edges before adding self-loop {}'.format(graph.num_edges())) print("Total edges before adding self-loop {}".format(graph.num_edges()))
graph = graph.remove_self_loop().add_self_loop() graph = graph.remove_self_loop().add_self_loop()
print('Total edges after adding self-loop {}'.format(graph.num_edges())) print("Total edges after adding self-loop {}".format(graph.num_edges()))
in_feats = nfeat.shape[1] in_feats = nfeat.shape[1]
n_classes = (labels.max() + 1).item() n_classes = (labels.max() + 1).item()
...@@ -240,10 +327,22 @@ if __name__ == '__main__': ...@@ -240,10 +327,22 @@ if __name__ == '__main__':
# This avoids creating certain formats in each data loader process, which saves memory and CPU. # This avoids creating certain formats in each data loader process, which saves memory and CPU.
graph.create_formats_() graph.create_formats_()
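# Added note (not part of the original commit): create_formats_() materialises
# all sparse formats of the graph (COO/CSR/CSC) up front, so the forked
# sampling workers reuse them instead of each worker rebuilding the formats
# on first access.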
# Pack data # Pack data
data = train_idx, val_idx, test_idx, in_feats, labels, n_classes, nfeat, graph, args.head data = (
train_idx,
val_idx,
test_idx,
in_feats,
labels,
n_classes,
nfeat,
graph,
args.head,
)
# Run 10 times # Run 10 times
test_accs = [] test_accs = []
for i in range(10): for i in range(10):
test_accs.append(run(args, device, data).cpu().numpy()) test_accs.append(run(args, device, data).cpu().numpy())
print('Average test accuracy:', np.mean(test_accs), '±', np.std(test_accs)) print(
"Average test accuracy:", np.mean(test_accs), "±", np.std(test_accs)
)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl import function as fn from dgl import function as fn
from dgl.ops import edge_softmax from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair from dgl.utils import expand_as_pair
...@@ -31,7 +32,9 @@ class GATConv(nn.Module): ...@@ -31,7 +32,9 @@ class GATConv(nn.Module):
self._use_symmetric_norm = use_symmetric_norm self._use_symmetric_norm = use_symmetric_norm
# feat fc # feat fc
self.src_fc = nn.Linear(self._in_src_feats, out_feats * n_heads, bias=False) self.src_fc = nn.Linear(
self._in_src_feats, out_feats * n_heads, bias=False
)
if residual: if residual:
self.dst_fc = nn.Linear(self._in_src_feats, out_feats * n_heads) self.dst_fc = nn.Linear(self._in_src_feats, out_feats * n_heads)
self.bias = None self.bias = None
...@@ -42,7 +45,9 @@ class GATConv(nn.Module): ...@@ -42,7 +45,9 @@ class GATConv(nn.Module):
# attn fc # attn fc
self.attn_src_fc = nn.Linear(self._in_src_feats, n_heads, bias=False) self.attn_src_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
if use_attn_dst: if use_attn_dst:
self.attn_dst_fc = nn.Linear(self._in_src_feats, n_heads, bias=False) self.attn_dst_fc = nn.Linear(
self._in_src_feats, n_heads, bias=False
)
else: else:
self.attn_dst_fc = None self.attn_dst_fc = None
if edge_feats > 0: if edge_feats > 0:
...@@ -93,8 +98,12 @@ class GATConv(nn.Module): ...@@ -93,8 +98,12 @@ class GATConv(nn.Module):
norm = torch.reshape(norm, shp) norm = torch.reshape(norm, shp)
feat_src = feat_src * norm feat_src = feat_src * norm
feat_src_fc = self.src_fc(feat_src).view(-1, self._n_heads, self._out_feats) feat_src_fc = self.src_fc(feat_src).view(
feat_dst_fc = self.dst_fc(feat_dst).view(-1, self._n_heads, self._out_feats) -1, self._n_heads, self._out_feats
)
feat_dst_fc = self.dst_fc(feat_dst).view(
-1, self._n_heads, self._out_feats
)
attn_src = self.attn_src_fc(feat_src).view(-1, self._n_heads, 1) attn_src = self.attn_src_fc(feat_src).view(-1, self._n_heads, 1)
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
...@@ -107,18 +116,24 @@ class GATConv(nn.Module): ...@@ -107,18 +116,24 @@ class GATConv(nn.Module):
# save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
# addition could be optimized with DGL's built-in function u_add_v, # addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint. # which further speeds up computation and saves memory footprint.
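# Added note (not part of the original commit): the decomposition relies on
# the linearity of the pre-activation GAT score,
#   e_ij = a^T [W h_i || W h_j] = a_src^T (W h_i) + a_dst^T (W h_j),
# so attn_src and attn_dst can be computed once per node and summed per edge
# with fn.u_add_v. A tiny self-check of the identity (hypothetical tensors):
#   import torch
#   wh_i, wh_j = torch.randn(8), torch.randn(8)
#   a = torch.randn(16); a_src, a_dst = a[:8], a[8:]
#   assert torch.allclose(a @ torch.cat([wh_i, wh_j]), a_src @ wh_i + a_dst @ wh_j)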
graph.srcdata.update({"feat_src_fc": feat_src_fc, "attn_src": attn_src}) graph.srcdata.update(
{"feat_src_fc": feat_src_fc, "attn_src": attn_src}
)
if self.attn_dst_fc is not None: if self.attn_dst_fc is not None:
attn_dst = self.attn_dst_fc(feat_dst).view(-1, self._n_heads, 1) attn_dst = self.attn_dst_fc(feat_dst).view(-1, self._n_heads, 1)
graph.dstdata.update({"attn_dst": attn_dst}) graph.dstdata.update({"attn_dst": attn_dst})
graph.apply_edges(fn.u_add_v("attn_src", "attn_dst", "attn_node")) graph.apply_edges(
fn.u_add_v("attn_src", "attn_dst", "attn_node")
)
else: else:
graph.apply_edges(fn.copy_u("attn_src", "attn_node")) graph.apply_edges(fn.copy_u("attn_src", "attn_node"))
e = graph.edata["attn_node"] e = graph.edata["attn_node"]
if feat_edge is not None: if feat_edge is not None:
attn_edge = self.attn_edge_fc(feat_edge).view(-1, self._n_heads, 1) attn_edge = self.attn_edge_fc(feat_edge).view(
-1, self._n_heads, 1
)
graph.edata.update({"attn_edge": attn_edge}) graph.edata.update({"attn_edge": attn_edge})
e += graph.edata["attn_edge"] e += graph.edata["attn_edge"]
e = self.leaky_relu(e) e = self.leaky_relu(e)
...@@ -128,12 +143,16 @@ class GATConv(nn.Module): ...@@ -128,12 +143,16 @@ class GATConv(nn.Module):
bound = int(graph.number_of_edges() * self.edge_drop) bound = int(graph.number_of_edges() * self.edge_drop)
eids = perm[bound:] eids = perm[bound:]
graph.edata["a"] = torch.zeros_like(e) graph.edata["a"] = torch.zeros_like(e)
graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids)) graph.edata["a"][eids] = self.attn_drop(
edge_softmax(graph, e[eids], eids=eids)
)
else: else:
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e)) graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# message passing # message passing
graph.update_all(fn.u_mul_e("feat_src_fc", "a", "m"), fn.sum("m", "feat_src_fc")) graph.update_all(
fn.u_mul_e("feat_src_fc", "a", "m"), fn.sum("m", "feat_src_fc")
)
rst = graph.dstdata["feat_src_fc"] rst = graph.dstdata["feat_src_fc"]
if self._use_symmetric_norm: if self._use_symmetric_norm:
...@@ -257,7 +276,15 @@ class GAT(nn.Module): ...@@ -257,7 +276,15 @@ class GAT(nn.Module):
class MLP(nn.Module): class MLP(nn.Module):
def __init__( def __init__(
self, in_feats, n_classes, n_layers, n_hidden, activation, dropout=0.0, input_drop=0.0, residual=False, self,
in_feats,
n_classes,
n_layers,
n_hidden,
activation,
dropout=0.0,
input_drop=0.0,
residual=False,
): ):
super().__init__() super().__init__()
self.n_layers = n_layers self.n_layers = n_layers
......
import dgl import argparse
import time
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm import tqdm
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
import dgl
import dgl.nn.pytorch as dglnn
class SAGE(nn.Module): class SAGE(nn.Module):
def __init__(self, def __init__(
in_feats, self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
n_hidden, ):
n_classes,
n_layers,
activation,
dropout):
super().__init__() super().__init__()
self.n_layers = n_layers self.n_layers = n_layers
self.n_hidden = n_hidden self.n_hidden = n_hidden
self.n_classes = n_classes self.n_classes = n_classes
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
for i in range(1, n_layers - 1): for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.activation = activation self.activation = activation
...@@ -37,7 +36,7 @@ class SAGE(nn.Module): ...@@ -37,7 +36,7 @@ class SAGE(nn.Module):
# appropriate nodes on the LHS. # appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D) # would be (num_nodes_RHS, D)
h_dst = h[:block.num_dst_nodes()] h_dst = h[: block.num_dst_nodes()]
# Then we compute the updated representation on the RHS. # Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D) # The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst)) h = layer(block, (h, h_dst))
...@@ -60,7 +59,10 @@ class SAGE(nn.Module): ...@@ -60,7 +59,10 @@ class SAGE(nn.Module):
# on each layer are of course split into batches. # on each layer are of course split into batches.
# TODO: can we standardize this? # TODO: can we standardize this?
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes).to(device) y = th.zeros(
g.num_nodes(),
self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
).to(device)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader( dataloader = dgl.dataloading.DataLoader(
...@@ -70,13 +72,14 @@ class SAGE(nn.Module): ...@@ -70,13 +72,14 @@ class SAGE(nn.Module):
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=args.num_workers) num_workers=args.num_workers,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0].int().to(device) block = blocks[0].int().to(device)
h = x[input_nodes] h = x[input_nodes]
h_dst = h[:block.num_dst_nodes()] h_dst = h[: block.num_dst_nodes()]
h = layer(block, (h, h_dst)) h = layer(block, (h, h_dst))
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = self.activation(h) h = self.activation(h)
...@@ -87,12 +90,14 @@ class SAGE(nn.Module): ...@@ -87,12 +90,14 @@ class SAGE(nn.Module):
x = y x = y
return y return y
def compute_acc(pred, labels): def compute_acc(pred, labels):
""" """
Compute the accuracy of prediction given the labels. Compute the accuracy of prediction given the labels.
""" """
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred) return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
def evaluate(model, g, nfeat, labels, val_nid, test_nid, device): def evaluate(model, g, nfeat, labels, val_nid, test_nid, device):
""" """
Evaluate the model on the validation set specified by ``val_nid``. Evaluate the model on the validation set specified by ``val_nid``.
...@@ -106,7 +111,12 @@ def evaluate(model, g, nfeat, labels, val_nid, test_nid, device): ...@@ -106,7 +111,12 @@ def evaluate(model, g, nfeat, labels, val_nid, test_nid, device):
with th.no_grad(): with th.no_grad():
pred = model.inference(g, nfeat, device) pred = model.inference(g, nfeat, device)
model.train() model.train()
return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(pred[test_nid], labels[test_nid]), pred return (
compute_acc(pred[val_nid], labels[val_nid]),
compute_acc(pred[test_nid], labels[test_nid]),
pred,
)
def load_subtensor(nfeat, labels, seeds, input_nodes): def load_subtensor(nfeat, labels, seeds, input_nodes):
""" """
...@@ -116,6 +126,7 @@ def load_subtensor(nfeat, labels, seeds, input_nodes): ...@@ -116,6 +126,7 @@ def load_subtensor(nfeat, labels, seeds, input_nodes):
batch_labels = labels[seeds] batch_labels = labels[seeds]
return batch_inputs, batch_labels return batch_inputs, batch_labels
#### Entry point #### Entry point
def run(args, device, data): def run(args, device, data):
# Unpack data # Unpack data
...@@ -123,7 +134,8 @@ def run(args, device, data): ...@@ -123,7 +134,8 @@ def run(args, device, data):
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler( sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')]) [int(fanout) for fanout in args.fan_out.split(",")]
)
dataloader = dgl.dataloading.DataLoader( dataloader = dgl.dataloading.DataLoader(
g, g,
train_nid, train_nid,
...@@ -131,10 +143,18 @@ def run(args, device, data): ...@@ -131,10 +143,18 @@ def run(args, device, data):
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=args.num_workers) num_workers=args.num_workers,
)
# Define model and optimizer # Define model and optimizer
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout) model = SAGE(
in_feats,
args.num_hidden,
n_classes,
args.num_layers,
F.relu,
args.dropout,
)
model = model.to(device) model = model.to(device)
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
...@@ -156,7 +176,9 @@ def run(args, device, data): ...@@ -156,7 +176,9 @@ def run(args, device, data):
blocks = [blk.int().to(device) for blk in blocks] blocks = [blk.int().to(device) for blk in blocks]
# Load the input features as well as output labels # Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(nfeat, labels, seeds, input_nodes) batch_inputs, batch_labels = load_subtensor(
nfeat, labels, seeds, input_nodes
)
# Compute loss and prediction # Compute loss and prediction
batch_pred = model(blocks, batch_inputs) batch_pred = model(blocks, batch_inputs)
...@@ -168,58 +190,93 @@ def run(args, device, data): ...@@ -168,58 +190,93 @@ def run(args, device, data):
iter_tput.append(len(seeds) / (time.time() - tic_step)) iter_tput.append(len(seeds) / (time.time() - tic_step))
if step % args.log_every == 0: if step % args.log_every == 0:
acc = compute_acc(batch_pred, batch_labels) acc = compute_acc(batch_pred, batch_labels)
gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0 gpu_mem_alloc = (
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB'.format( th.cuda.max_memory_allocated() / 1000000
epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc)) if th.cuda.is_available()
else 0
)
print(
"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB".format(
epoch,
step,
loss.item(),
acc.item(),
np.mean(iter_tput[3:]),
gpu_mem_alloc,
)
)
toc = time.time() toc = time.time()
print('Epoch Time(s): {:.4f}'.format(toc - tic)) print("Epoch Time(s): {:.4f}".format(toc - tic))
if epoch >= 5: if epoch >= 5:
avg += toc - tic avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0: if epoch % args.eval_every == 0 and epoch != 0:
eval_acc, test_acc, pred = evaluate(model, g, nfeat, labels, val_nid, test_nid, device) eval_acc, test_acc, pred = evaluate(
model, g, nfeat, labels, val_nid, test_nid, device
)
if args.save_pred: if args.save_pred:
np.savetxt(args.save_pred + '%02d' % epoch, pred.argmax(1).cpu().numpy(), '%d') np.savetxt(
print('Eval Acc {:.4f}'.format(eval_acc)) args.save_pred + "%02d" % epoch,
pred.argmax(1).cpu().numpy(),
"%d",
)
print("Eval Acc {:.4f}".format(eval_acc))
if eval_acc > best_eval_acc: if eval_acc > best_eval_acc:
best_eval_acc = eval_acc best_eval_acc = eval_acc
best_test_acc = test_acc best_test_acc = test_acc
print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc)) print(
"Best Eval Acc {:.4f} Test Acc {:.4f}".format(
best_eval_acc, best_test_acc
)
)
print('Avg epoch time: {}'.format(avg / (epoch - 4))) print("Avg epoch time: {}".format(avg / (epoch - 4)))
return best_test_acc return best_test_acc
if __name__ == '__main__':
if __name__ == "__main__":
argparser = argparse.ArgumentParser("multi-gpu training") argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--gpu', type=int, default=0, argparser.add_argument(
help="GPU device ID. Use -1 for CPU training") "--gpu",
argparser.add_argument('--num-epochs', type=int, default=20) type=int,
argparser.add_argument('--num-hidden', type=int, default=256) default=0,
argparser.add_argument('--num-layers', type=int, default=3) help="GPU device ID. Use -1 for CPU training",
argparser.add_argument('--fan-out', type=str, default='5,10,15') )
argparser.add_argument('--batch-size', type=int, default=1000) argparser.add_argument("--num-epochs", type=int, default=20)
argparser.add_argument('--val-batch-size', type=int, default=10000) argparser.add_argument("--num-hidden", type=int, default=256)
argparser.add_argument('--log-every', type=int, default=20) argparser.add_argument("--num-layers", type=int, default=3)
argparser.add_argument('--eval-every', type=int, default=1) argparser.add_argument("--fan-out", type=str, default="5,10,15")
argparser.add_argument('--lr', type=float, default=0.003) argparser.add_argument("--batch-size", type=int, default=1000)
argparser.add_argument('--dropout', type=float, default=0.5) argparser.add_argument("--val-batch-size", type=int, default=10000)
argparser.add_argument('--num-workers', type=int, default=4, argparser.add_argument("--log-every", type=int, default=20)
help="Number of sampling processes. Use 0 for no extra process.") argparser.add_argument("--eval-every", type=int, default=1)
argparser.add_argument('--save-pred', type=str, default='') argparser.add_argument("--lr", type=float, default=0.003)
argparser.add_argument('--wd', type=float, default=0) argparser.add_argument("--dropout", type=float, default=0.5)
argparser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of sampling processes. Use 0 for no extra process.",
)
argparser.add_argument("--save-pred", type=str, default="")
argparser.add_argument("--wd", type=float, default=0)
args = argparser.parse_args() args = argparser.parse_args()
if args.gpu >= 0: if args.gpu >= 0:
device = th.device('cuda:%d' % args.gpu) device = th.device("cuda:%d" % args.gpu)
else: else:
device = th.device('cpu') device = th.device("cpu")
# load ogbn-products data # load ogbn-products data
data = DglNodePropPredDataset(name='ogbn-products') data = DglNodePropPredDataset(name="ogbn-products")
splitted_idx = data.get_idx_split() splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx['train'], splitted_idx['valid'], splitted_idx['test'] train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0] graph, labels = data[0]
nfeat = graph.ndata.pop('feat').to(device) nfeat = graph.ndata.pop("feat").to(device)
labels = labels[:, 0].to(device) labels = labels[:, 0].to(device)
in_feats = nfeat.shape[1] in_feats = nfeat.shape[1]
...@@ -228,10 +285,21 @@ if __name__ == '__main__': ...@@ -228,10 +285,21 @@ if __name__ == '__main__':
# This avoids creating certain formats in each data loader process, which saves memory and CPU. # This avoids creating certain formats in each data loader process, which saves memory and CPU.
graph.create_formats_() graph.create_formats_()
# Pack data # Pack data
data = train_idx, val_idx, test_idx, in_feats, labels, n_classes, nfeat, graph data = (
train_idx,
val_idx,
test_idx,
in_feats,
labels,
n_classes,
nfeat,
graph,
)
# Run 10 times # Run 10 times
test_accs = [] test_accs = []
for i in range(10): for i in range(10):
test_accs.append(run(args, device, data).cpu().numpy()) test_accs.append(run(args, device, data).cpu().numpy())
print('Average test accuracy:', np.mean(test_accs), '±', np.std(test_accs)) print(
"Average test accuracy:", np.mean(test_accs), "±", np.std(test_accs)
)
...@@ -7,20 +7,23 @@ import random ...@@ -7,20 +7,23 @@ import random
import time import time
from collections import OrderedDict from collections import OrderedDict
import dgl.function as fn
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading import DataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from models import MLP
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn from torch import nn
from tqdm import tqdm from tqdm import tqdm
from models import MLP import dgl.function as fn
from dgl.dataloading import (
DataLoader,
MultiLayerFullNeighborSampler,
MultiLayerNeighborSampler,
)
epsilon = 1 - math.log(2) epsilon = 1 - math.log(2)
...@@ -44,7 +47,11 @@ def load_data(dataset): ...@@ -44,7 +47,11 @@ def load_data(dataset):
evaluator = Evaluator(name=dataset) evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split() splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"] train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0] graph, labels = data[0]
graph.ndata["labels"] = labels graph.ndata["labels"] = labels
...@@ -83,7 +90,9 @@ def custom_loss_function(x, labels): ...@@ -83,7 +90,9 @@ def custom_loss_function(x, labels):
return torch.mean(y) return torch.mean(y)
def train(args, model, dataloader, labels, train_idx, criterion, optimizer, evaluator): def train(
args, model, dataloader, labels, train_idx, criterion, optimizer, evaluator
):
model.train() model.train()
loss_sum, total = 0, 0 loss_sum, total = 0, 0
...@@ -97,7 +106,9 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval ...@@ -97,7 +106,9 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval
pred = model(subgraphs[0].srcdata["feat"]) pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred.cpu().detach() preds[output_nodes] = pred.cpu().detach()
loss = criterion(pred[new_train_idx], labels[output_nodes][new_train_idx]) loss = criterion(
pred[new_train_idx], labels[output_nodes][new_train_idx]
)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
...@@ -114,7 +125,17 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval ...@@ -114,7 +125,17 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval
@torch.no_grad() @torch.no_grad()
def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator): def evaluate(
args,
model,
dataloader,
labels,
train_idx,
val_idx,
test_idx,
criterion,
evaluator,
):
model.eval() model.eval()
preds = torch.zeros(labels.shape[0], n_classes, device=device) preds = torch.zeros(labels.shape[0], n_classes, device=device)
...@@ -144,43 +165,56 @@ def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, crit ...@@ -144,43 +165,56 @@ def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, crit
) )
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running): def run(
args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running
):
evaluator_wrapper = lambda pred, labels: evaluator.eval( evaluator_wrapper = lambda pred, labels: evaluator.eval(
{"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels} {"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
)["acc"] )["acc"]
criterion = custom_loss_function criterion = custom_loss_function
train_batch_size = 4096 train_batch_size = 4096
train_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)])  # do not sample neighbors train_sampler = MultiLayerNeighborSampler(
[0 for _ in range(args.n_layers)]
)  # do not sample neighbors
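# Added note (not part of the original commit, assuming gen_model builds the
# MLP imported above): a fanout of 0 for every layer yields blocks that
# contain only the seed nodes, so this dataloader just batches node features
# and labels; no neighbor aggregation takes place.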
train_dataloader = DataLoader( train_dataloader = DataLoader(
graph.cpu(), graph.cpu(),
train_idx.cpu(), train_idx.cpu(),
train_sampler, train_sampler,
batch_size=train_batch_size, batch_size=train_batch_size,
shuffle=True, shuffle=True,
num_workers=4 num_workers=4,
) )
eval_batch_size = 4096 eval_batch_size = 4096
eval_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)])  # do not sample neighbors eval_sampler = MultiLayerNeighborSampler(
[0 for _ in range(args.n_layers)]
)  # do not sample neighbors
if args.eval_last: if args.eval_last:
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu()]) eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu()])
else: else:
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]) eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()])
eval_dataloader = DataLoader( eval_dataloader = DataLoader(
graph.cpu(), graph.cpu(),
eval_idx, eval_idx,
eval_sampler, eval_sampler,
batch_size=eval_batch_size, batch_size=eval_batch_size,
shuffle=False, shuffle=False,
num_workers=4 num_workers=4,
) )
model = gen_model(args).to(device) model = gen_model(args).to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd) optimizer = optim.AdamW(
model.parameters(), lr=args.lr, weight_decay=args.wd
)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.7, patience=20, verbose=True, min_lr=1e-4 optimizer,
mode="max",
factor=0.7,
patience=20,
verbose=True,
min_lr=1e-4,
) )
best_model_state_dict = None best_model_state_dict = None
...@@ -193,21 +227,47 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -193,21 +227,47 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
for epoch in range(1, args.n_epochs + 1): for epoch in range(1, args.n_epochs + 1):
tic = time.time() tic = time.time()
loss, score = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper) loss, score = train(
args,
model,
train_dataloader,
labels,
train_idx,
criterion,
optimizer,
evaluator_wrapper,
)
toc = time.time() toc = time.time()
total_time += toc - tic total_time += toc - tic
if epoch % args.eval_every == 0 or epoch % args.log_every == 0: if epoch % args.eval_every == 0 or epoch % args.log_every == 0:
train_score, val_score, test_score, train_loss, val_loss, test_loss = evaluate( (
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper train_score,
val_score,
test_score,
train_loss,
val_loss,
test_loss,
) = evaluate(
args,
model,
eval_dataloader,
labels,
train_idx,
val_idx,
test_idx,
criterion,
evaluator_wrapper,
) )
if val_score > best_val_score: if val_score > best_val_score:
best_val_score = val_score best_val_score = val_score
final_test_score = test_score final_test_score = test_score
if args.eval_last: if args.eval_last:
best_model_state_dict = {k: v.to("cpu") for k, v in model.state_dict().items()} best_model_state_dict = {
k: v.to("cpu") for k, v in model.state_dict().items()
}
best_model_state_dict = OrderedDict(best_model_state_dict) best_model_state_dict = OrderedDict(best_model_state_dict)
if epoch % args.log_every == 0: if epoch % args.log_every == 0:
...@@ -221,8 +281,26 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -221,8 +281,26 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
) )
for l, e in zip( for l, e in zip(
[scores, train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses], [
[score, train_score, val_score, test_score, loss, train_loss, val_loss, test_loss], scores,
train_scores,
val_scores,
test_scores,
losses,
train_losses,
val_losses,
test_losses,
],
[
score,
train_score,
val_score,
test_score,
loss,
train_loss,
val_loss,
test_loss,
],
): ):
l.append(e) l.append(e)
...@@ -231,19 +309,29 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -231,19 +309,29 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
if args.eval_last: if args.eval_last:
model.load_state_dict(best_model_state_dict) model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoader( eval_dataloader = DataLoader(
graph.cpu(), graph.cpu(),
test_idx.cpu(), test_idx.cpu(),
eval_sampler, eval_sampler,
batch_size=eval_batch_size, batch_size=eval_batch_size,
shuffle=False, shuffle=False,
num_workers=4 num_workers=4,
) )
final_test_score = evaluate( final_test_score = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper args,
model,
eval_dataloader,
labels,
train_idx,
val_idx,
test_idx,
criterion,
evaluator_wrapper,
)[2] )[2]
print("*" * 50) print("*" * 50)
print(f"Average epoch time: {total_time / args.n_epochs}, Test score: {final_test_score}") print(
f"Average epoch time: {total_time / args.n_epochs}, Test score: {final_test_score}"
)
if args.plot_curves: if args.plot_curves:
fig = plt.figure(figsize=(24, 24)) fig = plt.figure(figsize=(24, 24))
...@@ -251,8 +339,16 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -251,8 +339,16 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_xticks(np.arange(0, args.n_epochs, 100)) ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101)) ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True) ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([train_scores, val_scores, test_scores], ["train score", "val score", "test score"]): for y, label in zip(
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1) [train_scores, val_scores, test_scores],
["train score", "val score", "test score"],
):
plt.plot(
range(1, args.n_epochs + 1, args.log_every),
y,
label=label,
linewidth=1,
)
ax.xaxis.set_major_locator(MultipleLocator(20)) ax.xaxis.set_major_locator(MultipleLocator(20))
ax.xaxis.set_minor_locator(AutoMinorLocator(1)) ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.01)) ax.yaxis.set_major_locator(MultipleLocator(0.01))
...@@ -268,9 +364,15 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -268,9 +364,15 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_xticks(np.arange(0, args.n_epochs, 100)) ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True) ax.tick_params(labeltop=True, labelright=True)
for y, label in zip( for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"] [losses, train_losses, val_losses, test_losses],
["loss", "train loss", "val loss", "test loss"],
): ):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1) plt.plot(
range(1, args.n_epochs + 1, args.log_every),
y,
label=label,
linewidth=1,
)
ax.xaxis.set_major_locator(MultipleLocator(20)) ax.xaxis.set_major_locator(MultipleLocator(20))
ax.xaxis.set_minor_locator(AutoMinorLocator(1)) ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.1)) ax.yaxis.set_major_locator(MultipleLocator(0.1))
...@@ -286,14 +388,23 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -286,14 +388,23 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
def count_parameters(args): def count_parameters(args):
model = gen_model(args) model = gen_model(args)
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad]) return sum(
[np.prod(p.size()) for p in model.parameters() if p.requires_grad]
)
def main(): def main():
global device global device
argparser = argparse.ArgumentParser("GAT on OGBN-Proteins", formatter_class=argparse.ArgumentDefaultsHelpFormatter) argparser = argparse.ArgumentParser(
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides '--gpu'.") "GAT on OGBN-Proteins",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
argparser.add_argument(
"--cpu",
action="store_true",
help="CPU mode. This option overrides '--gpu'.",
)
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.") argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
argparser.add_argument("--seed", type=int, help="seed", default=0) argparser.add_argument("--seed", type=int, help="seed", default=0)
argparser.add_argument("--n-runs", type=int, default=10) argparser.add_argument("--n-runs", type=int, default=10)
...@@ -304,8 +415,16 @@ def main(): ...@@ -304,8 +415,16 @@ def main():
argparser.add_argument("--dropout", type=float, default=0.2) argparser.add_argument("--dropout", type=float, default=0.2)
argparser.add_argument("--input-drop", type=float, default=0) argparser.add_argument("--input-drop", type=float, default=0)
argparser.add_argument("--wd", type=float, default=0) argparser.add_argument("--wd", type=float, default=0)
argparser.add_argument("--estimation-mode", action="store_true", help="Estimate the score of test set for speed.") argparser.add_argument(
argparser.add_argument("--eval-last", action="store_true", help="Evaluate the score of test set at last.") "--estimation-mode",
action="store_true",
help="Estimate the score of test set for speed.",
)
argparser.add_argument(
"--eval-last",
action="store_true",
help="Evaluate the score of test set at last.",
)
argparser.add_argument("--eval-every", type=int, default=1) argparser.add_argument("--eval-every", type=int, default=1)
argparser.add_argument("--log-every", type=int, default=1) argparser.add_argument("--log-every", type=int, default=1)
argparser.add_argument("--plot-curves", action="store_true") argparser.add_argument("--plot-curves", action="store_true")
...@@ -317,7 +436,9 @@ def main(): ...@@ -317,7 +436,9 @@ def main():
device = torch.device("cuda:%d" % args.gpu) device = torch.device("cuda:%d" % args.gpu)
if args.estimation_mode: if args.estimation_mode:
print("WARNING: Estimation mode is enabled. The test score is not accurate.") print(
"WARNING: Estimation mode is enabled. The test score is not accurate."
)
seed(args.seed) seed(args.seed)
...@@ -336,7 +457,9 @@ def main(): ...@@ -336,7 +457,9 @@ def main():
val_scores, test_scores = [], [] val_scores, test_scores = [], []
for i in range(1, args.n_runs + 1): for i in range(1, args.n_runs + 1):
val_score, test_score = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i) val_score, test_score = run(
args, graph, labels, train_idx, val_idx, test_idx, evaluator, i
)
val_scores.append(val_score) val_scores.append(val_score)
test_scores.append(test_score) test_scores.append(test_score)
...@@ -349,7 +472,9 @@ def main(): ...@@ -349,7 +472,9 @@ def main():
print(f"Number of params: {count_parameters(args)}") print(f"Number of params: {count_parameters(args)}")
if args.estimation_mode: if args.estimation_mode:
print("WARNING: Estimation mode is enabled. The test score is not accurate.") print(
"WARNING: Estimation mode is enabled. The test score is not accurate."
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -5,7 +5,15 @@ import torch.nn.functional as F ...@@ -5,7 +5,15 @@ import torch.nn.functional as F
class MLP(nn.Module): class MLP(nn.Module):
def __init__( def __init__(
self, in_feats, n_classes, n_layers, n_hidden, activation, dropout=0.0, input_drop=0.0, residual=False, self,
in_feats,
n_classes,
n_layers,
n_hidden,
activation,
dropout=0.0,
input_drop=0.0,
residual=False,
): ):
super().__init__() super().__init__()
self.n_layers = n_layers self.n_layers = n_layers
......