Unverified commit 011656fd, authored by jjhu94 and committed by GitHub

[DGL-LifeSci] Weave for Molecular Property Prediction (#1441)



* featurize for weave model

* weave module for molecular graphs

* Update

* completed weave model

* update the whole weave model

* Update

* update atom (node) features


* "featurizer"

* Add files via upload

add edge featurizer

* Add files via upload

* Update
Co-authored-by: mufeili <mufeili1996@gmail.com>
parent 28117cd9
# Contributing to DGL-LifeSci
Contributions are always welcome. All contributions must go through pull requests
and code review.
Below is a list of community contributors for this project.
Contributors
------------
* [Chengqiang Lu](https://github.com/geekinglcq): Alchemy dataset; MPNN, MGCN and SchNet
* [Jiajing Hu](https://github.com/jjhu94): Weave
@@ -12,6 +12,8 @@ with graph neural networks.
We provide various functionalities, including but not limited to methods for graph construction,
featurization, and evaluation, model architectures, training scripts and pre-trained models.

For a list of community contributors, see [here](CONTRIBUTORS.md).

**For a full list of work implemented in DGL-LifeSci, see [here](examples/README.md).**

## Installation
...
@@ -39,6 +39,11 @@ SchNet
.. automodule:: dgllife.model.gnn.schnet
   :members:

Weave
-----
.. automodule:: dgllife.model.gnn.weave
   :members:

WLN
---
.. automodule:: dgllife.model.gnn.wln
   :members:
...
@@ -26,3 +26,8 @@ Weighted Sum and Max Readout
----------------------------
.. automodule:: dgllife.model.readout.weighted_sum_and_max
   :members:

Weave Readout
-------------
.. automodule:: dgllife.model.readout.weave_readout
   :members:
@@ -49,6 +49,11 @@ SchNet Predictor
.. automodule:: dgllife.model.model_zoo.schnet_predictor
   :members:

Weave Predictor
```````````````
.. automodule:: dgllife.model.model_zoo.weave_predictor
   :members:

Generative Models
-----------------
...
@@ -12,6 +12,9 @@ We provide various examples across 3 applications -- property prediction, genera

## Property Prediction

- Molecular graph convolutions: moving beyond fingerprints (Weave) [[paper]](https://arxiv.org/abs/1603.00856), [[github]](https://github.com/deepchem/deepchem)
  - [Weave Predictor with DGL](../python/dgllife/model/model_zoo/weave_predictor.py)
  - [Example for Molecule Classification](property_prediction/classification.py)
- Semi-Supervised Classification with Graph Convolutional Networks (GCN) [[paper]](https://arxiv.org/abs/1609.02907), [[github]](https://github.com/tkipf/gcn)
  - [GCN-Based Predictor with DGL](../python/dgllife/model/model_zoo/gcn_predictor.py)
  - [Example for Molecule Classification](property_prediction/classification.py)
...
@@ -12,6 +12,7 @@ stress response pathways. Each target yields a binary prediction problem. Molecu
into training, validation and test sets with an 80/10/10 ratio. By default we follow their split method.

### Models

- **Weave** [9]. Weave is one of the pioneering efforts in applying graph neural networks to molecular property prediction.
- **Graph Convolutional Network** [2], [3]. Graph Convolutional Networks (GCN) have been one of the most popular graph neural
networks and they can be easily extended for graph level prediction. MoleculeNet [1] reports baseline results of graph
convolutions over multiple datasets.
@@ -22,7 +23,7 @@ explicitly modeling the interactions between adjacent atoms.

Use `classification.py` with arguments

```
-m {GCN, GAT, Weave}, MODEL, model to use
-d {Tox21}, DATASET, dataset to use
```
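
For example, training the newly added Weave model on Tox21 should then amount to the following invocation (run from `examples/property_prediction`):

```
python classification.py -m Weave -d Tox21
```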
@@ -49,6 +50,12 @@ a real difference.
| ---------------- | --------------------------- |
| Pretrained model | 0.853 |
#### Weave on Tox21
| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| Pretrained model | 0.8074 |
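
The pretrained Weave model should be loadable in the same way as the other Tox21 models. A minimal sketch, assuming the checkpoint registered in `pretrain.py` later in this diff is available for download:

```
from dgllife.model import load_pretrained

model = load_pretrained('Weave_Tox21')  # downloads weave_tox21.pth on first use
model.eval()
```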
## Regression

Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.
@@ -177,3 +184,6 @@ Machine Learning*, JMLR. 1263-1272.
[8] Xiong et al. (2019) Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
Attention Mechanism. *Journal of Medicinal Chemistry*.
[9] Kearnes et al. (2016) Molecular graph convolutions: moving beyond fingerprints.
*Journal of Computer-Aided Molecular Design*.
@@ -9,16 +9,21 @@ from torch.utils.data import DataLoader
from utils import set_random_seed, load_dataset_for_classification, collate_molgraphs, load_model

def predict(args, model, bg):
    node_feats = bg.ndata.pop(args['node_data_field']).to(args['device'])
    if args.get('edge_featurizer', None) is not None:
        edge_feats = bg.edata.pop(args['edge_data_field']).to(args['device'])
        return model(bg, node_feats, edge_feats)
    else:
        return model(bg, node_feats)
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data
        labels, masks = labels.to(args['device']), masks.to(args['device'])
        logits = predict(args, model, bg)
        # Mask non-existing labels
        loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
@@ -37,9 +42,8 @@ def run_an_eval_epoch(args, model, data_loader):
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, bg, labels, masks = batch_data
            labels = labels.to(args['device'])
            logits = predict(args, model, bg)
            eval_meter.update(logits, labels, masks)
    return np.mean(eval_meter.compute_metric(args['metric_name']))
@@ -92,7 +96,7 @@ if __name__ == '__main__':
    from configure import get_exp_configure

    parser = argparse.ArgumentParser(description='Molecule Classification')
    parser.add_argument('-m', '--model', type=str, choices=['GCN', 'GAT', 'Weave'],
                        help='Model to use')
    parser.add_argument('-d', '--dataset', type=str, choices=['Tox21'],
                        help='Dataset to use')
...
from functools import partial

# graph construction
from dgllife.utils import smiles_to_bigraph, smiles_to_complete_graph
# general featurization
from dgllife.utils import ConcatFeaturizer
# node featurization
from dgllife.utils import CanonicalAtomFeaturizer, BaseAtomFeaturizer, WeaveAtomFeaturizer, \
    atom_type_one_hot, atom_degree_one_hot, atom_formal_charge, atom_num_radical_electrons, \
    atom_hybridization_one_hot, atom_total_num_H_one_hot
# edge featurization
from dgllife.utils.featurizers import BaseBondFeaturizer, WeaveEdgeFeaturizer

from utils import chirality
@@ -12,7 +18,7 @@ GCN_Tox21 = {
    'batch_size': 128,
    'lr': 1e-3,
    'num_epochs': 100,
    'node_data_field': 'h',
    'frac_train': 0.8,
    'frac_val': 0.1,
    'frac_test': 0.1,
@@ -20,7 +26,8 @@ GCN_Tox21 = {
    'gcn_hidden_feats': [64, 64],
    'classifier_hidden_feats': 64,
    'patience': 10,
    'smiles_to_graph': smiles_to_bigraph,
    'node_featurizer': CanonicalAtomFeaturizer(),
    'metric_name': 'roc_auc_score'
}
@@ -29,7 +36,7 @@ GAT_Tox21 = {
    'batch_size': 128,
    'lr': 1e-3,
    'num_epochs': 100,
    'node_data_field': 'h',
    'frac_train': 0.8,
    'frac_val': 0.1,
    'frac_test': 0.1,
@@ -38,7 +45,28 @@ GAT_Tox21 = {
    'classifier_hidden_feats': 64,
    'num_heads': [4, 4],
    'patience': 10,
    'smiles_to_graph': smiles_to_bigraph,
    'node_featurizer': CanonicalAtomFeaturizer(),
    'metric_name': 'roc_auc_score'
}

Weave_Tox21 = {
    'random_seed': 2,
    'batch_size': 32,
    'lr': 1e-3,
    'num_epochs': 100,
    'node_data_field': 'h',
    'edge_data_field': 'e',
    'frac_train': 0.8,
    'frac_val': 0.1,
    'frac_test': 0.1,
    'num_gnn_layers': 2,
    'gnn_hidden_feats': 50,
    'graph_feats': 128,
    'patience': 10,
    'smiles_to_graph': partial(smiles_to_complete_graph, add_self_loop=True),
    'node_featurizer': WeaveAtomFeaturizer(),
    'edge_featurizer': WeaveEdgeFeaturizer(max_distance=2),
    'metric_name': 'roc_auc_score'
}
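
For reference, the Weave configuration above is consumed by `load_dataset_for_classification` in `utils.py` (shown later in this diff). A minimal sketch of the equivalent direct dataset construction, assuming dgllife and RDKit are installed:

```
from functools import partial

from dgllife.data import Tox21
from dgllife.utils import smiles_to_complete_graph, WeaveAtomFeaturizer, WeaveEdgeFeaturizer

dataset = Tox21(smiles_to_graph=partial(smiles_to_complete_graph, add_self_loop=True),
                node_featurizer=WeaveAtomFeaturizer(),
                edge_featurizer=WeaveEdgeFeaturizer(max_distance=2))
smiles, g, label, mask = dataset[0]
```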
@@ -103,8 +131,9 @@ AttentiveFP_Aromaticity = {
    'frac_test': 0.1,
    'patience': 80,
    'metric_name': 'rmse',
    'smiles_to_graph': smiles_to_bigraph,
    # Follow the atom featurization in the original work
    'node_featurizer': BaseAtomFeaturizer(
        featurizer_funcs={'hv': ConcatFeaturizer([
            partial(atom_type_one_hot, allowable_set=[
                'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
@@ -117,7 +146,7 @@ AttentiveFP_Aromaticity = {
            ],
        )}
    ),
    'edge_featurizer': BaseBondFeaturizer({
        'he': lambda bond: [0 for _ in range(10)]
    })
}
@@ -125,6 +154,7 @@ AttentiveFP_Aromaticity = {
experiment_configures = {
    'GCN_Tox21': GCN_Tox21,
    'GAT_Tox21': GAT_Tox21,
    'Weave_Tox21': Weave_Tox21,
    'MPNN_Alchemy': MPNN_Alchemy,
    'SchNet_Alchemy': SchNet_Alchemy,
    'MGCN_Alchemy': MGCN_Alchemy,
...
@@ -4,7 +4,6 @@ import random
import torch

from dgllife.utils.featurizers import one_hot_encoding
from dgllife.utils.splitters import RandomSplitter

def set_random_seed(seed=0):
@@ -20,6 +19,7 @@ def set_random_seed(seed=0):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

def load_dataset_for_classification(args):
    """Load dataset for classification tasks.

    Parameters
@@ -40,13 +40,16 @@ def load_dataset_for_classification(args):
    assert args['dataset'] in ['Tox21']
    if args['dataset'] == 'Tox21':
        from dgllife.data import Tox21
        dataset = Tox21(smiles_to_graph=args['smiles_to_graph'],
                        node_featurizer=args.get('node_featurizer', None),
                        edge_featurizer=args.get('edge_featurizer', None))
        train_set, val_set, test_set = RandomSplitter.train_val_test_split(
            dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
            frac_test=args['frac_test'], random_state=args['random_seed'])

    return dataset, train_set, val_set, test_set
def load_dataset_for_regression(args):
    """Load dataset for regression tasks.

    Parameters
@@ -72,15 +75,16 @@ def load_dataset_for_regression(args):
    if args['dataset'] == 'Aromaticity':
        from dgllife.data import PubChemBioAssayAromaticity
        dataset = PubChemBioAssayAromaticity(smiles_to_graph=args['smiles_to_graph'],
                                             node_featurizer=args.get('node_featurizer', None),
                                             edge_featurizer=args.get('edge_featurizer', None))
        train_set, val_set, test_set = RandomSplitter.train_val_test_split(
            dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
            frac_test=args['frac_test'], random_state=args['random_seed'])

    return train_set, val_set, test_set
def collate_molgraphs(data):
    """Batching a list of datapoints for dataloader.
@@ -124,26 +128,36 @@ def collate_molgraphs(data):
        masks = torch.stack(masks, dim=0)

    return smiles, bg, labels, masks
def load_model(args):
    if args['model'] == 'GCN':
        from dgllife.model import GCNPredictor
        model = GCNPredictor(in_feats=args['node_featurizer'].feat_size(),
                             hidden_feats=args['gcn_hidden_feats'],
                             classifier_hidden_feats=args['classifier_hidden_feats'],
                             n_tasks=args['n_tasks'])

    if args['model'] == 'GAT':
        from dgllife.model import GATPredictor
        model = GATPredictor(in_feats=args['node_featurizer'].feat_size(),
                             hidden_feats=args['gat_hidden_feats'],
                             num_heads=args['num_heads'],
                             classifier_hidden_feats=args['classifier_hidden_feats'],
                             n_tasks=args['n_tasks'])

    if args['model'] == 'Weave':
        from dgllife.model import WeavePredictor
        model = WeavePredictor(node_in_feats=args['node_featurizer'].feat_size(),
                               edge_in_feats=args['edge_featurizer'].feat_size(),
                               num_gnn_layers=args['num_gnn_layers'],
                               gnn_hidden_feats=args['gnn_hidden_feats'],
                               graph_feats=args['graph_feats'],
                               n_tasks=args['n_tasks'])

    if args['model'] == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(node_feat_size=args['node_featurizer'].feat_size(),
                                     edge_feat_size=args['edge_featurizer'].feat_size(),
                                     num_layers=args['num_layers'],
                                     num_timesteps=args['num_timesteps'],
                                     graph_feat_size=args['graph_feat_size'],
@@ -174,6 +188,7 @@ def load_model(args):
    return model

def chirality(atom):
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
...
@@ -6,3 +6,4 @@ from .mgcn import *
from .mpnn import *
from .schnet import *
from .wln import *
from .weave import *
@@ -59,9 +59,9 @@ class MPNNGNN(nn.Module):
    g : DGLGraph
        DGLGraph for a batch of graphs.
    node_feats : float32 tensor of shape (V, node_in_feats)
        Input node features. V for the number of nodes in the batch of graphs.
    edge_feats : float32 tensor of shape (E, edge_in_feats)
        Input edge features. E for the number of edges in the batch of graphs.

    Returns
    -------
...
"""Weave"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['WeaveGNN']
# pylint: disable=W0221, E1101
class WeaveLayer(nn.Module):
r"""Single Weave layer from `Molecular Graph Convolutions: Moving Beyond Fingerprints
<https://arxiv.org/abs/1603.00856>`__
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_node_hidden_feats : int
Size for the hidden node representations in updating node representations.
Default to 50.
edge_node_hidden_feats : int
Size for the hidden edge representations in updating node representations.
Default to 50.
node_out_feats : int
Size for the output node representations. Default to 50.
node_edge_hidden_feats : int
Size for the hidden node representations in updating edge representations.
Default to 50.
edge_edge_hidden_feats : int
Size for the hidden edge representations in updating edge representations.
Default to 50.
edge_out_feats : int
Size for the output edge representations. Default to 50.
activation : callable
Activation function to apply. Default to ReLU.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_node_hidden_feats=50,
edge_node_hidden_feats=50,
node_out_feats=50,
node_edge_hidden_feats=50,
edge_edge_hidden_feats=50,
edge_out_feats=50,
activation=F.relu):
super(WeaveLayer, self).__init__()
self.activation = activation
# Layers for updating node representations
self.node_to_node = nn.Linear(node_in_feats, node_node_hidden_feats)
self.edge_to_node = nn.Linear(edge_in_feats, edge_node_hidden_feats)
self.update_node = nn.Linear(
node_node_hidden_feats + edge_node_hidden_feats, node_out_feats)
# Layers for updating edge representations
self.left_node_to_edge = nn.Linear(node_in_feats, node_edge_hidden_feats)
self.right_node_to_edge = nn.Linear(node_in_feats, node_edge_hidden_feats)
self.edge_to_edge = nn.Linear(edge_in_feats, edge_edge_hidden_feats)
self.update_edge = nn.Linear(
2 * node_edge_hidden_feats + edge_edge_hidden_feats, edge_out_feats)
def forward(self, g, node_feats, edge_feats, node_only=False):
r"""Update node and edge representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges in the batch of graphs.
node_only : bool
Whether to update node representations only. If False, edge representations
will be updated as well. Default to False.
Returns
-------
new_node_feats : float32 tensor of shape (V, node_out_feats)
Updated node representations.
new_edge_feats : float32 tensor of shape (E, edge_out_feats)
Updated edge representations.
"""
g = g.local_var()
# Update node features
node_node_feats = self.activation(self.node_to_node(node_feats))
g.edata['e2n'] = self.activation(self.edge_to_node(edge_feats))
g.update_all(fn.copy_edge('e2n', 'm'), fn.sum('m', 'e2n'))
edge_node_feats = g.ndata.pop('e2n')
new_node_feats = self.activation(self.update_node(
torch.cat([node_node_feats, edge_node_feats], dim=1)))
if node_only:
return new_node_feats
# Update edge features
g.ndata['left_hv'] = self.left_node_to_edge(node_feats)
g.ndata['right_hv'] = self.right_node_to_edge(node_feats)
g.apply_edges(fn.u_add_v('left_hv', 'right_hv', 'first'))
g.apply_edges(fn.u_add_v('right_hv', 'left_hv', 'second'))
first_edge_feats = self.activation(g.edata.pop('first'))
second_edge_feats = self.activation(g.edata.pop('second'))
third_edge_feats = self.activation(self.edge_to_edge(edge_feats))
new_edge_feats = self.activation(self.update_edge(
torch.cat([first_edge_feats, second_edge_feats, third_edge_feats], dim=1)))
return new_node_feats, new_edge_feats
class WeaveGNN(nn.Module):
    r"""The component of Weave for updating node and edge representations.

    Weave is introduced in `Molecular Graph Convolutions: Moving Beyond Fingerprints
    <https://arxiv.org/abs/1603.00856>`__.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    num_layers : int
        Number of Weave layers to use, which equals the number of message passing rounds.
        Default to 2.
    hidden_feats : int
        Size for the hidden node and edge representations. Default to 50.
    activation : callable
        Activation function to be used. It cannot be None. Default to ReLU.
    """
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 num_layers=2,
                 hidden_feats=50,
                 activation=F.relu):
        super(WeaveGNN, self).__init__()

        self.gnn_layers = nn.ModuleList()
        for i in range(num_layers):
            if i == 0:
                self.gnn_layers.append(WeaveLayer(node_in_feats=node_in_feats,
                                                  edge_in_feats=edge_in_feats,
                                                  node_node_hidden_feats=hidden_feats,
                                                  edge_node_hidden_feats=hidden_feats,
                                                  node_out_feats=hidden_feats,
                                                  node_edge_hidden_feats=hidden_feats,
                                                  edge_edge_hidden_feats=hidden_feats,
                                                  edge_out_feats=hidden_feats,
                                                  activation=activation))
            else:
                self.gnn_layers.append(WeaveLayer(node_in_feats=hidden_feats,
                                                  edge_in_feats=hidden_feats,
                                                  node_node_hidden_feats=hidden_feats,
                                                  edge_node_hidden_feats=hidden_feats,
                                                  node_out_feats=hidden_feats,
                                                  node_edge_hidden_feats=hidden_feats,
                                                  edge_edge_hidden_feats=hidden_feats,
                                                  edge_out_feats=hidden_feats,
                                                  activation=activation))

    def forward(self, g, node_feats, edge_feats, node_only=True):
        """Updates node representations (and edge representations).

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.
        node_only : bool
            Whether to return updated node representations only or to return both
            node and edge representations. Default to True.

        Returns
        -------
        float32 tensor of shape (V, gnn_hidden_feats)
            Updated node representations.
        float32 tensor of shape (E, gnn_hidden_feats), optional
            This is returned only when ``node_only==False``. Updated edge representations.
        """
        for i in range(len(self.gnn_layers) - 1):
            node_feats, edge_feats = self.gnn_layers[i](g, node_feats, edge_feats)

        return self.gnn_layers[-1](g, node_feats, edge_feats, node_only)
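
A quick smoke test for `WeaveGNN` (a sketch, assuming a DGL version where `dgl.graph` is available and `WeaveGNN` is re-exported from `dgllife.model`; 27 and 7 match the default `WeaveAtomFeaturizer` and `WeaveEdgeFeaturizer(max_distance=2)` sizes used elsewhere in this PR):

```
import dgl
import torch
from dgllife.model import WeaveGNN

# A 2-node complete graph with self loops, matching the Weave featurization setup
g = dgl.graph(([0, 0, 1, 1], [0, 1, 0, 1]))
gnn = WeaveGNN(node_in_feats=27, edge_in_feats=7)
node_feats = torch.randn(g.num_nodes(), 27)
edge_feats = torch.randn(g.num_edges(), 7)
print(gnn(g, node_feats, edge_feats).shape)  # torch.Size([2, 50]), node_only=True by default
```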
@@ -10,3 +10,4 @@ from .mgcn_predictor import *
from .mpnn_predictor import *
from .acnn import *
from .wln_reaction_center import *
from .weave_predictor import *
"""Weave"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
import torch.nn.functional as F
from ..gnn import WeaveGNN
from ..readout import WeaveGather
__all__ = ['WeavePredictor']
# pylint: disable=W0221
class WeavePredictor(nn.Module):
r"""Weave for regression and classification on graphs.
Weave is introduced in `Molecular Graph Convolutions: Moving Beyond Fingerprints
<https://arxiv.org/abs/1603.00856>`__
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
num_gnn_layers : int
Number of GNN (Weave) layers to use. Default to 2.
gnn_hidden_feats : int
Size for the hidden node and edge representations. Default to 50.
gnn_activation : callable
Activation function to be used in GNN (Weave) layers. Default to ReLU.
graph_feats : int
Size for the hidden graph representations. Default to 50.
gaussian_expand : bool
Whether to expand each dimension of node features by gaussian histogram in
computing graph representations. Default to True.
gaussian_memberships : list of 2-tuples
For each tuple, the first and second element separately specifies the mean
and std for constructing a normal distribution. This argument comes into
effect only when ``gaussian_expand==True``. By default, we set this to be
``[(-1.645, 0.283), (-1.080, 0.170), (-0.739, 0.134), (-0.468, 0.118),
(-0.228, 0.114), (0., 0.114), (0.228, 0.114), (0.468, 0.118),
(0.739, 0.134), (1.080, 0.170), (1.645, 0.283)]``.
readout_activation : callable
Activation function to be used in computing graph representations out of
node representations. Default to Tanh.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
num_gnn_layers=2,
gnn_hidden_feats=50,
gnn_activation=F.relu,
graph_feats=128,
gaussian_expand=True,
gaussian_memberships=None,
readout_activation=nn.Tanh(),
n_tasks=1):
super(WeavePredictor, self).__init__()
self.gnn = WeaveGNN(node_in_feats=node_in_feats,
edge_in_feats=edge_in_feats,
num_layers=num_gnn_layers,
hidden_feats=gnn_hidden_feats,
activation=gnn_activation)
self.node_to_graph = nn.Sequential(
nn.Linear(gnn_hidden_feats, graph_feats),
readout_activation,
nn.BatchNorm1d(graph_feats)
)
self.readout = WeaveGather(node_in_feats=graph_feats,
gaussian_expand=gaussian_expand,
gaussian_memberships=gaussian_memberships,
activation=readout_activation)
self.predict = nn.Linear(graph_feats, n_tasks)
def forward(self, g, node_feats, edge_feats):
"""Graph-level regression/soft classification.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges.
Returns
-------
float32 tensor of shape (G, n_tasks)
Prediction for the graphs in the batch. G for the number of graphs.
"""
node_feats = self.gnn(g, node_feats, edge_feats, node_only=True)
node_feats = self.node_to_graph(node_feats)
g_feats = self.readout(g, node_feats)
return self.predict(g_feats)
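
An end-to-end sketch tying the pieces together, from SMILES to prediction (assuming dgllife and RDKit are installed; the featurizers and graph constructor come from this PR):

```
from dgllife.model import WeavePredictor
from dgllife.utils import smiles_to_complete_graph, WeaveAtomFeaturizer, WeaveEdgeFeaturizer

node_featurizer = WeaveAtomFeaturizer()
edge_featurizer = WeaveEdgeFeaturizer(max_distance=2)
g = smiles_to_complete_graph('CCO', add_self_loop=True,
                             node_featurizer=node_featurizer,
                             edge_featurizer=edge_featurizer)
model = WeavePredictor(node_in_feats=node_featurizer.feat_size(),
                       edge_in_feats=edge_featurizer.feat_size(),
                       n_tasks=1)
model.eval()  # BatchNorm in the readout projection uses running stats in eval mode
print(model(g, g.ndata['h'], g.edata['e']).shape)  # torch.Size([1, 1])
```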
@@ -8,13 +8,14 @@ from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_arc
from rdkit import Chem

from ..model import GCNPredictor, GATPredictor, AttentiveFPPredictor, DGMG, DGLJTNNVAE, \
    WLNReactionCenter, WeavePredictor

__all__ = ['load_pretrained']

URL = {
    'GCN_Tox21': 'dgllife/pre_trained/gcn_tox21.pth',
    'GAT_Tox21': 'dgllife/pre_trained/gat_tox21.pth',
    'Weave_Tox21': 'dgllife/pre_trained/weave_tox21.pth',
    'AttentiveFP_Aromaticity': 'dgllife/pre_trained/attentivefp_aromaticity.pth',
    'DGMG_ChEMBL_canonical': 'pre_trained/dgmg_ChEMBL_canonical.pth',
    'DGMG_ChEMBL_random': 'pre_trained/dgmg_ChEMBL_random.pth',
@@ -70,6 +71,7 @@ def load_pretrained(model_name, log=True):
    * ``'GCN_Tox21'``: A GCN-based model for molecular property prediction on Tox21
    * ``'GAT_Tox21'``: A GAT-based model for molecular property prediction on Tox21
    * ``'Weave_Tox21'``: A Weave model for molecular property prediction on Tox21
    * ``'AttentiveFP_Aromaticity'``: An AttentiveFP model for predicting number of
      aromatic atoms on a subset of PubChem
    * ``'DGMG_ChEMBL_canonical'``: A DGMG model trained on ChEMBL with a canonical
@@ -108,6 +110,14 @@ def load_pretrained(model_name, log=True):
                             classifier_hidden_feats=64,
                             n_tasks=12)

    elif model_name == 'Weave_Tox21':
        model = WeavePredictor(node_in_feats=27,
                               edge_in_feats=7,
                               num_gnn_layers=2,
                               gnn_hidden_feats=50,
                               graph_feats=128,
                               n_tasks=12)

    elif model_name == 'AttentiveFP_Aromaticity':
        model = AttentiveFPPredictor(node_feat_size=39,
                                     edge_feat_size=10,
...
@@ -5,3 +5,4 @@ out of node and edge representations.
from .attentivefp_readout import *
from .weighted_sum_and_max import *
from .mlp_readout import *
from .weave_readout import *
"""Readout for Weave"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch
import torch.nn as nn
from torch.distributions import Normal
__all__ = ['WeaveGather']
# pylint: disable=W0221, E1101, E1102
class WeaveGather(nn.Module):
r"""Readout in Weave
Parameters
----------
node_in_feats : int
Size for the input node features.
gaussian_expand : bool
Whether to expand each dimension of node features by gaussian histogram.
Default to True.
gaussian_memberships : list of 2-tuples
For each tuple, the first and second element separately specifies the mean
and std for constructing a normal distribution. This argument comes into
effect only when ``gaussian_expand==True``. By default, we set this to be
``[(-1.645, 0.283), (-1.080, 0.170), (-0.739, 0.134), (-0.468, 0.118),
(-0.228, 0.114), (0., 0.114), (0.228, 0.114), (0.468, 0.118),
(0.739, 0.134), (1.080, 0.170), (1.645, 0.283)]``.
activation : callable
Activation function to apply. Default to tanh.
"""
def __init__(self,
node_in_feats,
gaussian_expand=True,
gaussian_memberships=None,
activation=nn.Tanh()):
super(WeaveGather, self).__init__()
self.gaussian_expand = gaussian_expand
if gaussian_expand:
if gaussian_memberships is None:
gaussian_memberships = [
(-1.645, 0.283), (-1.080, 0.170), (-0.739, 0.134), (-0.468, 0.118),
(-0.228, 0.114), (0., 0.114), (0.228, 0.114), (0.468, 0.118),
(0.739, 0.134), (1.080, 0.170), (1.645, 0.283)]
means, stds = map(list, zip(*gaussian_memberships))
self.means = nn.ParameterList([
nn.Parameter(torch.tensor(value), requires_grad=False)
for value in means
])
self.stds = nn.ParameterList([
nn.Parameter(torch.tensor(value), requires_grad=False)
for value in stds
])
self.to_out = nn.Linear(node_in_feats * len(self.means), node_in_feats)
self.activation = activation
def gaussian_histogram(self, node_feats):
r"""Constructs a gaussian histogram to capture the distribution of features
Parameters
----------
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
Returns
-------
float32 tensor of shape (V, node_in_feats * len(self.means))
Updated node representations
"""
gaussian_dists = [Normal(self.means[i], self.stds[i])
for i in range(len(self.means))]
max_log_probs = [gaussian_dists[i].log_prob(self.means[i])
for i in range(len(self.means))]
# Normalize the probabilities by the maximum point-wise probabilities,
# whose results will be in range [0, 1]. Note that division of probabilities
# is equivalent to subtraction of log probabilities and the latter one is cheaper.
log_probs = [gaussian_dists[i].log_prob(node_feats) - max_log_probs[i]
for i in range(len(self.means))]
probs = torch.stack(log_probs, dim=2).exp() # (V, node_in_feats, len(self.means))
# Add a bias to avoid numerical issues in division
probs = probs + 1e-7
# Normalize the probabilities across all Gaussian distributions
probs = probs / probs.sum(2, keepdim=True)
return probs.reshape(node_feats.shape[0],
node_feats.shape[1] * len(self.means))
def forward(self, g, node_feats):
r"""Computes graph representations out of node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
Returns
-------
g_feats : float32 tensor of shape (G, node_in_feats)
Output graph representations. G for the number of graphs in the batch.
"""
if self.gaussian_expand:
node_feats = self.gaussian_histogram(node_feats)
with g.local_scope():
g.ndata['h'] = node_feats
g_feats = dgl.sum_nodes(g, 'h')
if self.gaussian_expand:
g_feats = self.to_out(g_feats)
if self.activation is not None:
g_feats = self.activation(g_feats)
return g_feats
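
Usage sketch for the readout on a batch of graphs (assuming `WeaveGather` is re-exported from `dgllife.model` and a DGL version with `dgl.graph`/`dgl.batch`):

```
import dgl
import torch
from dgllife.model import WeaveGather

bg = dgl.batch([dgl.graph(([0, 1], [1, 0])), dgl.graph(([0], [0]))])
node_feats = torch.randn(bg.num_nodes(), 128)
readout = WeaveGather(node_in_feats=128)
print(readout(bg, node_feats).shape)  # torch.Size([2, 128]): one vector per graph
```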
"""Node and edge featurization for molecular graphs.""" """Node and edge featurization for molecular graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import itertools import itertools
import os.path as osp
from collections import defaultdict from collections import defaultdict
import dgl.backend as F from functools import partial
import numpy as np from rdkit import Chem, RDConfig
from rdkit.Chem import AllChem, ChemicalFeatures
from rdkit import Chem import numpy as np
import torch
import dgl.backend as F
__all__ = ['one_hot_encoding',
           'atom_type_one_hot',
@@ -35,6 +40,7 @@ __all__ = ['one_hot_encoding',
           'ConcatFeaturizer',
           'BaseAtomFeaturizer',
           'CanonicalAtomFeaturizer',
           'WeaveAtomFeaturizer',
           'bond_type_one_hot',
           'bond_is_conjugated_one_hot',
           'bond_is_conjugated',
@@ -42,7 +48,8 @@ __all__ = ['one_hot_encoding',
           'bond_is_in_ring',
           'bond_stereo_one_hot',
           'BaseBondFeaturizer',
           'CanonicalBondFeaturizer',
           'WeaveEdgeFeaturizer']
def one_hot_encoding(x, allowable_set, encode_unknown=False):
    """One-hot encoding.
@@ -482,6 +489,30 @@ def atom_formal_charge(atom):
    """
    return [atom.GetFormalCharge()]
def atom_partial_charge(atom):
    """Get Gasteiger partial charge for an atom.

    Before using this function, ``AllChem.ComputeGasteigerCharges(mol)`` must have
    been called to compute Gasteiger charges.

    Occasionally, we can get nan or infinity Gasteiger charges, in which case we will set
    the result to be 0.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one float only.
    """
    gasteiger_charge = atom.GetProp('_GasteigerCharge')
    if gasteiger_charge in ['-nan', 'nan', '-inf', 'inf']:
        gasteiger_charge = 0
    return [float(gasteiger_charge)]
def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the number of radical electrons of an atom.
@@ -633,6 +664,11 @@ def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
@@ -740,14 +776,26 @@ class BaseAtomFeaturizer(object):
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name=None):
        """Get the feature size for ``feat_name``.

        When there is only one feature, users do not need to provide ``feat_name``.

        Parameters
        ----------
        feat_name : str
            Feature for query. Default to None.

        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.
        """
        if feat_name is None:
            assert len(self.featurizer_funcs) == 1, \
                'feat_name should be provided if there is more than one feature'
            feat_name = list(self.featurizer_funcs.keys())[0]

        if feat_name not in self.featurizer_funcs:
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))
@@ -818,7 +866,7 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
    Parameters
    ----------
    atom_data_field : str
        Name for storing atom features in DGLGraphs, default to 'h'.

    Examples
    --------
@@ -865,6 +913,162 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
                 atom_total_num_H_one_hot]
            )})
class WeaveAtomFeaturizer(object):
    """Atom featurizer in Weave.

    The atom featurization performed in `Molecular Graph Convolutions: Moving Beyond Fingerprints
    <https://arxiv.org/abs/1603.00856>`__, which considers:

    * atom types
    * chirality
    * formal charge
    * partial charge
    * aromatic atom
    * hybridization
    * hydrogen bond donor
    * hydrogen bond acceptor
    * the number of rings the atom belongs to for ring size between 3 and 8

    Parameters
    ----------
    atom_data_field : str
        Name for storing atom features in DGLGraphs, default to 'h'.
    atom_types : list of str or None
        Atom types to consider for one-hot encoding. If None, we will use a default
        choice of ``'H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I'``.
    chiral_types : list of Chem.rdchem.ChiralType or None
        Atom chirality to consider for one-hot encoding. If None, we will use a default
        choice of ``Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``.
    hybridization_types : list of Chem.rdchem.HybridizationType or None
        Atom hybridization types to consider for one-hot encoding. If None, we will use a
        default choice of ``Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3``.
    """
    def __init__(self, atom_data_field='h', atom_types=None, chiral_types=None,
                 hybridization_types=None):
        super(WeaveAtomFeaturizer, self).__init__()

        self._atom_data_field = atom_data_field

        if atom_types is None:
            atom_types = ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I']
        self._atom_types = atom_types

        if chiral_types is None:
            chiral_types = [Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW]
        self._chiral_types = chiral_types

        if hybridization_types is None:
            hybridization_types = [Chem.rdchem.HybridizationType.SP,
                                   Chem.rdchem.HybridizationType.SP2,
                                   Chem.rdchem.HybridizationType.SP3]
        self._hybridization_types = hybridization_types

        self._featurizer = ConcatFeaturizer([
            partial(atom_type_one_hot, allowable_set=atom_types, encode_unknown=True),
            partial(atom_chiral_tag_one_hot, allowable_set=chiral_types),
            atom_formal_charge, atom_partial_charge, atom_is_aromatic,
            partial(atom_hybridization_one_hot, allowable_set=hybridization_types)
        ])

    def feat_size(self):
        """Get the feature size.

        Returns
        -------
        int
            Feature size.
        """
        mol = Chem.MolFromSmiles('C')
        feats = self(mol)[self._atom_data_field]

        return feats.shape[-1]

    def get_donor_acceptor_info(self, mol_feats):
        """Bookkeep whether an atom is donor/acceptor for hydrogen bonds.

        Parameters
        ----------
        mol_feats : tuple of rdkit.Chem.rdMolChemicalFeatures.MolChemicalFeature
            Features for molecules.

        Returns
        -------
        is_donor : dict
            Mapping atom ids to binary values indicating whether atoms
            are donors for hydrogen bonds
        is_acceptor : dict
            Mapping atom ids to binary values indicating whether atoms
            are acceptors for hydrogen bonds
        """
        is_donor = defaultdict(bool)
        is_acceptor = defaultdict(bool)
        # Get hydrogen bond donor/acceptor information
        for feats in mol_feats:
            if feats.GetFamily() == 'Donor':
                nodes = feats.GetAtomIds()
                for u in nodes:
                    is_donor[u] = True
            elif feats.GetFamily() == 'Acceptor':
                nodes = feats.GetAtomIds()
                for u in nodes:
                    is_acceptor[u] = True

        return is_donor, is_acceptor

    def __call__(self, mol):
        """Featurizes the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping atom_data_field as specified in the input argument to the atom
            features, which is a float32 tensor of shape (N, M), where N is the number
            of atoms and M is the feature size.
        """
        atom_features = []

        AllChem.ComputeGasteigerCharges(mol)
        num_atoms = mol.GetNumAtoms()

        # Get information for donor and acceptor
        fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
        mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
        mol_feats = mol_featurizer.GetFeaturesForMol(mol)
        is_donor, is_acceptor = self.get_donor_acceptor_info(mol_feats)

        # Get a symmetrized smallest set of smallest rings
        # Following the practice from Chainer Chemistry (https://github.com/chainer/
        # chainer-chemistry/blob/da2507b38f903a8ee333e487d422ba6dcec49b05/chainer_chemistry/
        # dataset/preprocessors/weavenet_preprocessor.py)
        sssr = Chem.GetSymmSSSR(mol)

        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            # Features that can be computed directly from RDKit atom instances, as a list
            feats = self._featurizer(atom)
            # Donor/acceptor indicator
            feats.append(float(is_donor[i]))
            feats.append(float(is_acceptor[i]))
            # Count the number of rings the atom belongs to for ring size between 3 and 8
            count = [0 for _ in range(3, 9)]
            for ring in sssr:
                ring_size = len(ring)
                if i in ring and 3 <= ring_size <= 8:
                    count[ring_size - 3] += 1
            feats.extend(count)
            atom_features.append(feats)
        atom_features = np.stack(atom_features)

        return {self._atom_data_field: F.zerocopy_from_numpy(atom_features.astype(np.float32))}
def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the type of a bond.
@@ -1071,14 +1275,26 @@ class BaseBondFeaturizer(object):
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name=None):
        """Get the feature size for ``feat_name``.

        When there is only one feature, users do not need to provide ``feat_name``.

        Parameters
        ----------
        feat_name : str
            Feature for query. Default to None.

        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.
        """
        if feat_name is None:
            assert len(self.featurizer_funcs) == 1, \
                'feat_name should be provided if there is more than one feature'
            feat_name = list(self.featurizer_funcs.keys())[0]

        if feat_name not in self.featurizer_funcs:
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))
@@ -1165,3 +1381,101 @@ class CanonicalBondFeaturizer(BaseBondFeaturizer):
             bond_is_in_ring,
             bond_stereo_one_hot]
        )})
# pylint: disable=E1102
class WeaveEdgeFeaturizer(object):
    """Edge featurizer in Weave.

    The edge featurization is introduced in `Molecular Graph Convolutions:
    Moving Beyond Fingerprints <https://arxiv.org/abs/1603.00856>`__.

    This featurization is performed for a complete graph of atoms with self loops added,
    which considers:

    * Number of bonds between each pair of atoms
    * One-hot encoding of the bond type if a bond exists between a pair of atoms
    * Whether a pair of atoms belongs to the same ring

    Parameters
    ----------
    edge_data_field : str
        Name for storing edge features in DGLGraphs, default to ``'e'``.
    max_distance : int
        Maximum number of bonds to consider between each pair of atoms.
        Default to 7.
    bond_types : list of Chem.rdchem.BondType or None
        Bond types to consider for one hot encoding. If None, we consider by
        default single, double, triple and aromatic bonds.
    """
    def __init__(self, edge_data_field='e', max_distance=7, bond_types=None):
        super(WeaveEdgeFeaturizer, self).__init__()

        self._edge_data_field = edge_data_field
        self._max_distance = max_distance
        if bond_types is None:
            bond_types = [Chem.rdchem.BondType.SINGLE,
                          Chem.rdchem.BondType.DOUBLE,
                          Chem.rdchem.BondType.TRIPLE,
                          Chem.rdchem.BondType.AROMATIC]
        self._bond_types = bond_types

    def feat_size(self):
        """Get the feature size.

        Returns
        -------
        int
            Feature size.
        """
        mol = Chem.MolFromSmiles('C')
        feats = self(mol)[self._edge_data_field]

        return feats.shape[-1]

    def __call__(self, mol):
        """Featurizes the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping self._edge_data_field to a float32 tensor of shape (N, M), where
            N is the number of atom pairs and M is the feature size.
        """
        # Part 1 based on number of bonds between each pair of atoms
        distance_matrix = torch.from_numpy(Chem.GetDistanceMatrix(mol))
        # Change shape from (V, V) to (V^2, 1)
        distance_matrix = distance_matrix.float().reshape(-1, 1)
        # Elementwise compare if distance is bigger than 0, 1, ..., max_distance - 1
        distance_indicators = (distance_matrix >
                               torch.arange(0, self._max_distance).float()).float()

        # Part 2 for one hot encoding of bond type
        num_atoms = mol.GetNumAtoms()
        bond_indicators = torch.zeros(num_atoms, num_atoms, len(self._bond_types))
        for bond in mol.GetBonds():
            bond_type_encoding = torch.tensor(
                bond_type_one_hot(bond, allowable_set=self._bond_types)).float()
            begin_atom_idx, end_atom_idx = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            bond_indicators[begin_atom_idx, end_atom_idx] = bond_type_encoding
            bond_indicators[end_atom_idx, begin_atom_idx] = bond_type_encoding
        # Reshape from (V, V, num_bond_types) to (V^2, num_bond_types)
        bond_indicators = bond_indicators.reshape(-1, len(self._bond_types))

        # Part 3 for whether a pair of atoms belongs to the same ring
        sssr = Chem.GetSymmSSSR(mol)
        ring_mate_indicators = torch.zeros(num_atoms, num_atoms, 1)
        for ring in sssr:
            ring = list(ring)
            num_atoms_in_ring = len(ring)
            for i in range(num_atoms_in_ring):
                ring_mate_indicators[ring[i], torch.tensor(ring)] = 1
        ring_mate_indicators = ring_mate_indicators.reshape(-1, 1)

        return {self._edge_data_field: torch.cat([distance_indicators,
                                                  bond_indicators,
                                                  ring_mate_indicators], dim=1)}
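
A quick check of the feature dimensions (a sketch, assuming RDKit is installed): with `max_distance=2` there are 2 distance bins, 4 bond-type bits and 1 ring indicator, i.e. 7 features per atom pair, matching `edge_in_feats=7` in `load_pretrained` above.

```
from rdkit import Chem
from dgllife.utils import WeaveEdgeFeaturizer

featurizer = WeaveEdgeFeaturizer(max_distance=2)
mol = Chem.MolFromSmiles('CCO')
feats = featurizer(mol)['e']
print(feats.shape)  # torch.Size([9, 7]): all 3 x 3 atom pairs, 7 features each
```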
@@ -235,6 +235,32 @@ def test_wln():
    assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 3])
    assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 3])
def test_weave():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g, node_feats, edge_feats = test_graph3()
    g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
    bg, batch_node_feats, batch_edge_feats = test_graph4()
    bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
                                             batch_edge_feats.to(device)

    # Test default setting
    gnn = WeaveGNN(node_in_feats=1,
                   edge_in_feats=2).to(device)
    assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 50])
    assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 50])

    # Test configured setting
    gnn = WeaveGNN(node_in_feats=1,
                   edge_in_feats=2,
                   num_layers=1,
                   hidden_feats=2).to(device)
    assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 2])
    assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 2])
if __name__ == '__main__':
    test_gcn()
    test_gat()
@@ -243,3 +269,4 @@ if __name__ == '__main__':
    test_mgcn_gnn()
    test_mpnn_gnn()
    test_wln()
    test_weave()