"docs/vscode:/vscode.git/clone" did not exist on "4507bebc1abd5e87071d6268c073c1afe74bdd12"
Unverified Commit 704bcaf6 authored by Hongzhi (Steve) Chen, committed by GitHub
parent 6bc82161
......@@ -2,14 +2,14 @@ import argparse
import copy
import os
import dgl
import torch
import torch.nn.functional as F
import torch.optim as optim
from model import MLP, CorrectAndSmooth, MLPLinear
from model import CorrectAndSmooth, MLP, MLPLinear
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
import dgl
def evaluate(y_pred, y_true, idx, evaluator):
return evaluator.eval({"y_true": y_true[idx], "y_pred": y_pred[idx]})["acc"]
......@@ -104,7 +104,6 @@ def main():
# training
print("---------- Training ----------")
for i in range(args.epochs):
model.train()
opt.zero_grad()
......
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
class MLPLinear(nn.Module):
def __init__(self, in_dim, out_dim):
......
import argparse
import dgl.function as fn
import numpy as np
import torch
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F
from torch.nn import functional as F, Parameter
from tqdm import trange
from utils import evaluate, generate_random_seeds, set_random_state
import dgl.function as fn
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
class DAGNNConv(nn.Module):
def __init__(self, in_dim, k):
......@@ -26,7 +25,6 @@ class DAGNNConv(nn.Module):
nn.init.xavier_uniform_(self.s, gain=gain)
def forward(self, graph, feats):
with graph.local_scope():
results = [feats]
......@@ -68,7 +66,6 @@ class MLPLayer(nn.Module):
nn.init.zeros_(self.linear.bias)
def forward(self, feats):
feats = self.dropout(feats)
feats = self.linear(feats)
if self.activation:
......
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import MLP, MessageNorm
from ogb.graphproppred.mol_encoder import BondEncoder
import dgl.function as fn
from dgl.nn.functional import edge_softmax
from modules import MessageNorm, MLP
from ogb.graphproppred.mol_encoder import BondEncoder
class GENConv(nn.Module):
......
......@@ -6,7 +6,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from models import DeeperGCN
from ogb.graphproppred import DglGraphPropPredDataset, Evaluator, collate_dgl
from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator
from torch.utils.data import DataLoader
......
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch.glob import AvgPooling
from layers import GENConv
from ogb.graphproppred.mol_encoder import AtomEncoder
import dgl.function as fn
from dgl.nn.pytorch.glob import AvgPooling
class DeeperGCN(nn.Module):
r"""
......
import argparse, time
import numpy as np
import dgl
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgi import Classifier, DGI
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgi import DGI, Classifier
from dgl.data import load_data, register_data_args
def evaluate(model, features, labels, mask):
model.eval()
......@@ -19,20 +21,21 @@ def evaluate(model, features, labels, mask):
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
def main(args):
# load and preprocess dataset
data = load_data(args)
g = data[0]
features = torch.FloatTensor(g.ndata['feat'])
labels = torch.LongTensor(g.ndata['label'])
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(g.ndata['train_mask'])
val_mask = torch.BoolTensor(g.ndata['val_mask'])
test_mask = torch.BoolTensor(g.ndata['test_mask'])
features = torch.FloatTensor(g.ndata["feat"])
labels = torch.LongTensor(g.ndata["label"])
if hasattr(torch, "BoolTensor"):
train_mask = torch.BoolTensor(g.ndata["train_mask"])
val_mask = torch.BoolTensor(g.ndata["val_mask"])
test_mask = torch.BoolTensor(g.ndata["test_mask"])
else:
train_mask = torch.ByteTensor(g.ndata['train_mask'])
val_mask = torch.ByteTensor(g.ndata['val_mask'])
test_mask = torch.ByteTensor(g.ndata['test_mask'])
train_mask = torch.ByteTensor(g.ndata["train_mask"])
val_mask = torch.ByteTensor(g.ndata["val_mask"])
test_mask = torch.ByteTensor(g.ndata["test_mask"])
in_feats = features.shape[1]
n_classes = data.num_classes
n_edges = g.number_of_edges()
......@@ -57,19 +60,21 @@ def main(args):
if args.gpu >= 0:
g = g.to(args.gpu)
# create DGI model
dgi = DGI(g,
in_feats,
args.n_hidden,
args.n_layers,
nn.PReLU(args.n_hidden),
args.dropout)
dgi = DGI(
g,
in_feats,
args.n_hidden,
args.n_layers,
nn.PReLU(args.n_hidden),
args.dropout,
)
if cuda:
dgi.cuda()
dgi_optimizer = torch.optim.Adam(dgi.parameters(),
lr=args.dgi_lr,
weight_decay=args.weight_decay)
dgi_optimizer = torch.optim.Adam(
dgi.parameters(), lr=args.dgi_lr, weight_decay=args.weight_decay
)
# train deep graph infomax
cnt_wait = 0
......@@ -90,33 +95,38 @@ def main(args):
best = loss
best_t = epoch
cnt_wait = 0
torch.save(dgi.state_dict(), 'best_dgi.pkl')
torch.save(dgi.state_dict(), "best_dgi.pkl")
else:
cnt_wait += 1
if cnt_wait == args.patience:
print('Early stopping!')
print("Early stopping!")
break
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
n_edges / np.mean(dur) / 1000))
print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(
epoch, np.mean(dur), loss.item(), n_edges / np.mean(dur) / 1000
)
)
# create classifier model
classifier = Classifier(args.n_hidden, n_classes)
if cuda:
classifier.cuda()
classifier_optimizer = torch.optim.Adam(classifier.parameters(),
lr=args.classifier_lr,
weight_decay=args.weight_decay)
classifier_optimizer = torch.optim.Adam(
classifier.parameters(),
lr=args.classifier_lr,
weight_decay=args.weight_decay,
)
# train classifier
print('Loading {}th epoch'.format(best_t))
dgi.load_state_dict(torch.load('best_dgi.pkl'))
print("Loading {}th epoch".format(best_t))
dgi.load_state_dict(torch.load("best_dgi.pkl"))
embeds = dgi.encoder(features, corrupt=False)
embeds = embeds.detach()
dur = []
......@@ -135,39 +145,67 @@ def main(args):
dur.append(time.time() - t0)
acc = evaluate(classifier, embeds, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
acc, n_edges / np.mean(dur) / 1000))
print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.item(),
acc,
n_edges / np.mean(dur) / 1000,
)
)
print()
acc = evaluate(classifier, embeds, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='DGI')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="DGI")
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--dgi-lr", type=float, default=1e-3,
help="dgi learning rate")
parser.add_argument("--classifier-lr", type=float, default=1e-2,
help="classifier learning rate")
parser.add_argument("--n-dgi-epochs", type=int, default=300,
help="number of training epochs")
parser.add_argument("--n-classifier-epochs", type=int, default=300,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=512,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--weight-decay", type=float, default=0.,
help="Weight for L2 loss")
parser.add_argument("--patience", type=int, default=20,
help="early stop patience condition")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
parser.add_argument(
"--dropout", type=float, default=0.0, help="dropout probability"
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument(
"--dgi-lr", type=float, default=1e-3, help="dgi learning rate"
)
parser.add_argument(
"--classifier-lr",
type=float,
default=1e-2,
help="classifier learning rate",
)
parser.add_argument(
"--n-dgi-epochs",
type=int,
default=300,
help="number of training epochs",
)
parser.add_argument(
"--n-classifier-epochs",
type=int,
default=300,
help="number of training epochs",
)
parser.add_argument(
"--n-hidden", type=int, default=512, help="number of hidden gcn units"
)
parser.add_argument(
"--n-layers", type=int, default=1, help="number of hidden gcn layers"
)
parser.add_argument(
"--weight-decay", type=float, default=0.0, help="Weight for L2 loss"
)
parser.add_argument(
"--patience", type=int, default=20, help="early stop patience condition"
)
parser.add_argument(
"--self-loop",
action="store_true",
help="graph self-loop (default=False)",
)
parser.set_defaults(self_loop=False)
args = parser.parse_args()
print(args)
......
......@@ -4,7 +4,6 @@ and will be loaded when setting up."""
def dataset_based_configure(opts):
if opts["dataset"] == "cycles":
ds_configure = cycles_configure
else:
......
......@@ -65,7 +65,6 @@ def main(opts):
optimizer.zero_grad()
for i, data in enumerate(data_loader):
log_prob = model(actions=data)
prob = log_prob.detach().exp()
......
from functools import partial
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch.distributions import Bernoulli, Categorical
......@@ -15,20 +16,19 @@ class GraphEmbed(nn.Module):
# Embed graphs
self.node_gating = nn.Sequential(
nn.Linear(node_hidden_size, 1),
nn.Sigmoid()
nn.Linear(node_hidden_size, 1), nn.Sigmoid()
)
self.node_to_graph = nn.Linear(node_hidden_size,
self.graph_hidden_size)
self.node_to_graph = nn.Linear(node_hidden_size, self.graph_hidden_size)
def forward(self, g):
if g.number_of_nodes() == 0:
return torch.zeros(1, self.graph_hidden_size)
else:
# Node features are stored as hv in ndata.
hvs = g.ndata['hv']
return (self.node_gating(hvs) *
self.node_to_graph(hvs)).sum(0, keepdim=True)
hvs = g.ndata["hv"]
return (self.node_gating(hvs) * self.node_to_graph(hvs)).sum(
0, keepdim=True
)
class GraphProp(nn.Module):
......@@ -46,41 +46,45 @@ class GraphProp(nn.Module):
for t in range(num_prop_rounds):
# input being [hv, hu, xuv]
message_funcs.append(nn.Linear(2 * node_hidden_size + 1,
self.node_activation_hidden_size))
message_funcs.append(
nn.Linear(
2 * node_hidden_size + 1, self.node_activation_hidden_size
)
)
self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
node_update_funcs.append(
nn.GRUCell(self.node_activation_hidden_size,
node_hidden_size))
nn.GRUCell(self.node_activation_hidden_size, node_hidden_size)
)
self.message_funcs = nn.ModuleList(message_funcs)
self.node_update_funcs = nn.ModuleList(node_update_funcs)
def dgmg_msg(self, edges):
"""For an edge u->v, return concat([h_u, x_uv])"""
return {'m': torch.cat([edges.src['hv'],
edges.data['he']],
dim=1)}
return {"m": torch.cat([edges.src["hv"], edges.data["he"]], dim=1)}
def dgmg_reduce(self, nodes, round):
hv_old = nodes.data['hv']
m = nodes.mailbox['m']
message = torch.cat([
hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
hv_old = nodes.data["hv"]
m = nodes.mailbox["m"]
message = torch.cat(
[hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2
)
node_activation = (self.message_funcs[round](message)).sum(1)
return {'a': node_activation}
return {"a": node_activation}
def forward(self, g):
if g.number_of_edges() == 0:
return
else:
for t in range(self.num_prop_rounds):
g.update_all(message_func=self.dgmg_msg,
reduce_func=self.reduce_funcs[t])
g.ndata['hv'] = self.node_update_funcs[t](
g.ndata['a'], g.ndata['hv'])
g.update_all(
message_func=self.dgmg_msg, reduce_func=self.reduce_funcs[t]
)
g.ndata["hv"] = self.node_update_funcs[t](
g.ndata["a"], g.ndata["hv"]
)
def bernoulli_action_log_prob(logit, action):
......@@ -96,33 +100,39 @@ class AddNode(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size):
super(AddNode, self).__init__()
self.graph_op = {'embed': graph_embed_func}
self.graph_op = {"embed": graph_embed_func}
self.stop = 1
self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)
# If a node is added, initialize its hv
self.node_type_embed = nn.Embedding(1, node_hidden_size)
self.initialize_hv = nn.Linear(node_hidden_size + \
graph_embed_func.graph_hidden_size,
node_hidden_size)
self.initialize_hv = nn.Linear(
node_hidden_size + graph_embed_func.graph_hidden_size,
node_hidden_size,
)
self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
def _initialize_node_repr(self, g, node_type, graph_embed):
num_nodes = g.number_of_nodes()
hv_init = self.initialize_hv(
torch.cat([
self.node_type_embed(torch.LongTensor([node_type])),
graph_embed], dim=1))
g.nodes[num_nodes - 1].data['hv'] = hv_init
g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
torch.cat(
[
self.node_type_embed(torch.LongTensor([node_type])),
graph_embed,
],
dim=1,
)
)
g.nodes[num_nodes - 1].data["hv"] = hv_init
g.nodes[num_nodes - 1].data["a"] = self.init_node_activation
def prepare_training(self):
self.log_prob = []
def forward(self, g, action=None):
graph_embed = self.graph_op['embed'](g)
graph_embed = self.graph_op["embed"](g)
logit = self.add_node(graph_embed)
prob = torch.sigmoid(logit)
......@@ -146,19 +156,19 @@ class AddEdge(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size):
super(AddEdge, self).__init__()
self.graph_op = {'embed': graph_embed_func}
self.add_edge = nn.Linear(graph_embed_func.graph_hidden_size + \
node_hidden_size, 1)
self.graph_op = {"embed": graph_embed_func}
self.add_edge = nn.Linear(
graph_embed_func.graph_hidden_size + node_hidden_size, 1
)
def prepare_training(self):
self.log_prob = []
def forward(self, g, action=None):
graph_embed = self.graph_op['embed'](g)
src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']
graph_embed = self.graph_op["embed"](g)
src_embed = g.nodes[g.number_of_nodes() - 1].data["hv"]
logit = self.add_edge(torch.cat(
[graph_embed, src_embed], dim=1))
logit = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))
prob = torch.sigmoid(logit)
if not self.training:
......@@ -176,7 +186,7 @@ class ChooseDestAndUpdate(nn.Module):
def __init__(self, graph_prop_func, node_hidden_size):
super(ChooseDestAndUpdate, self).__init__()
self.graph_op = {'prop': graph_prop_func}
self.graph_op = {"prop": graph_prop_func}
self.choose_dest = nn.Linear(2 * node_hidden_size, 1)
def _initialize_edge_repr(self, g, src_list, dest_list):
......@@ -184,7 +194,7 @@ class ChooseDestAndUpdate(nn.Module):
# For multiple edge types, we can use a one hot representation
# or an embedding module.
edge_repr = torch.ones(len(src_list), 1)
g.edges[src_list, dest_list].data['he'] = edge_repr
g.edges[src_list, dest_list].data["he"] = edge_repr
def prepare_training(self):
self.log_prob = []
......@@ -193,12 +203,12 @@ class ChooseDestAndUpdate(nn.Module):
src = g.number_of_nodes() - 1
possible_dests = range(src)
src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data['hv']
src_embed_expand = g.nodes[src].data["hv"].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data["hv"]
dests_scores = self.choose_dest(
torch.cat([possible_dests_embed,
src_embed_expand], dim=1)).view(1, -1)
torch.cat([possible_dests_embed, src_embed_expand], dim=1)
).view(1, -1)
dests_probs = F.softmax(dests_scores, dim=1)
if not self.training:
......@@ -213,17 +223,17 @@ class ChooseDestAndUpdate(nn.Module):
g.add_edges(src_list, dest_list)
self._initialize_edge_repr(g, src_list, dest_list)
self.graph_op['prop'](g)
self.graph_op["prop"](g)
if self.training:
if dests_probs.nelement() > 1:
self.log_prob.append(
F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])
F.log_softmax(dests_scores, dim=1)[:, dest : dest + 1]
)
class DGMG(nn.Module):
def __init__(self, v_max, node_hidden_size,
num_prop_rounds):
def __init__(self, v_max, node_hidden_size, num_prop_rounds):
super(DGMG, self).__init__()
# Graph configuration
......@@ -233,22 +243,20 @@ class DGMG(nn.Module):
self.graph_embed = GraphEmbed(node_hidden_size)
# Graph propagation module
self.graph_prop = GraphProp(num_prop_rounds,
node_hidden_size)
self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size)
# Actions
self.add_node_agent = AddNode(
self.graph_embed, node_hidden_size)
self.add_edge_agent = AddEdge(
self.graph_embed, node_hidden_size)
self.add_node_agent = AddNode(self.graph_embed, node_hidden_size)
self.add_edge_agent = AddEdge(self.graph_embed, node_hidden_size)
self.choose_dest_agent = ChooseDestAndUpdate(
self.graph_prop, node_hidden_size)
self.graph_prop, node_hidden_size
)
# Weight initialization
self.init_weights()
def init_weights(self):
from utils import weights_init, dgmg_message_weight_init
from utils import dgmg_message_weight_init, weights_init
self.graph_embed.apply(weights_init)
self.graph_prop.apply(weights_init)
......@@ -290,9 +298,11 @@ class DGMG(nn.Module):
self.choose_dest_agent(self.g, a)
def get_log_prob(self):
return torch.cat(self.add_node_agent.log_prob).sum()\
+ torch.cat(self.add_edge_agent.log_prob).sum()\
+ torch.cat(self.choose_dest_agent.log_prob).sum()
return (
torch.cat(self.add_node_agent.log_prob).sum()
+ torch.cat(self.add_edge_agent.log_prob).sum()
+ torch.cat(self.choose_dest_agent.log_prob).sum()
)
def forward_train(self, actions):
self.prepare_for_train()
......
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.linalg import block_diag
import dgl.function as fn
from model.loss import EntropyLoss
from ..model_utils import masked_softmax
from .aggregator import MaxPoolAggregator, MeanAggregator, LSTMAggregator
from .aggregator import LSTMAggregator, MaxPoolAggregator, MeanAggregator
from .bundler import Bundler
from ..model_utils import masked_softmax
from model.loss import EntropyLoss
class GraphSageLayer(nn.Module):
......@@ -18,17 +18,27 @@ class GraphSageLayer(nn.Module):
Here, the GraphSage layer is a reduce function in the DGL framework
"""
def __init__(self, in_feats, out_feats, activation, dropout,
aggregator_type, bn=False, bias=True):
def __init__(
self,
in_feats,
out_feats,
activation,
dropout,
aggregator_type,
bn=False,
bias=True,
):
super(GraphSageLayer, self).__init__()
self.use_bn = bn
self.bundler = Bundler(in_feats, out_feats, activation, dropout,
bias=bias)
self.bundler = Bundler(
in_feats, out_feats, activation, dropout, bias=bias
)
self.dropout = nn.Dropout(p=dropout)
if aggregator_type == "maxpool":
self.aggregator = MaxPoolAggregator(in_feats, in_feats,
activation, bias)
self.aggregator = MaxPoolAggregator(
in_feats, in_feats, activation, bias
)
elif aggregator_type == "lstm":
self.aggregator = LSTMAggregator(in_feats, in_feats)
else:
......@@ -36,15 +46,14 @@ class GraphSageLayer(nn.Module):
def forward(self, g, h):
h = self.dropout(h)
g.ndata['h'] = h
if self.use_bn and not hasattr(self, 'bn'):
g.ndata["h"] = h
if self.use_bn and not hasattr(self, "bn"):
device = h.device
self.bn = nn.BatchNorm1d(h.size()[1]).to(device)
g.update_all(fn.copy_u(u='h', out='m'), self.aggregator,
self.bundler)
g.update_all(fn.copy_u(u="h", out="m"), self.aggregator, self.bundler)
if self.use_bn:
h = self.bn(h)
h = g.ndata.pop('h')
h = g.ndata.pop("h")
return h
......@@ -53,21 +62,36 @@ class GraphSage(nn.Module):
GraphSage network that concatenates several GraphSage layers
"""
def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation,
dropout, aggregator_type):
def __init__(
self,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout,
aggregator_type,
):
super(GraphSage, self).__init__()
self.layers = nn.ModuleList()
# input layer
self.layers.append(GraphSageLayer(in_feats, n_hidden, activation, dropout,
aggregator_type))
self.layers.append(
GraphSageLayer(
in_feats, n_hidden, activation, dropout, aggregator_type
)
)
# hidden layers
for _ in range(n_layers - 1):
self.layers.append(GraphSageLayer(n_hidden, n_hidden, activation,
dropout, aggregator_type))
self.layers.append(
GraphSageLayer(
n_hidden, n_hidden, activation, dropout, aggregator_type
)
)
# output layer
self.layers.append(GraphSageLayer(n_hidden, n_classes, None,
dropout, aggregator_type))
self.layers.append(
GraphSageLayer(n_hidden, n_classes, None, dropout, aggregator_type)
)
def forward(self, g, features):
h = features
......@@ -77,37 +101,44 @@ class GraphSage(nn.Module):
class DiffPoolBatchedGraphLayer(nn.Module):
def __init__(self, input_dim, assign_dim, output_feat_dim,
activation, dropout, aggregator_type, link_pred):
def __init__(
self,
input_dim,
assign_dim,
output_feat_dim,
activation,
dropout,
aggregator_type,
link_pred,
):
super(DiffPoolBatchedGraphLayer, self).__init__()
self.embedding_dim = input_dim
self.assign_dim = assign_dim
self.hidden_dim = output_feat_dim
self.link_pred = link_pred
self.feat_gc = GraphSageLayer(
input_dim,
output_feat_dim,
activation,
dropout,
aggregator_type)
input_dim, output_feat_dim, activation, dropout, aggregator_type
)
self.pool_gc = GraphSageLayer(
input_dim,
assign_dim,
activation,
dropout,
aggregator_type)
input_dim, assign_dim, activation, dropout, aggregator_type
)
self.reg_loss = nn.ModuleList([])
self.loss_log = {}
self.reg_loss.append(EntropyLoss())
def forward(self, g, h):
feat = self.feat_gc(g, h) # size = (sum_N, F_out), sum_N is num of nodes in this batch
feat = self.feat_gc(
g, h
) # size = (sum_N, F_out), sum_N is num of nodes in this batch
device = feat.device
assign_tensor = self.pool_gc(g, h) # size = (sum_N, N_a), N_a is num of nodes in pooled graph.
assign_tensor = self.pool_gc(
g, h
) # size = (sum_N, N_a), N_a is num of nodes in pooled graph.
assign_tensor = F.softmax(assign_tensor, dim=1)
assign_tensor = torch.split(assign_tensor, g.batch_num_nodes().tolist())
assign_tensor = torch.block_diag(*assign_tensor) # size = (sum_N, batch_size * N_a)
assign_tensor = torch.block_diag(
*assign_tensor
) # size = (sum_N, batch_size * N_a)
h = torch.matmul(torch.t(assign_tensor), feat)
adj = g.adjacency_matrix(transpose=True, ctx=device)
......@@ -115,9 +146,10 @@ class DiffPoolBatchedGraphLayer(nn.Module):
adj_new = torch.mm(torch.t(assign_tensor), adj_new)
if self.link_pred:
current_lp_loss = torch.norm(adj.to_dense() -
torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2)
self.loss_log['LinkPredLoss'] = current_lp_loss
current_lp_loss = torch.norm(
adj.to_dense() - torch.mm(assign_tensor, torch.t(assign_tensor))
) / np.power(g.number_of_nodes(), 2)
self.loss_log["LinkPredLoss"] = current_lp_loss
for loss_layer in self.reg_loss:
loss_name = str(type(loss_layer).__name__)
......
import time
import dgl
import numpy as np
import torch
import torch.nn as nn
......@@ -7,8 +9,6 @@ import torch.nn.functional as F
from scipy.linalg import block_diag
from torch.nn import init
import dgl
from .dgl_layers import DiffPoolBatchedGraphLayer, GraphSage, GraphSageLayer
from .model_utils import batch2tensor
from .tensorized_layers import *
......@@ -91,7 +91,6 @@ class DiffPool(nn.Module):
# and return pool_embedding_dim node embedding
pool_embedding_dim = hidden_dim * (n_layers - 1) + embedding_dim
else:
pool_embedding_dim = embedding_dim
self.first_diffpool_layer = DiffPoolBatchedGraphLayer(
......
import torch
from model.tensorized_layers.graphsage import BatchedGraphSAGE
from torch import nn as nn
from torch.autograd import Variable
from torch.nn import functional as F
from model.tensorized_layers.graphsage import BatchedGraphSAGE
class DiffPoolAssignment(nn.Module):
def __init__(self, nfeat, nnext):
......
import torch
from torch import nn as nn
from model.loss import EntropyLoss, LinkPredLoss
from model.tensorized_layers.assignment import DiffPoolAssignment
from model.tensorized_layers.graphsage import BatchedGraphSAGE
from torch import nn as nn
class BatchedDiffPool(nn.Module):
......
......@@ -3,6 +3,9 @@ import os
import random
import time
import dgl
import dgl.function as fn
import networkx as nx
import numpy as np
import torch
......@@ -10,12 +13,9 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from data_utils import pre_process
from model.encoder import DiffPool
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import tu
from model.encoder import DiffPool
global_train_time_per_epoch = []
......@@ -261,8 +261,8 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
total = 0
print("\nEPOCH ###### {} ######".format(epoch))
computation_time = 0.0
for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
for (key, value) in batch_graph.ndata.items():
for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
for key, value in batch_graph.ndata.items():
batch_graph.ndata[key] = value.float()
graph_labels = graph_labels.long()
if torch.cuda.is_available():
......@@ -341,7 +341,7 @@ def evaluate(dataloader, model, prog_args, logger=None):
correct_label = 0
with torch.no_grad():
for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
for (key, value) in batch_graph.ndata.items():
for key, value in batch_graph.ndata.items():
batch_graph.ndata[key] = value.float()
graph_labels = graph_labels.long()
if torch.cuda.is_available():
......
......@@ -2,11 +2,14 @@ import copy
from pathlib import Path
import click
import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dgl.data.utils import Subset
from logzero import logger
from modules.dimenet import DimeNet
from modules.dimenet_pp import DimeNetPP
......@@ -16,9 +19,6 @@ from ruamel.yaml import YAML
from sklearn.metrics import mean_absolute_error
from torch.utils.data import DataLoader
import dgl
from dgl.data.utils import Subset
def split_dataset(
dataset, num_train, num_valid, shuffle=False, random_state=None
......
import dgl.function as fn
import torch
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
from modules.residual_layer import ResidualLayer
import dgl.function as fn
class InteractionBlock(nn.Module):
def __init__(
......
import dgl
import dgl.function as fn
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
from modules.residual_layer import ResidualLayer
import dgl
import dgl.function as fn
class InteractionPPBlock(nn.Module):
def __init__(
......
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
import dgl
import dgl.function as fn
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
class OutputBlock(nn.Module):
......
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
import dgl
import dgl.function as fn
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
class OutputPPBlock(nn.Module):
......