Unverified commit 23d09057, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Black auto fix. (#4642)

* [Misc] Black auto fix.
* sort

Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a9f2acf3
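For readers reproducing this pass locally, the sketch below applies the same two tools through their Python APIs. It is an illustration, not part of the commit: the line length of 79 is inferred from the wrapping visible in this diff, and the repository's actual lint configuration may differ.

# Hedged sketch: reproduce this style of reformatting with isort + Black.
# black.format_str() and isort.code() are the tools' public Python APIs;
# line_length=79 is an assumption inferred from the wrapping in this diff.
import black
import isort

src = (
    "from .gnn import GraphSage, GraphSageLayer, DiffPoolBatchedGraphLayer\n"
    "nn.init.xavier_uniform_(self.linear.weight,\n"
    "                        gain=nn.init.calculate_gain('relu'))\n"
)

# isort.code() sorts the import names; black.format_str() then normalizes
# quotes to double quotes, wraps the long call, and adds trailing commas.
formatted = black.format_str(isort.code(src), mode=black.Mode(line_length=79))
print(formatted)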
from .gnn import DiffPoolBatchedGraphLayer, GraphSage, GraphSageLayer
@@ -15,7 +15,7 @@ class Aggregator(nn.Module):
        super(Aggregator, self).__init__()

    def forward(self, node):
        neighbour = node.mailbox["m"]
        c = self.aggre(neighbour)
        return {"c": c}
@@ -25,9 +25,9 @@ class Aggregator(nn.Module):
class MeanAggregator(Aggregator):
    """
    Mean Aggregator for graphsage
    """

    def __init__(self):
        super(MeanAggregator, self).__init__()
@@ -38,17 +38,18 @@ class MeanAggregator(Aggregator):
class MaxPoolAggregator(Aggregator):
    """
    Maxpooling aggregator for graphsage
    """

    def __init__(self, in_feats, out_feats, activation, bias):
        super(MaxPoolAggregator, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats, bias=bias)
        self.activation = activation
        # Xavier initialization of weight
        nn.init.xavier_uniform_(
            self.linear.weight, gain=nn.init.calculate_gain("relu")
        )

    def aggre(self, neighbour):
        neighbour = self.linear(neighbour)
@@ -59,9 +60,9 @@ class MaxPoolAggregator(Aggregator):
class LSTMAggregator(Aggregator):
    """
    LSTM aggregator for graphsage
    """

    def __init__(self, in_feats, hidden_feats):
        super(LSTMAggregator, self).__init__()
@@ -69,31 +70,33 @@ class LSTMAggregator(Aggregator):
        self.hidden_dim = hidden_feats
        self.hidden = self.init_hidden()

        nn.init.xavier_uniform_(
            self.lstm.weight, gain=nn.init.calculate_gain("relu")
        )

    def init_hidden(self):
        """
        Defaulted to initialite all zero
        """
        return (
            torch.zeros(1, 1, self.hidden_dim),
            torch.zeros(1, 1, self.hidden_dim),
        )

    def aggre(self, neighbours):
        """
        aggregation function
        """
        # N X F
        rand_order = torch.randperm(neighbours.size()[1])
        neighbours = neighbours[:, rand_order, :]
        (lstm_out, self.hidden) = self.lstm(
            neighbours.view(neighbours.size()[0], neighbours.size()[1], -1)
        )
        return lstm_out[:, -1, :]

    def forward(self, node):
        neighbour = node.mailbox["m"]
        c = self.aggre(neighbour)
        return {"c": c}
@@ -14,8 +14,9 @@ class Bundler(nn.Module):
        self.linear = nn.Linear(in_feats * 2, out_feats, bias)
        self.activation = activation
        nn.init.xavier_uniform_(
            self.linear.weight, gain=nn.init.calculate_gain("relu")
        )

    def concat(self, h, aggre_result):
        bundle = torch.cat((h, aggre_result), 1)
@@ -23,8 +24,8 @@ class Bundler(nn.Module):
        return bundle

    def forward(self, node):
        h = node.data["h"]
        c = node.data["c"]
        bundle = self.concat(h, c)
        bundle = F.normalize(bundle, p=2, dim=1)
        if self.activation:
......
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.linalg import block_diag
from torch.nn import init

import dgl

from .dgl_layers import DiffPoolBatchedGraphLayer, GraphSage, GraphSageLayer
from .model_utils import batch2tensor
from .tensorized_layers import *


class DiffPool(nn.Module):
@@ -19,10 +19,23 @@ class DiffPool(nn.Module):
    DiffPool Fuse
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        embedding_dim,
        label_dim,
        activation,
        n_layers,
        dropout,
        n_pooling,
        linkpred,
        batch_size,
        aggregator_type,
        assign_dim,
        pool_ratio,
        cat=False,
    ):
        super(DiffPool, self).__init__()
        self.link_pred = linkpred
        self.concat = cat
@@ -51,7 +64,9 @@ class DiffPool(nn.Module):
                activation,
                dropout,
                aggregator_type,
                self.bn,
            )
        )
        for _ in range(n_layers - 2):
            self.gc_before_pool.append(
                GraphSageLayer(
@@ -60,14 +75,14 @@ class DiffPool(nn.Module):
                    activation,
                    dropout,
                    aggregator_type,
                    self.bn,
                )
            )
        self.gc_before_pool.append(
            GraphSageLayer(
                hidden_dim, embedding_dim, None, dropout, aggregator_type
            )
        )

        assign_dims = []
        assign_dims.append(self.assign_dim)
@@ -86,7 +101,8 @@ class DiffPool(nn.Module):
            activation,
            dropout,
            aggregator_type,
            self.link_pred,
        )
        gc_after_per_pool = nn.ModuleList()
        for _ in range(n_layers - 1):
@@ -102,23 +118,26 @@ class DiffPool(nn.Module):
                    pool_embedding_dim,
                    self.assign_dim,
                    hidden_dim,
                    self.link_pred,
                )
            )
            gc_after_per_pool = nn.ModuleList()
            for _ in range(n_layers - 1):
                gc_after_per_pool.append(
                    BatchedGraphSAGE(hidden_dim, hidden_dim)
                )
            gc_after_per_pool.append(
                BatchedGraphSAGE(hidden_dim, embedding_dim)
            )
            self.gc_after_pool.append(gc_after_per_pool)

            assign_dims.append(self.assign_dim)
            self.assign_dim = int(self.assign_dim * pool_ratio)

        # predicting layer
        if self.concat:
            self.pred_input_dim = (
                pool_embedding_dim * self.num_aggs * (n_pooling + 1)
            )
        else:
            self.pred_input_dim = embedding_dim * self.num_aggs
        self.pred_layer = nn.Linear(self.pred_input_dim, label_dim)
@@ -126,8 +145,9 @@ class DiffPool(nn.Module):
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform_(
                    m.weight.data, gain=nn.init.calculate_gain("relu")
                )
                if m.bias is not None:
                    m.bias.data = init.constant_(m.bias.data, 0.0)
@@ -161,7 +181,7 @@ class DiffPool(nn.Module):
    def forward(self, g):
        self.link_pred_loss = []
        self.entropy_loss = []
        h = g.ndata["feat"]
        # node feature for assignment matrix computation is the same as the
        # original node feature
        h_a = h
@@ -171,12 +191,12 @@ class DiffPool(nn.Module):
        # we use GCN blocks to get an embedding first
        g_embedding = self.gcn_forward(g, h, self.gc_before_pool, self.concat)

        g.ndata["h"] = g_embedding

        readout = dgl.sum_nodes(g, "h")
        out_all.append(readout)
        if self.num_aggs == 2:
            readout = dgl.max_nodes(g, "h")
            out_all.append(readout)

        adj, h = self.first_diffpool_layer(g, g_embedding)
@@ -184,7 +204,8 @@ class DiffPool(nn.Module):
        h, adj = batch2tensor(adj, h, node_per_pool_graph)
        h = self.gcn_forward_tensorized(
            h, adj, self.gc_after_pool[0], self.concat
        )
        readout = torch.sum(h, dim=1)
        out_all.append(readout)
        if self.num_aggs == 2:
@@ -194,7 +215,8 @@ class DiffPool(nn.Module):
        for i, diffpool_layer in enumerate(self.diffpool_layers):
            h, adj = diffpool_layer(h, adj)
            h = self.gcn_forward_tensorized(
                h, adj, self.gc_after_pool[i + 1], self.concat
            )
            readout = torch.sum(h, dim=1)
            out_all.append(readout)
            if self.num_aggs == 2:
@@ -208,10 +230,10 @@ class DiffPool(nn.Module):
        return ypred

    def loss(self, pred, label):
        """
        loss function
        """
        # softmax + CE
        criterion = nn.CrossEntropyLoss()
        loss = criterion(pred, label)
        for key, value in self.first_diffpool_layer.loss_log.items():
......
@@ -5,16 +5,19 @@ import torch.nn as nn
class EntropyLoss(nn.Module):
    # Return Scalar
    def forward(self, adj, anext, s_l):
        entropy = (
            (torch.distributions.Categorical(probs=s_l).entropy())
            .sum(-1)
            .mean(-1)
        )
        assert not torch.isnan(entropy)
        return entropy


class LinkPredLoss(nn.Module):
    def forward(self, adj, anext, s_l):
        link_pred_loss = (adj - s_l.matmul(s_l.transpose(-1, -2))).norm(
            dim=(1, 2)
        )
        link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
        return link_pred_loss.mean()
@@ -22,12 +22,13 @@ def batch2tensor(batch_adj, batch_feat, node_per_pool_graph):
    return feat, adj


def masked_softmax(
    matrix, mask, dim=-1, memory_efficient=True, mask_fill_value=-1e32
):
    """
    masked_softmax for dgl batch graph
    code snippet contributed by AllenNLP (https://github.com/allenai/allennlp)
    """
    if mask is None:
        result = th.nn.functional.softmax(matrix, dim=dim)
    else:
@@ -39,7 +40,8 @@ def masked_softmax(matrix, mask, dim=-1, memory_efficient=True,
            result = result * mask
            result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
        else:
            masked_matrix = matrix.masked_fill(
                (1 - mask).byte(), mask_fill_value
            )
            result = th.nn.functional.softmax(masked_matrix, dim=dim)
    return result
from .diffpool import BatchedDiffPool
from .graphsage import BatchedGraphSAGE
import torch

from model.tensorized_layers.graphsage import BatchedGraphSAGE
from torch import nn as nn
from torch.autograd import Variable
from torch.nn import functional as F


class DiffPoolAssignment(nn.Module):
......
import torch

from model.loss import EntropyLoss, LinkPredLoss
from model.tensorized_layers.assignment import DiffPoolAssignment
from model.tensorized_layers.graphsage import BatchedGraphSAGE
from torch import nn as nn


class BatchedDiffPool(nn.Module):
@@ -25,7 +24,7 @@ class BatchedDiffPool(nn.Module):
        z_l = self.embed(x, adj)
        s_l = self.assign(x, adj)
        if log:
            self.log["s"] = s_l.cpu().numpy()
        xnext = torch.matmul(s_l.transpose(-1, -2), z_l)
        anext = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l)
@@ -33,5 +32,5 @@ class BatchedDiffPool(nn.Module):
            loss_name = str(type(loss_layer).__name__)
            self.loss_log[loss_name] = loss_layer(adj, anext, s_l)
        if log:
            self.log["a"] = anext.cpu().numpy()
        return xnext, anext
@@ -4,8 +4,9 @@ from torch.nn import functional as F
class BatchedGraphSAGE(nn.Module):
    def __init__(
        self, infeat, outfeat, use_bn=True, mean=False, add_self=False
    ):
        super().__init__()
        self.add_self = add_self
        self.use_bn = use_bn
@@ -13,12 +14,12 @@ class BatchedGraphSAGE(nn.Module):
        self.W = nn.Linear(infeat, outfeat, bias=True)
        nn.init.xavier_uniform_(
            self.W.weight, gain=nn.init.calculate_gain("relu")
        )

    def forward(self, x, adj):
        num_node_per_graph = adj.size(1)
        if self.use_bn and not hasattr(self, "bn"):
            self.bn = nn.BatchNorm1d(num_node_per_graph).to(adj.device)

        if self.add_self:
@@ -37,6 +38,6 @@ class BatchedGraphSAGE(nn.Module):
    def __repr__(self):
        if self.use_bn:
            return "BN" + super(BatchedGraphSAGE, self).__repr__()
        else:
            return super(BatchedGraphSAGE, self).__repr__()
import argparse
import os
import random
import time

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from data_utils import pre_process
from model.encoder import DiffPool

import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import tu

global_train_time_per_epoch = []
def arg_parse():
    """
    argument parser
    """
    parser = argparse.ArgumentParser(description="DiffPool arguments")
    parser.add_argument("--dataset", dest="dataset", help="Input Dataset")
parser.add_argument(
"--pool_ratio", dest="pool_ratio", type=float, help="pooling ratio"
)
parser.add_argument(
"--num_pool", dest="num_pool", type=int, help="num_pooling layer"
)
parser.add_argument(
"--no_link_pred",
dest="linkpred",
action="store_false",
help="switch of link prediction object",
)
parser.add_argument("--cuda", dest="cuda", type=int, help="switch cuda")
parser.add_argument("--lr", dest="lr", type=float, help="learning rate")
parser.add_argument(
"--clip", dest="clip", type=float, help="gradient clipping"
)
parser.add_argument(
"--batch-size", dest="batch_size", type=int, help="batch size"
)
parser.add_argument("--epochs", dest="epoch", type=int, help="num-of-epoch")
parser.add_argument(
"--train-ratio",
dest="train_ratio",
type=float,
help="ratio of trainning dataset split",
)
parser.add_argument( parser.add_argument(
'--pool_ratio', "--test-ratio",
dest='pool_ratio', dest="test_ratio",
type=float, type=float,
help='pooling ratio') help="ratio of testing dataset split",
)
parser.add_argument( parser.add_argument(
'--num_pool', "--num_workers",
dest='num_pool', dest="n_worker",
type=int, type=int,
help='num_pooling layer') help="number of workers when dataloading",
parser.add_argument('--no_link_pred', dest='linkpred', action='store_false', )
help='switch of link prediction object')
parser.add_argument('--cuda', dest='cuda', type=int, help='switch cuda')
parser.add_argument('--lr', dest='lr', type=float, help='learning rate')
parser.add_argument( parser.add_argument(
'--clip', "--gc-per-block",
dest='clip', dest="gc_per_block",
type=float, type=int,
help='gradient clipping') help="number of graph conv layer per block",
)
parser.add_argument(
"--bn",
dest="bn",
action="store_const",
const=True,
default=True,
help="switch for bn",
)
parser.add_argument(
"--dropout", dest="dropout", type=float, help="dropout rate"
)
parser.add_argument( parser.add_argument(
'--batch-size', "--bias",
dest='batch_size', dest="bias",
action="store_const",
const=True,
default=True,
help="switch for bias",
)
parser.add_argument(
"--save_dir",
dest="save_dir",
help="model saving directory: SAVE_DICT/DATASET",
)
parser.add_argument(
"--load_epoch",
dest="load_epoch",
type=int, type=int,
help='batch size') help="load trained model params from\
parser.add_argument('--epochs', dest='epoch', type=int, SAVE_DICT/DATASET/model-LOAD_EPOCH",
help='num-of-epoch') )
parser.add_argument('--train-ratio', dest='train_ratio', type=float,
help='ratio of trainning dataset split')
parser.add_argument('--test-ratio', dest='test_ratio', type=float,
help='ratio of testing dataset split')
parser.add_argument('--num_workers', dest='n_worker', type=int,
help='number of workers when dataloading')
parser.add_argument('--gc-per-block', dest='gc_per_block', type=int,
help='number of graph conv layer per block')
parser.add_argument('--bn', dest='bn', action='store_const', const=True,
default=True, help='switch for bn')
parser.add_argument('--dropout', dest='dropout', type=float,
help='dropout rate')
parser.add_argument('--bias', dest='bias', action='store_const',
const=True, default=True, help='switch for bias')
parser.add_argument( parser.add_argument(
'--save_dir', "--data_mode",
dest='save_dir', dest="data_mode",
help='model saving directory: SAVE_DICT/DATASET') help="data\
parser.add_argument('--load_epoch', dest='load_epoch', type=int, help='load trained model params from\
SAVE_DICT/DATASET/model-LOAD_EPOCH')
parser.add_argument('--data_mode', dest='data_mode', help='data\
preprocessing mode: default, id, degree, or one-hot\ preprocessing mode: default, id, degree, or one-hot\
vector of degree number', choices=['default', 'id', 'deg', vector of degree number",
'deg_num']) choices=["default", "id", "deg", "deg_num"],
)
parser.set_defaults(dataset='ENZYMES',
pool_ratio=0.15, parser.set_defaults(
num_pool=1, dataset="ENZYMES",
cuda=1, pool_ratio=0.15,
lr=1e-3, num_pool=1,
clip=2.0, cuda=1,
batch_size=20, lr=1e-3,
epoch=4000, clip=2.0,
train_ratio=0.7, batch_size=20,
test_ratio=0.1, epoch=4000,
n_worker=1, train_ratio=0.7,
gc_per_block=3, test_ratio=0.1,
dropout=0.0, n_worker=1,
method='diffpool', gc_per_block=3,
bn=True, dropout=0.0,
bias=True, method="diffpool",
save_dir="./model_param", bn=True,
load_epoch=-1, bias=True,
data_mode='default') save_dir="./model_param",
load_epoch=-1,
data_mode="default",
)
return parser.parse_args() return parser.parse_args()
def prepare_data(dataset, prog_args, train=False, pre_process=None):
    """
    preprocess TU dataset according to DiffPool's paper setting and load dataset into dataloader
    """
    if train:
        shuffle = True
    else:
@@ -111,16 +148,18 @@ def prepare_data(dataset, prog_args, train=False, pre_process=None):
        pre_process(dataset, prog_args)

    # dataset.set_fold(fold)
    return dgl.dataloading.GraphDataLoader(
        dataset,
        batch_size=prog_args.batch_size,
        shuffle=shuffle,
        num_workers=prog_args.n_worker,
    )


def graph_classify_task(prog_args):
    """
    perform graph classification task
    """
    dataset = tu.LegacyTUDataset(name=prog_args.dataset)
    train_size = int(prog_args.train_ratio * len(dataset))
@@ -128,13 +167,17 @@ def graph_classify_task(prog_args):
val_size = int(len(dataset) - train_size - test_size) val_size = int(len(dataset) - train_size - test_size)
dataset_train, dataset_val, dataset_test = torch.utils.data.random_split( dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
dataset, (train_size, val_size, test_size)) dataset, (train_size, val_size, test_size)
train_dataloader = prepare_data(dataset_train, prog_args, train=True, )
pre_process=pre_process) train_dataloader = prepare_data(
val_dataloader = prepare_data(dataset_val, prog_args, train=False, dataset_train, prog_args, train=True, pre_process=pre_process
pre_process=pre_process) )
test_dataloader = prepare_data(dataset_test, prog_args, train=False, val_dataloader = prepare_data(
pre_process=pre_process) dataset_val, prog_args, train=False, pre_process=pre_process
)
test_dataloader = prepare_data(
dataset_test, prog_args, train=False, pre_process=pre_process
)
input_dim, label_dim, max_num_node = dataset.statistics() input_dim, label_dim, max_num_node = dataset.statistics()
print("++++++++++STATISTICS ABOUT THE DATASET") print("++++++++++STATISTICS ABOUT THE DATASET")
print("dataset feature dimension is", input_dim) print("dataset feature dimension is", input_dim)
@@ -157,23 +200,32 @@ def graph_classify_task(prog_args):
# initialize model # initialize model
# 'diffpool' : diffpool # 'diffpool' : diffpool
model = DiffPool(input_dim, model = DiffPool(
hidden_dim, input_dim,
embedding_dim, hidden_dim,
label_dim, embedding_dim,
activation, label_dim,
prog_args.gc_per_block, activation,
prog_args.dropout, prog_args.gc_per_block,
prog_args.num_pool, prog_args.dropout,
prog_args.linkpred, prog_args.num_pool,
prog_args.batch_size, prog_args.linkpred,
'meanpool', prog_args.batch_size,
assign_dim, "meanpool",
prog_args.pool_ratio) assign_dim,
prog_args.pool_ratio,
)
if prog_args.load_epoch >= 0 and prog_args.save_dir is not None: if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:
model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset model.load_state_dict(
+ "/model.iter-" + str(prog_args.load_epoch))) torch.load(
prog_args.save_dir
+ "/"
+ prog_args.dataset
+ "/model.iter-"
+ str(prog_args.load_epoch)
)
)
print("model init finished") print("model init finished")
print("MODEL:::::::", prog_args.method) print("MODEL:::::::", prog_args.method)
@@ -181,24 +233,23 @@ def graph_classify_task(prog_args):
        model = model.cuda()

    logger = train(
        train_dataloader, model, prog_args, val_dataset=val_dataloader
    )
    result = evaluate(test_dataloader, model, prog_args, logger)
    print("test accuracy {:.2f}%".format(result * 100))


def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
    """
    training function
    """
    dir = prog_args.save_dir + "/" + prog_args.dataset
    if not os.path.exists(dir):
        os.makedirs(dir)
    dataloader = dataset
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=0.001
    )
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}
    if prog_args.cuda > 0:
@@ -233,34 +284,59 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
optimizer.step() optimizer.step()
train_accu = accum_correct / total train_accu = accum_correct / total
print("train accuracy for this epoch {} is {:.2f}%".format(epoch, print(
train_accu * 100)) "train accuracy for this epoch {} is {:.2f}%".format(
epoch, train_accu * 100
)
)
elapsed_time = time.time() - begin_time elapsed_time = time.time() - begin_time
print("loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format( print(
loss.item(), elapsed_time, computation_time)) "loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
loss.item(), elapsed_time, computation_time
)
)
global_train_time_per_epoch.append(elapsed_time) global_train_time_per_epoch.append(elapsed_time)
if val_dataset is not None: if val_dataset is not None:
result = evaluate(val_dataset, model, prog_args) result = evaluate(val_dataset, model, prog_args)
print("validation accuracy {:.2f}%".format(result * 100)) print("validation accuracy {:.2f}%".format(result * 100))
if result >= early_stopping_logger['val_acc'] and result <=\ if (
train_accu: result >= early_stopping_logger["val_acc"]
and result <= train_accu
):
early_stopping_logger.update(best_epoch=epoch, val_acc=result) early_stopping_logger.update(best_epoch=epoch, val_acc=result)
if prog_args.save_dir is not None: if prog_args.save_dir is not None:
torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset torch.save(
+ "/model.iter-" + str(early_stopping_logger['best_epoch'])) model.state_dict(),
print("best epoch is EPOCH {}, val_acc is {:.2f}%".format(early_stopping_logger['best_epoch'], prog_args.save_dir
early_stopping_logger['val_acc'] * 100)) + "/"
+ prog_args.dataset
+ "/model.iter-"
+ str(early_stopping_logger["best_epoch"]),
)
print(
"best epoch is EPOCH {}, val_acc is {:.2f}%".format(
early_stopping_logger["best_epoch"],
early_stopping_logger["val_acc"] * 100,
)
)
torch.cuda.empty_cache() torch.cuda.empty_cache()
return early_stopping_logger return early_stopping_logger
def evaluate(dataloader, model, prog_args, logger=None):
    """
    evaluate function
    """
    if logger is not None and prog_args.save_dir is not None:
        model.load_state_dict(
            torch.load(
                prog_args.save_dir
                + "/"
                + prog_args.dataset
                + "/model.iter-"
                + str(logger["best_epoch"])
            )
        )
    model.eval()
    correct_label = 0
    with torch.no_grad():
@@ -280,15 +356,23 @@ def evaluate(dataloader, model, prog_args, logger=None):
def main():
    """
    main
    """
    prog_args = arg_parse()
    print(prog_args)
    graph_classify_task(prog_args)
    print(
        "Train time per epoch: {:.4f}".format(
            sum(global_train_time_per_epoch) / len(global_train_time_per_epoch)
        )
    )
    print(
        "Max memory usage: {:.4f}".format(
            torch.cuda.max_memory_allocated(0) / (1024 * 1024)
        )
    )


if __name__ == "__main__":
......
import os
from pathlib import Path

import click
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
from logzero import logger
from modules.dimenet_pp import DimeNetPP
from modules.initializers import GlorotOrthogonal
from ruamel.yaml import YAML
@click.command()
@click.option(
    "-m",
    "--model-cnf",
    type=click.Path(exists=True),
    help="Path of model config yaml.",
)
@click.option(
    "-c",
    "--convert-cnf",
    type=click.Path(exists=True),
    help="Path of convert config yaml.",
)
def main(model_cnf, convert_cnf):
    yaml = YAML(typ="safe")
    model_cnf = yaml.load(Path(model_cnf))
    convert_cnf = yaml.load(Path(convert_cnf))

    model_name, model_params, _ = (
        model_cnf["name"],
        model_cnf["model"],
        model_cnf["train"],
    )
    logger.info(f"Model name: {model_name}")
    logger.info(f"Model params: {model_params}")
    if model_params["targets"] in ["mu", "homo", "lumo", "gap", "zpve"]:
        model_params["output_init"] = nn.init.zeros_
    else:
        # 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv
        model_params["output_init"] = GlorotOrthogonal

    # model initialization
    logger.info("Loading Model")
model = DimeNetPP(emb_size=model_params['emb_size'], model = DimeNetPP(
out_emb_size=model_params['out_emb_size'], emb_size=model_params["emb_size"],
int_emb_size=model_params['int_emb_size'], out_emb_size=model_params["out_emb_size"],
basis_emb_size=model_params['basis_emb_size'], int_emb_size=model_params["int_emb_size"],
num_blocks=model_params['num_blocks'], basis_emb_size=model_params["basis_emb_size"],
num_spherical=model_params['num_spherical'], num_blocks=model_params["num_blocks"],
num_radial=model_params['num_radial'], num_spherical=model_params["num_spherical"],
cutoff=model_params['cutoff'], num_radial=model_params["num_radial"],
envelope_exponent=model_params['envelope_exponent'], cutoff=model_params["cutoff"],
num_before_skip=model_params['num_before_skip'], envelope_exponent=model_params["envelope_exponent"],
num_after_skip=model_params['num_after_skip'], num_before_skip=model_params["num_before_skip"],
num_dense_output=model_params['num_dense_output'], num_after_skip=model_params["num_after_skip"],
num_targets=len(model_params['targets']), num_dense_output=model_params["num_dense_output"],
extensive=model_params['extensive'], num_targets=len(model_params["targets"]),
output_init=model_params['output_init']) extensive=model_params["extensive"],
output_init=model_params["output_init"],
)
logger.info(model.state_dict()) logger.info(model.state_dict())
tf_path, torch_path = convert_cnf['tf']['ckpt_path'], convert_cnf['torch']['dump_path'] tf_path, torch_path = (
convert_cnf["tf"]["ckpt_path"],
convert_cnf["torch"]["dump_path"],
)
init_vars = tf.train.list_variables(tf_path) init_vars = tf.train.list_variables(tf_path)
tf_vars_dict = {} tf_vars_dict = {}
# 147 keys # 147 keys
for name, shape in init_vars: for name, shape in init_vars:
if name == '_CHECKPOINTABLE_OBJECT_GRAPH': if name == "_CHECKPOINTABLE_OBJECT_GRAPH":
continue continue
array = tf.train.load_variable(tf_path, name) array = tf.train.load_variable(tf_path, name)
logger.info(f'Loading TF weight {name} with shape {shape}') logger.info(f"Loading TF weight {name} with shape {shape}")
tf_vars_dict[name] = array tf_vars_dict[name] = array
for name, array in tf_vars_dict.items(): for name, array in tf_vars_dict.items():
name = name.split('/')[:-2] name = name.split("/")[:-2]
pointer = model pointer = model
for m_name in name: for m_name in name:
if m_name == 'kernel': if m_name == "kernel":
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, "weight")
elif m_name == 'int_blocks': elif m_name == "int_blocks":
pointer = getattr(pointer, 'interaction_blocks') pointer = getattr(pointer, "interaction_blocks")
elif m_name == 'embeddings': elif m_name == "embeddings":
pointer = getattr(pointer, 'embedding') pointer = getattr(pointer, "embedding")
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, "weight")
else: else:
pointer = getattr(pointer, m_name) pointer = getattr(pointer, m_name)
if name[-1] == 'kernel': if name[-1] == "kernel":
array = np.transpose(array) array = np.transpose(array)
assert array.shape == pointer.shape assert array.shape == pointer.shape
logger.info(f'Initialize PyTorch weight {name}') logger.info(f"Initialize PyTorch weight {name}")
pointer.data = torch.from_numpy(array) pointer.data = torch.from_numpy(array)
    logger.info(f"Save PyTorch model to {torch_path}")
    if not os.path.exists(torch_path):
        os.makedirs(torch_path)
    target = model_params["targets"][0]
    torch.save(model.state_dict(), f"{torch_path}/{target}.pt")
    logger.info(model.state_dict())


if __name__ == "__main__":
    main()
import copy
from pathlib import Path

import click
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from logzero import logger
from modules.dimenet import DimeNet
from modules.dimenet_pp import DimeNetPP
from modules.initializers import GlorotOrthogonal
from qm9 import QM9
from ruamel.yaml import YAML
from sklearn.metrics import mean_absolute_error
from torch.utils.data import DataLoader

import dgl
from dgl.data.utils import Subset


def split_dataset(
    dataset, num_train, num_valid, shuffle=False, random_state=None
):
    """Split dataset into training, validation and test set.

    Parameters
@@ -46,6 +50,7 @@ def split_dataset(dataset, num_train, num_valid, shuffle=False, random_state=Non
        Subsets for training, validation and test.
    """
    from itertools import accumulate

    num_data = len(dataset)
    assert num_train + num_valid < num_data
    lengths = [num_train, num_valid, num_data - num_train - num_valid]
@@ -53,20 +58,26 @@ def split_dataset(dataset, num_train, num_valid, shuffle=False, random_state=Non
        indices = np.random.RandomState(seed=random_state).permutation(num_data)
    else:
        indices = np.arange(num_data)
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(accumulate(lengths), lengths)
    ]


@torch.no_grad()
def ema(ema_model, model, decay):
    msd = model.state_dict()
    for k, ema_v in ema_model.state_dict().items():
        model_v = msd[k].detach()
        ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v)


def edge_init(edges):
    R_src, R_dst = edges.src["R"], edges.dst["R"]
    dist = torch.sqrt(F.relu(torch.sum((R_src - R_dst) ** 2, -1)))
    # d: bond length, o: bond orientation
    return {"d": dist, "o": R_src - R_dst}


def _collate_fn(batch):
    graphs, line_graphs, labels = map(list, zip(*batch))
@@ -74,6 +85,7 @@ def _collate_fn(batch):
    labels = torch.tensor(labels, dtype=torch.float32)
    return g, l_g, labels
def train(device, model, opt, loss_fn, train_loader):
    model.train()
    epoch_loss = 0
@@ -93,162 +105,202 @@ def train(device, model, opt, loss_fn, train_loader):
    return epoch_loss / num_samples


@torch.no_grad()
def evaluate(device, model, valid_loader):
    model.eval()
    predictions_all, labels_all = [], []
    for g, l_g, labels in valid_loader:
        g = g.to(device)
        l_g = l_g.to(device)
        logits = model(g, l_g)
        labels_all.extend(labels)
        predictions_all.extend(
            logits.view(
                -1,
            )
            .cpu()
            .numpy()
        )
    return np.array(predictions_all), np.array(labels_all)
@click.command() @click.command()
@click.option('-m', '--model-cnf', type=click.Path(exists=True), help='Path of model config yaml.') @click.option(
"-m",
"--model-cnf",
type=click.Path(exists=True),
help="Path of model config yaml.",
)
def main(model_cnf): def main(model_cnf):
yaml = YAML(typ='safe') yaml = YAML(typ="safe")
model_cnf = yaml.load(Path(model_cnf)) model_cnf = yaml.load(Path(model_cnf))
model_name, model_params, train_params, pretrain_params = model_cnf['name'], model_cnf['model'], model_cnf['train'], model_cnf['pretrain'] model_name, model_params, train_params, pretrain_params = (
logger.info(f'Model name: {model_name}') model_cnf["name"],
logger.info(f'Model params: {model_params}') model_cnf["model"],
logger.info(f'Train params: {train_params}') model_cnf["train"],
model_cnf["pretrain"],
)
logger.info(f"Model name: {model_name}")
logger.info(f"Model params: {model_params}")
logger.info(f"Train params: {train_params}")
if model_params['targets'] in ['mu', 'homo', 'lumo', 'gap', 'zpve']: if model_params["targets"] in ["mu", "homo", "lumo", "gap", "zpve"]:
model_params['output_init'] = nn.init.zeros_ model_params["output_init"] = nn.init.zeros_
else: else:
# 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv # 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv
model_params['output_init'] = GlorotOrthogonal model_params["output_init"] = GlorotOrthogonal
logger.info('Loading Data Set') logger.info("Loading Data Set")
dataset = QM9(label_keys=model_params['targets'], edge_funcs=[edge_init]) dataset = QM9(label_keys=model_params["targets"], edge_funcs=[edge_init])
# data split # data split
train_data, valid_data, test_data = split_dataset(dataset, train_data, valid_data, test_data = split_dataset(
num_train=train_params['num_train'], dataset,
num_valid=train_params['num_valid'], num_train=train_params["num_train"],
shuffle=True, num_valid=train_params["num_valid"],
random_state=train_params['data_seed']) shuffle=True,
logger.info(f'Size of Training Set: {len(train_data)}') random_state=train_params["data_seed"],
logger.info(f'Size of Validation Set: {len(valid_data)}') )
logger.info(f'Size of Test Set: {len(test_data)}') logger.info(f"Size of Training Set: {len(train_data)}")
logger.info(f"Size of Validation Set: {len(valid_data)}")
logger.info(f"Size of Test Set: {len(test_data)}")
# data loader # data loader
train_loader = DataLoader(train_data, train_loader = DataLoader(
batch_size=train_params['batch_size'], train_data,
shuffle=True, batch_size=train_params["batch_size"],
collate_fn=_collate_fn, shuffle=True,
num_workers=train_params['num_workers']) collate_fn=_collate_fn,
num_workers=train_params["num_workers"],
valid_loader = DataLoader(valid_data, )
batch_size=train_params['batch_size'],
shuffle=False, valid_loader = DataLoader(
collate_fn=_collate_fn, valid_data,
num_workers=train_params['num_workers']) batch_size=train_params["batch_size"],
shuffle=False,
test_loader = DataLoader(test_data, collate_fn=_collate_fn,
batch_size=train_params['batch_size'], num_workers=train_params["num_workers"],
shuffle=False, )
collate_fn=_collate_fn,
num_workers=train_params['num_workers']) test_loader = DataLoader(
test_data,
batch_size=train_params["batch_size"],
shuffle=False,
collate_fn=_collate_fn,
num_workers=train_params["num_workers"],
)
# check cuda # check cuda
gpu = train_params['gpu'] gpu = train_params["gpu"]
device = f'cuda:{gpu}' if gpu >= 0 and torch.cuda.is_available() else 'cpu' device = f"cuda:{gpu}" if gpu >= 0 and torch.cuda.is_available() else "cpu"
# model initialization # model initialization
logger.info('Loading Model') logger.info("Loading Model")
if model_name == 'dimenet': if model_name == "dimenet":
model = DimeNet(emb_size=model_params['emb_size'], model = DimeNet(
num_blocks=model_params['num_blocks'], emb_size=model_params["emb_size"],
num_bilinear=model_params['num_bilinear'], num_blocks=model_params["num_blocks"],
num_spherical=model_params['num_spherical'], num_bilinear=model_params["num_bilinear"],
num_radial=model_params['num_radial'], num_spherical=model_params["num_spherical"],
cutoff=model_params['cutoff'], num_radial=model_params["num_radial"],
envelope_exponent=model_params['envelope_exponent'], cutoff=model_params["cutoff"],
num_before_skip=model_params['num_before_skip'], envelope_exponent=model_params["envelope_exponent"],
num_after_skip=model_params['num_after_skip'], num_before_skip=model_params["num_before_skip"],
num_dense_output=model_params['num_dense_output'], num_after_skip=model_params["num_after_skip"],
num_targets=len(model_params['targets']), num_dense_output=model_params["num_dense_output"],
output_init=model_params['output_init']).to(device) num_targets=len(model_params["targets"]),
elif model_name == 'dimenet++': output_init=model_params["output_init"],
model = DimeNetPP(emb_size=model_params['emb_size'], ).to(device)
out_emb_size=model_params['out_emb_size'], elif model_name == "dimenet++":
int_emb_size=model_params['int_emb_size'], model = DimeNetPP(
basis_emb_size=model_params['basis_emb_size'], emb_size=model_params["emb_size"],
num_blocks=model_params['num_blocks'], out_emb_size=model_params["out_emb_size"],
num_spherical=model_params['num_spherical'], int_emb_size=model_params["int_emb_size"],
num_radial=model_params['num_radial'], basis_emb_size=model_params["basis_emb_size"],
cutoff=model_params['cutoff'], num_blocks=model_params["num_blocks"],
envelope_exponent=model_params['envelope_exponent'], num_spherical=model_params["num_spherical"],
num_before_skip=model_params['num_before_skip'], num_radial=model_params["num_radial"],
num_after_skip=model_params['num_after_skip'], cutoff=model_params["cutoff"],
num_dense_output=model_params['num_dense_output'], envelope_exponent=model_params["envelope_exponent"],
num_targets=len(model_params['targets']), num_before_skip=model_params["num_before_skip"],
extensive=model_params['extensive'], num_after_skip=model_params["num_after_skip"],
output_init=model_params['output_init']).to(device) num_dense_output=model_params["num_dense_output"],
num_targets=len(model_params["targets"]),
extensive=model_params["extensive"],
output_init=model_params["output_init"],
).to(device)
else: else:
raise ValueError(f'Invalid Model Name {model_name}') raise ValueError(f"Invalid Model Name {model_name}")
if pretrain_params['flag']: if pretrain_params["flag"]:
torch_path = pretrain_params['path'] torch_path = pretrain_params["path"]
target = model_params['targets'][0] target = model_params["targets"][0]
model.load_state_dict(torch.load(f'{torch_path}/{target}.pt')) model.load_state_dict(torch.load(f"{torch_path}/{target}.pt"))
logger.info('Testing with Pretrained model') logger.info("Testing with Pretrained model")
predictions, labels = evaluate(device, model, test_loader) predictions, labels = evaluate(device, model, test_loader)
test_mae = mean_absolute_error(labels, predictions) test_mae = mean_absolute_error(labels, predictions)
logger.info(f'Test MAE {test_mae:.4f}') logger.info(f"Test MAE {test_mae:.4f}")
return return
# define loss function and optimization # define loss function and optimization
loss_fn = nn.L1Loss() loss_fn = nn.L1Loss()
opt = optim.Adam(model.parameters(), lr=train_params['lr'], weight_decay=train_params['weight_decay'], amsgrad=True) opt = optim.Adam(
scheduler = optim.lr_scheduler.StepLR(opt, train_params['step_size'], gamma=train_params['gamma']) model.parameters(),
lr=train_params["lr"],
weight_decay=train_params["weight_decay"],
amsgrad=True,
)
scheduler = optim.lr_scheduler.StepLR(
opt, train_params["step_size"], gamma=train_params["gamma"]
)
# model training # model training
best_mae = 1e9 best_mae = 1e9
no_improvement = 0 no_improvement = 0
# EMA for valid and test # EMA for valid and test
logger.info('EMA Init') logger.info("EMA Init")
ema_model = copy.deepcopy(model) ema_model = copy.deepcopy(model)
for p in ema_model.parameters(): for p in ema_model.parameters():
p.requires_grad_(False) p.requires_grad_(False)
best_model = copy.deepcopy(ema_model) best_model = copy.deepcopy(ema_model)
logger.info('Training') logger.info("Training")
for i in range(train_params['epochs']): for i in range(train_params["epochs"]):
train_loss = train(device, model, opt, loss_fn, train_loader) train_loss = train(device, model, opt, loss_fn, train_loader)
ema(ema_model, model, train_params['ema_decay']) ema(ema_model, model, train_params["ema_decay"])
if i % train_params['interval'] == 0: if i % train_params["interval"] == 0:
predictions, labels = evaluate(device, ema_model, valid_loader) predictions, labels = evaluate(device, ema_model, valid_loader)
valid_mae = mean_absolute_error(labels, predictions) valid_mae = mean_absolute_error(labels, predictions)
logger.info(f'Epoch {i} | Train Loss {train_loss:.4f} | Val MAE {valid_mae:.4f}') logger.info(
f"Epoch {i} | Train Loss {train_loss:.4f} | Val MAE {valid_mae:.4f}"
)
if valid_mae > best_mae: if valid_mae > best_mae:
no_improvement += 1 no_improvement += 1
if no_improvement == train_params['early_stopping']: if no_improvement == train_params["early_stopping"]:
logger.info('Early stop.') logger.info("Early stop.")
break break
else: else:
no_improvement = 0 no_improvement = 0
best_mae = valid_mae best_mae = valid_mae
best_model = copy.deepcopy(ema_model) best_model = copy.deepcopy(ema_model)
else: else:
logger.info(f'Epoch {i} | Train Loss {train_loss:.4f}') logger.info(f"Epoch {i} | Train Loss {train_loss:.4f}")
scheduler.step() scheduler.step()
    logger.info("Testing")
    predictions, labels = evaluate(device, best_model, test_loader)
    test_mae = mean_absolute_error(labels, predictions)
    logger.info("Test MAE {:.4f}".format(test_mae))


if __name__ == "__main__":
    main()
import torch


def swish(x):
    """
    Swish activation function,
    from Ramachandran, Zopf, Le 2017. "Searching for Activation Functions"
    """
    return x * torch.sigmoid(x)
import numpy as np
import sympy as sym
from scipy import special as sp
from scipy.optimize import brentq


def Jn(r, n):
    """
@@ -19,6 +19,7 @@ def Jn(r, n):
""" """
return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) # the same shape as n return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) # the same shape as n
def Jn_zeros(n, k): def Jn_zeros(n, k):
""" """
n: int n: int
@@ -40,6 +41,7 @@ def Jn_zeros(n, k):
return zerosj return zerosj
def spherical_bessel_formulas(n): def spherical_bessel_formulas(n):
""" """
n: int n: int
@@ -48,7 +50,7 @@ def spherical_bessel_formulas(n):
n sympy functions n sympy functions
Computes the sympy formulas for the spherical bessel functions up to order n (excluded) Computes the sympy formulas for the spherical bessel functions up to order n (excluded)
""" """
    x = sym.symbols("x")
f = [sym.sin(x) / x] f = [sym.sin(x) / x]
a = sym.sin(x) / x a = sym.sin(x) / x
@@ -58,6 +60,7 @@ def spherical_bessel_formulas(n):
a = sym.simplify(b) a = sym.simplify(b)
return f return f
def bessel_basis(n, k): def bessel_basis(n, k):
""" """
n: int n: int
@@ -79,26 +82,36 @@ def bessel_basis(n, k):
normalizer += [normalizer_tmp] normalizer += [normalizer_tmp]
f = spherical_bessel_formulas(n) f = spherical_bessel_formulas(n)
    x = sym.symbols("x")
bess_basis = [] bess_basis = []
for order in range(n): for order in range(n):
bess_basis_tmp = [] bess_basis_tmp = []
for i in range(k): for i in range(k):
bess_basis_tmp += [sym.simplify(normalizer[order][i] * f[order].subs(x, zeros[order, i] * x))] bess_basis_tmp += [
sym.simplify(
normalizer[order][i] * f[order].subs(x, zeros[order, i] * x)
)
]
bess_basis += [bess_basis_tmp] bess_basis += [bess_basis_tmp]
return bess_basis return bess_basis
def sph_harm_prefactor(l, m): def sph_harm_prefactor(l, m):
""" """
l: int l: int
m: int m: int
res: float res: float
Computes the constant pre-factor for the spherical harmonic of degree l and order m Computes the constant pre-factor for the spherical harmonic of degree l and order m
input: input:
l: int, l>=0 l: int, l>=0
m: int, -l<=m<=l m: int, -l<=m<=l
""" """
return ((2 * l + 1) * np.math.factorial(l - abs(m)) / (4 * np.pi * np.math.factorial(l + abs(m)))) ** 0.5 return (
(2 * l + 1)
* np.math.factorial(l - abs(m))
/ (4 * np.pi * np.math.factorial(l + abs(m)))
) ** 0.5
def associated_legendre_polynomials(l, zero_m_only=True): def associated_legendre_polynomials(l, zero_m_only=True):
""" """
...@@ -106,7 +119,7 @@ def associated_legendre_polynomials(l, zero_m_only=True): ...@@ -106,7 +119,7 @@ def associated_legendre_polynomials(l, zero_m_only=True):
return: l sympy functions return: l sympy functions
Computes sympy formulas of the associated legendre polynomials up to order l (excluded). Computes sympy formulas of the associated legendre polynomials up to order l (excluded).
""" """
z = sym.symbols('z') z = sym.symbols("z")
P_l_m = [[0] * (j + 1) for j in range(l)] P_l_m = [[0] * (j + 1) for j in range(l)]
P_l_m[0][0] = 1 P_l_m[0][0] = 1
...@@ -116,18 +129,29 @@ def associated_legendre_polynomials(l, zero_m_only=True): ...@@ -116,18 +129,29 @@ def associated_legendre_polynomials(l, zero_m_only=True):
for j in range(2, l): for j in range(2, l):
P_l_m[j][0] = sym.simplify( P_l_m[j][0] = sym.simplify(
((2 * j - 1) * z * P_l_m[j - 1][0] - (j - 1) * P_l_m[j - 2][0]) / j) ((2 * j - 1) * z * P_l_m[j - 1][0] - (j - 1) * P_l_m[j - 2][0])
/ j
)
if not zero_m_only: if not zero_m_only:
for i in range(1, l): for i in range(1, l):
P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1]) P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])
if i + 1 < l: if i + 1 < l:
P_l_m[i + 1][i] = sym.simplify((2 * i + 1) * z * P_l_m[i][i]) P_l_m[i + 1][i] = sym.simplify(
(2 * i + 1) * z * P_l_m[i][i]
)
for j in range(i + 2, l): for j in range(i + 2, l):
P_l_m[j][i] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][i] - (i + j - 1) * P_l_m[j - 2][i]) / (j - i)) P_l_m[j][i] = sym.simplify(
(
(2 * j - 1) * z * P_l_m[j - 1][i]
- (i + j - 1) * P_l_m[j - 2][i]
)
/ (j - i)
)
return P_l_m return P_l_m
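Illustrative check of the recurrence above: the l = 2, m = 0 entry should simplify to (3*z**2 - 1)/2, the familiar Legendre polynomial P_2.

import sympy as sym

z = sym.symbols("z")
P = associated_legendre_polynomials(3)                  # m = 0 formulas for l = 0, 1, 2
assert sym.simplify(P[2][0] - (3 * z**2 - 1) / 2) == 0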
def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True): def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True):
""" """
    return: a list of length l; the i-th entry is itself a list of (2 * i + 1) sympy functions return: a list of length l; the i-th entry is itself a list of (2 * i + 1) sympy functions
...@@ -138,30 +162,38 @@ def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True): ...@@ -138,30 +162,38 @@ def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True):
S_m = [0] S_m = [0]
C_m = [1] C_m = [1]
for i in range(1, l): for i in range(1, l):
x = sym.symbols('x') x = sym.symbols("x")
y = sym.symbols('y') y = sym.symbols("y")
S_m += [x * S_m[i - 1] + y * C_m[i - 1]] S_m += [x * S_m[i - 1] + y * C_m[i - 1]]
C_m += [x * C_m[i - 1] - y * S_m[i - 1]] C_m += [x * C_m[i - 1] - y * S_m[i - 1]]
P_l_m = associated_legendre_polynomials(l, zero_m_only) P_l_m = associated_legendre_polynomials(l, zero_m_only)
if spherical_coordinates: if spherical_coordinates:
theta = sym.symbols('theta') theta = sym.symbols("theta")
z = sym.symbols('z') z = sym.symbols("z")
for i in range(len(P_l_m)): for i in range(len(P_l_m)):
for j in range(len(P_l_m[i])): for j in range(len(P_l_m[i])):
if type(P_l_m[i][j]) != int: if type(P_l_m[i][j]) != int:
P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta)) P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta))
if not zero_m_only: if not zero_m_only:
phi = sym.symbols('phi') phi = sym.symbols("phi")
for i in range(len(S_m)): for i in range(len(S_m)):
S_m[i] = S_m[i].subs(x, sym.sin(theta) * sym.cos(phi)).subs(y, sym.sin(theta) * sym.sin(phi)) S_m[i] = (
S_m[i]
.subs(x, sym.sin(theta) * sym.cos(phi))
.subs(y, sym.sin(theta) * sym.sin(phi))
)
for i in range(len(C_m)): for i in range(len(C_m)):
C_m[i] = C_m[i].subs(x, sym.sin(theta) * sym.cos(phi)).subs(y, sym.sin(theta) * sym.sin(phi)) C_m[i] = (
C_m[i]
.subs(x, sym.sin(theta) * sym.cos(phi))
.subs(y, sym.sin(theta) * sym.sin(phi))
)
Y_func_l_m = [['0'] * (2 * j + 1) for j in range(l)] Y_func_l_m = [["0"] * (2 * j + 1) for j in range(l)]
for i in range(l): for i in range(l):
Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0]) Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])
...@@ -169,9 +201,13 @@ def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True): ...@@ -169,9 +201,13 @@ def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True):
if not zero_m_only: if not zero_m_only:
for i in range(1, l): for i in range(1, l):
for j in range(1, i + 1): for j in range(1, i + 1):
Y_func_l_m[i][j] = sym.simplify(2 ** 0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]) Y_func_l_m[i][j] = sym.simplify(
2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]
)
for i in range(1, l): for i in range(1, l):
for j in range(1, i + 1): for j in range(1, i + 1):
Y_func_l_m[i][-j] = sym.simplify(2 ** 0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]) Y_func_l_m[i][-j] = sym.simplify(
2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]
)
return Y_func_l_m return Y_func_l_m
\ No newline at end of file
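Usage sketch (illustrative; assumes the zero_m_only path shown above): the m = 0 harmonics depend only on theta, so they can be lambdified and evaluated directly. Here Y[1][0] is sqrt(3 / (4 * pi)) * cos(theta) ~ 0.4886 * cos(theta).

import numpy as np
import sympy as sym

Y = real_sph_harm(3, zero_m_only=True, spherical_coordinates=True)   # l = 0, 1, 2
theta = sym.symbols("theta")
y10 = sym.lambdify(theta, Y[1][0], "numpy")
print(y10(np.array([0.0, np.pi / 2, np.pi])))                        # ~[0.489, 0.0, -0.489]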
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from modules.envelope import Envelope from modules.envelope import Envelope
class BesselBasisLayer(nn.Module): class BesselBasisLayer(nn.Module):
def __init__(self, def __init__(self, num_radial, cutoff, envelope_exponent=5):
num_radial,
cutoff,
envelope_exponent=5):
super(BesselBasisLayer, self).__init__() super(BesselBasisLayer, self).__init__()
self.cutoff = cutoff self.cutoff = cutoff
self.envelope = Envelope(envelope_exponent) self.envelope = Envelope(envelope_exponent)
self.frequencies = nn.Parameter(torch.Tensor(num_radial)) self.frequencies = nn.Parameter(torch.Tensor(num_radial))
...@@ -18,13 +15,15 @@ class BesselBasisLayer(nn.Module): ...@@ -18,13 +15,15 @@ class BesselBasisLayer(nn.Module):
def reset_params(self): def reset_params(self):
with torch.no_grad(): with torch.no_grad():
torch.arange(1, self.frequencies.numel() + 1, out=self.frequencies).mul_(np.pi) torch.arange(
1, self.frequencies.numel() + 1, out=self.frequencies
).mul_(np.pi)
self.frequencies.requires_grad_() self.frequencies.requires_grad_()
def forward(self, g): def forward(self, g):
d_scaled = g.edata['d'] / self.cutoff d_scaled = g.edata["d"] / self.cutoff
# Necessary for proper broadcasting behaviour # Necessary for proper broadcasting behaviour
d_scaled = torch.unsqueeze(d_scaled, -1) d_scaled = torch.unsqueeze(d_scaled, -1)
d_cutoff = self.envelope(d_scaled) d_cutoff = self.envelope(d_scaled)
g.edata['rbf'] = d_cutoff * torch.sin(self.frequencies * d_scaled) g.edata["rbf"] = d_cutoff * torch.sin(self.frequencies * d_scaled)
return g return g
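Minimal usage sketch (assumption: a DGL graph whose edges carry scalar distances under edata["d"], as the DimeNet data pipeline prepares them; the values are illustrative):

import dgl
import torch

g = dgl.graph(([0, 1], [1, 0]))
g.edata["d"] = torch.tensor([1.2, 1.2])
rbf_layer = BesselBasisLayer(num_radial=6, cutoff=5.0, envelope_exponent=5)
g = rbf_layer(g)
print(g.edata["rbf"].shape)   # (num_edges, num_radial) -> torch.Size([2, 6])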
import torch import torch
import torch.nn as nn import torch.nn as nn
from modules.activations import swish from modules.activations import swish
from modules.bessel_basis_layer import BesselBasisLayer from modules.bessel_basis_layer import BesselBasisLayer
from modules.spherical_basis_layer import SphericalBasisLayer
from modules.embedding_block import EmbeddingBlock from modules.embedding_block import EmbeddingBlock
from modules.output_block import OutputBlock
from modules.interaction_block import InteractionBlock from modules.interaction_block import InteractionBlock
from modules.output_block import OutputBlock
from modules.spherical_basis_layer import SphericalBasisLayer
class DimeNet(nn.Module): class DimeNet(nn.Module):
""" """
...@@ -41,67 +41,86 @@ class DimeNet(nn.Module): ...@@ -41,67 +41,86 @@ class DimeNet(nn.Module):
output_init output_init
Initial function in output block Initial function in output block
""" """
def __init__(self,
emb_size, def __init__(
num_blocks, self,
num_bilinear, emb_size,
num_spherical, num_blocks,
num_radial, num_bilinear,
cutoff=5.0, num_spherical,
envelope_exponent=5, num_radial,
num_before_skip=1, cutoff=5.0,
num_after_skip=2, envelope_exponent=5,
num_dense_output=3, num_before_skip=1,
num_targets=12, num_after_skip=2,
activation=swish, num_dense_output=3,
output_init=nn.init.zeros_): num_targets=12,
activation=swish,
output_init=nn.init.zeros_,
):
super(DimeNet, self).__init__() super(DimeNet, self).__init__()
self.num_blocks = num_blocks self.num_blocks = num_blocks
self.num_radial = num_radial self.num_radial = num_radial
        # radial (Bessel) basis function expansion layer # radial (Bessel) basis function expansion layer
self.rbf_layer = BesselBasisLayer(num_radial=num_radial, self.rbf_layer = BesselBasisLayer(
cutoff=cutoff, num_radial=num_radial,
envelope_exponent=envelope_exponent) cutoff=cutoff,
envelope_exponent=envelope_exponent,
)
self.sbf_layer = SphericalBasisLayer(
num_spherical=num_spherical,
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent,
)
self.sbf_layer = SphericalBasisLayer(num_spherical=num_spherical,
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent)
# embedding block # embedding block
self.emb_block = EmbeddingBlock(emb_size=emb_size, self.emb_block = EmbeddingBlock(
num_radial=num_radial, emb_size=emb_size,
bessel_funcs=self.sbf_layer.get_bessel_funcs(), num_radial=num_radial,
cutoff=cutoff, bessel_funcs=self.sbf_layer.get_bessel_funcs(),
envelope_exponent=envelope_exponent, cutoff=cutoff,
activation=activation) envelope_exponent=envelope_exponent,
activation=activation,
)
# output block # output block
self.output_blocks = nn.ModuleList({ self.output_blocks = nn.ModuleList(
OutputBlock(emb_size=emb_size, {
num_radial=num_radial, OutputBlock(
num_dense=num_dense_output, emb_size=emb_size,
num_targets=num_targets, num_radial=num_radial,
activation=activation, num_dense=num_dense_output,
output_init=output_init) for _ in range(num_blocks + 1) num_targets=num_targets,
}) activation=activation,
output_init=output_init,
)
for _ in range(num_blocks + 1)
}
)
# interaction block # interaction block
self.interaction_blocks = nn.ModuleList({ self.interaction_blocks = nn.ModuleList(
InteractionBlock(emb_size=emb_size, {
num_radial=num_radial, InteractionBlock(
num_spherical=num_spherical, emb_size=emb_size,
num_bilinear=num_bilinear, num_radial=num_radial,
num_before_skip=num_before_skip, num_spherical=num_spherical,
num_after_skip=num_after_skip, num_bilinear=num_bilinear,
activation=activation) for _ in range(num_blocks) num_before_skip=num_before_skip,
}) num_after_skip=num_after_skip,
activation=activation,
)
for _ in range(num_blocks)
}
)
def edge_init(self, edges): def edge_init(self, edges):
# Calculate angles k -> j -> i # Calculate angles k -> j -> i
R1, R2 = edges.src['o'], edges.dst['o'] R1, R2 = edges.src["o"], edges.dst["o"]
x = torch.sum(R1 * R2, dim=-1) x = torch.sum(R1 * R2, dim=-1)
y = torch.cross(R1, R2) y = torch.cross(R1, R2)
y = torch.norm(y, dim=-1) y = torch.norm(y, dim=-1)
...@@ -110,9 +129,9 @@ class DimeNet(nn.Module): ...@@ -110,9 +129,9 @@ class DimeNet(nn.Module):
cbf = [f(angle) for f in self.sbf_layer.get_sph_funcs()] cbf = [f(angle) for f in self.sbf_layer.get_sph_funcs()]
cbf = torch.stack(cbf, dim=1) # [None, 7] cbf = torch.stack(cbf, dim=1) # [None, 7]
cbf = cbf.repeat_interleave(self.num_radial, dim=1) # [None, 42] cbf = cbf.repeat_interleave(self.num_radial, dim=1) # [None, 42]
sbf = edges.src['rbf_env'] * cbf # [None, 42] sbf = edges.src["rbf_env"] * cbf # [None, 42]
return {'sbf': sbf} return {"sbf": sbf}
def forward(self, g, l_g): def forward(self, g, l_g):
# add rbf features for each edge in one batch graph, [num_radial,] # add rbf features for each edge in one batch graph, [num_radial,]
g = self.rbf_layer(g) g = self.rbf_layer(g)
...@@ -129,5 +148,5 @@ class DimeNet(nn.Module): ...@@ -129,5 +148,5 @@ class DimeNet(nn.Module):
for i in range(self.num_blocks): for i in range(self.num_blocks):
g = self.interaction_blocks[i](g, l_g) g = self.interaction_blocks[i](g, l_g)
P += self.output_blocks[i + 1](g) P += self.output_blocks[i + 1](g)
return P return P
\ No newline at end of file
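Constructor sketch for reference (hyperparameter values are illustrative, not taken from the diff); the forward pass additionally expects the directed atom graph g and its line graph l_g, which the data pipeline builds for the angle features.

model = DimeNet(
    emb_size=128,
    num_blocks=6,
    num_bilinear=8,
    num_spherical=7,
    num_radial=6,
    cutoff=5.0,
    num_targets=12,
)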
import torch import torch
import torch.nn as nn import torch.nn as nn
from modules.activations import swish from modules.activations import swish
from modules.bessel_basis_layer import BesselBasisLayer from modules.bessel_basis_layer import BesselBasisLayer
from modules.spherical_basis_layer import SphericalBasisLayer
from modules.embedding_block import EmbeddingBlock from modules.embedding_block import EmbeddingBlock
from modules.output_pp_block import OutputPPBlock
from modules.interaction_pp_block import InteractionPPBlock from modules.interaction_pp_block import InteractionPPBlock
from modules.output_pp_block import OutputPPBlock
from modules.spherical_basis_layer import SphericalBasisLayer
class DimeNetPP(nn.Module): class DimeNetPP(nn.Module):
""" """
...@@ -47,73 +47,92 @@ class DimeNetPP(nn.Module): ...@@ -47,73 +47,92 @@ class DimeNetPP(nn.Module):
output_init output_init
Initial function in output block Initial function in output block
""" """
def __init__(self,
emb_size, def __init__(
out_emb_size, self,
int_emb_size, emb_size,
basis_emb_size, out_emb_size,
num_blocks, int_emb_size,
num_spherical, basis_emb_size,
num_radial, num_blocks,
cutoff=5.0, num_spherical,
envelope_exponent=5, num_radial,
num_before_skip=1, cutoff=5.0,
num_after_skip=2, envelope_exponent=5,
num_dense_output=3, num_before_skip=1,
num_targets=12, num_after_skip=2,
activation=swish, num_dense_output=3,
extensive=True, num_targets=12,
output_init=nn.init.zeros_): activation=swish,
extensive=True,
output_init=nn.init.zeros_,
):
super(DimeNetPP, self).__init__() super(DimeNetPP, self).__init__()
self.num_blocks = num_blocks self.num_blocks = num_blocks
self.num_radial = num_radial self.num_radial = num_radial
        # radial (Bessel) basis function expansion layer # radial (Bessel) basis function expansion layer
self.rbf_layer = BesselBasisLayer(num_radial=num_radial, self.rbf_layer = BesselBasisLayer(
cutoff=cutoff, num_radial=num_radial,
envelope_exponent=envelope_exponent) cutoff=cutoff,
envelope_exponent=envelope_exponent,
)
self.sbf_layer = SphericalBasisLayer(
num_spherical=num_spherical,
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent,
)
self.sbf_layer = SphericalBasisLayer(num_spherical=num_spherical,
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent)
# embedding block # embedding block
self.emb_block = EmbeddingBlock(emb_size=emb_size, self.emb_block = EmbeddingBlock(
num_radial=num_radial, emb_size=emb_size,
bessel_funcs=self.sbf_layer.get_bessel_funcs(), num_radial=num_radial,
cutoff=cutoff, bessel_funcs=self.sbf_layer.get_bessel_funcs(),
envelope_exponent=envelope_exponent, cutoff=cutoff,
activation=activation) envelope_exponent=envelope_exponent,
activation=activation,
)
# output block # output block
self.output_blocks = nn.ModuleList({ self.output_blocks = nn.ModuleList(
OutputPPBlock(emb_size=emb_size, {
out_emb_size=out_emb_size, OutputPPBlock(
num_radial=num_radial, emb_size=emb_size,
num_dense=num_dense_output, out_emb_size=out_emb_size,
num_targets=num_targets, num_radial=num_radial,
activation=activation, num_dense=num_dense_output,
extensive=extensive, num_targets=num_targets,
output_init=output_init) for _ in range(num_blocks + 1) activation=activation,
}) extensive=extensive,
output_init=output_init,
)
for _ in range(num_blocks + 1)
}
)
# interaction block # interaction block
self.interaction_blocks = nn.ModuleList({ self.interaction_blocks = nn.ModuleList(
InteractionPPBlock(emb_size=emb_size, {
int_emb_size=int_emb_size, InteractionPPBlock(
basis_emb_size=basis_emb_size, emb_size=emb_size,
num_radial=num_radial, int_emb_size=int_emb_size,
num_spherical=num_spherical, basis_emb_size=basis_emb_size,
num_before_skip=num_before_skip, num_radial=num_radial,
num_after_skip=num_after_skip, num_spherical=num_spherical,
activation=activation) for _ in range(num_blocks) num_before_skip=num_before_skip,
}) num_after_skip=num_after_skip,
activation=activation,
)
for _ in range(num_blocks)
}
)
def edge_init(self, edges): def edge_init(self, edges):
# Calculate angles k -> j -> i # Calculate angles k -> j -> i
R1, R2 = edges.src['o'], edges.dst['o'] R1, R2 = edges.src["o"], edges.dst["o"]
x = torch.sum(R1 * R2, dim=-1) x = torch.sum(R1 * R2, dim=-1)
y = torch.cross(R1, R2) y = torch.cross(R1, R2)
y = torch.norm(y, dim=-1) y = torch.norm(y, dim=-1)
...@@ -123,9 +142,9 @@ class DimeNetPP(nn.Module): ...@@ -123,9 +142,9 @@ class DimeNetPP(nn.Module):
cbf = torch.stack(cbf, dim=1) # [None, 7] cbf = torch.stack(cbf, dim=1) # [None, 7]
cbf = cbf.repeat_interleave(self.num_radial, dim=1) # [None, 42] cbf = cbf.repeat_interleave(self.num_radial, dim=1) # [None, 42]
# Notice: it's dst, not src # Notice: it's dst, not src
sbf = edges.dst['rbf_env'] * cbf # [None, 42] sbf = edges.dst["rbf_env"] * cbf # [None, 42]
return {'sbf': sbf} return {"sbf": sbf}
def forward(self, g, l_g): def forward(self, g, l_g):
# add rbf features for each edge in one batch graph, [num_radial,] # add rbf features for each edge in one batch graph, [num_radial,]
g = self.rbf_layer(g) g = self.rbf_layer(g)
...@@ -142,5 +161,5 @@ class DimeNetPP(nn.Module): ...@@ -142,5 +161,5 @@ class DimeNetPP(nn.Module):
for i in range(self.num_blocks): for i in range(self.num_blocks):
g = self.interaction_blocks[i](g, l_g) g = self.interaction_blocks[i](g, l_g)
P += self.output_blocks[i + 1](g) P += self.output_blocks[i + 1](g)
return P return P
\ No newline at end of file
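DimeNet++ adds the down-/up-projection sizes to the constructor; an illustrative instantiation (values are a plausible configuration, not from the diff):

model_pp = DimeNetPP(
    emb_size=128,
    out_emb_size=256,
    int_emb_size=64,
    basis_emb_size=8,
    num_blocks=4,
    num_spherical=7,
    num_radial=6,
    extensive=True,
)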
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from modules.envelope import Envelope from modules.envelope import Envelope
from modules.initializers import GlorotOrthogonal from modules.initializers import GlorotOrthogonal
class EmbeddingBlock(nn.Module): class EmbeddingBlock(nn.Module):
def __init__(self, def __init__(
emb_size, self,
num_radial, emb_size,
bessel_funcs, num_radial,
cutoff, bessel_funcs,
envelope_exponent, cutoff,
num_atom_types=95, envelope_exponent,
activation=None): num_atom_types=95,
activation=None,
):
super(EmbeddingBlock, self).__init__() super(EmbeddingBlock, self).__init__()
self.bessel_funcs = bessel_funcs self.bessel_funcs = bessel_funcs
...@@ -24,35 +26,35 @@ class EmbeddingBlock(nn.Module): ...@@ -24,35 +26,35 @@ class EmbeddingBlock(nn.Module):
self.dense_rbf = nn.Linear(num_radial, emb_size) self.dense_rbf = nn.Linear(num_radial, emb_size)
self.dense = nn.Linear(emb_size * 3, emb_size) self.dense = nn.Linear(emb_size * 3, emb_size)
self.reset_params() self.reset_params()
def reset_params(self): def reset_params(self):
nn.init.uniform_(self.embedding.weight, a=-np.sqrt(3), b=np.sqrt(3)) nn.init.uniform_(self.embedding.weight, a=-np.sqrt(3), b=np.sqrt(3))
GlorotOrthogonal(self.dense_rbf.weight) GlorotOrthogonal(self.dense_rbf.weight)
GlorotOrthogonal(self.dense.weight) GlorotOrthogonal(self.dense.weight)
def edge_init(self, edges): def edge_init(self, edges):
""" msg emb init """ """msg emb init"""
# m init # m init
rbf = self.dense_rbf(edges.data['rbf']) rbf = self.dense_rbf(edges.data["rbf"])
if self.activation is not None: if self.activation is not None:
rbf = self.activation(rbf) rbf = self.activation(rbf)
m = torch.cat([edges.src['h'], edges.dst['h'], rbf], dim=-1) m = torch.cat([edges.src["h"], edges.dst["h"], rbf], dim=-1)
m = self.dense(m) m = self.dense(m)
if self.activation is not None: if self.activation is not None:
m = self.activation(m) m = self.activation(m)
# rbf_env init # rbf_env init
d_scaled = edges.data['d'] / self.cutoff d_scaled = edges.data["d"] / self.cutoff
rbf_env = [f(d_scaled) for f in self.bessel_funcs] rbf_env = [f(d_scaled) for f in self.bessel_funcs]
rbf_env = torch.stack(rbf_env, dim=1) rbf_env = torch.stack(rbf_env, dim=1)
d_cutoff = self.envelope(d_scaled) d_cutoff = self.envelope(d_scaled)
rbf_env = d_cutoff[:, None] * rbf_env rbf_env = d_cutoff[:, None] * rbf_env
return {'m': m, 'rbf_env': rbf_env} return {"m": m, "rbf_env": rbf_env}
def forward(self, g): def forward(self, g):
g.ndata['h'] = self.embedding(g.ndata['Z']) g.ndata["h"] = self.embedding(g.ndata["Z"])
g.apply_edges(self.edge_init) g.apply_edges(self.edge_init)
return g return g
\ No newline at end of file
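For orientation, the graph fields this block reads and writes (inferred from edge_init/forward above; dtypes follow the usual torch conventions):

# g.ndata["Z"]   : atom types (long tensor), consumed by the nn.Embedding lookup
# g.edata["rbf"] : radial basis features produced by BesselBasisLayer
# g.edata["d"]   : scalar edge distances, rescaled by cutoff for the envelope
# after emb_block(g): g.edata["m"] holds the message embeddings and
# g.edata["rbf_env"] the envelope-weighted Bessel features used later when
# the spherical basis features ("sbf") are assembled.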
import torch.nn as nn import torch.nn as nn
class Envelope(nn.Module): class Envelope(nn.Module):
""" """
Envelope function that ensures a smooth cutoff Envelope function that ensures a smooth cutoff
""" """
def __init__(self, exponent): def __init__(self, exponent):
super(Envelope, self).__init__() super(Envelope, self).__init__()
...@@ -11,11 +13,11 @@ class Envelope(nn.Module): ...@@ -11,11 +13,11 @@ class Envelope(nn.Module):
self.a = -(self.p + 1) * (self.p + 2) / 2 self.a = -(self.p + 1) * (self.p + 2) / 2
self.b = self.p * (self.p + 2) self.b = self.p * (self.p + 2)
self.c = -self.p * (self.p + 1) / 2 self.c = -self.p * (self.p + 1) / 2
def forward(self, x): def forward(self, x):
# Envelope function divided by r # Envelope function divided by r
x_p_0 = x.pow(self.p - 1) x_p_0 = x.pow(self.p - 1)
x_p_1 = x_p_0 * x x_p_1 = x_p_0 * x
x_p_2 = x_p_1 * x x_p_2 = x_p_1 * x
env_val = 1 / x + self.a * x_p_0 + self.b * x_p_1 + self.c * x_p_2 env_val = 1 / x + self.a * x_p_0 + self.b * x_p_1 + self.c * x_p_2
return env_val return env_val
\ No newline at end of file
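Numeric sketch (illustrative): the coefficients a, b, c are chosen so that both the envelope value and its first derivative vanish at the scaled cutoff x = 1, while the 1/x term dominates for small x.

import torch

env = Envelope(exponent=5)
x = torch.tensor([0.25, 0.5, 1.0])
print(env(x))   # last entry is 0: the polynomial part exactly cancels 1/x at the cutoff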