Unverified Commit 0b47e868 authored by xnouhz, committed by GitHub

[Example] Add DimeNet(++) for Molecular Graph Property Prediction (#2706)



* [example] arma

* update

* update

* update

* update

* update

* [example] dimenet

* [docs] update dimenet

* [docs] update tf results

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent c88fca50
@@ -80,6 +80,7 @@ The folder contains example implementations of selected research papers related
| [Dynamic Graph CNN for Learning on Point Clouds](#dgcnnpoint) | | | | | |
| [Supervised Community Detection with Line Graph Neural Networks](#lgnn) | | | | | |
| [Text Generation from Knowledge Graphs with Graph Transformers](#graphwriter) | | | | | |
| [Directional Message Passing for Molecular Graphs](#dimenet) | | | :heavy_check_mark: | | |
| [Link Prediction Based on Graph Neural Networks](#seal) | | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: |
| [Variational Graph Auto-Encoders](#vgae) | | :heavy_check_mark: | | | |
@@ -101,15 +102,19 @@ The folder contains example implementations of selected research papers related
- Example code: [Molecule embedding](https://github.com/awslabs/dgl-lifesci/tree/master/examples/molecule_embeddings), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)
- Tags: molecules, graph classification, unsupervised learning, self-supervised learning, molecular property prediction
- <a name="gnnfilm"></a> Marc Brockschmidt. GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation. [Paper link](https://arxiv.org/abs/1906.12192).
- Example code: [Pytorch](../examples/pytorch/GNN-FiLM)
- Example code: [PyTorch](../examples/pytorch/GNN-FiLM)
- Tags: multi-relational graphs, hypernetworks, GNN architectures
- <a name="gxn"></a> Li, Maosen, et al. Graph Cross Networks with Vertex Infomax Pooling. [Paper link](https://arxiv.org/abs/2010.01804).
- Example code: [Pytorch](../examples/pytorch/gxn)
- Example code: [PyTorch](../examples/pytorch/gxn)
- Tags: pooling, graph classification
- <a name="dagnn"></a> Liu et al. Towards Deeper Graph Neural Networks. [Paper link](https://arxiv.org/abs/2007.09296).
- Example code: [Pytorch](../examples/pytorch/dagnn)
- Example code: [PyTorch](../examples/pytorch/dagnn)
- Tags: over-smoothing, node classification
- <a name="dimenet"></a> Klicpera et al. Directional Message Passing for Molecular Graphs. [Paper link](https://arxiv.org/abs/2003.03123).
- Example code: [PyTorch](../examples/pytorch/dimenet)
- Tags: molecules, molecular property prediction, quantum chemistry
## 2019
@@ -180,10 +185,10 @@ The folder contains example implementations of selected research papers related
- Example code: [PyTorch](../examples/pytorch/hgp_sl)
- Tags: graph classification, pooling
- <a name='hardgat'></a> Gao, Hongyang, et al. Graph Representation Learning via Hard and Channel-Wise Attention Networks [Paper link](https://arxiv.org/abs/1907.04652).
- Example code: [Pytorch](../examples/pytorch/hardgat)
- Example code: [PyTorch](../examples/pytorch/hardgat)
- Tags: node classification, graph attention
- <a name='ngcf'></a> Wang, Xiang, et al. Neural Graph Collaborative Filtering. [Paper link](https://arxiv.org/abs/1905.08108).
- Example code: [Pytorch](../examples/pytorch/NGCF)
- Example code: [PyTorch](../examples/pytorch/NGCF)
- Tags: Collaborative Filtering, Recommendation, Graph Neural Network
# DGL Implementation of DimeNet and DimeNet++
This DGL example implements the GNN models proposed in the papers [Directional Message Passing for Molecular Graphs](https://arxiv.org/abs/2003.03123) and [Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules](https://arxiv.org/abs/2011.14115). For the original implementation, see [here](https://github.com/klicperajo/dimenet).
Contributor: [xnuohz](https://github.com/xnuohz)
* This example implements both DimeNet and DimeNet++.
* Advantages of DimeNet++ over DimeNet:
  - Fast interactions: the bilinear layer is replaced with a simple Hadamard product (see the sketch below)
  - Embedding hierarchy: a larger number of embeddings is used by reducing the embedding size in the blocks via down- and up-projection layers
  - Other improvements: fewer interaction blocks
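The interaction-block change is the main source of the speed-up. As a rough illustration (the tensors and shapes below are invented for the sketch and are not taken from the example code), the bilinear message update used in DimeNet and the Hadamard-product update used in DimeNet++ differ as follows:

```python
import torch

num_edges, emb_size, num_bilinear, int_emb_size = 10, 128, 8, 64

# DimeNet: bilinear combination of the spherical basis and the incoming message.
sbf = torch.randn(num_edges, num_bilinear)       # basis representation after a dense layer
x_kj = torch.randn(num_edges, emb_size)          # incoming message
W_bilin = torch.randn(emb_size, num_bilinear, emb_size)
out_dimenet = torch.einsum("wj,wl,ijl->wi", sbf, x_kj, W_bilin)   # [num_edges, emb_size]

# DimeNet++: element-wise (Hadamard) product in a smaller space,
# followed by an up-projection back to emb_size.
sbf_pp = torch.randn(num_edges, int_emb_size)    # basis after two small dense layers
x_kj_pp = torch.randn(num_edges, int_emb_size)   # message after a down-projection
up_projection = torch.nn.Linear(int_emb_size, emb_size, bias=False)
out_dimenet_pp = up_projection(x_kj_pp * sbf_pp)                  # [num_edges, emb_size]
```

The corresponding code in this example lives in `modules/interaction_block.py` (bilinear, via `torch.einsum`) and `modules/interaction_pp_block.py` (Hadamard product with down-/up-projections).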
### Requirements
The codebase is implemented in Python 3.6. For the version requirements of the packages, see below.
```
click 7.1.2
dgl 0.6.0
logzero 1.6.3
numpy 1.19.5
ruamel.yaml 0.16.12
scikit-learn 0.24.1
scipy 1.5.4
sympy 1.7.1
torch 1.7.0
tqdm 4.56.0
```
### The graph datasets used in this example
DGL's built-in QM9 dataset is used (a loading sketch follows the summary). Dataset summary:
* Number of Molecular Graphs: 130,831
* Number of Tasks: 12
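A minimal sketch of loading the dataset and reproducing the default 110,000/10,000/test split; it assumes the local `qm9.QM9` wrapper and the `edge_init`, `split_dataset` and `_collate_fn` helpers shipped with this example in `main.py`:

```python
from torch.utils.data import DataLoader

from qm9 import QM9                                       # dataset wrapper used by main.py
from main import edge_init, split_dataset, _collate_fn    # helpers defined in main.py

# Each sample is (graph, line_graph, label); edge_init attaches the bond length 'd'
# and the bond orientation 'o' to every edge.
dataset = QM9(label_keys=['mu'], edge_funcs=[edge_init])
train_data, valid_data, test_data = split_dataset(
    dataset, num_train=110000, num_valid=10000, shuffle=True, random_state=42)

train_loader = DataLoader(train_data, batch_size=100, shuffle=True, collate_fn=_collate_fn)
print(len(train_data), len(valid_data), len(test_data))
```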
### Usage
**Note: DimeNet++ is recommended over DimeNet.**
##### Examples
The following commands train a neural network and predict on the test set.
Train a DimeNet model on the QM9 dataset.
```bash
python main.py --model-cnf config/dimenet.yaml
```
Train a DimeNet++ model on the QM9 dataset.
```bash
python main.py --model-cnf config/dimenet_pp.yaml
```
For faster experimentation, first copy the authors' [pretrained](https://github.com/klicperajo/dimenet/tree/master/pretrained) folder here; it contains pre-trained TensorFlow models. You can convert a TensorFlow checkpoint to a PyTorch model with the following command.
```bash
python convert_tf_ckpt_to_pytorch.py --model-cnf config/dimenet_pp.yaml --convert-cnf config/convert.yaml
```
Then set `flag: True` in `config/dimenet_pp.yaml` and run `main.py` with that config again; DimeNet++ will use the pretrained weights to predict on the test set.
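For reference, the pretrained-evaluation path in `main.py` boils down to the sketch below. It is only an illustration: it assumes the default `config/dimenet_pp.yaml`, that the conversion above has produced `pretrained/converted/mu.pt`, and that evaluation runs on CPU.

```python
import torch
import torch.nn as nn
from pathlib import Path
from ruamel.yaml import YAML
from sklearn.metrics import mean_absolute_error
from torch.utils.data import DataLoader

from qm9 import QM9
from main import edge_init, split_dataset, _collate_fn, evaluate
from modules.dimenet_pp import DimeNetPP

cfg = YAML(typ='safe').load(Path('config/dimenet_pp.yaml'))
m, t = cfg['model'], cfg['train']

# The output initializer is irrelevant here: the converted weights overwrite it.
model = DimeNetPP(emb_size=m['emb_size'], out_emb_size=m['out_emb_size'],
                  int_emb_size=m['int_emb_size'], basis_emb_size=m['basis_emb_size'],
                  num_blocks=m['num_blocks'], num_spherical=m['num_spherical'],
                  num_radial=m['num_radial'], cutoff=m['cutoff'],
                  envelope_exponent=m['envelope_exponent'],
                  num_before_skip=m['num_before_skip'], num_after_skip=m['num_after_skip'],
                  num_dense_output=m['num_dense_output'], num_targets=len(m['targets']),
                  extensive=m['extensive'], output_init=nn.init.zeros_)

target = m['targets'][0]
model.load_state_dict(torch.load(f'pretrained/converted/{target}.pt'))

dataset = QM9(label_keys=m['targets'], edge_funcs=[edge_init])
_, _, test_data = split_dataset(dataset, t['num_train'], t['num_valid'],
                                shuffle=True, random_state=t['data_seed'])
test_loader = DataLoader(test_data, batch_size=t['batch_size'], shuffle=False,
                         collate_fn=_collate_fn)

predictions, labels = evaluate('cpu', model, test_loader)
print(f'Test MAE {mean_absolute_error(labels, predictions):.4f}')
```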
##### Configuration
For more details, please see `config/dimenet.yaml` and `config/dimenet_pp.yaml`.
###### Model options
```
// The following parameters are only used in DimeNet++
out_emb_size int Output block embedding size. Default is 256
int_emb_size int Embedding size of the interaction triplets. Default is 64
basis_emb_size int Basis embedding size. Default is 8
extensive bool Whether the graph-level readout is extensive (sum over atoms) or intensive (mean). Default is True
// The following parameter is only used in DimeNet
num_bilinear int Third dimension of the bilinear layer tensor in DimeNet. Default is 8
// The following parameters are used in both DimeNet and DimeNet++
emb_size int Embedding size used throughout the model. Default is 128
num_blocks int Number of building blocks to be stacked. Default is 6 in DimeNet and 4 in DimeNet++
num_spherical int Number of spherical harmonics. Default is 7
num_radial int Number of radial basis functions. Default is 6
envelope_exponent int Shape of the smooth cutoff. Default is 5
cutoff float Cutoff distance for interatomic interactions. Default is 5.0
num_before_skip int Number of residual layers in interaction block before skip connection. Default is 1
num_after_skip int Number of residual layers in interaction block after skip connection. Default is 2
num_dense_output int Number of dense layers for the output blocks. Default is 3
targets list List of targets to predict. Default is ['mu']
output_init string Initialization function for the final output layer. Default is 'GlorotOrthogonal' (see the sketch after this block)
```
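Note that `output_init` is resolved in `main.py` rather than read literally from the YAML file: the final output layer is zero-initialized for mu, homo, lumo, gap and zpve, and initialized with `GlorotOrthogonal` for the remaining targets. A small sketch of that rule (the helper name `pick_output_init` is made up for illustration):

```python
import torch.nn as nn

from modules.initializers import GlorotOrthogonal

def pick_output_init(target):
    # Mirrors the branch in main.py: zero-init the final output layer for
    # mu, homo, lumo, gap and zpve; GlorotOrthogonal for the other targets.
    if target in ['mu', 'homo', 'lumo', 'gap', 'zpve']:
        return nn.init.zeros_
    return GlorotOrthogonal

print(pick_output_init('mu').__name__)   # zeros_
print(pick_output_init('U0').__name__)   # GlorotOrthogonal
```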
###### Training options
```
num_train int Number of training samples. Default is 110000
num_valid int Number of validation samples. Default is 10000
data_seed int Random seed. Default is 42
lr float Learning rate. Default is 0.001
weight_decay float Weight decay. Default is 0.0001
ema_decay float EMA decay. Default is 0.
batch_size int Batch size. Default is 100
epochs int Training epochs. Default is 300
early_stopping int Patience (number of evaluations without improvement) before early stopping. Default is 20
num_workers int Number of subprocesses to use for data loading. Default is 18
gpu int GPU index. Default is 0, i.e. cuda:0
interval int Number of epochs between model evaluations. Default is 50
step_size int Period of learning rate decay. Default is 100
gamma float Factor of learning rate decay. Default is 0.3
```
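These options are consumed in `main.py` roughly as in the sketch below; the model here is a placeholder standing in for DimeNet/DimeNet++, and the `ema` helper mirrors the one defined in `main.py`.

```python
import copy
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 1)   # placeholder model for the sketch

# Adam with AMSGrad and weight decay, plus step-wise learning rate decay.
opt = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001, amsgrad=True)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.3)

# An exponential-moving-average copy of the model is kept for validation and testing.
ema_model = copy.deepcopy(model)
for p in ema_model.parameters():
    p.requires_grad_(False)

@torch.no_grad()
def ema(ema_model, model, decay):
    msd = model.state_dict()
    for k, ema_v in ema_model.state_dict().items():
        ema_v.copy_(ema_v * decay + (1. - decay) * msd[k].detach())

# With the default ema_decay of 0, the EMA model simply tracks the latest weights.
ema(ema_model, model, decay=0.0)
```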
### Performance
- The batch size differs from the original implementation
- Linear learning rate warm-up is not used
- Exponential learning rate decay is not used
- Exponential moving average (EMA) is not used
- For all tasks except mu, alpha, r2 and Cv, the values in the first two rows below are in units of 10^-3 (e.g. 24.6 means 0.0246)
- The authors' code does not provide a pretrained model for the gap task
- MAE(DimeNet in Table 1) is taken from the [DimeNet paper](https://arxiv.org/abs/2003.03123)
- MAE(DimeNet++ in Table 2) is taken from the [DimeNet++ paper](https://arxiv.org/abs/2011.14115)
| Target | mu | alpha | homo | lumo | gap | r2 | zpve | U0 | U | H | G | Cv |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| MAE(DimeNet in Table 1) | 0.0286 | 0.0469 | 27.8 | 19.7 | 34.8 | 0.331 | 1.29 | 8.02 | 7.89 | 8.11 | 8.98 | 0.0249 |
| MAE(DimeNet++ in Table 2) | 0.0297 | 0.0435 | 24.6 | 19.5 | 32.6 | 0.331 | 1.21 | 6.32 | 6.28 | 6.53 | 7.56 | 0.0230 |
| MAE(DimeNet++, TF, pretrain) | 0.0297 | 0.0435 | 0.0246 | 0.0195 | - | 0.3312 | 0.00121 | 0.0063 | 0.00628 | 0.00653 | 0.00756 | 0.0230 |
| MAE(DimeNet++, TF, scratch) | 0.0330 | 0.0447 | 0.0251 | 0.0227 | 0.0486 | 0.3574 | 0.00123 | 0.0065 | 0.00635 | 0.00658 | 0.00747 | 0.0224 |
| MAE(DimeNet++, DGL) | 0.0326 | 0.0537 | 0.0311 | 0.0255 | 0.0490 | 0.4801 | 0.0043 | 0.0141 | 0.0109 | 0.0117 | 0.0150 | 0.0254 |
### Speed
| Model | Original Implementation | DGL Implementation | Improvement |
| :-: | :-: | :-: | :-: |
| DimeNet | 2839 | 1345 | 2.1x |
| DimeNet++ | 624 | 238 | 2.6x |
tf:
ckpt_path: 'pretrained/dimenet_pp/mu'
torch:
dump_path: 'pretrained/converted'
name: "dimenet"
model:
emb_size: 128
num_blocks: 6
num_bilinear: 8
num_spherical: 7
num_radial: 6
envelope_exponent: 5
cutoff: 5.0
num_before_skip: 1
num_after_skip: 2
num_dense_output: 3
# ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
targets: ['U0']
train:
num_train: 110000
num_valid: 10000
data_seed: 42
lr: 0.001
weight_decay: 0.0001
ema_decay: 0
batch_size: 45
epochs: 300
early_stopping: 20
num_workers: 18
gpu: 0
interval: 50
step_size: 100
gamma: 0.3
pretrain:
flag: False
path: 'pretrained/converted/'
name: "dimenet++"
model:
emb_size: 128
out_emb_size: 256
int_emb_size: 64
basis_emb_size: 8
num_blocks: 4
num_spherical: 7
num_radial: 6
envelope_exponent: 5
cutoff: 5.0
extensive: True
num_before_skip: 1
num_after_skip: 2
num_dense_output: 3
# ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
targets: ['mu']
train:
num_train: 110000
num_valid: 10000
data_seed: 42
lr: 0.001
weight_decay: 0.0001
ema_decay: 0
batch_size: 100
epochs: 300
early_stopping: 20
num_workers: 18
gpu: 0
interval: 50
step_size: 100
gamma: 0.3
pretrain:
flag: False
path: 'pretrained/converted/'
import tensorflow as tf
import torch
import torch.nn as nn
import click
import numpy as np
import os
from logzero import logger
from pathlib import Path
from ruamel.yaml import YAML
from modules.initializers import GlorotOrthogonal
from modules.dimenet_pp import DimeNetPP
@click.command()
@click.option('-m', '--model-cnf', type=click.Path(exists=True), help='Path of model config yaml.')
@click.option('-c', '--convert-cnf', type=click.Path(exists=True), help='Path of convert config yaml.')
def main(model_cnf, convert_cnf):
yaml = YAML(typ='safe')
model_cnf = yaml.load(Path(model_cnf))
convert_cnf = yaml.load(Path(convert_cnf))
model_name, model_params, _ = model_cnf['name'], model_cnf['model'], model_cnf['train']
logger.info(f'Model name: {model_name}')
logger.info(f'Model params: {model_params}')
    if model_params['targets'][0] in ['mu', 'homo', 'lumo', 'gap', 'zpve']:
        model_params['output_init'] = nn.init.zeros_
    else:
        # 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv
        model_params['output_init'] = GlorotOrthogonal
# model initialization
logger.info('Loading Model')
model = DimeNetPP(emb_size=model_params['emb_size'],
out_emb_size=model_params['out_emb_size'],
int_emb_size=model_params['int_emb_size'],
basis_emb_size=model_params['basis_emb_size'],
num_blocks=model_params['num_blocks'],
num_spherical=model_params['num_spherical'],
num_radial=model_params['num_radial'],
cutoff=model_params['cutoff'],
envelope_exponent=model_params['envelope_exponent'],
num_before_skip=model_params['num_before_skip'],
num_after_skip=model_params['num_after_skip'],
num_dense_output=model_params['num_dense_output'],
num_targets=len(model_params['targets']),
extensive=model_params['extensive'],
output_init=model_params['output_init'])
logger.info(model.state_dict())
tf_path, torch_path = convert_cnf['tf']['ckpt_path'], convert_cnf['torch']['dump_path']
init_vars = tf.train.list_variables(tf_path)
tf_vars_dict = {}
# 147 keys
for name, shape in init_vars:
if name == '_CHECKPOINTABLE_OBJECT_GRAPH':
continue
array = tf.train.load_variable(tf_path, name)
logger.info(f'Loading TF weight {name} with shape {shape}')
tf_vars_dict[name] = array
for name, array in tf_vars_dict.items():
name = name.split('/')[:-2]
pointer = model
for m_name in name:
if m_name == 'kernel':
pointer = getattr(pointer, 'weight')
elif m_name == 'int_blocks':
pointer = getattr(pointer, 'interaction_blocks')
elif m_name == 'embeddings':
pointer = getattr(pointer, 'embedding')
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, m_name)
if name[-1] == 'kernel':
array = np.transpose(array)
assert array.shape == pointer.shape
logger.info(f'Initialize PyTorch weight {name}')
pointer.data = torch.from_numpy(array)
logger.info(f'Save PyTorch model to {torch_path}')
if not os.path.exists(torch_path):
os.makedirs(torch_path)
target = model_params['targets'][0]
torch.save(model.state_dict(), f'{torch_path}/{target}.pt')
logger.info(model.state_dict())
if __name__ == "__main__":
main()
import click
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import dgl
from logzero import logger
from pathlib import Path
from ruamel.yaml import YAML
from torch.utils.data import DataLoader
from dgl.data.utils import Subset
from sklearn.metrics import mean_absolute_error
from qm9 import QM9
from modules.initializers import GlorotOrthogonal
from modules.dimenet import DimeNet
from modules.dimenet_pp import DimeNetPP
def split_dataset(dataset, num_train, num_valid, shuffle=False, random_state=None):
"""Split dataset into training, validation and test set.
Parameters
----------
dataset
We assume that ``len(dataset)`` gives the number of datapoints and ``dataset[i]``
gives the ith datapoint.
num_train : int
Number of training datapoints.
num_valid : int
Number of validation datapoints.
shuffle : bool, optional
By default we perform a consecutive split of the dataset. If True,
we will first randomly shuffle the dataset.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
This can be any integer between 0 and 2^32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default value).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
Returns
-------
list of length 3
Subsets for training, validation and test.
"""
from itertools import accumulate
num_data = len(dataset)
assert num_train + num_valid < num_data
lengths = [num_train, num_valid, num_data - num_train - num_valid]
if shuffle:
indices = np.random.RandomState(seed=random_state).permutation(num_data)
else:
indices = np.arange(num_data)
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(accumulate(lengths), lengths)]
@torch.no_grad()
def ema(ema_model, model, decay):
msd = model.state_dict()
for k, ema_v in ema_model.state_dict().items():
model_v = msd[k].detach()
ema_v.copy_(ema_v * decay + (1. - decay) * model_v)
def edge_init(edges):
R_src, R_dst = edges.src['R'], edges.dst['R']
dist = torch.sqrt(F.relu(torch.sum((R_src - R_dst) ** 2, -1)))
# d: bond length, o: bond orientation
return {'d': dist, 'o': R_src - R_dst}
def _collate_fn(batch):
graphs, line_graphs, labels = map(list, zip(*batch))
g, l_g = dgl.batch(graphs), dgl.batch(line_graphs)
labels = torch.tensor(labels, dtype=torch.float32)
return g, l_g, labels
def train(device, model, opt, loss_fn, train_loader):
model.train()
epoch_loss = 0
num_samples = 0
for g, l_g, labels in train_loader:
g = g.to(device)
l_g = l_g.to(device)
labels = labels.to(device)
logits = model(g, l_g)
loss = loss_fn(logits, labels.view([-1, 1]))
epoch_loss += loss.data.item() * len(labels)
num_samples += len(labels)
opt.zero_grad()
loss.backward()
opt.step()
return epoch_loss / num_samples
@torch.no_grad()
def evaluate(device, model, valid_loader):
model.eval()
predictions_all, labels_all = [], []
for g, l_g, labels in valid_loader:
g = g.to(device)
l_g = l_g.to(device)
logits = model(g, l_g)
labels_all.extend(labels)
predictions_all.extend(logits.view(-1,).cpu().numpy())
return np.array(predictions_all), np.array(labels_all)
@click.command()
@click.option('-m', '--model-cnf', type=click.Path(exists=True), help='Path of model config yaml.')
def main(model_cnf):
yaml = YAML(typ='safe')
model_cnf = yaml.load(Path(model_cnf))
model_name, model_params, train_params, pretrain_params = model_cnf['name'], model_cnf['model'], model_cnf['train'], model_cnf['pretrain']
logger.info(f'Model name: {model_name}')
logger.info(f'Model params: {model_params}')
logger.info(f'Train params: {train_params}')
    if model_params['targets'][0] in ['mu', 'homo', 'lumo', 'gap', 'zpve']:
        model_params['output_init'] = nn.init.zeros_
    else:
        # 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv
        model_params['output_init'] = GlorotOrthogonal
logger.info('Loading Data Set')
dataset = QM9(label_keys=model_params['targets'], edge_funcs=[edge_init])
# data split
train_data, valid_data, test_data = split_dataset(dataset,
num_train=train_params['num_train'],
num_valid=train_params['num_valid'],
shuffle=True,
random_state=train_params['data_seed'])
logger.info(f'Size of Training Set: {len(train_data)}')
logger.info(f'Size of Validation Set: {len(valid_data)}')
logger.info(f'Size of Test Set: {len(test_data)}')
# data loader
train_loader = DataLoader(train_data,
batch_size=train_params['batch_size'],
shuffle=True,
collate_fn=_collate_fn,
num_workers=train_params['num_workers'])
valid_loader = DataLoader(valid_data,
batch_size=train_params['batch_size'],
shuffle=False,
collate_fn=_collate_fn,
num_workers=train_params['num_workers'])
test_loader = DataLoader(test_data,
batch_size=train_params['batch_size'],
shuffle=False,
collate_fn=_collate_fn,
num_workers=train_params['num_workers'])
# check cuda
gpu = train_params['gpu']
device = f'cuda:{gpu}' if gpu >= 0 and torch.cuda.is_available() else 'cpu'
# model initialization
logger.info('Loading Model')
if model_name == 'dimenet':
model = DimeNet(emb_size=model_params['emb_size'],
num_blocks=model_params['num_blocks'],
num_bilinear=model_params['num_bilinear'],
num_spherical=model_params['num_spherical'],
num_radial=model_params['num_radial'],
cutoff=model_params['cutoff'],
envelope_exponent=model_params['envelope_exponent'],
num_before_skip=model_params['num_before_skip'],
num_after_skip=model_params['num_after_skip'],
num_dense_output=model_params['num_dense_output'],
num_targets=len(model_params['targets']),
output_init=model_params['output_init']).to(device)
elif model_name == 'dimenet++':
model = DimeNetPP(emb_size=model_params['emb_size'],
out_emb_size=model_params['out_emb_size'],
int_emb_size=model_params['int_emb_size'],
basis_emb_size=model_params['basis_emb_size'],
num_blocks=model_params['num_blocks'],
num_spherical=model_params['num_spherical'],
num_radial=model_params['num_radial'],
cutoff=model_params['cutoff'],
envelope_exponent=model_params['envelope_exponent'],
num_before_skip=model_params['num_before_skip'],
num_after_skip=model_params['num_after_skip'],
num_dense_output=model_params['num_dense_output'],
num_targets=len(model_params['targets']),
extensive=model_params['extensive'],
output_init=model_params['output_init']).to(device)
else:
raise ValueError(f'Invalid Model Name {model_name}')
if pretrain_params['flag']:
torch_path = pretrain_params['path']
target = model_params['targets'][0]
model.load_state_dict(torch.load(f'{torch_path}/{target}.pt'))
logger.info('Testing with Pretrained model')
predictions, labels = evaluate(device, model, test_loader)
test_mae = mean_absolute_error(labels, predictions)
logger.info(f'Test MAE {test_mae:.4f}')
return
# define loss function and optimization
loss_fn = nn.L1Loss()
opt = optim.Adam(model.parameters(), lr=train_params['lr'], weight_decay=train_params['weight_decay'], amsgrad=True)
scheduler = optim.lr_scheduler.StepLR(opt, train_params['step_size'], gamma=train_params['gamma'])
# model training
best_mae = 1e9
no_improvement = 0
# EMA for valid and test
logger.info('EMA Init')
ema_model = copy.deepcopy(model)
for p in ema_model.parameters():
p.requires_grad_(False)
best_model = copy.deepcopy(ema_model)
logger.info('Training')
for i in range(train_params['epochs']):
train_loss = train(device, model, opt, loss_fn, train_loader)
ema(ema_model, model, train_params['ema_decay'])
if i % train_params['interval'] == 0:
predictions, labels = evaluate(device, ema_model, valid_loader)
valid_mae = mean_absolute_error(labels, predictions)
logger.info(f'Epoch {i} | Train Loss {train_loss:.4f} | Val MAE {valid_mae:.4f}')
if valid_mae > best_mae:
no_improvement += 1
if no_improvement == train_params['early_stopping']:
logger.info('Early stop.')
break
else:
no_improvement = 0
best_mae = valid_mae
best_model = copy.deepcopy(ema_model)
else:
logger.info(f'Epoch {i} | Train Loss {train_loss:.4f}')
scheduler.step()
logger.info('Testing')
predictions, labels = evaluate(device, best_model, test_loader)
test_mae = mean_absolute_error(labels, predictions)
logger.info('Test MAE {:.4f}'.format(test_mae))
if __name__ == "__main__":
main()
import torch
def swish(x):
"""
Swish activation function,
    from Ramachandran, Zoph, Le 2017. "Searching for Activation Functions"
"""
return x * torch.sigmoid(x)
import numpy as np
import sympy as sym
from scipy.optimize import brentq
from scipy import special as sp
def Jn(r, n):
"""
r: int or list
n: int or list
len(r) == len(n)
return value should be the same shape as the input data
===
example:
r = n = np.array([1, 2, 3, 4])
res = [0.3, 0.1, 0.1, 0.1]
===
numerical spherical bessel functions of order n
"""
return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) # the same shape as n
def Jn_zeros(n, k):
"""
n: int
k: int
res: array of shape [n, k]
Compute the first k zeros of the spherical bessel functions up to order n (excluded)
"""
zerosj = np.zeros((n, k), dtype="float32")
zerosj[0] = np.arange(1, k + 1) * np.pi
points = np.arange(1, k + n) * np.pi
racines = np.zeros(k + n - 1, dtype="float32")
for i in range(1, n):
for j in range(k + n - 1 - i):
foo = brentq(Jn, points[j], points[j + 1], (i,))
racines[j] = foo
points = racines
zerosj[i][:k] = racines[:k]
return zerosj
def spherical_bessel_formulas(n):
"""
n: int
res: array of shape [n,]
n sympy functions
Computes the sympy formulas for the spherical bessel functions up to order n (excluded)
"""
x = sym.symbols('x')
f = [sym.sin(x) / x]
a = sym.sin(x) / x
for i in range(1, n):
b = sym.diff(a, x) / x
f += [sym.simplify(b * (-x) ** i)]
a = sym.simplify(b)
return f
def bessel_basis(n, k):
"""
n: int
k: int
res: [n, k]
n * k sympy functions
Computes the sympy formulas for the normalized and rescaled spherical bessel functions up to
order n (excluded) and maximum frequency k (excluded).
"""
zeros = Jn_zeros(n, k)
normalizer = []
for order in range(n):
normalizer_tmp = []
for i in range(k):
normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1) ** 2]
normalizer_tmp = 1 / np.array(normalizer_tmp) ** 0.5
normalizer += [normalizer_tmp]
f = spherical_bessel_formulas(n)
x = sym.symbols('x')
bess_basis = []
for order in range(n):
bess_basis_tmp = []
for i in range(k):
bess_basis_tmp += [sym.simplify(normalizer[order][i] * f[order].subs(x, zeros[order, i] * x))]
bess_basis += [bess_basis_tmp]
return bess_basis
def sph_harm_prefactor(l, m):
"""
l: int
m: int
res: float
Computes the constant pre-factor for the spherical harmonic of degree l and order m
input:
l: int, l>=0
m: int, -l<=m<=l
"""
return ((2 * l + 1) * np.math.factorial(l - abs(m)) / (4 * np.pi * np.math.factorial(l + abs(m)))) ** 0.5
def associated_legendre_polynomials(l, zero_m_only=True):
"""
l: int
return: l sympy functions
Computes sympy formulas of the associated legendre polynomials up to order l (excluded).
"""
z = sym.symbols('z')
P_l_m = [[0] * (j + 1) for j in range(l)]
P_l_m[0][0] = 1
if l > 0:
P_l_m[1][0] = z
for j in range(2, l):
P_l_m[j][0] = sym.simplify(
((2 * j - 1) * z * P_l_m[j - 1][0] - (j - 1) * P_l_m[j - 2][0]) / j)
if not zero_m_only:
for i in range(1, l):
P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])
if i + 1 < l:
P_l_m[i + 1][i] = sym.simplify((2 * i + 1) * z * P_l_m[i][i])
for j in range(i + 2, l):
P_l_m[j][i] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][i] - (i + j - 1) * P_l_m[j - 2][i]) / (j - i))
return P_l_m
def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True):
"""
return: a sympy function list of length l, for i-th index of the list, it is also a list of length (2 * i + 1)
Computes formula strings of the real part of the spherical harmonics up to order l (excluded).
Variables are either cartesian coordinates x,y,z on the unit sphere or spherical coordinates phi and theta.
"""
if not zero_m_only:
S_m = [0]
C_m = [1]
for i in range(1, l):
x = sym.symbols('x')
y = sym.symbols('y')
S_m += [x * S_m[i - 1] + y * C_m[i - 1]]
C_m += [x * C_m[i - 1] - y * S_m[i - 1]]
P_l_m = associated_legendre_polynomials(l, zero_m_only)
if spherical_coordinates:
theta = sym.symbols('theta')
z = sym.symbols('z')
for i in range(len(P_l_m)):
for j in range(len(P_l_m[i])):
if type(P_l_m[i][j]) != int:
P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta))
if not zero_m_only:
phi = sym.symbols('phi')
for i in range(len(S_m)):
S_m[i] = S_m[i].subs(x, sym.sin(theta) * sym.cos(phi)).subs(y, sym.sin(theta) * sym.sin(phi))
for i in range(len(C_m)):
C_m[i] = C_m[i].subs(x, sym.sin(theta) * sym.cos(phi)).subs(y, sym.sin(theta) * sym.sin(phi))
Y_func_l_m = [['0'] * (2 * j + 1) for j in range(l)]
for i in range(l):
Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])
if not zero_m_only:
for i in range(1, l):
for j in range(1, i + 1):
Y_func_l_m[i][j] = sym.simplify(2 ** 0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j])
for i in range(1, l):
for j in range(1, i + 1):
Y_func_l_m[i][-j] = sym.simplify(2 ** 0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j])
return Y_func_l_m
import numpy as np
import torch
import torch.nn as nn
from modules.envelope import Envelope
class BesselBasisLayer(nn.Module):
def __init__(self,
num_radial,
cutoff,
envelope_exponent=5):
super(BesselBasisLayer, self).__init__()
self.cutoff = cutoff
self.envelope = Envelope(envelope_exponent)
self.frequencies = nn.Parameter(torch.Tensor(num_radial))
self.reset_params()
def reset_params(self):
torch.arange(1, self.frequencies.numel() + 1, out=self.frequencies).mul_(np.pi)
def forward(self, g):
d_scaled = g.edata['d'] / self.cutoff
# Necessary for proper broadcasting behaviour
d_scaled = torch.unsqueeze(d_scaled, -1)
d_cutoff = self.envelope(d_scaled)
g.edata['rbf'] = d_cutoff * torch.sin(self.frequencies * d_scaled)
return g
import torch
import torch.nn as nn
from modules.activations import swish
from modules.bessel_basis_layer import BesselBasisLayer
from modules.spherical_basis_layer import SphericalBasisLayer
from modules.embedding_block import EmbeddingBlock
from modules.output_block import OutputBlock
from modules.interaction_block import InteractionBlock
class DimeNet(nn.Module):
"""
DimeNet model.
Parameters
----------
emb_size
Embedding size used throughout the model
num_blocks
Number of building blocks to be stacked
num_bilinear
Third dimension of the bilinear layer tensor
num_spherical
Number of spherical harmonics
num_radial
Number of radial basis functions
cutoff
Cutoff distance for interatomic interactions
envelope_exponent
Shape of the smooth cutoff
num_before_skip
Number of residual layers in interaction block before skip connection
num_after_skip
Number of residual layers in interaction block after skip connection
num_dense_output
Number of dense layers for the output blocks
num_targets
Number of targets to predict
activation
Activation function
output_init
Initial function in output block
"""
def __init__(self,
emb_size,
num_blocks,
num_bilinear,
num_spherical,
num_radial,
cutoff=5.0,
envelope_exponent=5,
num_before_skip=1,
num_after_skip=2,
num_dense_output=3,
num_targets=12,
activation=swish,
output_init=nn.init.zeros_):
super(DimeNet, self).__init__()
self.num_blocks = num_blocks
self.num_radial = num_radial
# cosine basis function expansion layer
self.rbf_layer = BesselBasisLayer(num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent)
self.sbf_layer = SphericalBasisLayer(num_spherical=num_spherical,
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent)
# embedding block
self.emb_block = EmbeddingBlock(emb_size=emb_size,
num_radial=num_radial,
bessel_funcs=self.sbf_layer.get_bessel_funcs(),
cutoff=cutoff,
envelope_exponent=envelope_exponent,
activation=activation)
# output block
self.output_blocks = nn.ModuleList({
OutputBlock(emb_size=emb_size,
num_radial=num_radial,
num_dense=num_dense_output,
num_targets=num_targets,
activation=activation,
output_init=output_init) for _ in range(num_blocks + 1)
})
# interaction block
self.interaction_blocks = nn.ModuleList({
InteractionBlock(emb_size=emb_size,
num_radial=num_radial,
num_spherical=num_spherical,
num_bilinear=num_bilinear,
num_before_skip=num_before_skip,
num_after_skip=num_after_skip,
activation=activation) for _ in range(num_blocks)
})
def edge_init(self, edges):
# Calculate angles k -> j -> i
R1, R2 = edges.src['o'], edges.dst['o']
x = torch.sum(R1 * R2, dim=-1)
y = torch.cross(R1, R2)
y = torch.norm(y, dim=-1)
angle = torch.atan2(y, x)
# Transform via angles
cbf = [f(angle) for f in self.sbf_layer.get_sph_funcs()]
cbf = torch.stack(cbf, dim=1) # [None, 7]
cbf = cbf.repeat_interleave(self.num_radial, dim=1) # [None, 42]
sbf = edges.src['rbf_env'] * cbf # [None, 42]
return {'sbf': sbf}
def forward(self, g, l_g):
# add rbf features for each edge in one batch graph, [num_radial,]
g = self.rbf_layer(g)
# Embedding block
g = self.emb_block(g)
# Output block
P = self.output_blocks[0](g) # [batch_size, num_targets]
# Prepare sbf feature before the following blocks
for k, v in g.edata.items():
l_g.ndata[k] = v
l_g.apply_edges(self.edge_init)
# Interaction blocks
for i in range(self.num_blocks):
g = self.interaction_blocks[i](g, l_g)
P += self.output_blocks[i + 1](g)
return P
import torch
import torch.nn as nn
from modules.activations import swish
from modules.bessel_basis_layer import BesselBasisLayer
from modules.spherical_basis_layer import SphericalBasisLayer
from modules.embedding_block import EmbeddingBlock
from modules.output_pp_block import OutputPPBlock
from modules.interaction_pp_block import InteractionPPBlock
class DimeNetPP(nn.Module):
"""
DimeNet++ model.
Parameters
----------
emb_size
Embedding size used for the messages
out_emb_size
Embedding size used for atoms in the output block
int_emb_size
Embedding size used for interaction triplets
basis_emb_size
Embedding size used inside the basis transformation
num_blocks
Number of building blocks to be stacked
num_spherical
Number of spherical harmonics
num_radial
Number of radial basis functions
cutoff
Cutoff distance for interatomic interactions
envelope_exponent
Shape of the smooth cutoff
num_before_skip
Number of residual layers in interaction block before skip connection
num_after_skip
Number of residual layers in interaction block after skip connection
num_dense_output
Number of dense layers for the output blocks
num_targets
Number of targets to predict
activation
Activation function
extensive
Whether the output should be extensive (proportional to the number of atoms)
output_init
Initial function in output block
"""
def __init__(self,
emb_size,
out_emb_size,
int_emb_size,
basis_emb_size,
num_blocks,
num_spherical,
num_radial,
cutoff=5.0,
envelope_exponent=5,
num_before_skip=1,
num_after_skip=2,
num_dense_output=3,
num_targets=12,
activation=swish,
extensive=True,
output_init=nn.init.zeros_):
super(DimeNetPP, self).__init__()
self.num_blocks = num_blocks
self.num_radial = num_radial
# cosine basis function expansion layer
self.rbf_layer = BesselBasisLayer(num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent)
self.sbf_layer = SphericalBasisLayer(num_spherical=num_spherical,
num_radial=num_radial,
cutoff=cutoff,
envelope_exponent=envelope_exponent)
# embedding block
self.emb_block = EmbeddingBlock(emb_size=emb_size,
num_radial=num_radial,
bessel_funcs=self.sbf_layer.get_bessel_funcs(),
cutoff=cutoff,
envelope_exponent=envelope_exponent,
activation=activation)
# output block
self.output_blocks = nn.ModuleList({
OutputPPBlock(emb_size=emb_size,
out_emb_size=out_emb_size,
num_radial=num_radial,
num_dense=num_dense_output,
num_targets=num_targets,
activation=activation,
extensive=extensive,
output_init=output_init) for _ in range(num_blocks + 1)
})
# interaction block
self.interaction_blocks = nn.ModuleList({
InteractionPPBlock(emb_size=emb_size,
int_emb_size=int_emb_size,
basis_emb_size=basis_emb_size,
num_radial=num_radial,
num_spherical=num_spherical,
num_before_skip=num_before_skip,
num_after_skip=num_after_skip,
activation=activation) for _ in range(num_blocks)
})
def edge_init(self, edges):
# Calculate angles k -> j -> i
R1, R2 = edges.src['o'], edges.dst['o']
x = torch.sum(R1 * R2, dim=-1)
y = torch.cross(R1, R2)
y = torch.norm(y, dim=-1)
angle = torch.atan2(y, x)
# Transform via angles
cbf = [f(angle) for f in self.sbf_layer.get_sph_funcs()]
cbf = torch.stack(cbf, dim=1) # [None, 7]
cbf = cbf.repeat_interleave(self.num_radial, dim=1) # [None, 42]
# Notice: it's dst, not src
sbf = edges.dst['rbf_env'] * cbf # [None, 42]
return {'sbf': sbf}
def forward(self, g, l_g):
# add rbf features for each edge in one batch graph, [num_radial,]
g = self.rbf_layer(g)
# Embedding block
g = self.emb_block(g)
# Output block
P = self.output_blocks[0](g) # [batch_size, num_targets]
# Prepare sbf feature before the following blocks
for k, v in g.edata.items():
l_g.ndata[k] = v
l_g.apply_edges(self.edge_init)
# Interaction blocks
for i in range(self.num_blocks):
g = self.interaction_blocks[i](g, l_g)
P += self.output_blocks[i + 1](g)
return P
import numpy as np
import torch
import torch.nn as nn
from modules.envelope import Envelope
from modules.initializers import GlorotOrthogonal
class EmbeddingBlock(nn.Module):
def __init__(self,
emb_size,
num_radial,
bessel_funcs,
cutoff,
envelope_exponent,
num_atom_types=95,
activation=None):
super(EmbeddingBlock, self).__init__()
self.bessel_funcs = bessel_funcs
self.cutoff = cutoff
self.activation = activation
self.envelope = Envelope(envelope_exponent)
self.embedding = nn.Embedding(num_atom_types, emb_size)
self.dense_rbf = nn.Linear(num_radial, emb_size)
self.dense = nn.Linear(emb_size * 3, emb_size)
self.reset_params()
def reset_params(self):
nn.init.uniform_(self.embedding.weight, a=-np.sqrt(3), b=np.sqrt(3))
GlorotOrthogonal(self.dense_rbf.weight)
GlorotOrthogonal(self.dense.weight)
def edge_init(self, edges):
""" msg emb init """
# m init
rbf = self.dense_rbf(edges.data['rbf'])
if self.activation is not None:
rbf = self.activation(rbf)
m = torch.cat([edges.src['h'], edges.dst['h'], rbf], dim=-1)
m = self.dense(m)
if self.activation is not None:
m = self.activation(m)
# rbf_env init
d_scaled = edges.data['d'] / self.cutoff
rbf_env = [f(d_scaled) for f in self.bessel_funcs]
rbf_env = torch.stack(rbf_env, dim=1)
d_cutoff = self.envelope(d_scaled)
rbf_env = d_cutoff[:, None] * rbf_env
return {'m': m, 'rbf_env': rbf_env}
def forward(self, g):
g.ndata['h'] = self.embedding(g.ndata['Z'])
g.apply_edges(self.edge_init)
return g
import torch.nn as nn
class Envelope(nn.Module):
"""
Envelope function that ensures a smooth cutoff
"""
def __init__(self, exponent):
super(Envelope, self).__init__()
self.p = exponent + 1
self.a = -(self.p + 1) * (self.p + 2) / 2
self.b = self.p * (self.p + 2)
self.c = -self.p * (self.p + 1) / 2
def forward(self, x):
# Envelope function divided by r
x_p_0 = x.pow(self.p - 1)
x_p_1 = x_p_0 * x
x_p_2 = x_p_1 * x
env_val = 1 / x + self.a * x_p_0 + self.b * x_p_1 + self.c * x_p_2
return env_val
import torch.nn as nn
def GlorotOrthogonal(tensor, scale=2.0):
if tensor is not None:
nn.init.orthogonal_(tensor.data)
scale /= (tensor.size(-2) + tensor.size(-1)) * tensor.var()
tensor.data *= scale.sqrt()
import torch
import torch.nn as nn
import dgl.function as fn
from modules.residual_layer import ResidualLayer
from modules.initializers import GlorotOrthogonal
class InteractionBlock(nn.Module):
def __init__(self,
emb_size,
num_radial,
num_spherical,
num_bilinear,
num_before_skip,
num_after_skip,
activation=None):
super(InteractionBlock, self).__init__()
self.activation = activation
# Transformations of Bessel and spherical basis representations
self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
self.dense_sbf = nn.Linear(num_radial * num_spherical, num_bilinear, bias=False)
# Dense transformations of input messages
self.dense_ji = nn.Linear(emb_size, emb_size)
self.dense_kj = nn.Linear(emb_size, emb_size)
# Bilinear layer
bilin_initializer = torch.empty((emb_size, num_bilinear, emb_size)).normal_(mean=0, std=2 / emb_size)
self.W_bilin = nn.Parameter(bilin_initializer)
# Residual layers before skip connection
self.layers_before_skip = nn.ModuleList([
ResidualLayer(emb_size, activation=activation) for _ in range(num_before_skip)
])
self.final_before_skip = nn.Linear(emb_size, emb_size)
# Residual layers after skip connection
self.layers_after_skip = nn.ModuleList([
ResidualLayer(emb_size, activation=activation) for _ in range(num_after_skip)
])
self.reset_params()
def reset_params(self):
GlorotOrthogonal(self.dense_rbf.weight)
GlorotOrthogonal(self.dense_sbf.weight)
GlorotOrthogonal(self.dense_ji.weight)
GlorotOrthogonal(self.dense_kj.weight)
GlorotOrthogonal(self.final_before_skip.weight)
def edge_transfer(self, edges):
        # Transform from Bessel basis to dense vector
rbf = self.dense_rbf(edges.data['rbf'])
# Initial transformation
x_ji = self.dense_ji(edges.data['m'])
x_kj = self.dense_kj(edges.data['m'])
if self.activation is not None:
x_ji = self.activation(x_ji)
x_kj = self.activation(x_kj)
# w: W * e_RBF \bigodot \sigma(W * m + b)
return {'x_kj': x_kj * rbf, 'x_ji': x_ji}
def msg_func(self, edges):
sbf = self.dense_sbf(edges.data['sbf'])
# Apply bilinear layer to interactions and basis function activation
# [None, 8] * [128, 8, 128] * [None, 128] -> [None, 128]
x_kj = torch.einsum("wj,wl,ijl->wi", sbf, edges.src['x_kj'], self.W_bilin)
return {'x_kj': x_kj}
def forward(self, g, l_g):
g.apply_edges(self.edge_transfer)
# nodes correspond to edges and edges correspond to nodes in the original graphs
# node: d, rbf, o, rbf_env, x_kj, x_ji
for k, v in g.edata.items():
l_g.ndata[k] = v
l_g.update_all(self.msg_func, fn.sum('x_kj', 'm_update'))
for k, v in l_g.ndata.items():
g.edata[k] = v
# Transformations before skip connection
g.edata['m_update'] = g.edata['m_update'] + g.edata['x_ji']
for layer in self.layers_before_skip:
g.edata['m_update'] = layer(g.edata['m_update'])
g.edata['m_update'] = self.final_before_skip(g.edata['m_update'])
if self.activation is not None:
g.edata['m_update'] = self.activation(g.edata['m_update'])
# Skip connection
g.edata['m'] = g.edata['m'] + g.edata['m_update']
# Transformations after skip connection
for layer in self.layers_after_skip:
g.edata['m'] = layer(g.edata['m'])
return g
import torch.nn as nn
import dgl
import dgl.function as fn
from modules.residual_layer import ResidualLayer
from modules.initializers import GlorotOrthogonal
class InteractionPPBlock(nn.Module):
def __init__(self,
emb_size,
int_emb_size,
basis_emb_size,
num_radial,
num_spherical,
num_before_skip,
num_after_skip,
activation=None):
super(InteractionPPBlock, self).__init__()
self.activation = activation
# Transformations of Bessel and spherical basis representations
self.dense_rbf1 = nn.Linear(num_radial, basis_emb_size, bias=False)
self.dense_rbf2 = nn.Linear(basis_emb_size, emb_size, bias=False)
self.dense_sbf1 = nn.Linear(num_radial * num_spherical, basis_emb_size, bias=False)
self.dense_sbf2 = nn.Linear(basis_emb_size, int_emb_size, bias=False)
# Dense transformations of input messages
self.dense_ji = nn.Linear(emb_size, emb_size)
self.dense_kj = nn.Linear(emb_size, emb_size)
# Embedding projections for interaction triplets
self.down_projection = nn.Linear(emb_size, int_emb_size, bias=False)
self.up_projection = nn.Linear(int_emb_size, emb_size, bias=False)
# Residual layers before skip connection
self.layers_before_skip = nn.ModuleList([
ResidualLayer(emb_size, activation=activation) for _ in range(num_before_skip)
])
self.final_before_skip = nn.Linear(emb_size, emb_size)
# Residual layers after skip connection
self.layers_after_skip = nn.ModuleList([
ResidualLayer(emb_size, activation=activation) for _ in range(num_after_skip)
])
self.reset_params()
def reset_params(self):
GlorotOrthogonal(self.dense_rbf1.weight)
GlorotOrthogonal(self.dense_rbf2.weight)
GlorotOrthogonal(self.dense_sbf1.weight)
GlorotOrthogonal(self.dense_sbf2.weight)
GlorotOrthogonal(self.dense_ji.weight)
nn.init.zeros_(self.dense_ji.bias)
GlorotOrthogonal(self.dense_kj.weight)
nn.init.zeros_(self.dense_kj.bias)
GlorotOrthogonal(self.down_projection.weight)
GlorotOrthogonal(self.up_projection.weight)
def edge_transfer(self, edges):
# Transform from Bessel basis to dense vector
rbf = self.dense_rbf1(edges.data['rbf'])
rbf = self.dense_rbf2(rbf)
# Initial transformation
x_ji = self.dense_ji(edges.data['m'])
x_kj = self.dense_kj(edges.data['m'])
if self.activation is not None:
x_ji = self.activation(x_ji)
x_kj = self.activation(x_kj)
x_kj = self.down_projection(x_kj * rbf)
if self.activation is not None:
x_kj = self.activation(x_kj)
return {'x_kj': x_kj, 'x_ji': x_ji}
def msg_func(self, edges):
sbf = self.dense_sbf1(edges.data['sbf'])
sbf = self.dense_sbf2(sbf)
x_kj = edges.src['x_kj'] * sbf
return {'x_kj': x_kj}
def forward(self, g, l_g):
g.apply_edges(self.edge_transfer)
# nodes correspond to edges and edges correspond to nodes in the original graphs
# node: d, rbf, o, rbf_env, x_kj, x_ji
for k, v in g.edata.items():
l_g.ndata[k] = v
l_g_reverse = dgl.reverse(l_g, copy_edata=True)
l_g_reverse.update_all(self.msg_func, fn.sum('x_kj', 'm_update'))
g.edata['m_update'] = self.up_projection(l_g_reverse.ndata['m_update'])
if self.activation is not None:
g.edata['m_update'] = self.activation(g.edata['m_update'])
# Transformations before skip connection
g.edata['m_update'] = g.edata['m_update'] + g.edata['x_ji']
for layer in self.layers_before_skip:
g.edata['m_update'] = layer(g.edata['m_update'])
g.edata['m_update'] = self.final_before_skip(g.edata['m_update'])
if self.activation is not None:
g.edata['m_update'] = self.activation(g.edata['m_update'])
# Skip connection
g.edata['m'] = g.edata['m'] + g.edata['m_update']
# Transformations after skip connection
for layer in self.layers_after_skip:
g.edata['m'] = layer(g.edata['m'])
return g
import torch.nn as nn
import dgl
import dgl.function as fn
from modules.initializers import GlorotOrthogonal
class OutputBlock(nn.Module):
def __init__(self,
emb_size,
num_radial,
num_dense,
num_targets,
activation=None,
output_init=nn.init.zeros_):
super(OutputBlock, self).__init__()
self.activation = activation
self.output_init = output_init
self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
self.dense_layers = nn.ModuleList([
nn.Linear(emb_size, emb_size) for _ in range(num_dense)
])
self.dense_final = nn.Linear(emb_size, num_targets, bias=False)
self.reset_params()
def reset_params(self):
GlorotOrthogonal(self.dense_rbf.weight)
for layer in self.dense_layers:
GlorotOrthogonal(layer.weight)
self.output_init(self.dense_final.weight)
def forward(self, g):
with g.local_scope():
g.edata['tmp'] = g.edata['m'] * self.dense_rbf(g.edata['rbf'])
g.update_all(fn.copy_e('tmp', 'x'), fn.sum('x', 't'))
for layer in self.dense_layers:
g.ndata['t'] = layer(g.ndata['t'])
if self.activation is not None:
g.ndata['t'] = self.activation(g.ndata['t'])
g.ndata['t'] = self.dense_final(g.ndata['t'])
return dgl.readout_nodes(g, 't')
import torch.nn as nn
import dgl
import dgl.function as fn
from modules.initializers import GlorotOrthogonal
class OutputPPBlock(nn.Module):
def __init__(self,
emb_size,
out_emb_size,
num_radial,
num_dense,
num_targets,
activation=None,
output_init=nn.init.zeros_,
extensive=True):
super(OutputPPBlock, self).__init__()
self.activation = activation
self.output_init = output_init
self.extensive = extensive
self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
self.up_projection = nn.Linear(emb_size, out_emb_size, bias=False)
self.dense_layers = nn.ModuleList([
nn.Linear(out_emb_size, out_emb_size) for _ in range(num_dense)
])
self.dense_final = nn.Linear(out_emb_size, num_targets, bias=False)
self.reset_params()
def reset_params(self):
GlorotOrthogonal(self.dense_rbf.weight)
GlorotOrthogonal(self.up_projection.weight)
for layer in self.dense_layers:
GlorotOrthogonal(layer.weight)
self.output_init(self.dense_final.weight)
def forward(self, g):
with g.local_scope():
g.edata['tmp'] = g.edata['m'] * self.dense_rbf(g.edata['rbf'])
g_reverse = dgl.reverse(g, copy_edata=True)
g_reverse.update_all(fn.copy_e('tmp', 'x'), fn.sum('x', 't'))
g.ndata['t'] = self.up_projection(g_reverse.ndata['t'])
for layer in self.dense_layers:
g.ndata['t'] = layer(g.ndata['t'])
if self.activation is not None:
g.ndata['t'] = self.activation(g.ndata['t'])
g.ndata['t'] = self.dense_final(g.ndata['t'])
return dgl.readout_nodes(g, 't', op='sum' if self.extensive else 'mean')
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
class ResidualLayer(nn.Module):
def __init__(self, units, activation=None):
super(ResidualLayer, self).__init__()
self.activation = activation
self.dense_1 = nn.Linear(units, units)
self.dense_2 = nn.Linear(units, units)
self.reset_params()
def reset_params(self):
GlorotOrthogonal(self.dense_1.weight)
nn.init.zeros_(self.dense_1.bias)
GlorotOrthogonal(self.dense_2.weight)
nn.init.zeros_(self.dense_2.bias)
def forward(self, inputs):
x = self.dense_1(inputs)
if self.activation is not None:
x = self.activation(x)
x = self.dense_2(x)
if self.activation is not None:
x = self.activation(x)
return inputs + x