Unverified commit 704bcaf6, authored by Hongzhi (Steve), Chen and committed by GitHub
parent 6bc82161
@@ -2,14 +2,14 @@ import argparse
 import copy
 import os
+import dgl
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-from model import MLP, CorrectAndSmooth, MLPLinear
+from model import CorrectAndSmooth, MLP, MLPLinear
 from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
-import dgl


 def evaluate(y_pred, y_true, idx, evaluator):
     return evaluator.eval({"y_true": y_true[idx], "y_pred": y_pred[idx]})["acc"]
@@ -104,7 +104,6 @@ def main():
     # training
     print("---------- Training ----------")
     for i in range(args.epochs):
         model.train()
         opt.zero_grad()
......
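For context, the `evaluate` helper in the hunk above only wraps the OGB `Evaluator`. A minimal sketch of how that helper is typically driven follows; the dataset name ("ogbn-arxiv"), the random logits, and the split handling are assumptions for illustration and are not part of this commit:

import torch
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

# Hypothetical driver; "ogbn-arxiv" is an assumed dataset name.
dataset = DglNodePropPredDataset(name="ogbn-arxiv")
evaluator = Evaluator(name="ogbn-arxiv")
split_idx = dataset.get_idx_split()
graph, labels = dataset[0]  # labels: (num_nodes, 1)

# Stand-in for the MLP output used by the example.
logits = torch.randn(graph.num_nodes(), dataset.num_classes)
y_pred = logits.argmax(dim=-1, keepdim=True)  # (num_nodes, 1), as Evaluator expects

# Same call pattern as evaluate(y_pred, y_true, idx, evaluator) above.
test_acc = evaluator.eval(
    {"y_true": labels[split_idx["test"]], "y_pred": y_pred[split_idx["test"]]}
)["acc"]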
+import dgl.function as fn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl.function as fn


 class MLPLinear(nn.Module):
     def __init__(self, in_dim, out_dim):
......
 import argparse
+import dgl.function as fn
 import numpy as np
 import torch
+from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
 from torch import nn
-from torch.nn import Parameter
-from torch.nn import functional as F
+from torch.nn import functional as F, Parameter
 from tqdm import trange
 from utils import evaluate, generate_random_seeds, set_random_state
-import dgl.function as fn
-from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


 class DAGNNConv(nn.Module):
     def __init__(self, in_dim, k):
@@ -26,7 +25,6 @@ class DAGNNConv(nn.Module):
         nn.init.xavier_uniform_(self.s, gain=gain)

     def forward(self, graph, feats):
         with graph.local_scope():
             results = [feats]
@@ -68,7 +66,6 @@ class MLPLayer(nn.Module):
         nn.init.zeros_(self.linear.bias)

     def forward(self, feats):
         feats = self.dropout(feats)
         feats = self.linear(feats)
         if self.activation:
......
+import dgl.function as fn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from modules import MLP, MessageNorm
-from ogb.graphproppred.mol_encoder import BondEncoder
-import dgl.function as fn
 from dgl.nn.functional import edge_softmax
+from modules import MessageNorm, MLP
+from ogb.graphproppred.mol_encoder import BondEncoder


 class GENConv(nn.Module):
......
@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 from models import DeeperGCN
-from ogb.graphproppred import DglGraphPropPredDataset, Evaluator, collate_dgl
+from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator
 from torch.utils.data import DataLoader
......
+import dgl.function as fn
 import torch.nn as nn
 import torch.nn.functional as F
+from dgl.nn.pytorch.glob import AvgPooling
 from layers import GENConv
 from ogb.graphproppred.mol_encoder import AtomEncoder
-import dgl.function as fn
-from dgl.nn.pytorch.glob import AvgPooling


 class DeeperGCN(nn.Module):
     r"""
......
 import argparse, time
-import numpy as np
+import dgl
 import networkx as nx
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl
+from dgi import Classifier, DGI
 from dgl import DGLGraph
-from dgl.data import register_data_args, load_data
-from dgi import DGI, Classifier
+from dgl.data import load_data, register_data_args


 def evaluate(model, features, labels, mask):
     model.eval()
@@ -19,20 +21,21 @@ def evaluate(model, features, labels, mask):
         correct = torch.sum(indices == labels)
         return correct.item() * 1.0 / len(labels)


 def main(args):
     # load and preprocess dataset
     data = load_data(args)
     g = data[0]
-    features = torch.FloatTensor(g.ndata['feat'])
-    labels = torch.LongTensor(g.ndata['label'])
-    if hasattr(torch, 'BoolTensor'):
-        train_mask = torch.BoolTensor(g.ndata['train_mask'])
-        val_mask = torch.BoolTensor(g.ndata['val_mask'])
-        test_mask = torch.BoolTensor(g.ndata['test_mask'])
+    features = torch.FloatTensor(g.ndata["feat"])
+    labels = torch.LongTensor(g.ndata["label"])
+    if hasattr(torch, "BoolTensor"):
+        train_mask = torch.BoolTensor(g.ndata["train_mask"])
+        val_mask = torch.BoolTensor(g.ndata["val_mask"])
+        test_mask = torch.BoolTensor(g.ndata["test_mask"])
     else:
-        train_mask = torch.ByteTensor(g.ndata['train_mask'])
-        val_mask = torch.ByteTensor(g.ndata['val_mask'])
-        test_mask = torch.ByteTensor(g.ndata['test_mask'])
+        train_mask = torch.ByteTensor(g.ndata["train_mask"])
+        val_mask = torch.ByteTensor(g.ndata["val_mask"])
+        test_mask = torch.ByteTensor(g.ndata["test_mask"])
     in_feats = features.shape[1]
     n_classes = data.num_classes
     n_edges = g.number_of_edges()
@@ -57,19 +60,21 @@ def main(args):
     if args.gpu >= 0:
         g = g.to(args.gpu)
     # create DGI model
-    dgi = DGI(g,
-              in_feats,
-              args.n_hidden,
-              args.n_layers,
-              nn.PReLU(args.n_hidden),
-              args.dropout)
+    dgi = DGI(
+        g,
+        in_feats,
+        args.n_hidden,
+        args.n_layers,
+        nn.PReLU(args.n_hidden),
+        args.dropout,
+    )
     if cuda:
         dgi.cuda()
-    dgi_optimizer = torch.optim.Adam(dgi.parameters(),
-                                     lr=args.dgi_lr,
-                                     weight_decay=args.weight_decay)
+    dgi_optimizer = torch.optim.Adam(
+        dgi.parameters(), lr=args.dgi_lr, weight_decay=args.weight_decay
+    )
     # train deep graph infomax
     cnt_wait = 0
@@ -90,33 +95,38 @@ def main(args):
             best = loss
             best_t = epoch
             cnt_wait = 0
-            torch.save(dgi.state_dict(), 'best_dgi.pkl')
+            torch.save(dgi.state_dict(), "best_dgi.pkl")
         else:
             cnt_wait += 1
         if cnt_wait == args.patience:
-            print('Early stopping!')
+            print("Early stopping!")
             break
         if epoch >= 3:
             dur.append(time.time() - t0)
-        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
-              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
-                                            n_edges / np.mean(dur) / 1000))
+        print(
+            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
+            "ETputs(KTEPS) {:.2f}".format(
+                epoch, np.mean(dur), loss.item(), n_edges / np.mean(dur) / 1000
+            )
+        )

     # create classifier model
     classifier = Classifier(args.n_hidden, n_classes)
     if cuda:
         classifier.cuda()
-    classifier_optimizer = torch.optim.Adam(classifier.parameters(),
-                                            lr=args.classifier_lr,
-                                            weight_decay=args.weight_decay)
+    classifier_optimizer = torch.optim.Adam(
+        classifier.parameters(),
+        lr=args.classifier_lr,
+        weight_decay=args.weight_decay,
+    )

     # train classifier
-    print('Loading {}th epoch'.format(best_t))
-    dgi.load_state_dict(torch.load('best_dgi.pkl'))
+    print("Loading {}th epoch".format(best_t))
+    dgi.load_state_dict(torch.load("best_dgi.pkl"))
     embeds = dgi.encoder(features, corrupt=False)
     embeds = embeds.detach()
     dur = []
@@ -135,39 +145,67 @@ def main(args):
             dur.append(time.time() - t0)
         acc = evaluate(classifier, embeds, labels, val_mask)
-        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
-              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
-                                            acc, n_edges / np.mean(dur) / 1000))
+        print(
+            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
+            "ETputs(KTEPS) {:.2f}".format(
+                epoch,
+                np.mean(dur),
+                loss.item(),
+                acc,
+                n_edges / np.mean(dur) / 1000,
+            )
+        )

     print()
     acc = evaluate(classifier, embeds, labels, test_mask)
     print("Test Accuracy {:.4f}".format(acc))


-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='DGI')
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="DGI")
     register_data_args(parser)
-    parser.add_argument("--dropout", type=float, default=0.,
-                        help="dropout probability")
-    parser.add_argument("--gpu", type=int, default=-1,
-                        help="gpu")
-    parser.add_argument("--dgi-lr", type=float, default=1e-3,
-                        help="dgi learning rate")
-    parser.add_argument("--classifier-lr", type=float, default=1e-2,
-                        help="classifier learning rate")
-    parser.add_argument("--n-dgi-epochs", type=int, default=300,
-                        help="number of training epochs")
-    parser.add_argument("--n-classifier-epochs", type=int, default=300,
-                        help="number of training epochs")
-    parser.add_argument("--n-hidden", type=int, default=512,
-                        help="number of hidden gcn units")
-    parser.add_argument("--n-layers", type=int, default=1,
-                        help="number of hidden gcn layers")
-    parser.add_argument("--weight-decay", type=float, default=0.,
-                        help="Weight for L2 loss")
-    parser.add_argument("--patience", type=int, default=20,
-                        help="early stop patience condition")
-    parser.add_argument("--self-loop", action='store_true',
-                        help="graph self-loop (default=False)")
+    parser.add_argument(
+        "--dropout", type=float, default=0.0, help="dropout probability"
+    )
+    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
+    parser.add_argument(
+        "--dgi-lr", type=float, default=1e-3, help="dgi learning rate"
+    )
+    parser.add_argument(
+        "--classifier-lr",
+        type=float,
+        default=1e-2,
+        help="classifier learning rate",
+    )
+    parser.add_argument(
+        "--n-dgi-epochs",
+        type=int,
+        default=300,
+        help="number of training epochs",
+    )
+    parser.add_argument(
+        "--n-classifier-epochs",
+        type=int,
+        default=300,
+        help="number of training epochs",
+    )
+    parser.add_argument(
+        "--n-hidden", type=int, default=512, help="number of hidden gcn units"
+    )
+    parser.add_argument(
+        "--n-layers", type=int, default=1, help="number of hidden gcn layers"
+    )
+    parser.add_argument(
+        "--weight-decay", type=float, default=0.0, help="Weight for L2 loss"
+    )
+    parser.add_argument(
+        "--patience", type=int, default=20, help="early stop patience condition"
+    )
+    parser.add_argument(
+        "--self-loop",
+        action="store_true",
+        help="graph self-loop (default=False)",
+    )
     parser.set_defaults(self_loop=False)
     args = parser.parse_args()
     print(args)
......
@@ -4,7 +4,6 @@ and will be loaded when setting up."""


 def dataset_based_configure(opts):
     if opts["dataset"] == "cycles":
         ds_configure = cycles_configure
     else:
......
@@ -65,7 +65,6 @@ def main(opts):
     optimizer.zero_grad()
     for i, data in enumerate(data_loader):
         log_prob = model(actions=data)
         prob = log_prob.detach().exp()
......
+from functools import partial
 import dgl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from functools import partial
 from torch.distributions import Bernoulli, Categorical
@@ -15,20 +16,19 @@ class GraphEmbed(nn.Module):
         # Embed graphs
         self.node_gating = nn.Sequential(
-            nn.Linear(node_hidden_size, 1),
-            nn.Sigmoid()
+            nn.Linear(node_hidden_size, 1), nn.Sigmoid()
         )
-        self.node_to_graph = nn.Linear(node_hidden_size,
-                                       self.graph_hidden_size)
+        self.node_to_graph = nn.Linear(node_hidden_size, self.graph_hidden_size)

     def forward(self, g):
         if g.number_of_nodes() == 0:
             return torch.zeros(1, self.graph_hidden_size)
         else:
             # Node features are stored as hv in ndata.
-            hvs = g.ndata['hv']
-            return (self.node_gating(hvs) *
-                    self.node_to_graph(hvs)).sum(0, keepdim=True)
+            hvs = g.ndata["hv"]
+            return (self.node_gating(hvs) * self.node_to_graph(hvs)).sum(
+                0, keepdim=True
+            )


 class GraphProp(nn.Module):
@@ -46,41 +46,45 @@ class GraphProp(nn.Module):
         for t in range(num_prop_rounds):
             # input being [hv, hu, xuv]
-            message_funcs.append(nn.Linear(2 * node_hidden_size + 1,
-                                           self.node_activation_hidden_size))
+            message_funcs.append(
+                nn.Linear(
+                    2 * node_hidden_size + 1, self.node_activation_hidden_size
+                )
+            )
             self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
             node_update_funcs.append(
-                nn.GRUCell(self.node_activation_hidden_size,
-                           node_hidden_size))
+                nn.GRUCell(self.node_activation_hidden_size, node_hidden_size)
+            )

         self.message_funcs = nn.ModuleList(message_funcs)
         self.node_update_funcs = nn.ModuleList(node_update_funcs)

     def dgmg_msg(self, edges):
         """For an edge u->v, return concat([h_u, x_uv])"""
-        return {'m': torch.cat([edges.src['hv'],
-                                edges.data['he']],
-                               dim=1)}
+        return {"m": torch.cat([edges.src["hv"], edges.data["he"]], dim=1)}

     def dgmg_reduce(self, nodes, round):
-        hv_old = nodes.data['hv']
-        m = nodes.mailbox['m']
-        message = torch.cat([
-            hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
+        hv_old = nodes.data["hv"]
+        m = nodes.mailbox["m"]
+        message = torch.cat(
+            [hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2
+        )
         node_activation = (self.message_funcs[round](message)).sum(1)
-        return {'a': node_activation}
+        return {"a": node_activation}

     def forward(self, g):
         if g.number_of_edges() == 0:
             return
         else:
             for t in range(self.num_prop_rounds):
-                g.update_all(message_func=self.dgmg_msg,
-                             reduce_func=self.reduce_funcs[t])
-                g.ndata['hv'] = self.node_update_funcs[t](
-                    g.ndata['a'], g.ndata['hv'])
+                g.update_all(
+                    message_func=self.dgmg_msg, reduce_func=self.reduce_funcs[t]
+                )
+                g.ndata["hv"] = self.node_update_funcs[t](
+                    g.ndata["a"], g.ndata["hv"]
+                )


 def bernoulli_action_log_prob(logit, action):
@@ -96,33 +100,39 @@ class AddNode(nn.Module):
     def __init__(self, graph_embed_func, node_hidden_size):
         super(AddNode, self).__init__()

-        self.graph_op = {'embed': graph_embed_func}
+        self.graph_op = {"embed": graph_embed_func}

         self.stop = 1
         self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)

         # If to add a node, initialize its hv
         self.node_type_embed = nn.Embedding(1, node_hidden_size)
-        self.initialize_hv = nn.Linear(node_hidden_size + \
-                                       graph_embed_func.graph_hidden_size,
-                                       node_hidden_size)
+        self.initialize_hv = nn.Linear(
+            node_hidden_size + graph_embed_func.graph_hidden_size,
+            node_hidden_size,
+        )

         self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)

     def _initialize_node_repr(self, g, node_type, graph_embed):
         num_nodes = g.number_of_nodes()
         hv_init = self.initialize_hv(
-            torch.cat([
-                self.node_type_embed(torch.LongTensor([node_type])),
-                graph_embed], dim=1))
-        g.nodes[num_nodes - 1].data['hv'] = hv_init
-        g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
+            torch.cat(
+                [
+                    self.node_type_embed(torch.LongTensor([node_type])),
+                    graph_embed,
+                ],
+                dim=1,
+            )
+        )
+        g.nodes[num_nodes - 1].data["hv"] = hv_init
+        g.nodes[num_nodes - 1].data["a"] = self.init_node_activation

     def prepare_training(self):
         self.log_prob = []

     def forward(self, g, action=None):
-        graph_embed = self.graph_op['embed'](g)
+        graph_embed = self.graph_op["embed"](g)

         logit = self.add_node(graph_embed)
         prob = torch.sigmoid(logit)
@@ -146,19 +156,19 @@ class AddEdge(nn.Module):
     def __init__(self, graph_embed_func, node_hidden_size):
         super(AddEdge, self).__init__()

-        self.graph_op = {'embed': graph_embed_func}
-        self.add_edge = nn.Linear(graph_embed_func.graph_hidden_size + \
-                                  node_hidden_size, 1)
+        self.graph_op = {"embed": graph_embed_func}
+        self.add_edge = nn.Linear(
+            graph_embed_func.graph_hidden_size + node_hidden_size, 1
+        )

     def prepare_training(self):
         self.log_prob = []

     def forward(self, g, action=None):
-        graph_embed = self.graph_op['embed'](g)
-        src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']
+        graph_embed = self.graph_op["embed"](g)
+        src_embed = g.nodes[g.number_of_nodes() - 1].data["hv"]

-        logit = self.add_edge(torch.cat(
-            [graph_embed, src_embed], dim=1))
+        logit = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))
         prob = torch.sigmoid(logit)

         if not self.training:
@@ -176,7 +186,7 @@ class ChooseDestAndUpdate(nn.Module):
     def __init__(self, graph_prop_func, node_hidden_size):
         super(ChooseDestAndUpdate, self).__init__()

-        self.graph_op = {'prop': graph_prop_func}
+        self.graph_op = {"prop": graph_prop_func}
         self.choose_dest = nn.Linear(2 * node_hidden_size, 1)

     def _initialize_edge_repr(self, g, src_list, dest_list):
@@ -184,7 +194,7 @@ class ChooseDestAndUpdate(nn.Module):
         # For multiple edge types, we can use a one hot representation
         # or an embedding module.
         edge_repr = torch.ones(len(src_list), 1)
-        g.edges[src_list, dest_list].data['he'] = edge_repr
+        g.edges[src_list, dest_list].data["he"] = edge_repr

     def prepare_training(self):
         self.log_prob = []
@@ -193,12 +203,12 @@ class ChooseDestAndUpdate(nn.Module):
         src = g.number_of_nodes() - 1
         possible_dests = range(src)

-        src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
-        possible_dests_embed = g.nodes[possible_dests].data['hv']
+        src_embed_expand = g.nodes[src].data["hv"].expand(src, -1)
+        possible_dests_embed = g.nodes[possible_dests].data["hv"]

         dests_scores = self.choose_dest(
-            torch.cat([possible_dests_embed,
-                       src_embed_expand], dim=1)).view(1, -1)
+            torch.cat([possible_dests_embed, src_embed_expand], dim=1)
+        ).view(1, -1)
         dests_probs = F.softmax(dests_scores, dim=1)

         if not self.training:
@@ -213,17 +223,17 @@ class ChooseDestAndUpdate(nn.Module):
             g.add_edges(src_list, dest_list)
             self._initialize_edge_repr(g, src_list, dest_list)

-            self.graph_op['prop'](g)
+            self.graph_op["prop"](g)

         if self.training:
             if dests_probs.nelement() > 1:
                 self.log_prob.append(
-                    F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])
+                    F.log_softmax(dests_scores, dim=1)[:, dest : dest + 1]
+                )


 class DGMG(nn.Module):
-    def __init__(self, v_max, node_hidden_size,
-                 num_prop_rounds):
+    def __init__(self, v_max, node_hidden_size, num_prop_rounds):
         super(DGMG, self).__init__()

         # Graph configuration
@@ -233,22 +243,20 @@ class DGMG(nn.Module):
         self.graph_embed = GraphEmbed(node_hidden_size)

         # Graph propagation module
-        self.graph_prop = GraphProp(num_prop_rounds,
-                                    node_hidden_size)
+        self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size)

         # Actions
-        self.add_node_agent = AddNode(
-            self.graph_embed, node_hidden_size)
-        self.add_edge_agent = AddEdge(
-            self.graph_embed, node_hidden_size)
+        self.add_node_agent = AddNode(self.graph_embed, node_hidden_size)
+        self.add_edge_agent = AddEdge(self.graph_embed, node_hidden_size)
         self.choose_dest_agent = ChooseDestAndUpdate(
-            self.graph_prop, node_hidden_size)
+            self.graph_prop, node_hidden_size
+        )

         # Weight initialization
         self.init_weights()

     def init_weights(self):
-        from utils import weights_init, dgmg_message_weight_init
+        from utils import dgmg_message_weight_init, weights_init

         self.graph_embed.apply(weights_init)
         self.graph_prop.apply(weights_init)
@@ -290,9 +298,11 @@ class DGMG(nn.Module):
         self.choose_dest_agent(self.g, a)

     def get_log_prob(self):
-        return torch.cat(self.add_node_agent.log_prob).sum()\
-            + torch.cat(self.add_edge_agent.log_prob).sum()\
-            + torch.cat(self.choose_dest_agent.log_prob).sum()
+        return (
+            torch.cat(self.add_node_agent.log_prob).sum()
+            + torch.cat(self.add_edge_agent.log_prob).sum()
+            + torch.cat(self.choose_dest_agent.log_prob).sum()
+        )

     def forward_train(self, actions):
         self.prepare_for_train()
......
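As a side note, the reformatted `GraphEmbed.forward` above computes a gated sum readout over node states, (node_gating(hvs) * node_to_graph(hvs)).sum(0). The snippet below is a small self-contained sketch of that computation; the sizes and the 2 * node_hidden_size value for the graph embedding are assumptions for illustration, not part of this commit:

import torch
import torch.nn as nn

node_hidden_size = 4
graph_hidden_size = 2 * node_hidden_size  # assumed; GraphEmbed sets self.graph_hidden_size elsewhere

node_gating = nn.Sequential(nn.Linear(node_hidden_size, 1), nn.Sigmoid())
node_to_graph = nn.Linear(node_hidden_size, graph_hidden_size)

hvs = torch.randn(5, node_hidden_size)  # stand-in for g.ndata["hv"] of a 5-node graph
graph_embed = (node_gating(hvs) * node_to_graph(hvs)).sum(0, keepdim=True)
print(graph_embed.shape)  # torch.Size([1, 8])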
+import dgl.function as fn
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np
 from scipy.linalg import block_diag
-import dgl.function as fn
-from .aggregator import MaxPoolAggregator, MeanAggregator, LSTMAggregator
+from model.loss import EntropyLoss

+from ..model_utils import masked_softmax
+from .aggregator import LSTMAggregator, MaxPoolAggregator, MeanAggregator
 from .bundler import Bundler
-from ..model_utils import masked_softmax
-from model.loss import EntropyLoss


 class GraphSageLayer(nn.Module):
@@ -18,17 +18,27 @@ class GraphSageLayer(nn.Module):
     Here, graphsage layer is a reduced function in DGL framework
     """

-    def __init__(self, in_feats, out_feats, activation, dropout,
-                 aggregator_type, bn=False, bias=True):
+    def __init__(
+        self,
+        in_feats,
+        out_feats,
+        activation,
+        dropout,
+        aggregator_type,
+        bn=False,
+        bias=True,
+    ):
         super(GraphSageLayer, self).__init__()
         self.use_bn = bn
-        self.bundler = Bundler(in_feats, out_feats, activation, dropout,
-                               bias=bias)
+        self.bundler = Bundler(
+            in_feats, out_feats, activation, dropout, bias=bias
+        )
         self.dropout = nn.Dropout(p=dropout)

         if aggregator_type == "maxpool":
-            self.aggregator = MaxPoolAggregator(in_feats, in_feats,
-                                                activation, bias)
+            self.aggregator = MaxPoolAggregator(
+                in_feats, in_feats, activation, bias
+            )
         elif aggregator_type == "lstm":
             self.aggregator = LSTMAggregator(in_feats, in_feats)
         else:
@@ -36,15 +46,14 @@ class GraphSageLayer(nn.Module):
     def forward(self, g, h):
         h = self.dropout(h)
-        g.ndata['h'] = h
-        if self.use_bn and not hasattr(self, 'bn'):
+        g.ndata["h"] = h
+        if self.use_bn and not hasattr(self, "bn"):
             device = h.device
             self.bn = nn.BatchNorm1d(h.size()[1]).to(device)
-        g.update_all(fn.copy_u(u='h', out='m'), self.aggregator,
-                     self.bundler)
+        g.update_all(fn.copy_u(u="h", out="m"), self.aggregator, self.bundler)
         if self.use_bn:
             h = self.bn(h)
-        h = g.ndata.pop('h')
+        h = g.ndata.pop("h")
         return h
@@ -53,21 +62,36 @@ class GraphSage(nn.Module):
     Grahpsage network that concatenate several graphsage layer
     """

-    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation,
-                 dropout, aggregator_type):
+    def __init__(
+        self,
+        in_feats,
+        n_hidden,
+        n_classes,
+        n_layers,
+        activation,
+        dropout,
+        aggregator_type,
+    ):
         super(GraphSage, self).__init__()
         self.layers = nn.ModuleList()

         # input layer
-        self.layers.append(GraphSageLayer(in_feats, n_hidden, activation, dropout,
-                                          aggregator_type))
+        self.layers.append(
+            GraphSageLayer(
+                in_feats, n_hidden, activation, dropout, aggregator_type
+            )
+        )
         # hidden layers
         for _ in range(n_layers - 1):
-            self.layers.append(GraphSageLayer(n_hidden, n_hidden, activation,
-                                              dropout, aggregator_type))
+            self.layers.append(
+                GraphSageLayer(
+                    n_hidden, n_hidden, activation, dropout, aggregator_type
+                )
+            )
         # output layer
-        self.layers.append(GraphSageLayer(n_hidden, n_classes, None,
-                                          dropout, aggregator_type))
+        self.layers.append(
+            GraphSageLayer(n_hidden, n_classes, None, dropout, aggregator_type)
+        )

     def forward(self, g, features):
         h = features
@@ -77,37 +101,44 @@ class GraphSage(nn.Module):
 class DiffPoolBatchedGraphLayer(nn.Module):

-    def __init__(self, input_dim, assign_dim, output_feat_dim,
-                 activation, dropout, aggregator_type, link_pred):
+    def __init__(
+        self,
+        input_dim,
+        assign_dim,
+        output_feat_dim,
+        activation,
+        dropout,
+        aggregator_type,
+        link_pred,
+    ):
         super(DiffPoolBatchedGraphLayer, self).__init__()
         self.embedding_dim = input_dim
         self.assign_dim = assign_dim
         self.hidden_dim = output_feat_dim
         self.link_pred = link_pred
         self.feat_gc = GraphSageLayer(
-            input_dim,
-            output_feat_dim,
-            activation,
-            dropout,
-            aggregator_type)
+            input_dim, output_feat_dim, activation, dropout, aggregator_type
+        )
         self.pool_gc = GraphSageLayer(
-            input_dim,
-            assign_dim,
-            activation,
-            dropout,
-            aggregator_type)
+            input_dim, assign_dim, activation, dropout, aggregator_type
+        )
         self.reg_loss = nn.ModuleList([])
         self.loss_log = {}
         self.reg_loss.append(EntropyLoss())

     def forward(self, g, h):
-        feat = self.feat_gc(g, h)  # size = (sum_N, F_out), sum_N is num of nodes in this batch
+        feat = self.feat_gc(
+            g, h
+        )  # size = (sum_N, F_out), sum_N is num of nodes in this batch
         device = feat.device
-        assign_tensor = self.pool_gc(g, h)  # size = (sum_N, N_a), N_a is num of nodes in pooled graph.
+        assign_tensor = self.pool_gc(
+            g, h
+        )  # size = (sum_N, N_a), N_a is num of nodes in pooled graph.
         assign_tensor = F.softmax(assign_tensor, dim=1)
         assign_tensor = torch.split(assign_tensor, g.batch_num_nodes().tolist())
-        assign_tensor = torch.block_diag(*assign_tensor)  # size = (sum_N, batch_size * N_a)
+        assign_tensor = torch.block_diag(
+            *assign_tensor
+        )  # size = (sum_N, batch_size * N_a)

         h = torch.matmul(torch.t(assign_tensor), feat)
         adj = g.adjacency_matrix(transpose=True, ctx=device)
@@ -115,9 +146,10 @@ class DiffPoolBatchedGraphLayer(nn.Module):
         adj_new = torch.mm(torch.t(assign_tensor), adj_new)
         if self.link_pred:
-            current_lp_loss = torch.norm(adj.to_dense() -
-                                         torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2)
-            self.loss_log['LinkPredLoss'] = current_lp_loss
+            current_lp_loss = torch.norm(
+                adj.to_dense() - torch.mm(assign_tensor, torch.t(assign_tensor))
+            ) / np.power(g.number_of_nodes(), 2)
+            self.loss_log["LinkPredLoss"] = current_lp_loss

         for loss_layer in self.reg_loss:
             loss_name = str(type(loss_layer).__name__)
......
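For reference, the `DiffPoolBatchedGraphLayer.forward` shown above follows the usual DiffPool coarsening, X' = S^T X and A' = S^T A S, with S the softmax-normalized assignment matrix built block-diagonally over the batch. A tiny stand-alone sketch of that arithmetic, with shapes invented purely for illustration:

import torch

num_nodes, feat_dim, num_clusters = 5, 3, 2
feat = torch.randn(num_nodes, feat_dim)  # like feat = self.feat_gc(g, h)
assign = torch.softmax(torch.randn(num_nodes, num_clusters), dim=1)  # like assign_tensor
adj = torch.rand(num_nodes, num_nodes)  # dense adjacency stand-in

h_pooled = torch.matmul(torch.t(assign), feat)  # (num_clusters, feat_dim)
adj_pooled = torch.mm(torch.t(assign), torch.mm(adj, assign))  # (num_clusters, num_clusters)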
 import time

+import dgl
 import numpy as np
 import torch
 import torch.nn as nn
@@ -7,8 +9,6 @@ import torch.nn.functional as F
 from scipy.linalg import block_diag
 from torch.nn import init

-import dgl
 from .dgl_layers import DiffPoolBatchedGraphLayer, GraphSage, GraphSageLayer
 from .model_utils import batch2tensor
 from .tensorized_layers import *
@@ -91,7 +91,6 @@ class DiffPool(nn.Module):
             # and return pool_embedding_dim node embedding
             pool_embedding_dim = hidden_dim * (n_layers - 1) + embedding_dim
         else:
             pool_embedding_dim = embedding_dim

         self.first_diffpool_layer = DiffPoolBatchedGraphLayer(
......
 import torch
+from model.tensorized_layers.graphsage import BatchedGraphSAGE
 from torch import nn as nn
 from torch.autograd import Variable
 from torch.nn import functional as F
-from model.tensorized_layers.graphsage import BatchedGraphSAGE


 class DiffPoolAssignment(nn.Module):
     def __init__(self, nfeat, nnext):
......
 import torch
-from torch import nn as nn
 from model.loss import EntropyLoss, LinkPredLoss
 from model.tensorized_layers.assignment import DiffPoolAssignment
 from model.tensorized_layers.graphsage import BatchedGraphSAGE
+from torch import nn as nn


 class BatchedDiffPool(nn.Module):
......
@@ -3,6 +3,9 @@ import os
 import random
 import time

+import dgl
+import dgl.function as fn
 import networkx as nx
 import numpy as np
 import torch
@@ -10,12 +13,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
 from data_utils import pre_process
-from model.encoder import DiffPool
-import dgl
-import dgl.function as fn
 from dgl import DGLGraph
 from dgl.data import tu
+from model.encoder import DiffPool

 global_train_time_per_epoch = []
@@ -261,8 +261,8 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
         total = 0
         print("\nEPOCH ###### {} ######".format(epoch))
         computation_time = 0.0
-        for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
-            for (key, value) in batch_graph.ndata.items():
+        for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
+            for key, value in batch_graph.ndata.items():
                 batch_graph.ndata[key] = value.float()
             graph_labels = graph_labels.long()
             if torch.cuda.is_available():
@@ -341,7 +341,7 @@ def evaluate(dataloader, model, prog_args, logger=None):
     correct_label = 0
     with torch.no_grad():
         for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
-            for (key, value) in batch_graph.ndata.items():
+            for key, value in batch_graph.ndata.items():
                 batch_graph.ndata[key] = value.float()
             graph_labels = graph_labels.long()
             if torch.cuda.is_available():
......
@@ -2,11 +2,14 @@ import copy
 from pathlib import Path

 import click
+import dgl
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+from dgl.data.utils import Subset
 from logzero import logger
 from modules.dimenet import DimeNet
 from modules.dimenet_pp import DimeNetPP
@@ -16,9 +19,6 @@ from ruamel.yaml import YAML
 from sklearn.metrics import mean_absolute_error
 from torch.utils.data import DataLoader

-import dgl
-from dgl.data.utils import Subset


 def split_dataset(
     dataset, num_train, num_valid, shuffle=False, random_state=None
......
+import dgl.function as fn
 import torch
 import torch.nn as nn
 from modules.initializers import GlorotOrthogonal
 from modules.residual_layer import ResidualLayer
-import dgl.function as fn


 class InteractionBlock(nn.Module):
     def __init__(
......
+import dgl
+import dgl.function as fn
 import torch.nn as nn
 from modules.initializers import GlorotOrthogonal
 from modules.residual_layer import ResidualLayer
-import dgl
-import dgl.function as fn


 class InteractionPPBlock(nn.Module):
     def __init__(
......
-import torch.nn as nn
-from modules.initializers import GlorotOrthogonal
 import dgl
 import dgl.function as fn
+import torch.nn as nn
+from modules.initializers import GlorotOrthogonal


 class OutputBlock(nn.Module):
......
-import torch.nn as nn
-from modules.initializers import GlorotOrthogonal
 import dgl
 import dgl.function as fn
+import torch.nn as nn
+from modules.initializers import GlorotOrthogonal


 class OutputPPBlock(nn.Module):
......