"vscode:/vscode.git/clone" did not exist on "d7056c5236483d247cf0f4f149e3c7e72767efab"
Unverified commit 23d09057, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Black auto fix. (#4642)



* [Misc] Black auto fix.

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a9f2acf3
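The diff below is purely mechanical: single quotes become double quotes, docstrings switch to triple double quotes, long signatures and call sites are wrapped with trailing commas, and import blocks are re-grouped and alphabetized (the "sort" bullet above is consistent with an isort pass). A minimal sketch of how such a pass is typically reproduced, assuming Black's Python API and a 79-column limit, neither of which is recorded in this commit:

# Illustrative sketch only, not part of this commit. The repository's exact
# Black/isort configuration is an assumption; lines in this diff appear to
# wrap at 79 columns. Requires `pip install black`.
import black

SOURCE = "y = long_call(argument_one, argument_two, argument_three, keyword_argument='value')\n"

# Black normalizes quotes to double quotes, adds trailing commas where it
# splits a call, and wraps argument lists that exceed the line-length limit.
print(black.format_str(SOURCE, mode=black.FileMode(line_length=79)))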
from .gnn import GraphSage, GraphSageLayer, DiffPoolBatchedGraphLayer
\ No newline at end of file
from .gnn import DiffPoolBatchedGraphLayer, GraphSage, GraphSageLayer
......@@ -15,7 +15,7 @@ class Aggregator(nn.Module):
super(Aggregator, self).__init__()
def forward(self, node):
neighbour = node.mailbox['m']
neighbour = node.mailbox["m"]
c = self.aggre(neighbour)
return {"c": c}
......@@ -25,9 +25,9 @@ class Aggregator(nn.Module):
class MeanAggregator(Aggregator):
'''
"""
Mean Aggregator for graphsage
'''
"""
def __init__(self):
super(MeanAggregator, self).__init__()
......@@ -38,17 +38,18 @@ class MeanAggregator(Aggregator):
class MaxPoolAggregator(Aggregator):
'''
"""
Maxpooling aggregator for graphsage
'''
"""
def __init__(self, in_feats, out_feats, activation, bias):
super(MaxPoolAggregator, self).__init__()
self.linear = nn.Linear(in_feats, out_feats, bias=bias)
self.activation = activation
# Xavier initialization of weight
nn.init.xavier_uniform_(self.linear.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.linear.weight, gain=nn.init.calculate_gain("relu")
)
def aggre(self, neighbour):
neighbour = self.linear(neighbour)
......@@ -59,9 +60,9 @@ class MaxPoolAggregator(Aggregator):
class LSTMAggregator(Aggregator):
'''
"""
LSTM aggregator for graphsage
'''
"""
def __init__(self, in_feats, hidden_feats):
super(LSTMAggregator, self).__init__()
......@@ -69,31 +70,33 @@ class LSTMAggregator(Aggregator):
self.hidden_dim = hidden_feats
self.hidden = self.init_hidden()
nn.init.xavier_uniform_(self.lstm.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.lstm.weight, gain=nn.init.calculate_gain("relu")
)
def init_hidden(self):
"""
Defaults to initializing all zeros
"""
return (torch.zeros(1, 1, self.hidden_dim),
torch.zeros(1, 1, self.hidden_dim))
return (
torch.zeros(1, 1, self.hidden_dim),
torch.zeros(1, 1, self.hidden_dim),
)
def aggre(self, neighbours):
'''
"""
aggregation function
'''
"""
# N X F
rand_order = torch.randperm(neighbours.size()[1])
neighbours = neighbours[:, rand_order, :]
(lstm_out, self.hidden) = self.lstm(neighbours.view(neighbours.size()[0],
neighbours.size()[
1],
-1))
(lstm_out, self.hidden) = self.lstm(
neighbours.view(neighbours.size()[0], neighbours.size()[1], -1)
)
return lstm_out[:, -1, :]
def forward(self, node):
neighbour = node.mailbox['m']
neighbour = node.mailbox["m"]
c = self.aggre(neighbour)
return {"c": c}
......@@ -14,8 +14,9 @@ class Bundler(nn.Module):
self.linear = nn.Linear(in_feats * 2, out_feats, bias)
self.activation = activation
nn.init.xavier_uniform_(self.linear.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.linear.weight, gain=nn.init.calculate_gain("relu")
)
def concat(self, h, aggre_result):
bundle = torch.cat((h, aggre_result), 1)
......@@ -23,8 +24,8 @@ class Bundler(nn.Module):
return bundle
def forward(self, node):
h = node.data['h']
c = node.data['c']
h = node.data["h"]
c = node.data["c"]
bundle = self.concat(h, c)
bundle = F.normalize(bundle, p=2, dim=1)
if self.activation:
......
import time
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np
from scipy.linalg import block_diag
from torch.nn import init
import dgl
from .dgl_layers import GraphSage, GraphSageLayer, DiffPoolBatchedGraphLayer
from .tensorized_layers import *
from .dgl_layers import DiffPoolBatchedGraphLayer, GraphSage, GraphSageLayer
from .model_utils import batch2tensor
import time
from .tensorized_layers import *
class DiffPool(nn.Module):
......@@ -19,10 +19,23 @@ class DiffPool(nn.Module):
DiffPool Fuse
"""
def __init__(self, input_dim, hidden_dim, embedding_dim,
label_dim, activation, n_layers, dropout,
n_pooling, linkpred, batch_size, aggregator_type,
assign_dim, pool_ratio, cat=False):
def __init__(
self,
input_dim,
hidden_dim,
embedding_dim,
label_dim,
activation,
n_layers,
dropout,
n_pooling,
linkpred,
batch_size,
aggregator_type,
assign_dim,
pool_ratio,
cat=False,
):
super(DiffPool, self).__init__()
self.link_pred = linkpred
self.concat = cat
......@@ -51,7 +64,9 @@ class DiffPool(nn.Module):
activation,
dropout,
aggregator_type,
self.bn))
self.bn,
)
)
for _ in range(n_layers - 2):
self.gc_before_pool.append(
GraphSageLayer(
......@@ -60,14 +75,14 @@ class DiffPool(nn.Module):
activation,
dropout,
aggregator_type,
self.bn))
self.bn,
)
)
self.gc_before_pool.append(
GraphSageLayer(
hidden_dim,
embedding_dim,
None,
dropout,
aggregator_type))
hidden_dim, embedding_dim, None, dropout, aggregator_type
)
)
assign_dims = []
assign_dims.append(self.assign_dim)
......@@ -86,7 +101,8 @@ class DiffPool(nn.Module):
activation,
dropout,
aggregator_type,
self.link_pred)
self.link_pred,
)
gc_after_per_pool = nn.ModuleList()
for _ in range(n_layers - 1):
......@@ -102,23 +118,26 @@ class DiffPool(nn.Module):
pool_embedding_dim,
self.assign_dim,
hidden_dim,
self.link_pred))
self.link_pred,
)
)
gc_after_per_pool = nn.ModuleList()
for _ in range(n_layers - 1):
gc_after_per_pool.append(
BatchedGraphSAGE(
hidden_dim, hidden_dim))
BatchedGraphSAGE(hidden_dim, hidden_dim)
)
gc_after_per_pool.append(
BatchedGraphSAGE(
hidden_dim, embedding_dim))
BatchedGraphSAGE(hidden_dim, embedding_dim)
)
self.gc_after_pool.append(gc_after_per_pool)
assign_dims.append(self.assign_dim)
self.assign_dim = int(self.assign_dim * pool_ratio)
# predicting layer
if self.concat:
self.pred_input_dim = pool_embedding_dim * \
self.num_aggs * (n_pooling + 1)
self.pred_input_dim = (
pool_embedding_dim * self.num_aggs * (n_pooling + 1)
)
else:
self.pred_input_dim = embedding_dim * self.num_aggs
self.pred_layer = nn.Linear(self.pred_input_dim, label_dim)
......@@ -126,8 +145,9 @@ class DiffPool(nn.Module):
# weight initialization
for m in self.modules():
if isinstance(m, nn.Linear):
m.weight.data = init.xavier_uniform_(m.weight.data,
gain=nn.init.calculate_gain('relu'))
m.weight.data = init.xavier_uniform_(
m.weight.data, gain=nn.init.calculate_gain("relu")
)
if m.bias is not None:
m.bias.data = init.constant_(m.bias.data, 0.0)
......@@ -161,7 +181,7 @@ class DiffPool(nn.Module):
def forward(self, g):
self.link_pred_loss = []
self.entropy_loss = []
h = g.ndata['feat']
h = g.ndata["feat"]
# node feature for assignment matrix computation is the same as the
# original node feature
h_a = h
......@@ -171,12 +191,12 @@ class DiffPool(nn.Module):
# we use GCN blocks to get an embedding first
g_embedding = self.gcn_forward(g, h, self.gc_before_pool, self.concat)
g.ndata['h'] = g_embedding
g.ndata["h"] = g_embedding
readout = dgl.sum_nodes(g, 'h')
readout = dgl.sum_nodes(g, "h")
out_all.append(readout)
if self.num_aggs == 2:
readout = dgl.max_nodes(g, 'h')
readout = dgl.max_nodes(g, "h")
out_all.append(readout)
adj, h = self.first_diffpool_layer(g, g_embedding)
......@@ -184,7 +204,8 @@ class DiffPool(nn.Module):
h, adj = batch2tensor(adj, h, node_per_pool_graph)
h = self.gcn_forward_tensorized(
h, adj, self.gc_after_pool[0], self.concat)
h, adj, self.gc_after_pool[0], self.concat
)
readout = torch.sum(h, dim=1)
out_all.append(readout)
if self.num_aggs == 2:
......@@ -194,7 +215,8 @@ class DiffPool(nn.Module):
for i, diffpool_layer in enumerate(self.diffpool_layers):
h, adj = diffpool_layer(h, adj)
h = self.gcn_forward_tensorized(
h, adj, self.gc_after_pool[i + 1], self.concat)
h, adj, self.gc_after_pool[i + 1], self.concat
)
readout = torch.sum(h, dim=1)
out_all.append(readout)
if self.num_aggs == 2:
......@@ -208,10 +230,10 @@ class DiffPool(nn.Module):
return ypred
def loss(self, pred, label):
'''
"""
loss function
'''
#softmax + CE
"""
# softmax + CE
criterion = nn.CrossEntropyLoss()
loss = criterion(pred, label)
for key, value in self.first_diffpool_layer.loss_log.items():
......
......@@ -5,16 +5,19 @@ import torch.nn as nn
class EntropyLoss(nn.Module):
# Return Scalar
def forward(self, adj, anext, s_l):
entropy = (torch.distributions.Categorical(
probs=s_l).entropy()).sum(-1).mean(-1)
entropy = (
(torch.distributions.Categorical(probs=s_l).entropy())
.sum(-1)
.mean(-1)
)
assert not torch.isnan(entropy)
return entropy
class LinkPredLoss(nn.Module):
def forward(self, adj, anext, s_l):
link_pred_loss = (
adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2))
link_pred_loss = (adj - s_l.matmul(s_l.transpose(-1, -2))).norm(
dim=(1, 2)
)
link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
return link_pred_loss.mean()
......@@ -22,12 +22,13 @@ def batch2tensor(batch_adj, batch_feat, node_per_pool_graph):
return feat, adj
def masked_softmax(matrix, mask, dim=-1, memory_efficient=True,
mask_fill_value=-1e32):
'''
def masked_softmax(
matrix, mask, dim=-1, memory_efficient=True, mask_fill_value=-1e32
):
"""
masked_softmax for dgl batch graph
code snippet contributed by AllenNLP (https://github.com/allenai/allennlp)
'''
"""
if mask is None:
result = th.nn.functional.softmax(matrix, dim=dim)
else:
......@@ -39,7 +40,8 @@ def masked_softmax(matrix, mask, dim=-1, memory_efficient=True,
result = result * mask
result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
else:
masked_matrix = matrix.masked_fill((1 - mask).byte(),
mask_fill_value)
masked_matrix = matrix.masked_fill(
(1 - mask).byte(), mask_fill_value
)
result = th.nn.functional.softmax(masked_matrix, dim=dim)
return result
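For context on the helper reformatted above: masked_softmax computes a softmax over only the unmasked positions of a dense batched tensor. A minimal sketch of the memory-efficient branch shown in the diff, using hypothetical values and plain PyTorch:

import torch

# One row of scores with the last position masked out (hypothetical values).
matrix = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])

# Fill masked logits with a very negative value so they receive ~0 probability
# mass, mirroring the masked_fill branch in the function above.
masked = matrix.masked_fill((1 - mask).bool(), -1e32)
print(torch.softmax(masked, dim=-1))  # approx. [0.2689, 0.7311, 0.0000]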
from .diffpool import BatchedDiffPool
from .graphsage import BatchedGraphSAGE
\ No newline at end of file
from .graphsage import BatchedGraphSAGE
import torch
from model.tensorized_layers.graphsage import BatchedGraphSAGE
from torch import nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
from model.tensorized_layers.graphsage import BatchedGraphSAGE
from torch.nn import functional as F
class DiffPoolAssignment(nn.Module):
......
import torch
from torch import nn as nn
from model.loss import EntropyLoss, LinkPredLoss
from model.tensorized_layers.assignment import DiffPoolAssignment
from model.tensorized_layers.graphsage import BatchedGraphSAGE
from model.loss import EntropyLoss, LinkPredLoss
from torch import nn as nn
class BatchedDiffPool(nn.Module):
......@@ -25,7 +24,7 @@ class BatchedDiffPool(nn.Module):
z_l = self.embed(x, adj)
s_l = self.assign(x, adj)
if log:
self.log['s'] = s_l.cpu().numpy()
self.log["s"] = s_l.cpu().numpy()
xnext = torch.matmul(s_l.transpose(-1, -2), z_l)
anext = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l)
......@@ -33,5 +32,5 @@ class BatchedDiffPool(nn.Module):
loss_name = str(type(loss_layer).__name__)
self.loss_log[loss_name] = loss_layer(adj, anext, s_l)
if log:
self.log['a'] = anext.cpu().numpy()
self.log["a"] = anext.cpu().numpy()
return xnext, anext
......@@ -4,8 +4,9 @@ from torch.nn import functional as F
class BatchedGraphSAGE(nn.Module):
def __init__(self, infeat, outfeat, use_bn=True,
mean=False, add_self=False):
def __init__(
self, infeat, outfeat, use_bn=True, mean=False, add_self=False
):
super().__init__()
self.add_self = add_self
self.use_bn = use_bn
......@@ -13,12 +14,12 @@ class BatchedGraphSAGE(nn.Module):
self.W = nn.Linear(infeat, outfeat, bias=True)
nn.init.xavier_uniform_(
self.W.weight,
gain=nn.init.calculate_gain('relu'))
self.W.weight, gain=nn.init.calculate_gain("relu")
)
def forward(self, x, adj):
num_node_per_graph = adj.size(1)
if self.use_bn and not hasattr(self, 'bn'):
if self.use_bn and not hasattr(self, "bn"):
self.bn = nn.BatchNorm1d(num_node_per_graph).to(adj.device)
if self.add_self:
......@@ -37,6 +38,6 @@ class BatchedGraphSAGE(nn.Module):
def __repr__(self):
if self.use_bn:
return 'BN' + super(BatchedGraphSAGE, self).__repr__()
return "BN" + super(BatchedGraphSAGE, self).__repr__()
else:
return super(BatchedGraphSAGE, self).__repr__()
import os
import numpy as np
import torch
import dgl
import networkx as nx
import argparse
import os
import random
import time
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from data_utils import pre_process
from model.encoder import DiffPool
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import tu
from model.encoder import DiffPool
from data_utils import pre_process
global_train_time_per_epoch = []
def arg_parse():
'''
"""
argument parser
'''
parser = argparse.ArgumentParser(description='DiffPool arguments')
parser.add_argument('--dataset', dest='dataset', help='Input Dataset')
"""
parser = argparse.ArgumentParser(description="DiffPool arguments")
parser.add_argument("--dataset", dest="dataset", help="Input Dataset")
parser.add_argument(
"--pool_ratio", dest="pool_ratio", type=float, help="pooling ratio"
)
parser.add_argument(
"--num_pool", dest="num_pool", type=int, help="num_pooling layer"
)
parser.add_argument(
"--no_link_pred",
dest="linkpred",
action="store_false",
help="switch of link prediction object",
)
parser.add_argument("--cuda", dest="cuda", type=int, help="switch cuda")
parser.add_argument("--lr", dest="lr", type=float, help="learning rate")
parser.add_argument(
"--clip", dest="clip", type=float, help="gradient clipping"
)
parser.add_argument(
"--batch-size", dest="batch_size", type=int, help="batch size"
)
parser.add_argument("--epochs", dest="epoch", type=int, help="num-of-epoch")
parser.add_argument(
"--train-ratio",
dest="train_ratio",
type=float,
help="ratio of trainning dataset split",
)
parser.add_argument(
'--pool_ratio',
dest='pool_ratio',
"--test-ratio",
dest="test_ratio",
type=float,
help='pooling ratio')
help="ratio of testing dataset split",
)
parser.add_argument(
'--num_pool',
dest='num_pool',
"--num_workers",
dest="n_worker",
type=int,
help='num_pooling layer')
parser.add_argument('--no_link_pred', dest='linkpred', action='store_false',
help='switch of link prediction object')
parser.add_argument('--cuda', dest='cuda', type=int, help='switch cuda')
parser.add_argument('--lr', dest='lr', type=float, help='learning rate')
help="number of workers when dataloading",
)
parser.add_argument(
'--clip',
dest='clip',
type=float,
help='gradient clipping')
"--gc-per-block",
dest="gc_per_block",
type=int,
help="number of graph conv layer per block",
)
parser.add_argument(
"--bn",
dest="bn",
action="store_const",
const=True,
default=True,
help="switch for bn",
)
parser.add_argument(
"--dropout", dest="dropout", type=float, help="dropout rate"
)
parser.add_argument(
'--batch-size',
dest='batch_size',
"--bias",
dest="bias",
action="store_const",
const=True,
default=True,
help="switch for bias",
)
parser.add_argument(
"--save_dir",
dest="save_dir",
help="model saving directory: SAVE_DICT/DATASET",
)
parser.add_argument(
"--load_epoch",
dest="load_epoch",
type=int,
help='batch size')
parser.add_argument('--epochs', dest='epoch', type=int,
help='num-of-epoch')
parser.add_argument('--train-ratio', dest='train_ratio', type=float,
help='ratio of training dataset split')
parser.add_argument('--test-ratio', dest='test_ratio', type=float,
help='ratio of testing dataset split')
parser.add_argument('--num_workers', dest='n_worker', type=int,
help='number of workers when dataloading')
parser.add_argument('--gc-per-block', dest='gc_per_block', type=int,
help='number of graph conv layer per block')
parser.add_argument('--bn', dest='bn', action='store_const', const=True,
default=True, help='switch for bn')
parser.add_argument('--dropout', dest='dropout', type=float,
help='dropout rate')
parser.add_argument('--bias', dest='bias', action='store_const',
const=True, default=True, help='switch for bias')
help="load trained model params from\
SAVE_DICT/DATASET/model-LOAD_EPOCH",
)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='model saving directory: SAVE_DICT/DATASET')
parser.add_argument('--load_epoch', dest='load_epoch', type=int, help='load trained model params from\
SAVE_DICT/DATASET/model-LOAD_EPOCH')
parser.add_argument('--data_mode', dest='data_mode', help='data\
"--data_mode",
dest="data_mode",
help="data\
preprocessing mode: default, id, degree, or one-hot\
vector of degree number', choices=['default', 'id', 'deg',
'deg_num'])
parser.set_defaults(dataset='ENZYMES',
pool_ratio=0.15,
num_pool=1,
cuda=1,
lr=1e-3,
clip=2.0,
batch_size=20,
epoch=4000,
train_ratio=0.7,
test_ratio=0.1,
n_worker=1,
gc_per_block=3,
dropout=0.0,
method='diffpool',
bn=True,
bias=True,
save_dir="./model_param",
load_epoch=-1,
data_mode='default')
vector of degree number",
choices=["default", "id", "deg", "deg_num"],
)
parser.set_defaults(
dataset="ENZYMES",
pool_ratio=0.15,
num_pool=1,
cuda=1,
lr=1e-3,
clip=2.0,
batch_size=20,
epoch=4000,
train_ratio=0.7,
test_ratio=0.1,
n_worker=1,
gc_per_block=3,
dropout=0.0,
method="diffpool",
bn=True,
bias=True,
save_dir="./model_param",
load_epoch=-1,
data_mode="default",
)
return parser.parse_args()
def prepare_data(dataset, prog_args, train=False, pre_process=None):
'''
"""
preprocess TU dataset according to DiffPool's paper setting and load dataset into dataloader
'''
"""
if train:
shuffle = True
else:
......@@ -111,16 +148,18 @@ def prepare_data(dataset, prog_args, train=False, pre_process=None):
pre_process(dataset, prog_args)
# dataset.set_fold(fold)
return dgl.dataloading.GraphDataLoader(dataset,
batch_size=prog_args.batch_size,
shuffle=shuffle,
num_workers=prog_args.n_worker)
return dgl.dataloading.GraphDataLoader(
dataset,
batch_size=prog_args.batch_size,
shuffle=shuffle,
num_workers=prog_args.n_worker,
)
def graph_classify_task(prog_args):
'''
"""
perform graph classification task
'''
"""
dataset = tu.LegacyTUDataset(name=prog_args.dataset)
train_size = int(prog_args.train_ratio * len(dataset))
......@@ -128,13 +167,17 @@ def graph_classify_task(prog_args):
val_size = int(len(dataset) - train_size - test_size)
dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
dataset, (train_size, val_size, test_size))
train_dataloader = prepare_data(dataset_train, prog_args, train=True,
pre_process=pre_process)
val_dataloader = prepare_data(dataset_val, prog_args, train=False,
pre_process=pre_process)
test_dataloader = prepare_data(dataset_test, prog_args, train=False,
pre_process=pre_process)
dataset, (train_size, val_size, test_size)
)
train_dataloader = prepare_data(
dataset_train, prog_args, train=True, pre_process=pre_process
)
val_dataloader = prepare_data(
dataset_val, prog_args, train=False, pre_process=pre_process
)
test_dataloader = prepare_data(
dataset_test, prog_args, train=False, pre_process=pre_process
)
input_dim, label_dim, max_num_node = dataset.statistics()
print("++++++++++STATISTICS ABOUT THE DATASET")
print("dataset feature dimension is", input_dim)
......@@ -157,23 +200,32 @@ def graph_classify_task(prog_args):
# initialize model
# 'diffpool' : diffpool
model = DiffPool(input_dim,
hidden_dim,
embedding_dim,
label_dim,
activation,
prog_args.gc_per_block,
prog_args.dropout,
prog_args.num_pool,
prog_args.linkpred,
prog_args.batch_size,
'meanpool',
assign_dim,
prog_args.pool_ratio)
model = DiffPool(
input_dim,
hidden_dim,
embedding_dim,
label_dim,
activation,
prog_args.gc_per_block,
prog_args.dropout,
prog_args.num_pool,
prog_args.linkpred,
prog_args.batch_size,
"meanpool",
assign_dim,
prog_args.pool_ratio,
)
if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:
model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
+ "/model.iter-" + str(prog_args.load_epoch)))
model.load_state_dict(
torch.load(
prog_args.save_dir
+ "/"
+ prog_args.dataset
+ "/model.iter-"
+ str(prog_args.load_epoch)
)
)
print("model init finished")
print("MODEL:::::::", prog_args.method)
......@@ -181,24 +233,23 @@ def graph_classify_task(prog_args):
model = model.cuda()
logger = train(
train_dataloader,
model,
prog_args,
val_dataset=val_dataloader)
train_dataloader, model, prog_args, val_dataset=val_dataloader
)
result = evaluate(test_dataloader, model, prog_args, logger)
print("test accuracy {:.2f}%".format(result * 100))
def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
'''
"""
training function
'''
"""
dir = prog_args.save_dir + "/" + prog_args.dataset
if not os.path.exists(dir):
os.makedirs(dir)
dataloader = dataset
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
model.parameters()), lr=0.001)
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()), lr=0.001
)
early_stopping_logger = {"best_epoch": -1, "val_acc": -1}
if prog_args.cuda > 0:
......@@ -233,34 +284,59 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
optimizer.step()
train_accu = accum_correct / total
print("train accuracy for this epoch {} is {:.2f}%".format(epoch,
train_accu * 100))
print(
"train accuracy for this epoch {} is {:.2f}%".format(
epoch, train_accu * 100
)
)
elapsed_time = time.time() - begin_time
print("loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
loss.item(), elapsed_time, computation_time))
print(
"loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
loss.item(), elapsed_time, computation_time
)
)
global_train_time_per_epoch.append(elapsed_time)
if val_dataset is not None:
result = evaluate(val_dataset, model, prog_args)
print("validation accuracy {:.2f}%".format(result * 100))
if result >= early_stopping_logger['val_acc'] and result <=\
train_accu:
if (
result >= early_stopping_logger["val_acc"]
and result <= train_accu
):
early_stopping_logger.update(best_epoch=epoch, val_acc=result)
if prog_args.save_dir is not None:
torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset
+ "/model.iter-" + str(early_stopping_logger['best_epoch']))
print("best epoch is EPOCH {}, val_acc is {:.2f}%".format(early_stopping_logger['best_epoch'],
early_stopping_logger['val_acc'] * 100))
torch.save(
model.state_dict(),
prog_args.save_dir
+ "/"
+ prog_args.dataset
+ "/model.iter-"
+ str(early_stopping_logger["best_epoch"]),
)
print(
"best epoch is EPOCH {}, val_acc is {:.2f}%".format(
early_stopping_logger["best_epoch"],
early_stopping_logger["val_acc"] * 100,
)
)
torch.cuda.empty_cache()
return early_stopping_logger
def evaluate(dataloader, model, prog_args, logger=None):
'''
"""
evaluate function
'''
"""
if logger is not None and prog_args.save_dir is not None:
model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
+ "/model.iter-" + str(logger['best_epoch'])))
model.load_state_dict(
torch.load(
prog_args.save_dir
+ "/"
+ prog_args.dataset
+ "/model.iter-"
+ str(logger["best_epoch"])
)
)
model.eval()
correct_label = 0
with torch.no_grad():
......@@ -280,15 +356,23 @@ def evaluate(dataloader, model, prog_args, logger=None):
def main():
'''
"""
main
'''
"""
prog_args = arg_parse()
print(prog_args)
graph_classify_task(prog_args)
print("Train time per epoch: {:.4f}".format( sum(global_train_time_per_epoch) / len(global_train_time_per_epoch) ))
print("Max memory usage: {:.4f}".format(torch.cuda.max_memory_allocated(0) / (1024 * 1024)))
print(
"Train time per epoch: {:.4f}".format(
sum(global_train_time_per_epoch) / len(global_train_time_per_epoch)
)
)
print(
"Max memory usage: {:.4f}".format(
torch.cuda.max_memory_allocated(0) / (1024 * 1024)
)
)
if __name__ == "__main__":
......
import os
from pathlib import Path
import click
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import click
import numpy as np
import os
from logzero import logger
from pathlib import Path
from ruamel.yaml import YAML
from modules.initializers import GlorotOrthogonal
from modules.dimenet_pp import DimeNetPP
from modules.initializers import GlorotOrthogonal
from ruamel.yaml import YAML
@click.command()
@click.option('-m', '--model-cnf', type=click.Path(exists=True), help='Path of model config yaml.')
@click.option('-c', '--convert-cnf', type=click.Path(exists=True), help='Path of convert config yaml.')
@click.option(
"-m",
"--model-cnf",
type=click.Path(exists=True),
help="Path of model config yaml.",
)
@click.option(
"-c",
"--convert-cnf",
type=click.Path(exists=True),
help="Path of convert config yaml.",
)
def main(model_cnf, convert_cnf):
yaml = YAML(typ='safe')
yaml = YAML(typ="safe")
model_cnf = yaml.load(Path(model_cnf))
convert_cnf = yaml.load(Path(convert_cnf))
model_name, model_params, _ = model_cnf['name'], model_cnf['model'], model_cnf['train']
logger.info(f'Model name: {model_name}')
logger.info(f'Model params: {model_params}')
model_name, model_params, _ = (
model_cnf["name"],
model_cnf["model"],
model_cnf["train"],
)
logger.info(f"Model name: {model_name}")
logger.info(f"Model params: {model_params}")
if model_params['targets'] in ['mu', 'homo', 'lumo', 'gap', 'zpve']:
model_params['output_init'] = nn.init.zeros_
if model_params["targets"] in ["mu", "homo", "lumo", "gap", "zpve"]:
model_params["output_init"] = nn.init.zeros_
else:
# 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv
model_params['output_init'] = GlorotOrthogonal
model_params["output_init"] = GlorotOrthogonal
# model initialization
logger.info('Loading Model')
model = DimeNetPP(emb_size=model_params['emb_size'],
out_emb_size=model_params['out_emb_size'],
int_emb_size=model_params['int_emb_size'],
basis_emb_size=model_params['basis_emb_size'],
num_blocks=model_params['num_blocks'],
num_spherical=model_params['num_spherical'],
num_radial=model_params['num_radial'],
cutoff=model_params['cutoff'],
envelope_exponent=model_params['envelope_exponent'],
num_before_skip=model_params['num_before_skip'],
num_after_skip=model_params['num_after_skip'],
num_dense_output=model_params['num_dense_output'],
num_targets=len(model_params['targets']),
extensive=model_params['extensive'],
output_init=model_params['output_init'])
logger.info("Loading Model")
model = DimeNetPP(
emb_size=model_params["emb_size"],
out_emb_size=model_params["out_emb_size"],
int_emb_size=model_params["int_emb_size"],
basis_emb_size=model_params["basis_emb_size"],
num_blocks=model_params["num_blocks"],
num_spherical=model_params["num_spherical"],
num_radial=model_params["num_radial"],
cutoff=model_params["cutoff"],
envelope_exponent=model_params["envelope_exponent"],
num_before_skip=model_params["num_before_skip"],
num_after_skip=model_params["num_after_skip"],
num_dense_output=model_params["num_dense_output"],
num_targets=len(model_params["targets"]),
extensive=model_params["extensive"],
output_init=model_params["output_init"],
)
logger.info(model.state_dict())
tf_path, torch_path = convert_cnf['tf']['ckpt_path'], convert_cnf['torch']['dump_path']
tf_path, torch_path = (
convert_cnf["tf"]["ckpt_path"],
convert_cnf["torch"]["dump_path"],
)
init_vars = tf.train.list_variables(tf_path)
tf_vars_dict = {}
# 147 keys
for name, shape in init_vars:
if name == '_CHECKPOINTABLE_OBJECT_GRAPH':
if name == "_CHECKPOINTABLE_OBJECT_GRAPH":
continue
array = tf.train.load_variable(tf_path, name)
logger.info(f'Loading TF weight {name} with shape {shape}')
logger.info(f"Loading TF weight {name} with shape {shape}")
tf_vars_dict[name] = array
for name, array in tf_vars_dict.items():
name = name.split('/')[:-2]
name = name.split("/")[:-2]
pointer = model
for m_name in name:
if m_name == 'kernel':
pointer = getattr(pointer, 'weight')
elif m_name == 'int_blocks':
pointer = getattr(pointer, 'interaction_blocks')
elif m_name == 'embeddings':
pointer = getattr(pointer, 'embedding')
pointer = getattr(pointer, 'weight')
if m_name == "kernel":
pointer = getattr(pointer, "weight")
elif m_name == "int_blocks":
pointer = getattr(pointer, "interaction_blocks")
elif m_name == "embeddings":
pointer = getattr(pointer, "embedding")
pointer = getattr(pointer, "weight")
else:
pointer = getattr(pointer, m_name)
if name[-1] == 'kernel':
if name[-1] == "kernel":
array = np.transpose(array)
assert array.shape == pointer.shape
logger.info(f'Initialize PyTorch weight {name}')
logger.info(f"Initialize PyTorch weight {name}")
pointer.data = torch.from_numpy(array)
logger.info(f'Save PyTorch model to {torch_path}')
logger.info(f"Save PyTorch model to {torch_path}")
if not os.path.exists(torch_path):
os.makedirs(torch_path)
target = model_params['targets'][0]
torch.save(model.state_dict(), f'{torch_path}/{target}.pt')
target = model_params["targets"][0]
torch.save(model.state_dict(), f"{torch_path}/{target}.pt")
logger.info(model.state_dict())
if __name__ == "__main__":
main()
\ No newline at end of file
main()
import click
import copy
from pathlib import Path
import click
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import dgl
import torch.optim as optim
from logzero import logger
from pathlib import Path
from modules.dimenet import DimeNet
from modules.dimenet_pp import DimeNetPP
from modules.initializers import GlorotOrthogonal
from qm9 import QM9
from ruamel.yaml import YAML
from sklearn.metrics import mean_absolute_error
from torch.utils.data import DataLoader
import dgl
from dgl.data.utils import Subset
from sklearn.metrics import mean_absolute_error
from qm9 import QM9
from modules.initializers import GlorotOrthogonal
from modules.dimenet import DimeNet
from modules.dimenet_pp import DimeNetPP
def split_dataset(dataset, num_train, num_valid, shuffle=False, random_state=None):
def split_dataset(
dataset, num_train, num_valid, shuffle=False, random_state=None
):
"""Split dataset into training, validation and test set.
Parameters
......@@ -46,6 +50,7 @@ def split_dataset(dataset, num_train, num_valid, shuffle=False, random_state=Non
Subsets for training, validation and test.
"""
from itertools import accumulate
num_data = len(dataset)
assert num_train + num_valid < num_data
lengths = [num_train, num_valid, num_data - num_train - num_valid]
......@@ -53,20 +58,26 @@ def split_dataset(dataset, num_train, num_valid, shuffle=False, random_state=Non
indices = np.random.RandomState(seed=random_state).permutation(num_data)
else:
indices = np.arange(num_data)
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(accumulate(lengths), lengths)]
return [
Subset(dataset, indices[offset - length : offset])
for offset, length in zip(accumulate(lengths), lengths)
]
@torch.no_grad()
def ema(ema_model, model, decay):
msd = model.state_dict()
for k, ema_v in ema_model.state_dict().items():
model_v = msd[k].detach()
ema_v.copy_(ema_v * decay + (1. - decay) * model_v)
ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v)
def edge_init(edges):
R_src, R_dst = edges.src['R'], edges.dst['R']
R_src, R_dst = edges.src["R"], edges.dst["R"]
dist = torch.sqrt(F.relu(torch.sum((R_src - R_dst) ** 2, -1)))
# d: bond length, o: bond orientation
return {'d': dist, 'o': R_src - R_dst}
return {"d": dist, "o": R_src - R_dst}
def _collate_fn(batch):
graphs, line_graphs, labels = map(list, zip(*batch))
......@@ -74,6 +85,7 @@ def _collate_fn(batch):
labels = torch.tensor(labels, dtype=torch.float32)
return g, l_g, labels
def train(device, model, opt, loss_fn, train_loader):
model.train()
epoch_loss = 0
......@@ -93,162 +105,202 @@ def train(device, model, opt, loss_fn, train_loader):
return epoch_loss / num_samples
@torch.no_grad()
def evaluate(device, model, valid_loader):
model.eval()
predictions_all, labels_all = [], []
for g, l_g, labels in valid_loader:
g = g.to(device)
l_g = l_g.to(device)
logits = model(g, l_g)
labels_all.extend(labels)
predictions_all.extend(logits.view(-1,).cpu().numpy())
predictions_all.extend(
logits.view(
-1,
)
.cpu()
.numpy()
)
return np.array(predictions_all), np.array(labels_all)
@click.command()
@click.option('-m', '--model-cnf', type=click.Path(exists=True), help='Path of model config yaml.')
@click.option(
"-m",
"--model-cnf",
type=click.Path(exists=True),
help="Path of model config yaml.",
)
def main(model_cnf):
yaml = YAML(typ='safe')
yaml = YAML(typ="safe")
model_cnf = yaml.load(Path(model_cnf))
model_name, model_params, train_params, pretrain_params = model_cnf['name'], model_cnf['model'], model_cnf['train'], model_cnf['pretrain']
logger.info(f'Model name: {model_name}')
logger.info(f'Model params: {model_params}')
logger.info(f'Train params: {train_params}')
model_name, model_params, train_params, pretrain_params = (
model_cnf["name"],
model_cnf["model"],
model_cnf["train"],
model_cnf["pretrain"],
)
logger.info(f"Model name: {model_name}")
logger.info(f"Model params: {model_params}")
logger.info(f"Train params: {train_params}")
if model_params['targets'] in ['mu', 'homo', 'lumo', 'gap', 'zpve']:
model_params['output_init'] = nn.init.zeros_
if model_params["targets"] in ["mu", "homo", "lumo", "gap", "zpve"]:
model_params["output_init"] = nn.init.zeros_
else:
# 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv
model_params['output_init'] = GlorotOrthogonal
model_params["output_init"] = GlorotOrthogonal
logger.info('Loading Data Set')
dataset = QM9(label_keys=model_params['targets'], edge_funcs=[edge_init])
logger.info("Loading Data Set")
dataset = QM9(label_keys=model_params["targets"], edge_funcs=[edge_init])
# data split
train_data, valid_data, test_data = split_dataset(dataset,
num_train=train_params['num_train'],
num_valid=train_params['num_valid'],
shuffle=True,
random_state=train_params['data_seed'])
logger.info(f'Size of Training Set: {len(train_data)}')
logger.info(f'Size of Validation Set: {len(valid_data)}')
logger.info(f'Size of Test Set: {len(test_data)}')
train_data, valid_data, test_data = split_dataset(
dataset,
num_train=train_params["num_train"],
num_valid=train_params["num_valid"],
shuffle=True,
random_state=train_params["data_seed"],
)
logger.info(f"Size of Training Set: {len(train_data)}")
logger.info(f"Size of Validation Set: {len(valid_data)}")
logger.info(f"Size of Test Set: {len(test_data)}")
# data loader
train_loader = DataLoader(train_data,
batch_size=train_params['batch_size'],
shuffle=True,
collate_fn=_collate_fn,
num_workers=train_params['num_workers'])
valid_loader = DataLoader(valid_data,
batch_size=train_params['batch_size'],
shuffle=False,
collate_fn=_collate_fn,
num_workers=train_params['num_workers'])
test_loader = DataLoader(test_data,
batch_size=train_params['batch_size'],
shuffle=False,
collate_fn=_collate_fn,
num_workers=train_params['num_workers'])
train_loader = DataLoader(
train_data,
batch_size=train_params["batch_size"],
shuffle=True,
collate_fn=_collate_fn,
num_workers=train_params["num_workers"],
)
valid_loader = DataLoader(
valid_data,
batch_size=train_params["batch_size"],
shuffle=False,
collate_fn=_collate_fn,
num_workers=train_params["num_workers"],
)
test_loader = DataLoader(
test_data,
batch_size=train_params["batch_size"],
shuffle=False,
collate_fn=_collate_fn,
num_workers=train_params["num_workers"],
)
# check cuda
gpu = train_params['gpu']
device = f'cuda:{gpu}' if gpu >= 0 and torch.cuda.is_available() else 'cpu'
gpu = train_params["gpu"]
device = f"cuda:{gpu}" if gpu >= 0 and torch.cuda.is_available() else "cpu"
# model initialization
logger.info('Loading Model')
if model_name == 'dimenet':
model = DimeNet(emb_size=model_params['emb_size'],
num_blocks=model_params['num_blocks'],
num_bilinear=model_params['num_bilinear'],
num_spherical=model_params['num_spherical'],
num_radial=model_params['num_radial'],
cutoff=model_params['cutoff'],
envelope_exponent=model_params['envelope_exponent'],
num_before_skip=model_params['num_before_skip'],
num_after_skip=model_params['num_after_skip'],
num_dense_output=model_params['num_dense_output'],
num_targets=len(model_params['targets']),
output_init=model_params['output_init']).to(device)
elif model_name == 'dimenet++':
model = DimeNetPP(emb_size=model_params['emb_size'],
out_emb_size=model_params['out_emb_size'],
int_emb_size=model_params['int_emb_size'],
basis_emb_size=model_params['basis_emb_size'],
num_blocks=model_params['num_blocks'],
num_spherical=model_params['num_spherical'],
num_radial=model_params['num_radial'],
cutoff=model_params['cutoff'],
envelope_exponent=model_params['envelope_exponent'],
num_before_skip=model_params['num_before_skip'],
num_after_skip=model_params['num_after_skip'],
num_dense_output=model_params['num_dense_output'],
num_targets=len(model_params['targets']),
extensive=model_params['extensive'],
output_init=model_params['output_init']).to(device)
logger.info("Loading Model")
if model_name == "dimenet":
model = DimeNet(
emb_size=model_params["emb_size"],
num_blocks=model_params["num_blocks"],
num_bilinear=model_params["num_bilinear"],
num_spherical=model_params["num_spherical"],
num_radial=model_params["num_radial"],
cutoff=model_params["cutoff"],
envelope_exponent=model_params["envelope_exponent"],
num_before_skip=model_params["num_before_skip"],
num_after_skip=model_params["num_after_skip"],
num_dense_output=model_params["num_dense_output"],
num_targets=len(model_params["targets"]),
output_init=model_params["output_init"],
).to(device)
elif model_name == "dimenet++":
model = DimeNetPP(
emb_size=model_params["emb_size"],
out_emb_size=model_params["out_emb_size"],
int_emb_size=model_params["int_emb_size"],
basis_emb_size=model_params["basis_emb_size"],
num_blocks=model_params["num_blocks"],
num_spherical=model_params["num_spherical"],
num_radial=model_params["num_radial"],
cutoff=model_params["cutoff"],
envelope_exponent=model_params["envelope_exponent"],
num_before_skip=model_params["num_before_skip"],
num_after_skip=model_params["num_after_skip"],
num_dense_output=model_params["num_dense_output"],
num_targets=len(model_params["targets"]),
extensive=model_params["extensive"],
output_init=model_params["output_init"],
).to(device)
else:
raise ValueError(f'Invalid Model Name {model_name}')
raise ValueError(f"Invalid Model Name {model_name}")
if pretrain_params['flag']:
torch_path = pretrain_params['path']
target = model_params['targets'][0]
model.load_state_dict(torch.load(f'{torch_path}/{target}.pt'))
if pretrain_params["flag"]:
torch_path = pretrain_params["path"]
target = model_params["targets"][0]
model.load_state_dict(torch.load(f"{torch_path}/{target}.pt"))
logger.info('Testing with Pretrained model')
logger.info("Testing with Pretrained model")
predictions, labels = evaluate(device, model, test_loader)
test_mae = mean_absolute_error(labels, predictions)
logger.info(f'Test MAE {test_mae:.4f}')
logger.info(f"Test MAE {test_mae:.4f}")
return
# define loss function and optimization
loss_fn = nn.L1Loss()
opt = optim.Adam(model.parameters(), lr=train_params['lr'], weight_decay=train_params['weight_decay'], amsgrad=True)
scheduler = optim.lr_scheduler.StepLR(opt, train_params['step_size'], gamma=train_params['gamma'])
opt = optim.Adam(
model.parameters(),
lr=train_params["lr"],
weight_decay=train_params["weight_decay"],
amsgrad=True,
)
scheduler = optim.lr_scheduler.StepLR(
opt, train_params["step_size"], gamma=train_params["gamma"]
)
# model training
best_mae = 1e9
no_improvement = 0
# EMA for valid and test
logger.info('EMA Init')
logger.info("EMA Init")
ema_model = copy.deepcopy(model)
for p in ema_model.parameters():
p.requires_grad_(False)
best_model = copy.deepcopy(ema_model)
logger.info('Training')
for i in range(train_params['epochs']):
logger.info("Training")
for i in range(train_params["epochs"]):
train_loss = train(device, model, opt, loss_fn, train_loader)
ema(ema_model, model, train_params['ema_decay'])
if i % train_params['interval'] == 0:
ema(ema_model, model, train_params["ema_decay"])
if i % train_params["interval"] == 0:
predictions, labels = evaluate(device, ema_model, valid_loader)
valid_mae = mean_absolute_error(labels, predictions)
logger.info(f'Epoch {i} | Train Loss {train_loss:.4f} | Val MAE {valid_mae:.4f}')
logger.info(
f"Epoch {i} | Train Loss {train_loss:.4f} | Val MAE {valid_mae:.4f}"
)
if valid_mae > best_mae:
no_improvement += 1
if no_improvement == train_params['early_stopping']:
logger.info('Early stop.')
if no_improvement == train_params["early_stopping"]:
logger.info("Early stop.")
break
else:
no_improvement = 0
best_mae = valid_mae
best_model = copy.deepcopy(ema_model)
else:
logger.info(f'Epoch {i} | Train Loss {train_loss:.4f}')
logger.info(f"Epoch {i} | Train Loss {train_loss:.4f}")
scheduler.step()
logger.info('Testing')
logger.info("Testing")
predictions, labels = evaluate(device, best_model, test_loader)
test_mae = mean_absolute_error(labels, predictions)
logger.info('Test MAE {:.4f}'.format(test_mae))
logger.info("Test MAE {:.4f}".format(test_mae))
if __name__ == "__main__":
main()
import torch
def swish(x):
"""
Swish activation function,
from Ramachandran, Zoph, Le 2017. "Searching for Activation Functions"
"""
return x * torch.sigmoid(x)
\ No newline at end of file
return x * torch.sigmoid(x)
import numpy as np
import sympy as sym
from scipy.optimize import brentq
from scipy import special as sp
from scipy.optimize import brentq
def Jn(r, n):
"""
......@@ -19,6 +19,7 @@ def Jn(r, n):
"""
return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) # the same shape as n
def Jn_zeros(n, k):
"""
n: int
......@@ -40,6 +41,7 @@ def Jn_zeros(n, k):
return zerosj
def spherical_bessel_formulas(n):
"""
n: int
......@@ -48,7 +50,7 @@ def spherical_bessel_formulas(n):
n sympy functions
Computes the sympy formulas for the spherical bessel functions up to order n (excluded)
"""
x = sym.symbols('x')
x = sym.symbols("x")
f = [sym.sin(x) / x]
a = sym.sin(x) / x
......@@ -58,6 +60,7 @@ def spherical_bessel_formulas(n):
a = sym.simplify(b)
return f
def bessel_basis(n, k):
"""
n: int
......@@ -79,26 +82,36 @@ def bessel_basis(n, k):
normalizer += [normalizer_tmp]
f = spherical_bessel_formulas(n)
x = sym.symbols('x')
x = sym.symbols("x")
bess_basis = []
for order in range(n):
bess_basis_tmp = []
for i in range(k):
bess_basis_tmp += [sym.simplify(normalizer[order][i] * f[order].subs(x, zeros[order, i] * x))]
bess_basis_tmp += [
sym.simplify(
normalizer[order][i] * f[order].subs(x, zeros[order, i] * x)
)
]
bess_basis += [bess_basis_tmp]
return bess_basis
def sph_harm_prefactor(l, m):
"""
l: int
m: int
res: float
res: float
Computes the constant pre-factor for the spherical harmonic of degree l and order m
input:
l: int, l>=0
m: int, -l<=m<=l
"""
return ((2 * l + 1) * np.math.factorial(l - abs(m)) / (4 * np.pi * np.math.factorial(l + abs(m)))) ** 0.5
return (
(2 * l + 1)
* np.math.factorial(l - abs(m))
/ (4 * np.pi * np.math.factorial(l + abs(m)))
) ** 0.5
def associated_legendre_polynomials(l, zero_m_only=True):
"""
......@@ -106,7 +119,7 @@ def associated_legendre_polynomials(l, zero_m_only=True):
return: l sympy functions
Computes sympy formulas of the associated legendre polynomials up to order l (excluded).
"""
z = sym.symbols('z')
z = sym.symbols("z")
P_l_m = [[0] * (j + 1) for j in range(l)]
P_l_m[0][0] = 1
......@@ -116,18 +129,29 @@ def associated_legendre_polynomials(l, zero_m_only=True):
for j in range(2, l):
P_l_m[j][0] = sym.simplify(
((2 * j - 1) * z * P_l_m[j - 1][0] - (j - 1) * P_l_m[j - 2][0]) / j)
((2 * j - 1) * z * P_l_m[j - 1][0] - (j - 1) * P_l_m[j - 2][0])
/ j
)
if not zero_m_only:
for i in range(1, l):
P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])
if i + 1 < l:
P_l_m[i + 1][i] = sym.simplify((2 * i + 1) * z * P_l_m[i][i])
P_l_m[i + 1][i] = sym.simplify(
(2 * i + 1) * z * P_l_m[i][i]
)
for j in range(i + 2, l):
P_l_m[j][i] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][i] - (i + j - 1) * P_l_m[j - 2][i]) / (j - i))
P_l_m[j][i] = sym.simplify(
(
(2 * j - 1) * z * P_l_m[j - 1][i]
- (i + j - 1) * P_l_m[j - 2][i]
)
/ (j - i)
)
return P_l_m
def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True):
"""
return: a sympy function list of length l, for i-th index of the list, it is also a list of length (2 * i + 1)
......@@ -138,30 +162,38 @@ def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True):
S_m = [0]
C_m = [1]
for i in range(1, l):
x = sym.symbols('x')
y = sym.symbols('y')
x = sym.symbols("x")
y = sym.symbols("y")
S_m += [x * S_m[i - 1] + y * C_m[i - 1]]
C_m += [x * C_m[i - 1] - y * S_m[i - 1]]
P_l_m = associated_legendre_polynomials(l, zero_m_only)
if spherical_coordinates:
theta = sym.symbols('theta')
z = sym.symbols('z')
theta = sym.symbols("theta")
z = sym.symbols("z")
for i in range(len(P_l_m)):
for j in range(len(P_l_m[i])):
if type(P_l_m[i][j]) != int:
P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta))
if not zero_m_only:
phi = sym.symbols('phi')
phi = sym.symbols("phi")
for i in range(len(S_m)):
S_m[i] = S_m[i].subs(x, sym.sin(theta) * sym.cos(phi)).subs(y, sym.sin(theta) * sym.sin(phi))
S_m[i] = (
S_m[i]
.subs(x, sym.sin(theta) * sym.cos(phi))
.subs(y, sym.sin(theta) * sym.sin(phi))
)
for i in range(len(C_m)):
C_m[i] = C_m[i].subs(x, sym.sin(theta) * sym.cos(phi)).subs(y, sym.sin(theta) * sym.sin(phi))
C_m[i] = (
C_m[i]
.subs(x, sym.sin(theta) * sym.cos(phi))
.subs(y, sym.sin(theta) * sym.sin(phi))
)
Y_func_l_m = [['0'] * (2 * j + 1) for j in range(l)]
Y_func_l_m = [["0"] * (2 * j + 1) for j in range(l)]
for i in range(l):
Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])
......@@ -169,9 +201,13 @@ def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True):
if not zero_m_only:
for i in range(1, l):
for j in range(1, i + 1):
Y_func_l_m[i][j] = sym.simplify(2 ** 0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j])
Y_func_l_m[i][j] = sym.simplify(
2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]
)
for i in range(1, l):
for j in range(1, i + 1):
Y_func_l_m[i][-j] = sym.simplify(2 ** 0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j])
Y_func_l_m[i][-j] = sym.simplify(
2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]
)
return Y_func_l_m
\ No newline at end of file
return Y_func_l_m
import numpy as np
import torch
import torch.nn as nn
from modules.envelope import Envelope
class BesselBasisLayer(nn.Module):
def __init__(self,
num_radial,
cutoff,
envelope_exponent=5):
def __init__(self, num_radial, cutoff, envelope_exponent=5):
super(BesselBasisLayer, self).__init__()
self.cutoff = cutoff
self.envelope = Envelope(envelope_exponent)
self.frequencies = nn.Parameter(torch.Tensor(num_radial))
......@@ -18,13 +15,15 @@ class BesselBasisLayer(nn.Module):
def reset_params(self):
with torch.no_grad():
torch.arange(1, self.frequencies.numel() + 1, out=self.frequencies).mul_(np.pi)
torch.arange(
1, self.frequencies.numel() + 1, out=self.frequencies
).mul_(np.pi)
self.frequencies.requires_grad_()
def forward(self, g):
d_scaled = g.edata['d'] / self.cutoff
d_scaled = g.edata["d"] / self.cutoff
# Necessary for proper broadcasting behaviour
d_scaled = torch.unsqueeze(d_scaled, -1)
d_cutoff = self.envelope(d_scaled)
g.edata['rbf'] = d_cutoff * torch.sin(self.frequencies * d_scaled)
g.edata["rbf"] = d_cutoff * torch.sin(self.frequencies * d_scaled)
return g
import torch
import torch.nn as nn
from modules.activations import swish
from modules.bessel_basis_layer import BesselBasisLayer
from modules.spherical_basis_layer import SphericalBasisLayer
from modules.embedding_block import EmbeddingBlock
from modules.output_block import OutputBlock
from modules.interaction_block import InteractionBlock
from modules.output_block import OutputBlock
from modules.spherical_basis_layer import SphericalBasisLayer
class DimeNet(nn.Module):
"""
......@@ -41,67 +41,86 @@ class DimeNet(nn.Module):
output_init
Initial function in output block
"""
def __init__(self,
emb_size,
num_blocks,
num_bilinear,
num_spherical,
num_radial,
cutoff=5.0,
envelope_exponent=5,
num_before_skip=1,
num_after_skip=2,
num_dense_output=3,
num_targets=12,
activation=swish,
output_init=nn.init.zeros_):
def __init__(
self,
emb_size,
num_blocks,
num_bilinear,
num_spherical,
num_radial,
cutoff=5.0,
envelope_exponent=5,
num_before_skip=1,
num_after_skip=2,
num_dense_output=3,
num_targets=12,
activation=swish,
output_init=nn.init.zeros_,
):
super(DimeNet, self).__init__()
self.num_blocks = num_blocks
self.num_radial = num_radial
# cosine basis function expansion layer
self.rbf_layer = BesselBasisLayer(num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent)
self.rbf_layer = BesselBasisLayer(
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent,
)
self.sbf_layer = SphericalBasisLayer(
num_spherical=num_spherical,
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent,
)
self.sbf_layer = SphericalBasisLayer(num_spherical=num_spherical,
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent)
# embedding block
self.emb_block = EmbeddingBlock(emb_size=emb_size,
num_radial=num_radial,
bessel_funcs=self.sbf_layer.get_bessel_funcs(),
cutoff=cutoff,
envelope_exponent=envelope_exponent,
activation=activation)
self.emb_block = EmbeddingBlock(
emb_size=emb_size,
num_radial=num_radial,
bessel_funcs=self.sbf_layer.get_bessel_funcs(),
cutoff=cutoff,
envelope_exponent=envelope_exponent,
activation=activation,
)
# output block
self.output_blocks = nn.ModuleList({
OutputBlock(emb_size=emb_size,
num_radial=num_radial,
num_dense=num_dense_output,
num_targets=num_targets,
activation=activation,
output_init=output_init) for _ in range(num_blocks + 1)
})
self.output_blocks = nn.ModuleList(
{
OutputBlock(
emb_size=emb_size,
num_radial=num_radial,
num_dense=num_dense_output,
num_targets=num_targets,
activation=activation,
output_init=output_init,
)
for _ in range(num_blocks + 1)
}
)
# interaction block
self.interaction_blocks = nn.ModuleList({
InteractionBlock(emb_size=emb_size,
num_radial=num_radial,
num_spherical=num_spherical,
num_bilinear=num_bilinear,
num_before_skip=num_before_skip,
num_after_skip=num_after_skip,
activation=activation) for _ in range(num_blocks)
})
self.interaction_blocks = nn.ModuleList(
{
InteractionBlock(
emb_size=emb_size,
num_radial=num_radial,
num_spherical=num_spherical,
num_bilinear=num_bilinear,
num_before_skip=num_before_skip,
num_after_skip=num_after_skip,
activation=activation,
)
for _ in range(num_blocks)
}
)
def edge_init(self, edges):
# Calculate angles k -> j -> i
R1, R2 = edges.src['o'], edges.dst['o']
R1, R2 = edges.src["o"], edges.dst["o"]
x = torch.sum(R1 * R2, dim=-1)
y = torch.cross(R1, R2)
y = torch.norm(y, dim=-1)
......@@ -110,9 +129,9 @@ class DimeNet(nn.Module):
cbf = [f(angle) for f in self.sbf_layer.get_sph_funcs()]
cbf = torch.stack(cbf, dim=1) # [None, 7]
cbf = cbf.repeat_interleave(self.num_radial, dim=1) # [None, 42]
sbf = edges.src['rbf_env'] * cbf # [None, 42]
return {'sbf': sbf}
sbf = edges.src["rbf_env"] * cbf # [None, 42]
return {"sbf": sbf}
def forward(self, g, l_g):
# add rbf features for each edge in one batch graph, [num_radial,]
g = self.rbf_layer(g)
......@@ -129,5 +148,5 @@ class DimeNet(nn.Module):
for i in range(self.num_blocks):
g = self.interaction_blocks[i](g, l_g)
P += self.output_blocks[i + 1](g)
return P
\ No newline at end of file
return P
import torch
import torch.nn as nn
from modules.activations import swish
from modules.bessel_basis_layer import BesselBasisLayer
from modules.spherical_basis_layer import SphericalBasisLayer
from modules.embedding_block import EmbeddingBlock
from modules.output_pp_block import OutputPPBlock
from modules.interaction_pp_block import InteractionPPBlock
from modules.output_pp_block import OutputPPBlock
from modules.spherical_basis_layer import SphericalBasisLayer
class DimeNetPP(nn.Module):
"""
......@@ -47,73 +47,92 @@ class DimeNetPP(nn.Module):
output_init
Initial function in output block
"""
def __init__(self,
emb_size,
out_emb_size,
int_emb_size,
basis_emb_size,
num_blocks,
num_spherical,
num_radial,
cutoff=5.0,
envelope_exponent=5,
num_before_skip=1,
num_after_skip=2,
num_dense_output=3,
num_targets=12,
activation=swish,
extensive=True,
output_init=nn.init.zeros_):
def __init__(
self,
emb_size,
out_emb_size,
int_emb_size,
basis_emb_size,
num_blocks,
num_spherical,
num_radial,
cutoff=5.0,
envelope_exponent=5,
num_before_skip=1,
num_after_skip=2,
num_dense_output=3,
num_targets=12,
activation=swish,
extensive=True,
output_init=nn.init.zeros_,
):
super(DimeNetPP, self).__init__()
self.num_blocks = num_blocks
self.num_radial = num_radial
# cosine basis function expansion layer
self.rbf_layer = BesselBasisLayer(num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent)
self.rbf_layer = BesselBasisLayer(
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent,
)
self.sbf_layer = SphericalBasisLayer(
num_spherical=num_spherical,
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent,
)
self.sbf_layer = SphericalBasisLayer(num_spherical=num_spherical,
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent)
# embedding block
self.emb_block = EmbeddingBlock(emb_size=emb_size,
num_radial=num_radial,
bessel_funcs=self.sbf_layer.get_bessel_funcs(),
cutoff=cutoff,
envelope_exponent=envelope_exponent,
activation=activation)
self.emb_block = EmbeddingBlock(
emb_size=emb_size,
num_radial=num_radial,
bessel_funcs=self.sbf_layer.get_bessel_funcs(),
cutoff=cutoff,
envelope_exponent=envelope_exponent,
activation=activation,
)
# output block
self.output_blocks = nn.ModuleList({
OutputPPBlock(emb_size=emb_size,
out_emb_size=out_emb_size,
num_radial=num_radial,
num_dense=num_dense_output,
num_targets=num_targets,
activation=activation,
extensive=extensive,
output_init=output_init) for _ in range(num_blocks + 1)
})
self.output_blocks = nn.ModuleList(
{
OutputPPBlock(
emb_size=emb_size,
out_emb_size=out_emb_size,
num_radial=num_radial,
num_dense=num_dense_output,
num_targets=num_targets,
activation=activation,
extensive=extensive,
output_init=output_init,
)
for _ in range(num_blocks + 1)
}
)
# interaction block
self.interaction_blocks = nn.ModuleList({
InteractionPPBlock(emb_size=emb_size,
int_emb_size=int_emb_size,
basis_emb_size=basis_emb_size,
num_radial=num_radial,
num_spherical=num_spherical,
num_before_skip=num_before_skip,
num_after_skip=num_after_skip,
activation=activation) for _ in range(num_blocks)
})
self.interaction_blocks = nn.ModuleList(
{
InteractionPPBlock(
emb_size=emb_size,
int_emb_size=int_emb_size,
basis_emb_size=basis_emb_size,
num_radial=num_radial,
num_spherical=num_spherical,
num_before_skip=num_before_skip,
num_after_skip=num_after_skip,
activation=activation,
)
for _ in range(num_blocks)
}
)
def edge_init(self, edges):
# Calculate angles k -> j -> i
R1, R2 = edges.src['o'], edges.dst['o']
R1, R2 = edges.src["o"], edges.dst["o"]
x = torch.sum(R1 * R2, dim=-1)
y = torch.cross(R1, R2)
y = torch.norm(y, dim=-1)
......@@ -123,9 +142,9 @@ class DimeNetPP(nn.Module):
cbf = torch.stack(cbf, dim=1) # [None, 7]
cbf = cbf.repeat_interleave(self.num_radial, dim=1) # [None, 42]
# Notice: it's dst, not src
sbf = edges.dst['rbf_env'] * cbf # [None, 42]
return {'sbf': sbf}
sbf = edges.dst["rbf_env"] * cbf # [None, 42]
return {"sbf": sbf}
def forward(self, g, l_g):
# add rbf features for each edge in one batch graph, [num_radial,]
g = self.rbf_layer(g)
......@@ -142,5 +161,5 @@ class DimeNetPP(nn.Module):
for i in range(self.num_blocks):
g = self.interaction_blocks[i](g, l_g)
P += self.output_blocks[i + 1](g)
return P
\ No newline at end of file
return P
import numpy as np
import torch
import torch.nn as nn
from modules.envelope import Envelope
from modules.initializers import GlorotOrthogonal
class EmbeddingBlock(nn.Module):
def __init__(self,
emb_size,
num_radial,
bessel_funcs,
cutoff,
envelope_exponent,
num_atom_types=95,
activation=None):
def __init__(
self,
emb_size,
num_radial,
bessel_funcs,
cutoff,
envelope_exponent,
num_atom_types=95,
activation=None,
):
super(EmbeddingBlock, self).__init__()
self.bessel_funcs = bessel_funcs
......@@ -24,35 +26,35 @@ class EmbeddingBlock(nn.Module):
self.dense_rbf = nn.Linear(num_radial, emb_size)
self.dense = nn.Linear(emb_size * 3, emb_size)
self.reset_params()
def reset_params(self):
nn.init.uniform_(self.embedding.weight, a=-np.sqrt(3), b=np.sqrt(3))
GlorotOrthogonal(self.dense_rbf.weight)
GlorotOrthogonal(self.dense.weight)
def edge_init(self, edges):
""" msg emb init """
"""msg emb init"""
# m init
rbf = self.dense_rbf(edges.data['rbf'])
rbf = self.dense_rbf(edges.data["rbf"])
if self.activation is not None:
rbf = self.activation(rbf)
m = torch.cat([edges.src['h'], edges.dst['h'], rbf], dim=-1)
m = torch.cat([edges.src["h"], edges.dst["h"], rbf], dim=-1)
m = self.dense(m)
if self.activation is not None:
m = self.activation(m)
# rbf_env init
d_scaled = edges.data['d'] / self.cutoff
d_scaled = edges.data["d"] / self.cutoff
rbf_env = [f(d_scaled) for f in self.bessel_funcs]
rbf_env = torch.stack(rbf_env, dim=1)
d_cutoff = self.envelope(d_scaled)
rbf_env = d_cutoff[:, None] * rbf_env
return {'m': m, 'rbf_env': rbf_env}
return {"m": m, "rbf_env": rbf_env}
def forward(self, g):
g.ndata['h'] = self.embedding(g.ndata['Z'])
g.ndata["h"] = self.embedding(g.ndata["Z"])
g.apply_edges(self.edge_init)
return g
\ No newline at end of file
return g
import torch.nn as nn
class Envelope(nn.Module):
"""
Envelope function that ensures a smooth cutoff
"""
def __init__(self, exponent):
super(Envelope, self).__init__()
......@@ -11,11 +13,11 @@ class Envelope(nn.Module):
self.a = -(self.p + 1) * (self.p + 2) / 2
self.b = self.p * (self.p + 2)
self.c = -self.p * (self.p + 1) / 2
def forward(self, x):
# Envelope function divided by r
x_p_0 = x.pow(self.p - 1)
x_p_1 = x_p_0 * x
x_p_2 = x_p_1 * x
env_val = 1 / x + self.a * x_p_0 + self.b * x_p_1 + self.c * x_p_2
return env_val
\ No newline at end of file
return env_val
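For reference, the Envelope module above evaluates the smooth cutoff polynomial divided by r: u(x) = 1/x + a*x^(p-1) + b*x^p + c*x^(p+1), with a = -(p+1)(p+2)/2, b = p(p+2), and c = -p(p+1)/2, where p is the stored exponent self.p; these are exactly the coefficients set in __init__ above.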