Unverified Commit 039fefc2 authored by Mufei Li's avatar Mufei Li Committed by GitHub

[DGL-LifeSci] WLN for Reaction Prediction (#1530)

* Update

* Finalize
parent 70dc2ee9
@@ -180,9 +180,10 @@ SVG(Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(180, 150), useSVG=True)
Below we provide some reference numbers to show how DGL improves the speed of training models per epoch in seconds.
| Model | Original Implementation | DGL Implementation | Improvement |
| ---------------------------------- | ----------------------- | -------------------------- | ---------------------------- |
| GCN on Tox21 | 5.5 (DeepChem) | 1.0 | 5.5x |
| AttentiveFP on Aromaticity | 6.0 | 1.2 | 5x |
| JTNN on ZINC | 1826 | 743 | 2.5x |
| WLN for reaction center prediction | 11657 | 858 (1 GPU) / 134 (8 GPUs) | 13.6x (1GPU) / 87.0x (8GPUs) |
| WLN for candidate ranking | 40122 | 22268 | 1.8x |
@@ -41,16 +41,23 @@ Reaction Prediction
USPTO
`````
.. autoclass:: dgllife.data.USPTOCenter
:members: __getitem__, __len__
:show-inheritance:
.. autoclass:: dgllife.data.USPTORank
:members: ignore_large, __getitem__, __len__
:show-inheritance:
Adapting to New Datasets for Weisfeiler-Lehman Networks
```````````````````````````````````````````````````````
.. autoclass:: dgllife.data.WLNCenterDataset
:members: __getitem__, __len__
.. autoclass:: dgllife.data.WLNRankDataset
:members: ignore_large, __getitem__, __len__
Protein-Ligand Binding Affinity Prediction
------------------------------------------
......
@@ -74,6 +74,11 @@ WLN for Reaction Center Prediction
.. automodule:: dgllife.model.model_zoo.wln_reaction_center
:members:
WLN for Ranking Candidate Products
``````````````````````````````````
.. automodule:: dgllife.model.model_zoo.wln_reaction_ranking
:members:
Protein-Ligand Binding Affinity Prediction
ACNN
......
@@ -7,6 +7,16 @@ An earlier version of the work was published in NeurIPS 2017 as
["Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network"](https://arxiv.org/abs/1709.04555) with some
slight differences in modeling.
This work proposes a template-free approach for reaction prediction in two stages:
1) Identify the reaction center (pairs of atoms that will lose or form a bond)
2) Enumerate the possible combinations of bond changes and rank the corresponding candidate products
We provide a Jupyter notebook that walks through a demonstration with our pre-trained models. Download it into
this directory with `wget https://data.dgl.ai/dgllife/reaction_prediction_pretrained.ipynb`.
Below we visualize a reaction prediction by the model:
![](https://data.dgl.ai/dgllife/wln_reaction.png)
## Dataset
The example by default works with reactions from USPTO (United States Patent and Trademark Office) granted patents,
@@ -29,53 +39,104 @@ whose reaction centers have all been selected.
We use GPU whenever possible. To train the model with default options, simply do

```bash
python find_reaction_center_train.py
```

Once the training process starts, the progress will be printed in the terminal as follows:

```bash
Epoch 1/50, iter 8150/20452 | time/minibatch 0.0260 | loss 8.4788 | grad norm 12.9927
Epoch 1/50, iter 8200/20452 | time/minibatch 0.0260 | loss 8.6722 | grad norm 14.0833
```

Every time the learning rate is decayed (specified as `'decay_every'` in `configure.py`'s `reaction_center_config`), we save a checkpoint of
the model and evaluate it on the validation set. The evaluation result is formatted as follows, where `total samples x` means
we have trained the model on `x` samples and `acc@k` means top-k accuracy:

```bash
total samples 800000, (epoch 2/35, iter 2443/2557) | acc@12 0.9278 | acc@16 0.9419 | acc@20 0.9496 | acc@40 0.9596 | acc@80 0.9596 |
```

All model checkpoints and evaluation results can be found under `center_results`. `model_x.pkl` stores a model
checkpoint after seeing `x` training samples. `val_eval.txt` stores all evaluation results on the validation set.

You may want to terminate the training process when the validation performance no longer improves for some time.

### Multi-GPU Training

By default we use one GPU only. We also support multi-GPU training. To use GPUs with ids `id1,id2,...`, do

```bash
python find_reaction_center_train.py --gpus id1,id2,...
```

A summary of the training speedup with the DGL implementation is presented below.

| Item                    | Training time (s/epoch) | Speedup |
| ----------------------- | ----------------------- | ------- |
| Authors' implementation | 11657                   | 1x      |
| DGL with 1 gpu          | 858                     | 13.6x   |
| DGL with 2 gpus         | 443                     | 26.3x   |
| DGL with 4 gpus         | 243                     | 48.0x   |
| DGL with 8 gpus         | 134                     | 87.0x   |
### Evaluation
```bash
python find_reaction_center_eval.py --model-path X
```
For example, you can evaluate the model trained for 800000 samples by setting `X` to be
`center_results/model_800000.pkl`. The evaluation results will be stored at `center_results/test_eval.txt`.
For model evaluation, we can choose whether to exclude reactants not contributing heavy atoms to the product
(e.g. reagents and solvents) in top-k atom pair selection, which will make the task easier.
For the easier evaluation, do
```bash
python find_reaction_center_eval.py --easy
```
A summary of the model performance of various settings is as follows:

| Item                                        | Top 6 accuracy | Top 8 accuracy | Top 10 accuracy |
| ------------------------------------------- | -------------- | -------------- | --------------- |
| Paper                                       | 89.8           | 92.0           | 93.3            |
| Hard evaluation from authors' code          | 87.7           | 90.6           | 92.1            |
| Easy evaluation from authors' code          | 90.0           | 92.8           | 94.2            |
| Hard evaluation                             | 88.9           | 91.7           | 93.1            |
| Easy evaluation                             | 91.2           | 93.8           | 95.0            |
| Hard evaluation for model trained on 8 gpus | 88.1           | 91.0           | 92.5            |
| Easy evaluation for model trained on 8 gpus | 90.3           | 93.3           | 94.6            |
1. We are able to match the results produced by the authors' code for both single-gpu and multi-gpu training.
2. While multi-gpu training provides a great speedup, performance with the default hyperparameters drops slightly.
### Data Pre-processing with Multi-Processing
By default we use 32 processes for data pre-processing. If you encounter an error with
`BrokenPipeError: [Errno 32] Broken pipe`, you can specify a smaller number of processes with
```bash
python find_reaction_center_train.py -np X
python find_reaction_center_eval.py -np X
```
where `X` is the number of processes that you would like to use.
### Pre-trained Model
We provide a pre-trained model so users do not need to train from scratch. To evaluate the pre-trained model, simply do
```bash
python find_reaction_center_eval.py
```
### Adapting to a New Dataset
New datasets should be processed such that each line corresponds to the SMILES for a reaction like below:
@@ -89,10 +150,127 @@ In addition, atom mapping information is provided.
You can then train a model on new datasets with
```bash
python find_reaction_center_train.py --train-path X --val-path Y
```
where `X`, `Y` are paths to the new training/validation set as described above.
For evaluation,
```bash
python find_reaction_center_eval.py --eval-path Z
```
where `Z` is the path to the new test set as described above.
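As a sanity check when preparing such a file, recall that reaction SMILES separate reactants, reagents and products with `>`, with atom-map numbers written as `:n` inside bracketed atoms. The small parser below is our own illustration and assumes that standard convention (the actual dataset lines may carry additional annotations):

```python
def parse_reaction_line(line):
    # Reaction SMILES layout: "reactants>reagents>products";
    # an empty middle field means no reagents are recorded.
    reactants, reagents, products = line.strip().split('>')
    return {'reactants': reactants, 'reagents': reagents, 'products': products}

rxn = parse_reaction_line('[CH3:1][OH:2]>>[CH3:1][O-:2]')
print(rxn['products'])  # [CH3:1][O-:2]
```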
## Candidate Ranking
### Additional Dependency
In addition to RDKit, we use MolVS as an alternative for checking whether two molecules are the same after sanitization.
- [molvs](https://molvs.readthedocs.io/en/latest/)
### Modeling
For candidate ranking, we assume that a model has been trained for reaction center prediction first.
The pipeline for predicting candidate products given a reaction proceeds as follows:
1. Select top-k bond changes for atom pairs in the reactants, ranked by the model for reaction center prediction.
By default, we use k=80 and exclude reactants not contributing heavy atoms to the ground truth product in
selecting top-k bond changes as in the paper.
2. Filter out candidate bond changes for bonds that are already in the reactants
3. Enumerate possible combinations of atom pairs with up to C pairs, which reflects the number of bond changes
(losing or forming a bond) in reactions. A statistical analysis of USPTO suggests that setting it to 5 is enough.
4. Filter out invalid combinations where 1) atoms in candidate bond changes are not connected or 2) an atom pair is
predicted to have different types of bond changes
(e.g. two atoms are predicted simultaneously to form a single and double bond) or 3) valence constraints are violated.
5. Apply the candidate bond changes for each valid combination and get the corresponding candidate products.
6. Construct molecular graphs for the reactants and candidate products, featurize their atoms and bonds.
7. Apply a Weisfeiler-Lehman Network to the molecular graphs for reactants and candidate products and score them
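Steps 3 and 4 above amount to enumerating bounded-size subsets of candidate bond changes and discarding invalid ones. The sketch below is our own simplification, not the actual implementation; of the filters in step 4 it only applies the conflicting-changes check (two different changes assigned to the same atom pair):

```python
from itertools import combinations

def enumerate_change_combos(candidate_changes, max_pairs=5):
    """Yield every combination of up to `max_pairs` candidate bond changes.
    Each change is (atom1, atom2, new_bond_order), with order 0 meaning the
    bond between the pair is lost. Connectivity and valence filtering happen
    separately in the real pipeline."""
    for size in range(1, max_pairs + 1):
        for combo in combinations(candidate_changes, size):
            # Skip combos assigning two different changes to one atom pair
            pairs = [(a1, a2) for a1, a2, _ in combo]
            if len(set(pairs)) == len(pairs):
                yield combo

# Three candidate changes; two of them conflict on the pair (0, 1)
changes = [(0, 1, 0), (0, 1, 2), (2, 3, 1)]
combos = list(enumerate_change_combos(changes, max_pairs=2))
print(len(combos))  # 5
```

Each surviving combination is then applied to the reactants to produce one candidate product for ranking.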
### Training with Default Options
We use GPU whenever possible. To train the model with default options, simply do
```bash
python candidate_ranking_train.py -cmp X
```
where `X` is the path to a trained model for reaction center prediction. You can use our
pre-trained model by not specifying `-cmp`.
Once the training process starts, the progress will be printed in the terminal as follows:
```bash
Epoch 6/6, iter 16439/20061 | time 1.1124 | accuracy 0.8500 | grad norm 5.3218
Epoch 6/6, iter 16440/20061 | time 1.1124 | accuracy 0.9500 | grad norm 2.1163
```
Every time the learning rate is decayed (specified as `'decay_every'` in `configure.py`'s `candidate_ranking_config`),
we save a checkpoint of the model and evaluate it on the validation set. The evaluation result is formatted
as follows, where `total samples x` means that we have trained the model on `x` samples, `acc@k` means top-k accuracy, and
`gfound` means the proportion of reactions where the ground truth product can be recovered by the ground truth bond changes.
We perform the evaluation based on RDKit-sanitized molecule equivalence (marked with `[strict]`) and MOLVS-sanitized
molecule equivalence (marked with `[molvs]`).
```bash
total samples 100000, (epoch 1/20, iter 5000/20061)
[strict] acc@1: 0.7732 acc@2: 0.8466 acc@3: 0.8763 acc@5: 0.8987 gfound 0.9864
[molvs] acc@1: 0.7779 acc@2: 0.8523 acc@3: 0.8826 acc@5: 0.9057 gfound 0.9953
```
All model checkpoints and evaluation results can be found under `candidate_results`. `model_x.pkl` stores a model
checkpoint after seeing `x` training samples in total. `val_eval.txt` stores all
evaluation results on the validation set.
You may want to terminate the training process when the validation performance no longer improves for some time.
### Evaluation
```bash
python candidate_ranking_eval.py --model-path X -cmp Y
```
where `X` is the path to a trained model for candidate ranking and `Y` is the path to a trained model
for reaction center prediction. For example, you can evaluate the model trained for 800000 samples by setting `X` to be
`candidate_results/model_800000.pkl`. The evaluation results will be stored at `candidate_results/test_eval.txt`. As
in training, you can use our pre-trained model by not specifying `-cmp`.
A summary of the model performance of various settings is as follows:
| Item | Top 1 accuracy | Top 2 accuracy | Top 3 accuracy | Top 5 accuracy |
| -------------------------- | -------------- | -------------- | -------------- | -------------- |
| Authors' strict evaluation | 85.6 | 90.5 | 92.8 | 93.4 |
| DGL's strict evaluation | 85.6 | 90.0 | 91.7 | 92.9 |
| Authors' molvs evaluation | 86.2 | 91.2 | 92.8 | 94.2 |
| DGL's molvs evaluation | 86.1 | 90.6 | 92.4 | 93.6 |
### Pre-trained Model
We provide a pre-trained model so users do not need to train from scratch. To evaluate the pre-trained model,
simply do
```bash
python candidate_ranking_eval.py
```
### Adapting to a New Dataset
You can train a model on new datasets with
```bash
python candidate_ranking_train.py --train-path X --val-path Y
```
where `X`, `Y` are paths to the new training/validation set as described in the `Reaction Center Prediction` section.
For evaluation,

```bash
python candidate_ranking_eval.py --eval-path Z
```

where `Z` is the path to the new test set as described in the `Reaction Center Prediction` section.
## References
......
import torch
from dgllife.data import USPTORank, WLNRankDataset
from dgllife.model import WLNReactionRanking, load_pretrained
from torch.utils.data import DataLoader
from configure import candidate_ranking_config, reaction_center_config
from utils import mkdir_p, prepare_reaction_center, collate_rank_eval, candidate_ranking_eval
def main(args, path_to_candidate_bonds):
if args['test_path'] is None:
test_set = USPTORank(
subset='test', candidate_bond_path=path_to_candidate_bonds['test'],
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
else:
test_set = WLNRankDataset(
raw_file_path=args['test_path'],
candidate_bond_path=path_to_candidate_bonds['test'], mode='test',
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
test_loader = DataLoader(test_set, batch_size=1, collate_fn=collate_rank_eval,
shuffle=False, num_workers=args['num_workers'])
if args['model_path'] is None:
model = load_pretrained('wln_rank_uspto')
else:
model = WLNReactionRanking(
node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_hidden_feats=args['hidden_size'],
num_encode_gnn_layers=args['num_encode_gnn_layers'])
model.load_state_dict(torch.load(
args['model_path'], map_location='cpu')['model_state_dict'])
model = model.to(args['device'])
prediction_summary = candidate_ranking_eval(args, model, test_loader)
with open(args['result_path'] + '/test_eval.txt', 'w') as f:
f.write(prediction_summary)
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description='Candidate Ranking')
parser.add_argument('--model-path', type=str, default=None,
help='Path to saved model. If None, we will directly evaluate '
'a pretrained model on the test set.')
parser.add_argument('--result-path', type=str, default='candidate_results',
help='Path to save modeling results')
parser.add_argument('--test-path', type=str, default=None,
help='Path to a new test set. '
'If None, we will use the default test set in USPTO.')
parser.add_argument('-cmp', '--center-model-path', type=str, default=None,
help='Path to a pre-trained model for reaction center prediction. '
'By default we use the official pre-trained model. If not None, '
'the model should follow the hyperparameters specified in '
'reaction_center_config.')
parser.add_argument('-rcb', '--reaction-center-batch-size', type=int, default=200,
help='Batch size to use for preparing candidate bonds from a trained '
'model on reaction center prediction')
parser.add_argument('-np', '--num-processes', type=int, default=8,
help='Number of processes to use for data pre-processing')
parser.add_argument('-nw', '--num-workers', type=int, default=32,
help='Number of workers to use for data loading in PyTorch data loader')
args = parser.parse_args().__dict__
args.update(candidate_ranking_config)
mkdir_p(args['result_path'])
if torch.cuda.is_available():
args['device'] = torch.device('cuda:0')
else:
args['device'] = torch.device('cpu')
path_to_candidate_bonds = prepare_reaction_center(args, reaction_center_config)
main(args, path_to_candidate_bonds)
import time
import torch
from dgllife.data import USPTORank, WLNRankDataset
from dgllife.model import WLNReactionRanking
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from configure import reaction_center_config, candidate_ranking_config
from utils import prepare_reaction_center, mkdir_p, set_seed, collate_rank_train, \
collate_rank_eval, candidate_ranking_eval
def main(args, path_to_candidate_bonds):
if args['train_path'] is None:
train_set = USPTORank(
subset='train', candidate_bond_path=path_to_candidate_bonds['train'],
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_train'],
num_processes=args['num_processes'])
else:
train_set = WLNRankDataset(
raw_file_path=args['train_path'],
candidate_bond_path=path_to_candidate_bonds['train'], mode='train',
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_train'],
num_processes=args['num_processes'])
train_set.ignore_large()
if args['val_path'] is None:
val_set = USPTORank(
subset='val', candidate_bond_path=path_to_candidate_bonds['val'],
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
else:
val_set = WLNRankDataset(
raw_file_path=args['val_path'],
candidate_bond_path=path_to_candidate_bonds['val'], mode='val',
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_rank_train,
shuffle=True, num_workers=args['num_workers'])
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_rank_eval,
shuffle=False, num_workers=args['num_workers'])
model = WLNReactionRanking(
node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_hidden_feats=args['hidden_size'],
num_encode_gnn_layers=args['num_encode_gnn_layers']).to(args['device'])
criterion = CrossEntropyLoss(reduction='sum')
    # Wrap Adam with a utility Optimizer that handles gradient clipping and lr decay
    from utils import Optimizer
    optimizer = Adam(model.parameters(), lr=args['lr'])
    optimizer = Optimizer(model, args['lr'], optimizer, max_grad_norm=args['max_norm'])
acc_sum = 0
grad_norm_sum = 0
dur = []
total_samples = 0
for epoch in range(args['num_epochs']):
t0 = time.time()
model.train()
for batch_id, batch_data in enumerate(train_loader):
batch_reactant_graphs, batch_product_graphs, \
batch_combo_scores, batch_labels, batch_num_candidate_products = batch_data
batch_combo_scores = batch_combo_scores.to(args['device'])
batch_labels = batch_labels.to(args['device'])
reactant_node_feats = batch_reactant_graphs.ndata.pop('hv').to(args['device'])
reactant_edge_feats = batch_reactant_graphs.edata.pop('he').to(args['device'])
product_node_feats = batch_product_graphs.ndata.pop('hv').to(args['device'])
product_edge_feats = batch_product_graphs.edata.pop('he').to(args['device'])
pred = model(reactant_graph=batch_reactant_graphs,
reactant_node_feats=reactant_node_feats,
reactant_edge_feats=reactant_edge_feats,
product_graphs=batch_product_graphs,
product_node_feats=product_node_feats,
product_edge_feats=product_edge_feats,
candidate_scores=batch_combo_scores,
batch_num_candidate_products=batch_num_candidate_products)
# Check if the ground truth candidate has the highest score
batch_loss = 0
product_graph_start = 0
for i in range(len(batch_num_candidate_products)):
product_graph_end = product_graph_start + batch_num_candidate_products[i]
reaction_pred = pred[product_graph_start:product_graph_end, :]
acc_sum += float(reaction_pred.max(dim=0)[1].detach().cpu().data.item() == 0)
batch_loss += criterion(reaction_pred.reshape(1, -1), batch_labels[i, :])
product_graph_start = product_graph_end
grad_norm_sum += optimizer.backward_and_step(batch_loss)
total_samples += args['batch_size']
if total_samples % args['print_every'] == 0:
progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time {:.4f} | ' \
'accuracy {:.4f} | grad norm {:.4f}'.format(
epoch + 1, args['num_epochs'],
(batch_id + 1) * args['batch_size'] // args['print_every'],
len(train_set) // args['print_every'],
(sum(dur) + time.time() - t0) / total_samples * args['print_every'],
acc_sum / args['print_every'],
grad_norm_sum / args['print_every'])
print(progress)
acc_sum = 0
grad_norm_sum = 0
if total_samples % args['decay_every'] == 0:
dur.append(time.time() - t0)
old_lr = optimizer.lr
optimizer.decay_lr(args['lr_decay_factor'])
new_lr = optimizer.lr
print('Learning rate decayed from {:.4f} to {:.4f}'.format(old_lr, new_lr))
torch.save({'model_state_dict': model.state_dict()},
args['result_path'] + '/model_{:d}.pkl'.format(total_samples))
prediction_summary = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d})\n'.format(
total_samples, epoch + 1, args['num_epochs'],
(batch_id + 1) * args['batch_size'] // args['print_every'],
len(train_set) // args['print_every']) + candidate_ranking_eval(args, model, val_loader)
print(prediction_summary)
with open(args['result_path'] + '/val_eval.txt', 'a') as f:
f.write(prediction_summary)
t0 = time.time()
model.train()
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description='Candidate Ranking')
parser.add_argument('--result-path', type=str, default='candidate_results',
help='Path to save modeling results')
parser.add_argument('--train-path', type=str, default=None,
help='Path to a new training set. '
'If None, we will use the default training set in USPTO.')
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('-cmp', '--center-model-path', type=str, default=None,
help='Path to a pre-trained model for reaction center prediction. '
'By default we use the official pre-trained model. If not None, '
'the model should follow the hyperparameters specified in '
'reaction_center_config.')
parser.add_argument('-rcb', '--reaction-center-batch-size', type=int, default=200,
help='Batch size to use for preparing candidate bonds from a trained '
'model on reaction center prediction')
parser.add_argument('-np', '--num-processes', type=int, default=8,
help='Number of processes to use for data pre-processing')
parser.add_argument('-nw', '--num-workers', type=int, default=100,
help='Number of workers to use for data loading in PyTorch data loader')
args = parser.parse_args().__dict__
args.update(candidate_ranking_config)
mkdir_p(args['result_path'])
set_seed()
if torch.cuda.is_available():
args['device'] = torch.device('cuda:0')
else:
args['device'] = torch.device('cpu')
path_to_candidate_bonds = prepare_reaction_center(args, reaction_center_config)
main(args, path_to_candidate_bonds)
@@ -10,10 +10,33 @@ reaction_center_config = {
'n_layers': 3,
'n_tasks': 5,
'lr': 0.001,
'num_epochs': 18,
'print_every': 50,
'decay_every': 10000, # Learning rate decay
'lr_decay_factor': 0.9,
'top_ks_val': [12, 16, 20, 40, 80],
'top_ks_test': [6, 8, 10],
'max_k': 80
}
# Configuration for candidate ranking
candidate_ranking_config = {
'batch_size': 4,
'hidden_size': 500,
'num_encode_gnn_layers': 3,
'max_norm': 50.0,
'node_in_feats': 89,
'edge_in_feats': 5,
'lr': 0.001,
'num_epochs': 6,
'print_every': 20,
'decay_every': 100000,
'lr_decay_factor': 0.9,
'top_ks': [1, 2, 3, 5],
'max_k': 10,
'max_num_change_combos_per_reaction_train': 150,
'max_num_change_combos_per_reaction_eval': 1500,
'num_candidate_bond_changes': 16
}
candidate_ranking_config['max_norm'] = candidate_ranking_config['max_norm'] * \
candidate_ranking_config['batch_size']
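For intuition on `decay_every` and `lr_decay_factor`: the learning rate is multiplied by the decay factor once per `decay_every` training samples, so with the candidate ranking defaults the effective rate after `n` samples works out as follows (our own helper mirroring the schedule, not code from this example):

```python
def effective_lr(samples, base_lr=0.001, decay_every=100000, lr_decay_factor=0.9):
    # One multiplicative decay per `decay_every` training samples seen
    return base_lr * lr_decay_factor ** (samples // decay_every)

# With the candidate ranking defaults, after 300000 samples:
print(effective_lr(300000))  # 0.001 * 0.9 ** 3
```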
import numpy as np
import time
import torch
from dgllife.data import USPTO, WLNReactionDataset
from dgllife.model import WLNReactionCenter, load_pretrained
from torch.nn import BCEWithLogitsLoss
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from utils import setup, collate, reaction_center_prediction, \
rough_eval_on_a_loader, reaction_center_final_eval
def main(args):
setup(args)
if args['train_path'] is None:
train_set = USPTO('train')
else:
train_set = WLNReactionDataset(raw_file_path=args['train_path'],
mol_graph_path='train.bin')
if args['val_path'] is None:
val_set = USPTO('val')
else:
val_set = WLNReactionDataset(raw_file_path=args['val_path'],
mol_graph_path='val.bin')
if args['test_path'] is None:
test_set = USPTO('test')
else:
test_set = WLNReactionDataset(raw_file_path=args['test_path'],
mol_graph_path='test.bin')
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=True)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=False)
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=False)
if args['pre_trained']:
model = load_pretrained('wln_center_uspto').to(args['device'])
args['num_epochs'] = 0
else:
model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_pair_in_feats=args['node_pair_in_feats'],
node_out_feats=args['node_out_feats'],
n_layers=args['n_layers'],
n_tasks=args['n_tasks']).to(args['device'])
criterion = BCEWithLogitsLoss(reduction='sum')
optimizer = Adam(model.parameters(), lr=args['lr'])
scheduler = StepLR(optimizer, step_size=args['decay_every'], gamma=args['lr_decay_factor'])
total_iter = 0
grad_norm_sum = 0
loss_sum = 0
dur = []
for epoch in range(args['num_epochs']):
t0 = time.time()
for batch_id, batch_data in enumerate(train_loader):
total_iter += 1
batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
labels = batch_atom_pair_labels.to(args['device'])
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
loss = criterion(pred, labels) / len(batch_reactions)
loss_sum += loss.cpu().detach().data.item()
optimizer.zero_grad()
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), args['max_norm'])
grad_norm_sum += grad_norm
optimizer.step()
scheduler.step()
if total_iter % args['print_every'] == 0:
progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time/minibatch {:.4f} | ' \
'loss {:.4f} | grad norm {:.4f}'.format(
epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
(np.sum(dur) + time.time() - t0) / total_iter, loss_sum / args['print_every'],
grad_norm_sum / args['print_every'])
grad_norm_sum = 0
loss_sum = 0
print(progress)
if total_iter % args['decay_every'] == 0:
torch.save(model.state_dict(), args['result_path'] + '/model.pkl')
dur.append(time.time() - t0)
print('Epoch {:d}/{:d}, validation '.format(epoch + 1, args['num_epochs']) + \
rough_eval_on_a_loader(args, model, val_loader))
del train_loader
del val_loader
del train_set
del val_set
print('Evaluation on the test set.')
test_result = reaction_center_final_eval(args, model, test_loader, args['easy'])
print(test_result)
with open(args['result_path'] + '/results.txt', 'w') as f:
f.write(test_result)
if __name__ == '__main__':
from argparse import ArgumentParser
from configure import reaction_center_config
parser = ArgumentParser(description='Reaction Center Identification')
parser.add_argument('--result-path', type=str, default='center_results',
help='Path to training results')
parser.add_argument('--train-path', type=str, default=None,
help='Path to a new training set. '
'If None, we will use the default training set in USPTO.')
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('--test-path', type=str, default=None,
                        help='Path to a new test set. '
'If None, we will use the default test set in USPTO.')
parser.add_argument('-p', '--pre-trained', action='store_true', default=False,
help='If true, we will directly evaluate a '
'pretrained model on the test set.')
parser.add_argument('--easy', action='store_true', default=False,
help='Whether to exclude reactants not contributing atoms to the '
'product in top-k atom pair selection, which will make the '
'task easier.')
args = parser.parse_args().__dict__
args.update(reaction_center_config)
main(args)
import torch
from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import WLNReactionCenter, load_pretrained
from torch.utils.data import DataLoader
from utils import reaction_center_final_eval, set_seed, collate_center, mkdir_p
def main(args):
set_seed()
    if torch.cuda.is_available():
        args['device'] = torch.device('cuda:0')
        # Setting the current CUDA device is only valid when CUDA is available
        torch.cuda.set_device(args['device'])
    else:
        args['device'] = torch.device('cpu')
if args['test_path'] is None:
test_set = USPTOCenter('test', num_processes=args['num_processes'])
else:
test_set = WLNCenterDataset(raw_file_path=args['test_path'],
mol_graph_path='test.bin',
num_processes=args['num_processes'])
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=False)
if args['model_path'] is None:
model = load_pretrained('wln_center_uspto')
else:
model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_pair_in_feats=args['node_pair_in_feats'],
node_out_feats=args['node_out_feats'],
n_layers=args['n_layers'],
n_tasks=args['n_tasks'])
model.load_state_dict(torch.load(
args['model_path'], map_location='cpu')['model_state_dict'])
model = model.to(args['device'])
print('Evaluation on the test set.')
test_result = reaction_center_final_eval(
args, args['top_ks_test'], model, test_loader, args['easy'])
print(test_result)
with open(args['result_path'] + '/test_eval.txt', 'w') as f:
f.write(test_result)
if __name__ == '__main__':
from argparse import ArgumentParser
from configure import reaction_center_config
parser = ArgumentParser(description='Reaction Center Identification -- Evaluation')
parser.add_argument('--model-path', type=str, default=None,
help='Path to saved model. If None, we will directly evaluate '
'a pretrained model on the test set.')
parser.add_argument('--result-path', type=str, default='center_results',
help='Path where we saved model training and evaluation results')
parser.add_argument('--test-path', type=str, default=None,
                        help='Path to a new test set. '
'If None, we will use the default test set in USPTO.')
parser.add_argument('--easy', action='store_true', default=False,
help='Whether to exclude reactants not contributing heavy atoms to the '
'product in top-k atom pair selection, which will make the '
'task easier.')
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='Number of processes to use for data pre-processing')
args = parser.parse_args().__dict__
args.update(reaction_center_config)
assert args['max_k'] >= max(args['top_ks_test']), \
'Expect max_k to be no smaller than the possible options ' \
'of top_ks_test, got {:d} and {:d}'.format(args['max_k'], max(args['top_ks_test']))
mkdir_p(args['result_path'])
main(args)
import numpy as np
import time
import torch
from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import WLNReactionCenter
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import collate_center, reaction_center_prediction, \
reaction_center_final_eval, mkdir_p, set_seed, synchronize, get_center_subset, \
count_parameters
def load_dataset(args):
if args['train_path'] is None:
train_set = USPTOCenter('train', num_processes=args['num_processes'])
else:
train_set = WLNCenterDataset(raw_file_path=args['train_path'],
mol_graph_path='train.bin',
num_processes=args['num_processes'])
if args['val_path'] is None:
val_set = USPTOCenter('val', num_processes=args['num_processes'])
else:
val_set = WLNCenterDataset(raw_file_path=args['val_path'],
mol_graph_path='val.bin',
num_processes=args['num_processes'])
return train_set, val_set
def main(rank, dev_id, args):
set_seed()
    # Removing the line below will result in problems for multiprocessing
if args['num_devices'] > 1:
torch.set_num_threads(1)
    if dev_id == -1:
        args['device'] = torch.device('cpu')
    else:
        args['device'] = torch.device('cuda:{}'.format(dev_id))
        # Set the current device only when a GPU is used
        torch.cuda.set_device(args['device'])
train_set, val_set = load_dataset(args)
get_center_subset(train_set, rank, args['num_devices'])
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=True)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=False)
model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_pair_in_feats=args['node_pair_in_feats'],
node_out_feats=args['node_out_feats'],
n_layers=args['n_layers'],
n_tasks=args['n_tasks']).to(args['device'])
model.train()
if rank == 0:
print('# trainable parameters in the model: ', count_parameters(model))
criterion = BCEWithLogitsLoss(reduction='sum')
optimizer = Adam(model.parameters(), lr=args['lr'])
if args['num_devices'] <= 1:
from utils import Optimizer
optimizer = Optimizer(model, args['lr'], optimizer, max_grad_norm=args['max_norm'])
else:
from utils import MultiProcessOptimizer
optimizer = MultiProcessOptimizer(args['num_devices'], model, args['lr'],
optimizer, max_grad_norm=args['max_norm'])
total_iter = 0
rank_iter = 0
grad_norm_sum = 0
loss_sum = 0
dur = []
for epoch in range(args['num_epochs']):
t0 = time.time()
for batch_id, batch_data in enumerate(train_loader):
total_iter += args['num_devices']
rank_iter += 1
batch_reactions, batch_graph_edits, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
labels = batch_atom_pair_labels.to(args['device'])
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
loss = criterion(pred, labels) / len(batch_reactions)
loss_sum += loss.cpu().detach().data.item()
grad_norm_sum += optimizer.backward_and_step(loss)
if rank_iter % args['print_every'] == 0 and rank == 0:
progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | ' \
'loss {:.4f} | grad norm {:.4f}'.format(
epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
loss_sum / args['print_every'], grad_norm_sum / args['print_every'])
print(progress)
grad_norm_sum = 0
loss_sum = 0
if total_iter % args['decay_every'] == 0:
optimizer.decay_lr(args['lr_decay_factor'])
if total_iter % args['decay_every'] == 0 and rank == 0:
if epoch >= 1:
dur.append(time.time() - t0)
print('Training time per {:d} iterations: {:.4f}'.format(
rank_iter, np.mean(dur)))
total_samples = total_iter * args['batch_size']
prediction_summary = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d}) '.format(
total_samples, epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader)) + \
reaction_center_final_eval(args, args['top_ks_val'], model, val_loader, easy=True)
print(prediction_summary)
with open(args['result_path'] + '/val_eval.txt', 'a') as f:
f.write(prediction_summary)
torch.save({'model_state_dict': model.state_dict()},
args['result_path'] + '/model_{:d}.pkl'.format(total_samples))
t0 = time.time()
model.train()
synchronize(args['num_devices'])
def run(rank, dev_id, args):
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip=args['master_ip'], master_port=args['master_port'])
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=args['num_devices'],
rank=rank)
assert torch.distributed.get_rank() == rank
main(rank, dev_id, args)
if __name__ == '__main__':
from argparse import ArgumentParser
from configure import reaction_center_config
parser = ArgumentParser(description='Reaction Center Identification -- Training')
parser.add_argument('--gpus', default='0', type=str,
help='To use multi-gpu training, '
'pass multiple gpu ids with --gpus id1,id2,...')
parser.add_argument('--result-path', type=str, default='center_results',
help='Path to save modeling results')
parser.add_argument('--train-path', type=str, default=None,
help='Path to a new training set. '
'If None, we will use the default training set in USPTO.')
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='Number of processes to use for data pre-processing')
parser.add_argument('--master-ip', type=str, default='127.0.0.1',
help='master ip address')
parser.add_argument('--master-port', type=str, default='12345',
help='master port')
args = parser.parse_args().__dict__
args.update(reaction_center_config)
assert args['max_k'] >= max(args['top_ks_val']), \
'Expect max_k to be no smaller than the possible options ' \
'of top_ks, got {:d} and {:d}'.format(args['max_k'], max(args['top_ks_val']))
mkdir_p(args['result_path'])
devices = list(map(int, args['gpus'].split(',')))
args['num_devices'] = len(devices)
if len(devices) == 1:
device_id = devices[0] if torch.cuda.is_available() else -1
main(0, device_id, args)
else:
if (args['train_path'] is not None) or (args['val_path'] is not None):
print('First pass for constructing DGLGraphs with multiprocessing')
load_dataset(args)
        # Daemon processes are not allowed to create subprocesses
args['num_processes'] = 1
# With multi-gpu training, the batch size increases and we need to
# increase learning rate accordingly.
args['lr'] = args['lr'] * args['num_devices']
mp = torch.multiprocessing.get_context('spawn')
procs = []
        for proc_id, device_id in enumerate(devices):
            print('Preparing for gpu {:d}/{:d}'.format(proc_id + 1, args['num_devices']))
            procs.append(mp.Process(target=run, args=(
                proc_id, device_id, args), daemon=True))
procs[-1].start()
for p in procs:
p.join()
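With multi-GPU training the effective batch size grows by a factor of `num_devices`, so the script scales the learning rate linearly before spawning workers. The rule can be sketched with illustrative numbers (none of these values come from `configure.py`):

```python
# Linear learning-rate scaling for multi-GPU data parallelism (illustrative values).
per_device_batch_size = 20
num_devices = 4
base_lr = 1e-3

# Each optimizer step now consumes num_devices times more samples,
# so the learning rate is scaled by the same factor.
effective_batch_size = per_device_batch_size * num_devices  # 80
scaled_lr = base_lr * num_devices
print(scaled_lr)  # 0.004
```

This is the common linear scaling heuristic for synchronous data parallelism; very large device counts typically also need warmup, which this example does not use.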
@@ -4,9 +4,22 @@ import numpy as np
import os
import random
import torch
import torch.distributed as dist
import torch.nn as nn
from collections import defaultdict
from copy import deepcopy
from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import load_pretrained, WLNReactionCenter
from rdkit import Chem
from rdkit.Chem import AllChem
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
try:
from molvs import Standardizer
except ImportError as e:
print('MolVS is not installed, which is required for candidate ranking')
def mkdir_p(path):
"""Create a folder for the given path.
@@ -24,44 +37,195 @@ def mkdir_p(path):
else:
raise
def set_seed(seed=0):
    """Fix random seed.

    Parameters
    ----------
    seed : int
        Random seed to use. Default to 0.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

def count_parameters(model):
    """Get the number of trainable parameters in the model.

    Parameters
    ----------
    model : nn.Module
        The model

    Returns
    -------
    int
        Number of trainable parameters in the model
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_center_subset(dataset, subset_id, num_subsets):
"""Get subset for reaction center identification.
Parameters
----------
dataset : WLNCenterDataset
Dataset for reaction center prediction with WLN
subset_id : int
Index for the subset
num_subsets : int
Number of total subsets
"""
if num_subsets == 1:
return
total_size = len(dataset)
subset_size = total_size // num_subsets
start = subset_id * subset_size
end = (subset_id + 1) * subset_size
dataset.mols = dataset.mols[start:end]
dataset.reactions = dataset.reactions[start:end]
dataset.graph_edits = dataset.graph_edits[start:end]
dataset.reactant_mol_graphs = dataset.reactant_mol_graphs[start:end]
dataset.atom_pair_features = [None for _ in range(subset_size)]
dataset.atom_pair_labels = [None for _ in range(subset_size)]
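`get_center_subset` above shards the dataset across workers with contiguous slices; any remainder (`total_size % num_subsets`) is silently dropped. The same arithmetic on a plain list (the helper name `shard` is ours):

```python
def shard(items, subset_id, num_subsets):
    """Contiguous shard for a per-process subset; drops the remainder."""
    subset_size = len(items) // num_subsets
    start = subset_id * subset_size
    end = (subset_id + 1) * subset_size
    return items[start:end]

data = list(range(10))
print(shard(data, 0, 3))  # [0, 1, 2]
print(shard(data, 2, 3))  # [6, 7, 8]; item 9 is dropped
```

Dropping the remainder keeps every worker's subset the same size, which simplifies synchronized training at the cost of ignoring up to `num_subsets - 1` datapoints.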
class Optimizer(nn.Module):
"""Wrapper for optimization
Parameters
----------
model : nn.Module
Model being trained
lr : float
Initial learning rate
optimizer : torch.optim.Optimizer
model optimizer
num_accum_times : int
Number of times for accumulating gradients
max_grad_norm : float or None
If not None, gradient clipping will be performed
"""
def __init__(self, model, lr, optimizer, num_accum_times=1, max_grad_norm=None):
super(Optimizer, self).__init__()
self.model = model
self.lr = lr
self.optimizer = optimizer
self.step_count = 0
self.num_accum_times = num_accum_times
self.max_grad_norm = max_grad_norm
self._reset()
def _reset(self):
self.optimizer.zero_grad()
def _clip_grad_norm(self):
grad_norm = None
if self.max_grad_norm is not None:
grad_norm = clip_grad_norm_(self.model.parameters(),
self.max_grad_norm)
return grad_norm
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
        loss : torch.Tensor
            Scalar tensor holding the loss for the current batch
Returns
-------
grad_norm : float
            Gradient norm. None will be returned if self.max_grad_norm is None,
            and 0 if gradients are still being accumulated.
"""
self.step_count += 1
loss.backward()
if self.step_count % self.num_accum_times == 0:
grad_norm = self._clip_grad_norm()
self.optimizer.step()
self._reset()
return grad_norm
else:
return 0
def decay_lr(self, decay_rate):
"""Decay learning rate.
Parameters
----------
decay_rate : float
Multiply the current learning rate by the decay_rate
"""
self.lr *= decay_rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
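Two pieces of bookkeeping in `Optimizer` are easy to check in isolation: a parameter update happens only every `num_accum_times` calls to `backward_and_step`, and `decay_lr` multiplies the stored rate in place. A torch-free sketch of just that bookkeeping (the class `StepCounter` is a stand-in, not part of the codebase):

```python
class StepCounter:
    """Mimics Optimizer's gradient-accumulation and lr-decay bookkeeping."""
    def __init__(self, lr, num_accum_times=1):
        self.lr = lr
        self.num_accum_times = num_accum_times
        self.step_count = 0
        self.updates = 0

    def backward_and_step(self):
        self.step_count += 1
        # An actual optimizer.step() would run only on these calls.
        if self.step_count % self.num_accum_times == 0:
            self.updates += 1

    def decay_lr(self, decay_rate):
        self.lr *= decay_rate

opt = StepCounter(lr=0.001, num_accum_times=4)
for _ in range(10):
    opt.backward_and_step()
print(opt.updates)  # 2: updates fire at calls 4 and 8
opt.decay_lr(0.9)   # lr is now 0.001 * 0.9
```

In the real class, the calls between updates still run `loss.backward()`, so gradients accumulate across batches before the single `optimizer.step()`.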
class MultiProcessOptimizer(Optimizer):
"""Wrapper for optimization with multiprocess
Parameters
----------
n_processes : int
Number of processes used
model : nn.Module
Model being trained
lr : float
Initial learning rate
optimizer : torch.optim.Optimizer
model optimizer
max_grad_norm : float or None
If not None, gradient clipping will be performed.
"""
def __init__(self, n_processes, model, lr, optimizer, max_grad_norm=None):
super(MultiProcessOptimizer, self).__init__(lr=lr, model=model, optimizer=optimizer,
max_grad_norm=max_grad_norm)
self.n_processes = n_processes
def _sync_gradient(self):
"""Average gradients across all subprocesses."""
for param_group in self.optimizer.param_groups:
for p in param_group['params']:
if p.requires_grad and p.grad is not None:
dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
p.grad.data /= self.n_processes
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
        loss : torch.Tensor
            Scalar tensor holding the loss for the current batch
Returns
-------
grad_norm : float
Gradient norm. If self.max_grad_norm is None, None will be returned.
"""
loss.backward()
self._sync_gradient()
grad_norm = self._clip_grad_norm()
self.optimizer.step()
self._reset()
return grad_norm
def synchronize(num_gpus):
"""Synchronize all processes for multi-gpu training.
Parameters
----------
num_gpus : int
Number of gpus used
"""
if num_gpus > 1:
dist.barrier()
def collate_center(data):
"""Collate multiple datapoints for reaction center prediction
Parameters
----------
@@ -77,8 +241,6 @@ def collate(data):
List of reactions.
graph_edits : list of str
List of graph edits in the reactions.
batch_mol_graphs : DGLGraph
DGLGraph for a batch of molecular graphs.
batch_complete_graphs : DGLGraph
@@ -86,7 +248,7 @@
batch_atom_pair_labels : float32 tensor of shape (V, 10)
Labels of atom pairs in the batch of graphs.
"""
    reactions, graph_edits, mol_graphs, complete_graphs, \
atom_pair_feats, atom_pair_labels = map(list, zip(*data))
batch_mol_graphs = dgl.batch(mol_graphs)
@@ -100,7 +262,7 @@
batch_atom_pair_labels = torch.cat(atom_pair_labels, dim=0)
    return reactions, graph_edits, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels
def reaction_center_prediction(device, model, mol_graphs, complete_graphs):
@@ -137,55 +299,88 @@ def reaction_center_prediction(device, model, mol_graphs, complete_graphs):
return model(mol_graphs, complete_graphs, node_feats, edge_feats, node_pair_feats)
# 0 for losing the bond
# 1, 2, 3, 1.5 separately for forming a single, double, triple or aromatic bond
bond_change_to_id = {0.0: 0, 1: 1, 2: 2, 3: 3, 1.5: 4}
id_to_bond_change = {v: k for k, v in bond_change_to_id.items()}
num_change_types = len(bond_change_to_id)

def get_candidate_bonds(reaction, preds, num_nodes, max_k, easy, include_scores=False):
    """Get candidate bonds for a reaction.

    Parameters
    ----------
    reaction : str
        Reaction
    preds : float32 tensor of shape (E * 5)
        E for the number of edges in a complete graph and 5 for the number of possible
        bond changes.
    num_nodes : int
        Number of nodes in the graph.
    max_k : int
        Maximum number of atom pairs to be selected.
    easy : bool
        If True, reactants not contributing atoms to the product will be excluded in
        top-k atom pair selection, which will make the task easier.
    include_scores : bool
        Whether to include the scores for the atom pairs selected. Default to False.

    Returns
    -------
    list of 3-tuples or 4-tuples
        The first three elements in a tuple separately specify the first atom,
        the second atom and the type for bond change. If include_scores is True,
        the score for the prediction will be included as a fourth element.
    """
    # Decide which atom-pairs will be considered.
    reaction_atoms = []
    reaction_bonds = defaultdict(bool)
    reactants, _, product = reaction.split('>')
    product_mol = Chem.MolFromSmiles(product)
    product_atoms = set([atom.GetAtomMapNum() for atom in product_mol.GetAtoms()])

    for reactant in reactants.split('.'):
        reactant_mol = Chem.MolFromSmiles(reactant)
        reactant_atoms = [atom.GetAtomMapNum() for atom in reactant_mol.GetAtoms()]
        # In the hard mode, all reactant atoms will be included.
        # In the easy mode, only reactants contributing atoms to the product will be included.
        if (len(set(reactant_atoms) & product_atoms) > 0) or (not easy):
            reaction_atoms.extend(reactant_atoms)
            for bond in reactant_mol.GetBonds():
                end_atoms = sorted([bond.GetBeginAtom().GetAtomMapNum(),
                                    bond.GetEndAtom().GetAtomMapNum()])
                bond = tuple(end_atoms + [bond.GetBondTypeAsDouble()])
                # Bookkeep bonds already in reactants
                reaction_bonds[bond] = True

    candidate_bonds = []
    topk_values, topk_indices = torch.topk(preds, max_k)
    for j in range(max_k):
        preds_j = topk_indices[j].cpu().item()
        # A bond change can be either losing the bond or forming a
        # single, double, triple or aromatic bond
        change_id = preds_j % num_change_types
        change_type = id_to_bond_change[change_id]
        pair_id = preds_j // num_change_types
        # Atom map numbers
        atom1 = pair_id // num_nodes + 1
        atom2 = pair_id % num_nodes + 1
        # Avoid duplicates and an atom cannot form a bond with itself
        if atom1 >= atom2:
            continue
        if atom1 not in reaction_atoms:
            continue
        if atom2 not in reaction_atoms:
            continue
        candidate = (int(atom1), int(atom2), float(change_type))
        if reaction_bonds[candidate]:
            continue
        if include_scores:
            candidate += (float(topk_values[j].cpu().item()),)
        candidate_bonds.append(candidate)

    return candidate_bonds
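`get_candidate_bonds` flattens the (atom pair, bond change) grid into one score vector of length `num_nodes * num_nodes * 5`, so every top-k index has to be decoded back into `(atom1, atom2, change_type)`. A torch-free sketch of that decoding arithmetic, using the same `bond_change_to_id` layout as above (the helper name `decode` is ours, not part of the codebase):

```python
# Decoding a flattened score index back to (atom1, atom2, change_type),
# mirroring the arithmetic inside get_candidate_bonds.
id_to_bond_change = {0: 0.0, 1: 1, 2: 2, 3: 3, 4: 1.5}
num_change_types = 5

def decode(flat_index, num_nodes):
    """Recover (atom1, atom2, change_type) from a flattened score index.

    Atom map numbers are 1-based, hence the + 1.
    """
    change_type = id_to_bond_change[flat_index % num_change_types]
    pair_id = flat_index // num_change_types
    atom1 = pair_id // num_nodes + 1
    atom2 = pair_id % num_nodes + 1
    return atom1, atom2, change_type

# With 4 nodes, index 0 is (atom 1, atom 1, lose bond). The encoding is
# invertible: the pair (1, 3) forming a double bond lives at flat index
# ((0 * 4) + 2) * 5 + 2 = 12.
print(decode(12, num_nodes=4))  # (1, 3, 2)
```

Only indices with `atom1 < atom2` survive the filtering above, which is why half of the decoded pairs are skipped.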
def reaction_center_eval(complete_graphs, preds, reactions,
                         graph_edits, num_correct, max_k, easy):
"""Evaluate top-k accuracies for reaction center prediction.
Parameters
@@ -211,57 +406,13 @@ def eval(complete_graphs, preds, reactions, graph_edits, num_correct, max_k, easy):
"""
# 0 for losing the bond
# 1, 2, 3, 1.5 separately for forming a single, double, triple or aromatic bond.
batch_size = complete_graphs.batch_size
start = 0
for i in range(batch_size):
end = start + complete_graphs.batch_num_edges[i]
candidate_bonds = get_candidate_bonds(
reactions[i], preds[start:end, :].flatten(),
complete_graphs.batch_num_nodes[i], max_k, easy)
gold_bonds = []
gold_edits = graph_edits[i]
@@ -275,11 +426,13 @@ def eval(complete_graphs, preds, reactions, graph_edits, num_correct, max_k, easy):
num_correct[k] += 1
start = end
def reaction_center_final_eval(args, top_ks, model, data_loader, easy):
"""Final evaluation of model performance.
    Parameters
    ----------
    args : dict
        Configuration for the experiment.
top_ks : list of int
Options for top-k evaluation
model : nn.Module
Model for reaction center prediction.
data_loader : torch.utils.data.DataLoader
@@ -294,18 +447,763 @@ def reaction_center_final_eval(args, model, data_loader, easy):
Summary of the top-k evaluation.
"""
model.eval()
    num_correct = {k: 0 for k in top_ks}
for batch_id, batch_data in enumerate(data_loader):
        batch_reactions, batch_graph_edits, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
with torch.no_grad():
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
        reaction_center_eval(batch_complete_graphs, biased_pred, batch_reactions,
batch_graph_edits, num_correct, args['max_k'], easy)
msg = '|'
for k, correct_count in num_correct.items():
msg += ' acc@{:d} {:.4f} |'.format(k, correct_count / len(data_loader.dataset))
    return msg + '\n'
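The evaluation summary is assembled as a pipe-delimited string of `acc@k` values, one entry per requested k. A sketch with made-up counts (a dataset of 200 reactions is assumed for illustration):

```python
# Building the pipe-delimited acc@k summary used by the evaluation loops.
num_correct = {6: 178, 8: 185, 10: 190}  # hypothetical correct counts
dataset_size = 200

msg = '|'
for k, correct_count in num_correct.items():
    msg += ' acc@{:d} {:.4f} |'.format(k, correct_count / dataset_size)
print(msg)  # | acc@6 0.8900 | acc@8 0.9250 | acc@10 0.9500 |
```

Since Python 3.7 dicts preserve insertion order, so the entries appear in the order the k values were registered.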
def output_candidate_bonds_for_a_reaction(info, max_k):
"""Prepare top-k atom pairs for each reaction as candidate bonds
Parameters
----------
info : 3-tuple for a reaction
Consists of the reaction, the scores for atom-pairs in reactants
and the number of nodes in reactants.
max_k : int
Maximum number of atom pairs to be selected.
Returns
-------
candidate_string : str
String representing candidate bonds for a reaction. Each candidate
bond is of format 'atom1 atom2 change_type score'.
"""
reaction, preds, num_nodes = info
# Note that we use the easy mode by default, which is also the
# setting in the paper.
candidate_bonds = get_candidate_bonds(reaction, preds, num_nodes, max_k,
easy=True, include_scores=True)
candidate_string = ''
for candidate in candidate_bonds:
# A 4-tuple consisting of the atom mapping number of atom 1,
# atom 2, the bond change type and the score
candidate_string += '{} {} {:.1f} {:.3f};'.format(
candidate[0], candidate[1], candidate[2], candidate[3])
candidate_string += '\n'
return candidate_string
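`output_candidate_bonds_for_a_reaction` serializes each 4-tuple as `'atom1 atom2 change_type score;'`, with all candidates for one reaction on a single newline-terminated line. The formatting in isolation, with hypothetical candidates:

```python
# Serializing candidate bonds the way output_candidate_bonds_for_a_reaction does.
candidate_bonds = [(22, 23, 1, 3.889), (8, 22, 0, 1.325)]  # hypothetical

candidate_string = ''
for candidate in candidate_bonds:
    # atom1 atom2 change_type (1 decimal) score (3 decimals), ';'-separated
    candidate_string += '{} {} {:.1f} {:.3f};'.format(*candidate)
candidate_string += '\n'
print(candidate_string)  # 22 23 1.0 3.889;8 22 0.0 1.325;
```

The one-decimal change type is what lets `1.5` (aromatic) round-trip through the text file alongside the integer bond orders.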
def prepare_reaction_center(args, reaction_center_config):
"""Use a trained model for reaction center prediction to prepare candidate bonds.
Parameters
----------
args : dict
Configuration for the experiment.
reaction_center_config : dict
Configuration for the experiment on reaction center prediction.
Returns
-------
path_to_candidate_bonds : dict
Mapping 'train', 'val', 'test' to the corresponding files for candidate bonds.
"""
if args['center_model_path'] is None:
reaction_center_model = load_pretrained('wln_center_uspto').to(args['device'])
else:
reaction_center_model = WLNReactionCenter(
node_in_feats=reaction_center_config['node_in_feats'],
edge_in_feats=reaction_center_config['edge_in_feats'],
node_pair_in_feats=reaction_center_config['node_pair_in_feats'],
node_out_feats=reaction_center_config['node_out_feats'],
n_layers=reaction_center_config['n_layers'],
n_tasks=reaction_center_config['n_tasks'])
reaction_center_model.load_state_dict(
torch.load(args['center_model_path'])['model_state_dict'])
reaction_center_model = reaction_center_model.to(args['device'])
reaction_center_model.eval()
path_to_candidate_bonds = dict()
for subset in ['train', 'val', 'test']:
if '{}_path'.format(subset) not in args:
continue
path_to_candidate_bonds[subset] = args['result_path'] + \
'/{}_candidate_bonds.txt'.format(subset)
if os.path.isfile(path_to_candidate_bonds[subset]):
continue
print('Processing subset {}...'.format(subset))
print('Stage 1/3: Loading dataset...')
if args['{}_path'.format(subset)] is None:
dataset = USPTOCenter(subset, num_processes=args['num_processes'])
else:
dataset = WLNCenterDataset(raw_file_path=args['{}_path'.format(subset)],
mol_graph_path='{}.bin'.format(subset),
num_processes=args['num_processes'])
dataloader = DataLoader(dataset, batch_size=args['reaction_center_batch_size'],
collate_fn=collate_center, shuffle=False)
print('Stage 2/3: Performing model prediction...')
output_strings = []
for batch_id, batch_data in enumerate(dataloader):
print('Computing candidate bonds for batch {:d}/{:d}'.format(
batch_id + 1, len(dataloader)))
batch_reactions, batch_graph_edits, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
with torch.no_grad():
pred, biased_pred = reaction_center_prediction(
args['device'], reaction_center_model,
batch_mol_graphs, batch_complete_graphs)
batch_size = len(batch_reactions)
start = 0
for i in range(batch_size):
end = start + batch_complete_graphs.batch_num_edges[i]
output_strings.append(output_candidate_bonds_for_a_reaction(
(batch_reactions[i], biased_pred[start:end, :].flatten(),
batch_complete_graphs.batch_num_nodes[i]), reaction_center_config['max_k']
))
start = end
print('Stage 3/3: Output candidate bonds...')
with open(path_to_candidate_bonds[subset], 'w') as f:
for candidate_string in output_strings:
f.write(candidate_string)
del dataset
del dataloader
del reaction_center_model
return path_to_candidate_bonds
def collate_rank_train(data):
"""Collate multiple datapoints for candidate product ranking during training
Parameters
----------
data : list of 3-tuples
Each tuple is for a single datapoint, consisting of DGLGraphs for reactants and candidate
products, scores for candidate products by the model for reaction center prediction,
and labels for candidate products.
Returns
-------
batch_reactant_graphs : DGLGraph
DGLGraph for a batch of batch_size reactants.
product_graphs : DGLGraph
DGLGraph for a batch of B candidate products
combo_scores : float32 tensor of shape (B, 1)
Scores for candidate products by the model for reaction center prediction.
labels : int64 tensor of shape (N, 1)
Indices for the true candidate product across reactions, which is always 0
with pre-processing. N is for the number of reactions.
batch_num_candidate_products : list of int
Number of candidate products for the reactions in this batch.
"""
batch_graphs, batch_combo_scores, batch_labels = map(list, zip(*data))
batch_reactant_graphs = dgl.batch([g_list[0] for g_list in batch_graphs])
batch_num_candidate_products = []
batch_product_graphs = []
for g_list in batch_graphs:
batch_num_candidate_products.append(len(g_list) - 1)
batch_product_graphs.extend(g_list[1:])
batch_product_graphs = dgl.batch(batch_product_graphs)
batch_combo_scores = torch.cat(batch_combo_scores, dim=0)
batch_labels = torch.cat(batch_labels, dim=0)
return batch_reactant_graphs, batch_product_graphs, batch_combo_scores, batch_labels, \
batch_num_candidate_products
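In `collate_rank_train` each datapoint's graph list stores the reactant graph first, followed by its candidate product graphs, so the number of candidates per reaction is `len(g_list) - 1`. The batching arithmetic with placeholder graph lists (strings stand in for DGLGraphs):

```python
# Each inner list: [reactant_graph, product_graph_1, ...]; strings are stand-ins.
batch_graphs = [['r0', 'p0a', 'p0b', 'p0c'],
                ['r1', 'p1a']]

batch_num_candidate_products = []
batch_product_graphs = []
for g_list in batch_graphs:
    batch_num_candidate_products.append(len(g_list) - 1)
    batch_product_graphs.extend(g_list[1:])

print(batch_num_candidate_products)  # [3, 1]
print(len(batch_product_graphs))     # 4
```

The per-reaction counts are what later allow the flat batch of product scores to be split back into one softmax per reaction.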
def collate_rank_eval(data):
"""Collate multiple datapoints for candidate product ranking during evaluation
Parameters
----------
data : list of 3-tuples
Each tuple is for a single datapoint, consisting of DGLGraphs for reactants and candidate
products, scores for candidate products by the model for reaction center prediction,
and valid combos of candidate bond changes, one for each candidate product.
Returns
-------
batch_reactant_graph : DGLGraph
DGLGraph for a batch of batch_size reactants.
None will be returned if no valid candidate products exist.
batch_product_graphs : DGLGraph
DGLGraph for a batch of B candidate products.
None will be returned if no valid candidate products exist.
batch_combo_scores : float32 tensor of shape (B, 1)
Scores for candidate products by the model for reaction center prediction.
None will be returned if no valid candidate products exist.
valid_candidate_combos_list : list of list
valid_candidate_combos_list[i] gives valid combos of candidate bond changes for the
i-th reaction. valid_candidate_combos_list[i][j] gives a list of tuples, which is
the j-th valid combo of candidate bond changes for the reaction. Each tuple is of form
(atom1, atom2, change_type, score). atom1, atom2 are the atom mapping numbers - 1 of the
two end atoms. change_type can be 0, 1, 2, 3, 1.5, separately for losing a bond, forming
a single, double, triple, and aromatic bond. None will be returned if no valid candidate
products exist.
reactant_mols_list : list of rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants in the batch.
None will be returned if no valid candidate products exist.
real_bond_changes_list : list of list
real_bond_changes_list[i] gives the ground truth bond changes in the i-th reaction,
which is a list of tuples. Each tuple is of form (atom1, atom2, change_type). atom1,
atom2 are the atom mapping numbers - 1 of the two end atoms. change_type can be
0, 1, 2, 3, 1.5, separately for losing a bond, forming a single, double, triple, and
aromatic bond. None will be returned if no valid candidate products exist.
product_mols_list : list of rdkit.Chem.rdchem.Mol
RDKit molecule instance for the candidate products in each reaction.
None will be returned if no valid candidate products exist.
batch_num_candidate_products : list of int
Number of candidate products for the reactions in this batch.
"""
batch_graphs, batch_combo_scores, batch_valid_candidate_combos, \
batch_reactant_mols, batch_real_bond_changes, batch_product_mols = map(list, zip(*data))
batch_reactant_graphs = []
batch_product_graphs = []
combo_scores_list = []
valid_candidate_combos_list = []
reactant_mols_list = []
real_bond_changes_list = []
product_mols_list = []
batch_num_candidate_products = []
for i in range(len(batch_graphs)):
g_list = batch_graphs[i]
# No valid candidate products have been predicted
if len(g_list) == 1:
continue
batch_reactant_graphs.append(g_list[0])
batch_product_graphs.extend(g_list[1:])
combo_scores_list.append(batch_combo_scores[i])
valid_candidate_combos_list.append(batch_valid_candidate_combos[i])
reactant_mols_list.append(batch_reactant_mols[i])
real_bond_changes_list.append(batch_real_bond_changes[i])
product_mols_list.append(batch_product_mols[i])
batch_num_candidate_products.append(len(g_list) - 1)
if len(batch_product_graphs) == 0:
return None, None, None, None, None, None, None, None
batch_reactant_graphs = dgl.batch(batch_reactant_graphs)
batch_product_graphs = dgl.batch(batch_product_graphs)
batch_combo_scores = torch.cat(combo_scores_list, dim=0)
return batch_reactant_graphs, batch_product_graphs, batch_combo_scores, \
valid_candidate_combos_list, reactant_mols_list, real_bond_changes_list, \
product_mols_list, batch_num_candidate_products
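The per-reaction bookkeeping above (flattening candidate graphs into one batch and recording `batch_num_candidate_products` so predictions can be sliced back per reaction) can be sketched with plain lists standing in for DGLGraphs; the data here is hypothetical:

```python
# Minimal sketch of the batching/unbatching bookkeeping, with plain lists
# in place of DGLGraphs (hypothetical graphs and scores).
def flatten_candidates(per_reaction_graphs):
    """Flatten candidate graphs and record per-reaction counts."""
    flat, counts = [], []
    for graphs in per_reaction_graphs:
        flat.extend(graphs)
        counts.append(len(graphs))
    return flat, counts

def slice_back(flat_scores, counts):
    """Recover per-reaction score chunks from the flat batch."""
    chunks, start = [], 0
    for n in counts:
        chunks.append(flat_scores[start:start + n])
        start += n
    return chunks

flat, counts = flatten_candidates([['g1', 'g2'], ['g3'], ['g4', 'g5', 'g6']])
chunks = slice_back([0.1, 0.9, 0.5, 0.3, 0.2, 0.7], counts)
```

The same start/end slicing reappears later when per-reaction predictions are recovered during evaluation.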
def sanitize_smiles_molvs(smiles, largest_fragment=False):
"""Sanitize a SMILES with MolVS
Parameters
----------
smiles : str
SMILES string for a molecule.
largest_fragment : bool
Whether to select only the largest covalent unit in a molecule with
multiple fragments. Default to False.
Returns
-------
str
SMILES string for the sanitized molecule.
"""
standardizer = Standardizer()
standardizer.prefer_organic = True
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return smiles
try:
mol = standardizer.standardize(mol) # standardize functional group reps
if largest_fragment:
mol = standardizer.largest_fragment(mol) # remove product counterions/salts/etc.
mol = standardizer.uncharge(mol) # neutralize, e.g., carboxylic acids
except Exception:
pass
return Chem.MolToSmiles(mol)
def bookkeep_reactant(mol):
"""Bookkeep bonds in the reactant.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants.
Returns
-------
pair_to_bond_type : dict
Mapping 2-tuples of atom indices to bond types. 1, 2, 3, 1.5 are
respectively for single, double, triple and aromatic bonds.
"""
pair_to_bond_type = dict()
for bond in mol.GetBonds():
atom1, atom2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
atom1, atom2 = min(atom1, atom2), max(atom1, atom2)
type_val = bond.GetBondTypeAsDouble()
pair_to_bond_type[(atom1, atom2)] = type_val
return pair_to_bond_type
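The returned mapping is later used to drop candidate bond changes that are no-ops on the reactants; a minimal sketch of that filtering with hypothetical atom pairs and bond types:

```python
# Sketch of how the pair-to-bond-type mapping is used to keep only real
# bond changes (hypothetical values; keys are sorted atom-index pairs).
pair_to_bond_type = {(0, 1): 1.0, (1, 2): 2.0}

def is_real_change(pair_to_bond_type, atom1, atom2, change_type):
    """A change is real if it forms a new bond or alters an existing one."""
    key = (min(atom1, atom2), max(atom1, atom2))
    existing = pair_to_bond_type.get(key, None)
    if existing is None:
        return change_type > 0          # forming a bond where none exists
    return existing != change_type      # breaking (0) or changing the order

kept = [c for c in [(0, 1, 1.0), (0, 1, 0.0), (2, 3, 1.0)]
        if is_real_change(pair_to_bond_type, *c)]
```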
bond_change_to_type = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE,
3: Chem.rdchem.BondType.TRIPLE, 1.5: Chem.rdchem.BondType.AROMATIC}
clean_rxns_postsani = [
# two adjacent aromatic nitrogens should allow for H shift
AllChem.ReactionFromSmarts('[n;H1;+0:1]:[n;H0;+1:2]>>[n;H0;+0:1]:[n;H0;+0:2]'),
# two aromatic nitrogens separated by one should allow for H shift
AllChem.ReactionFromSmarts('[n;H1;+0:1]:[c:3]:[n;H0;+1:2]>>[n;H0;+0:1]:[*:3]:[n;H0;+0:2]'),
AllChem.ReactionFromSmarts('[#7;H0;+:1]-[O;H1;+0:2]>>[#7;H0;+:1]-[O;H0;-:2]'),
# neutralize C(=O)[O-]
AllChem.ReactionFromSmarts('[C;H0;+0:1](=[O;H0;+0:2])[O;H0;-1:3]>>[C;H0;+0:1](=[O;H0;+0:2])[O;H1;+0:3]'),
# turn neutral halogens into anions EXCEPT HCl
AllChem.ReactionFromSmarts('[I,Br,F;H1;D0;+0:1]>>[*;H0;-1:1]'),
# inexplicable nitrogen anion in reactants gets fixed in prods
AllChem.ReactionFromSmarts('[N;H0;-1:1]([C:2])[C:3]>>[N;H1;+0:1]([*:2])[*:3]'),
]
def edit_mol(rmol, bond_changes, keep_atom_map=False):
"""Simulate reaction via graph editing
Parameters
----------
rmol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants
bond_changes : list of 3-tuples
Each tuple is of form (atom1, atom2, change_type)
keep_atom_map : bool
Whether to keep atom mapping number. Default to False.
Returns
-------
pred_smiles : list of str
SMILES strings for the edited molecule, one per fragment
"""
new_mol = Chem.RWMol(rmol)
# Keep track of aromatic nitrogens, which might cause explicit hydrogen issues
aromatic_nitrogen_ids = set()
aromatic_carbonyl_adj_to_aromatic_nh = dict()
aromatic_carbondeg3_adj_to_aromatic_nh0 = dict()
for atom in new_mol.GetAtoms():
if atom.GetIsAromatic() and atom.GetSymbol() == 'N':
aromatic_nitrogen_ids.add(atom.GetIdx())
for nbr in atom.GetNeighbors():
if atom.GetNumExplicitHs() == 1 and nbr.GetSymbol() == 'C' and \
nbr.GetIsAromatic() and \
any(b.GetBondTypeAsDouble() == 2 for b in nbr.GetBonds()):
aromatic_carbonyl_adj_to_aromatic_nh[nbr.GetIdx()] = atom.GetIdx()
elif atom.GetNumExplicitHs() == 0 and nbr.GetSymbol() == 'C' and \
nbr.GetIsAromatic() and len(nbr.GetBonds()) == 3:
aromatic_carbondeg3_adj_to_aromatic_nh0[nbr.GetIdx()] = atom.GetIdx()
else:
atom.SetNumExplicitHs(0)
new_mol.UpdatePropertyCache()
for atom1_id, atom2_id, change_type in bond_changes:
bond = new_mol.GetBondBetweenAtoms(atom1_id, atom2_id)
atom1 = new_mol.GetAtomWithIdx(atom1_id)
atom2 = new_mol.GetAtomWithIdx(atom2_id)
if bond is not None:
new_mol.RemoveBond(atom1_id, atom2_id)
# Are we losing a bond on an aromatic nitrogen?
if bond.GetBondTypeAsDouble() == 1.0:
if atom1_id in aromatic_nitrogen_ids:
if atom1.GetTotalNumHs() == 0:
atom1.SetNumExplicitHs(1)
elif atom1.GetFormalCharge() == 1:
atom1.SetFormalCharge(0)
elif atom2_id in aromatic_nitrogen_ids:
if atom2.GetTotalNumHs() == 0:
atom2.SetNumExplicitHs(1)
elif atom2.GetFormalCharge() == 1:
atom2.SetFormalCharge(0)
# Are we losing a c=O bond on an aromatic ring?
# If so, remove H from adjacent nH if appropriate
if bond.GetBondTypeAsDouble() == 2.0:
both_aromatic_nh_ids = [
aromatic_carbonyl_adj_to_aromatic_nh.get(atom1_id, None),
aromatic_carbonyl_adj_to_aromatic_nh.get(atom2_id, None)
]
for aromatic_nh_id in both_aromatic_nh_ids:
if aromatic_nh_id is not None:
new_mol.GetAtomWithIdx(aromatic_nh_id).SetNumExplicitHs(0)
if change_type > 0:
new_mol.AddBond(atom1_id, atom2_id, bond_change_to_type[change_type])
# Special alkylation case?
if change_type == 1:
if atom1_id in aromatic_nitrogen_ids:
if atom1.GetTotalNumHs() == 1:
atom1.SetNumExplicitHs(0)
else:
atom1.SetFormalCharge(1)
elif atom2_id in aromatic_nitrogen_ids:
if atom2.GetTotalNumHs() == 1:
atom2.SetNumExplicitHs(0)
else:
atom2.SetFormalCharge(1)
# Are we getting a c=O bond on an aromatic ring?
# If so, add H to adjacent nH0 if appropriate
if change_type == 2:
both_aromatic_nh0_ids = [
aromatic_carbondeg3_adj_to_aromatic_nh0.get(atom1_id, None),
aromatic_carbondeg3_adj_to_aromatic_nh0.get(atom2_id, None)
]
for aromatic_nh0_id in both_aromatic_nh0_ids:
if aromatic_nh0_id is not None:
new_mol.GetAtomWithIdx(aromatic_nh0_id).SetNumExplicitHs(1)
pred_mol = new_mol.GetMol()
# Clear formal charges to make molecules valid
# Note: because S and P (among others) can change valence, be more flexible
for atom in pred_mol.GetAtoms():
if not keep_atom_map:
atom.ClearProp('molAtomMapNumber')
if atom.GetSymbol() == 'N' and atom.GetFormalCharge() == 1:
# exclude negatively-charged azide
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
if bond_vals <= 3:
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'N' and atom.GetFormalCharge() == -1:
# handle negatively-charged azide addition
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
if bond_vals == 3 and any([nbr.GetSymbol() == 'N' for nbr in atom.GetNeighbors()]):
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'N':
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
if bond_vals == 4 and not atom.GetIsAromatic():
atom.SetFormalCharge(1)
elif atom.GetSymbol() == 'C' and atom.GetFormalCharge() != 0:
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'O' and atom.GetFormalCharge() != 0:
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]) + \
atom.GetNumExplicitHs()
if bond_vals == 2:
atom.SetFormalCharge(0)
elif atom.GetSymbol() in ['Cl', 'Br', 'I', 'F'] and atom.GetFormalCharge() != 0:
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
if bond_vals == 1:
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'S' and atom.GetFormalCharge() != 0:
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
if bond_vals in [2, 4, 6]:
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'P':
# quaternary phosphorus should be pos. charge with 0 H
bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
if sum(bond_vals) == 4 and len(bond_vals) == 4:
atom.SetFormalCharge(1)
atom.SetNumExplicitHs(0)
elif sum(bond_vals) == 3 and len(bond_vals) == 3:
# make sure neutral
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'B':
# quaternary boron should be neg. charge with 0 H
bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
if sum(bond_vals) == 4 and len(bond_vals) == 4:
atom.SetFormalCharge(-1)
atom.SetNumExplicitHs(0)
elif atom.GetSymbol() in ['Mg', 'Zn']:
bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
if sum(bond_vals) == 1 and len(bond_vals) == 1:
atom.SetFormalCharge(1)
elif atom.GetSymbol() == 'Si':
bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
if sum(bond_vals) == len(bond_vals):
atom.SetNumExplicitHs(max(0, 4 - len(bond_vals)))
# Bounce to/from SMILES to try to sanitize
pred_smiles = Chem.MolToSmiles(pred_mol)
pred_list = pred_smiles.split('.')
pred_mols = [Chem.MolFromSmiles(pred_smiles) for pred_smiles in pred_list]
for i, mol in enumerate(pred_mols):
if mol is None:
continue
mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
if mol is None:
continue
for rxn in clean_rxns_postsani:
out = rxn.RunReactants((mol,))
if out:
try:
Chem.SanitizeMol(out[0][0])
pred_mols[i] = Chem.MolFromSmiles(Chem.MolToSmiles(out[0][0]))
except Exception as e:
pass
pred_smiles = [Chem.MolToSmiles(pred_mol) for pred_mol in pred_mols if pred_mol is not None]
return pred_smiles
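The `bond_changes` tuples consumed by `edit_mol` are typically parsed from graph-edit strings such as `'23-33-1.0;23-25-0.0'`; a small parser sketch, assuming (per the surrounding docstrings) that tuple entries are atom mapping numbers minus one:

```python
# Parse a graph-edit string into (atom1, atom2, change_type) tuples.
# Assumes atom map numbers in the string are 1-based, while edit_mol
# indexes atoms from 0 after renumbering by atom map.
def parse_bond_changes(graph_edits):
    changes = []
    for triplet in graph_edits.split(';'):
        atom1, atom2, change_type = triplet.split('-')
        changes.append((int(atom1) - 1, int(atom2) - 1, float(change_type)))
    return changes

changes = parse_bond_changes('23-33-1.0;23-25-0.0')
```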
def examine_topk_candidate_product(topks, topk_combos, reactant_mol,
real_bond_changes, product_mol):
"""Perform topk evaluation for predicting the product of a reaction
Parameters
----------
topks : list of int
Options for top-k evaluation, e.g. [1, 3, ...].
topk_combos : list of list
topk_combos[i] gives the combo of valid bond changes ranked i-th,
which is a list of 3-tuples. Each tuple is of form
(atom1, atom2, change_type). atom1, atom2 are the atom mapping numbers - 1 of the two
end atoms. The change_type can be 0, 1, 2, 3, 1.5, respectively, for losing a bond or
forming a single, double, triple, or aromatic bond.
reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants.
real_bond_changes : list of tuples
Ground truth bond changes in a reaction. Each tuple is of form (atom1, atom2,
change_type). atom1, atom2 are the atom mapping numbers - 1 of the two
end atoms. change_type can be 0, 1, 2, 3, 1.5, respectively, for losing a bond and forming
a single, double, triple, or aromatic bond.
product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the product.
Returns
-------
found_info : dict
Binary values indicating whether we can recover the product from the ground truth
graph edits or top-k predicted edits
"""
found_info = defaultdict(bool)
# Avoid corrupting the RDKit molecule instances in the dataset
reactant_mol = deepcopy(reactant_mol)
product_mol = deepcopy(product_mol)
for atom in product_mol.GetAtoms():
atom.ClearProp('molAtomMapNumber')
product_smiles = Chem.MolToSmiles(product_mol)
product_smiles_sanitized = set(sanitize_smiles_molvs(product_smiles, True).split('.'))
product_smiles = set(product_smiles.split('.'))
########### Use *true* edits to try to recover product
# Generate product by modifying reactants with graph edits
pred_smiles = edit_mol(reactant_mol, real_bond_changes)
pred_smiles_sanitized = set(sanitize_smiles_molvs(smiles) for smiles in pred_smiles)
pred_smiles = set(pred_smiles)
if not product_smiles <= pred_smiles:
# Try again with kekulized form
Chem.Kekulize(reactant_mol)
pred_smiles_kek = edit_mol(reactant_mol, real_bond_changes)
pred_smiles_kek = set(pred_smiles_kek)
if not product_smiles <= pred_smiles_kek:
if product_smiles_sanitized <= pred_smiles_sanitized:
print('\nwarn: mismatch, but only due to standardization')
found_info['ground_sanitized'] = True
else:
print('\nwarn: could not regenerate product {}'.format(product_smiles))
print('sani product: {}'.format(product_smiles_sanitized))
print(Chem.MolToSmiles(reactant_mol))
print(Chem.MolToSmiles(product_mol))
print(real_bond_changes)
print('pred_smiles: {}'.format(pred_smiles))
print('pred_smiles_kek: {}'.format(pred_smiles_kek))
print('pred_smiles_sani: {}'.format(pred_smiles_sanitized))
else:
found_info['ground'] = True
found_info['ground_sanitized'] = True
else:
found_info['ground'] = True
found_info['ground_sanitized'] = True
########### Now use candidate edits to try to recover product
max_topk = max(topks)
current_rank = 0
correct_rank = max_topk + 1
sanitized_correct_rank = max_topk + 1
candidate_smiles_list = []
candidate_smiles_sanitized_list = []
for i, combo in enumerate(topk_combos):
prev_len_candidate_smiles = len(set(candidate_smiles_list))
# Generate products by modifying reactants with predicted edits.
candidate_smiles = edit_mol(reactant_mol, combo)
candidate_smiles = set(candidate_smiles)
candidate_smiles_sanitized = set(sanitize_smiles_molvs(smiles)
for smiles in candidate_smiles)
if product_smiles_sanitized <= candidate_smiles_sanitized:
sanitized_correct_rank = min(sanitized_correct_rank, current_rank + 1)
if product_smiles <= candidate_smiles:
correct_rank = min(correct_rank, current_rank + 1)
# Record unkekulized form
candidate_smiles_list.append('.'.join(candidate_smiles))
candidate_smiles_sanitized_list.append('.'.join(candidate_smiles_sanitized))
# Edit molecules with reactants kekulized. Sometimes previous editing fails due to
# RDKit sanitization error (edited molecule cannot be kekulized)
try:
Chem.Kekulize(reactant_mol)
except Exception as e:
pass
candidate_smiles = edit_mol(reactant_mol, combo)
candidate_smiles = set(candidate_smiles)
candidate_smiles_sanitized = set(sanitize_smiles_molvs(smiles)
for smiles in candidate_smiles)
if product_smiles_sanitized <= candidate_smiles_sanitized:
sanitized_correct_rank = min(sanitized_correct_rank, current_rank + 1)
if product_smiles <= candidate_smiles:
correct_rank = min(correct_rank, current_rank + 1)
# If we failed to come up with a new candidate, don't increment the counter!
if len(set(candidate_smiles_list)) > prev_len_candidate_smiles:
current_rank += 1
if correct_rank < max_topk + 1 and sanitized_correct_rank < max_topk + 1:
break
for k in topks:
if correct_rank <= k:
found_info['top_{:d}'.format(k)] = True
if sanitized_correct_rank <= k:
found_info['top_{:d}_sanitized'.format(k)] = True
return found_info
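The duplicate-aware rank counter above (a candidate that regenerates an already-seen product does not consume a rank slot) can be isolated into a small sketch with hypothetical product strings:

```python
# Rank bookkeeping: duplicates of an earlier candidate do not increment
# the rank counter, mirroring examine_topk_candidate_product.
def rank_of_match(candidates, target, max_k):
    seen, rank = set(), 0
    correct_rank = max_k + 1        # sentinel: not found within top max_k
    for smiles in candidates:
        if smiles == target:
            correct_rank = min(correct_rank, rank + 1)
        if smiles not in seen:      # only new candidates consume a slot
            seen.add(smiles)
            rank += 1
        if correct_rank <= max_k:
            break
    return correct_rank

# 'B' is the 3rd candidate but ranks 2nd because the repeated 'A'
# does not count as a new candidate.
rank = rank_of_match(['A', 'A', 'B', 'C'], 'B', max_k=5)
```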
def summary_candidate_ranking_info(top_ks, found_info, data_size):
"""Get a string for summarizing the candidate ranking results
Parameters
----------
top_ks : list of int
Options for top-k evaluation, e.g. [1, 3, ...].
found_info : dict
Storing the count of correct predictions
data_size : int
Size of the dataset
Returns
-------
string : str
String summarizing the evaluation results
"""
string = '[strict]'
for k in top_ks:
string += ' acc@{:d}: {:.4f}'.format(k, found_info['top_{:d}'.format(k)] / data_size)
string += ' gfound {:.4f}\n'.format(found_info['ground'] / data_size)
string += '[molvs]'
for k in top_ks:
string += ' acc@{:d}: {:.4f}'.format(
k, found_info['top_{:d}_sanitized'.format(k)] / data_size)
string += ' gfound {:.4f}\n'.format(found_info['ground_sanitized'] / data_size)
return string
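For reference, the `[strict]` line produced above looks like this for hypothetical accumulated counts over 200 reactions:

```python
from collections import defaultdict

# Hypothetical counts of correct predictions, formatted the same way as
# the [strict] line in summary_candidate_ranking_info.
found_info = defaultdict(float, {'top_1': 150, 'top_3': 172, 'ground': 198})
data_size = 200
line = '[strict]' + ''.join(
    ' acc@{:d}: {:.4f}'.format(k, found_info['top_{:d}'.format(k)] / data_size)
    for k in [1, 3]) + ' gfound {:.4f}\n'.format(found_info['ground'] / data_size)
```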
def candidate_ranking_eval(args, model, data_loader):
"""Evaluate model performance on candidate ranking.
Parameters
----------
args : dict
Configurations for the experiment.
model : nn.Module
Model for ranking candidate products.
data_loader : torch.utils.data.DataLoader
Loader for fetching and batching data.
Returns
-------
str
String summarizing the evaluation results
"""
model.eval()
# Record how many products can be recovered by real graph edits (with/without sanitization)
found_info_summary = {'ground': 0, 'ground_sanitized': 0}
for k in args['top_ks']:
found_info_summary['top_{:d}'.format(k)] = 0
found_info_summary['top_{:d}_sanitized'.format(k)] = 0
total_samples = 0
for batch_id, batch_data in enumerate(data_loader):
batch_reactant_graphs, batch_product_graphs, batch_combo_scores, \
batch_valid_candidate_combos, batch_reactant_mols, batch_real_bond_changes, \
batch_product_mols, batch_num_candidate_products = batch_data
# No valid candidate products have been predicted
if batch_reactant_graphs is None:
continue
total_samples += len(batch_num_candidate_products)
batch_combo_scores = batch_combo_scores.to(args['device'])
reactant_node_feats = batch_reactant_graphs.ndata.pop('hv').to(args['device'])
reactant_edge_feats = batch_reactant_graphs.edata.pop('he').to(args['device'])
product_node_feats = batch_product_graphs.ndata.pop('hv').to(args['device'])
product_edge_feats = batch_product_graphs.edata.pop('he').to(args['device'])
# Get candidate products with top-k ranking
with torch.no_grad():
pred = model(reactant_graph=batch_reactant_graphs,
reactant_node_feats=reactant_node_feats,
reactant_edge_feats=reactant_edge_feats,
product_graphs=batch_product_graphs,
product_node_feats=product_node_feats,
product_edge_feats=product_edge_feats,
candidate_scores=batch_combo_scores,
batch_num_candidate_products=batch_num_candidate_products)
product_graph_start = 0
for i in range(len(batch_num_candidate_products)):
num_candidate_products = batch_num_candidate_products[i]
reactant_mol = batch_reactant_mols[i]
valid_candidate_combos = batch_valid_candidate_combos[i]
real_bond_changes = batch_real_bond_changes[i]
product_mol = batch_product_mols[i]
product_graph_end = product_graph_start + num_candidate_products
top_k = min(args['max_k'], num_candidate_products)
reaction_pred = pred[product_graph_start:product_graph_end, :]
topk_values, topk_indices = torch.topk(reaction_pred, top_k, dim=0)
# Filter out invalid candidate bond changes
reactant_pair_to_bond = bookkeep_reactant(reactant_mol)
topk_combos = []
for j in topk_indices:
j = j.detach().cpu().item()
combo = []
for atom1, atom2, change_type, score in valid_candidate_combos[j]:
bond_in_reactant = reactant_pair_to_bond.get((atom1, atom2), None)
if (bond_in_reactant is None and change_type > 0) or \
(bond_in_reactant is not None and bond_in_reactant != change_type):
combo.append((atom1, atom2, change_type))
topk_combos.append(combo)
batch_found_info = examine_topk_candidate_product(
args['top_ks'], topk_combos, reactant_mol, real_bond_changes, product_mol)
for k, v in batch_found_info.items():
found_info_summary[k] += float(v)
product_graph_start = product_graph_end
if total_samples % args['print_every'] == 0:
print('Iter {:d}/{:d}'.format(
total_samples // args['print_every'],
len(data_loader.dataset) // args['print_every']))
print(summary_candidate_ranking_info(
args['top_ks'], found_info_summary, total_samples))
return summary_candidate_ranking_info(args['top_ks'], found_info_summary, total_samples)
"""USPTO for reaction prediction"""
import errno
import numpy as np
import os
import random
import torch
from collections import defaultdict
from copy import deepcopy
from dgl import DGLGraph
from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive, \
save_graphs, load_graphs
from functools import partial
from itertools import combinations
from multiprocessing import Pool
from rdkit import Chem, RDLogger
from rdkit.Chem import rdmolops
from tqdm import tqdm
from ..utils.featurizers import BaseAtomFeaturizer, ConcatFeaturizer, atom_type_one_hot, \
atom_degree_one_hot, atom_explicit_valence_one_hot, atom_implicit_valence_one_hot, \
atom_is_aromatic, BaseBondFeaturizer, bond_type_one_hot, bond_is_conjugated, bond_is_in_ring
from ..utils.featurizers import BaseAtomFeaturizer, ConcatFeaturizer, one_hot_encoding, \
atom_type_one_hot, atom_degree_one_hot, atom_explicit_valence_one_hot, \
atom_implicit_valence_one_hot, atom_is_aromatic, atom_formal_charge_one_hot, \
BaseBondFeaturizer, bond_type_one_hot, bond_is_conjugated, bond_is_in_ring
from ..utils.mol_to_graph import mol_to_bigraph, mol_to_complete_graph
__all__ = ['WLNReactionDataset',
'USPTO']
__all__ = ['WLNCenterDataset',
'USPTOCenter',
'WLNRankDataset',
'USPTORank']
# Disable RDKit warnings
RDLogger.DisableLog('rdApp.*')
......@@ -29,22 +38,48 @@ atom_types = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Cr', 'Pt', 'Hg', 'Pb', 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Tc', 'Ba', 'Bi',
'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir', 'Ce', 'Gd', 'Ga', 'Cs']
default_node_featurizer = BaseAtomFeaturizer({
default_node_featurizer_center = BaseAtomFeaturizer({
'hv': ConcatFeaturizer(
[partial(atom_type_one_hot, allowable_set=atom_types, encode_unknown=True),
partial(atom_degree_one_hot, allowable_set=list(range(6))),
atom_explicit_valence_one_hot,
partial(atom_implicit_valence_one_hot, allowable_set=list(range(6))),
[partial(atom_type_one_hot,
allowable_set=atom_types, encode_unknown=True),
partial(atom_degree_one_hot,
allowable_set=list(range(5)), encode_unknown=True),
partial(atom_explicit_valence_one_hot,
allowable_set=list(range(1, 6)), encode_unknown=True),
partial(atom_implicit_valence_one_hot,
allowable_set=list(range(5)), encode_unknown=True),
atom_is_aromatic]
)
})
default_edge_featurizer = BaseBondFeaturizer({
default_node_featurizer_rank = BaseAtomFeaturizer({
'hv': ConcatFeaturizer(
[partial(atom_type_one_hot,
allowable_set=atom_types, encode_unknown=True),
partial(atom_formal_charge_one_hot,
allowable_set=[-3, -2, -1, 0, 1, 2], encode_unknown=True),
partial(atom_degree_one_hot,
allowable_set=list(range(5)), encode_unknown=True),
partial(atom_explicit_valence_one_hot,
allowable_set=list(range(1, 6)), encode_unknown=True),
partial(atom_implicit_valence_one_hot,
allowable_set=list(range(5)), encode_unknown=True),
atom_is_aromatic]
)
})
default_edge_featurizer_center = BaseBondFeaturizer({
'he': ConcatFeaturizer([
bond_type_one_hot, bond_is_conjugated, bond_is_in_ring]
)
})
default_edge_featurizer_rank = BaseBondFeaturizer({
'he': ConcatFeaturizer([
bond_type_one_hot, bond_is_in_ring]
)
})
def default_atom_pair_featurizer(reactants):
"""Featurize each pair of atoms, which will be used in updating
the edata of a complete DGLGraph.
......@@ -198,25 +233,100 @@ def get_bond_changes(reaction):
return bond_changes
def process_file(path):
def process_line(line):
"""Process one line consisting of one reaction for working with WLN.
Parameters
----------
line : str
One reaction in one line
Returns
-------
formatted_reaction : str
Formatted reaction
"""
reaction = line.strip()
bond_changes = get_bond_changes(reaction)
formatted_reaction = '{} {}\n'.format(
reaction, ';'.join(['{}-{}-{}'.format(x[0], x[1], x[2]) for x in bond_changes]))
return formatted_reaction
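The formatting step of `process_line` is easy to check in isolation; here `get_bond_changes` (which needs RDKit) is replaced by hypothetical precomputed bond changes and a toy reaction string:

```python
# Format a reaction plus its bond changes the way process_line does.
# The reaction string and bond changes below are hypothetical stand-ins.
reaction = 'CC.O>>CCO'
bond_changes = [(1, 2, 1.0), (0, 3, 0.0)]
formatted = '{} {}\n'.format(
    reaction,
    ';'.join('{}-{}-{}'.format(a, b, c) for a, b, c in bond_changes))
```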
def process_file(path, num_processes=1):
"""Pre-process a file of reactions for working with WLN.
Parameters
----------
path : str
Path to the file of reactions
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
with open(path, 'r') as input_file, open(path + '.proc', 'w') as output_file:
for line in tqdm(input_file):
reaction = line.strip()
bond_changes = get_bond_changes(reaction)
output_file.write('{} {}\n'.format(
reaction,
';'.join(['{}-{}-{}'.format(x[0], x[1], x[2]) for x in bond_changes])))
with open(path, 'r') as input_file:
lines = input_file.readlines()
if num_processes == 1:
results = []
for li in lines:
results.append(process_line(li))
else:
with Pool(processes=num_processes) as pool:
results = pool.map(process_line, lines)
with open(path + '.proc', 'w') as output_file:
for line in results:
output_file.write(line)
print('Finished processing {}'.format(path))
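The serial-versus-`Pool` dispatch used here (and again in `load_reaction_data`) follows a simple pattern; a toy sketch with a stand-in per-line function:

```python
from multiprocessing import Pool

def shout(line):
    # Stand-in for the real per-line work (process_line / load_one_reaction)
    return line.upper()

def map_lines(lines, num_processes=1):
    """Serial/parallel dispatch pattern mirrored from process_file."""
    if num_processes == 1:
        return [shout(li) for li in lines]
    with Pool(processes=num_processes) as pool:
        return pool.map(shout, lines)

results = map_lines(['a\n', 'b\n'])
```

Note that functions passed to `Pool.map` must be picklable, which is why the per-line work is a module-level function rather than a lambda or closure.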
class WLNReactionDataset(object):
"""Dataset for reaction prediction with WLN
def load_one_reaction(line):
"""Load one reaction and check if the reactants are valid.
Parameters
----------
line : str
One reaction and the associated graph edits
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants. None will be returned if the
reactants are not valid.
reaction : str
Reaction. None will be returned if the reactants are not valid.
graph_edits : str
Graph edits associated with the reaction. None will be returned if the
reactants are not valid.
"""
# Each line represents a reaction and the corresponding graph edits
#
# reaction example:
# [CH3:14][OH:15].[NH2:12][NH2:13].[OH2:11].[n:1]1[n:2][cH:3][c:4]
# ([C:7]([O:9][CH3:8])=[O:10])[cH:5][cH:6]1>>[n:1]1[n:2][cH:3][c:4]
# ([C:7](=[O:9])[NH:12][NH2:13])[cH:5][cH:6]1
# The reactants are on the left-hand-side of the reaction and the product
# is on the right-hand-side of the reaction. The numbers represent atom mapping.
#
# graph_edits example:
# 23-33-1.0;23-25-0.0
# For a triplet a-b-c, a and b are the atoms that form or lose the bond.
# c specifies the particular change, 0.0 for losing a bond and 1.0, 2.0, 3.0,
# 1.5 respectively for forming a single, double, triple or aromatic bond.
reaction, graph_edits = line.strip("\r\n ").split()
reactants = reaction.split('>')[0]
mol = Chem.MolFromSmiles(reactants)
if mol is None:
return None, None, None
# Reorder atoms according to the order specified in the atom map
atom_map_order = [-1 for _ in range(mol.GetNumAtoms())]
for j in range(mol.GetNumAtoms()):
atom = mol.GetAtomWithIdx(j)
atom_map_order[atom.GetIntProp('molAtomMapNumber') - 1] = j
mol = rdmolops.RenumberAtoms(mol, atom_map_order)
return mol, reaction, graph_edits
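The renumbering above relies on `rdmolops.RenumberAtoms(mol, order)` taking `order[i]` as the old index of the atom that becomes atom `i`; the permutation logic can be checked with plain lists of hypothetical atom map numbers:

```python
# Build the permutation so that new atom i is the atom whose map number
# is i + 1, using plain lists in place of an RDKit Mol (hypothetical data).
atom_map_numbers = [3, 1, 2]            # map number of each current atom
order = [-1] * len(atom_map_numbers)
for idx, map_num in enumerate(atom_map_numbers):
    order[map_num - 1] = idx            # position map_num - 1 takes old atom idx
reordered = [atom_map_numbers[j] for j in order]
```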
class WLNCenterDataset(object):
"""Dataset for reaction center prediction with WLN
Parameters
----------
......@@ -246,16 +356,19 @@ class WLNReactionDataset(object):
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def __init__(self,
raw_file_path,
mol_graph_path,
mol_to_graph=mol_to_bigraph,
node_featurizer=default_node_featurizer,
edge_featurizer=default_edge_featurizer,
node_featurizer=default_node_featurizer_center,
edge_featurizer=default_edge_featurizer_center,
atom_pair_featurizer=default_atom_pair_featurizer,
load=True):
super(WLNReactionDataset, self).__init__()
load=True,
num_processes=1):
super(WLNCenterDataset, self).__init__()
self._atom_pair_featurizer = atom_pair_featurizer
self.atom_pair_features = []
......@@ -265,23 +378,33 @@ class WLNReactionDataset(object):
path_to_reaction_file = raw_file_path + '.proc'
if not os.path.isfile(path_to_reaction_file):
# Pre-process graph edits information
process_file(raw_file_path)
print('Pre-processing graph edits from reaction data')
process_file(raw_file_path, num_processes)
import time
t0 = time.time()
full_mols, full_reactions, full_graph_edits = \
self.load_reaction_data(path_to_reaction_file)
self.load_reaction_data(path_to_reaction_file, num_processes)
print('Time spent', time.time() - t0)
if load and os.path.isfile(mol_graph_path):
print('Loading previously saved graphs...')
self.reactant_mol_graphs, _ = load_graphs(mol_graph_path)
else:
self.reactant_mol_graphs = []
for i in range(len(full_mols)):
if i % 10000 == 0:
print('Processing reaction {:d}/{:d}'.format(i + 1, len(full_mols)))
mol = full_mols[i]
reactant_mol_graph = mol_to_graph(mol, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer,
canonical_atom_order=False)
self.reactant_mol_graphs.append(reactant_mol_graph)
print('Constructing graphs from scratch...')
if num_processes == 1:
self.reactant_mol_graphs = []
for mol in full_mols:
self.reactant_mol_graphs.append(mol_to_graph(
mol, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer, canonical_atom_order=False))
else:
torch.multiprocessing.set_sharing_strategy('file_system')
with Pool(processes=num_processes) as pool:
self.reactant_mol_graphs = pool.map(
partial(mol_to_graph, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer, canonical_atom_order=False),
full_mols)
save_graphs(mol_graph_path, self.reactant_mol_graphs)
......@@ -291,13 +414,15 @@ class WLNReactionDataset(object):
self.atom_pair_features.extend([None for _ in range(len(self.mols))])
self.atom_pair_labels.extend([None for _ in range(len(self.mols))])
def load_reaction_data(self, file_path):
def load_reaction_data(self, file_path, num_processes):
"""Load reaction data from the raw file.
Parameters
----------
file_path : str
Path to read the file.
num_processes : int
Number of processes to use for data pre-processing.
Returns
-------
......@@ -312,38 +437,22 @@ class WLNReactionDataset(object):
all_reactions = []
all_graph_edits = []
with open(file_path, 'r') as f:
for i, line in enumerate(f):
if i % 10000 == 0:
print('Processing line {:d}'.format(i))
# Each line represents a reaction and the corresponding graph edits
#
# reaction example:
# [CH3:14][OH:15].[NH2:12][NH2:13].[OH2:11].[n:1]1[n:2][cH:3][c:4]
# ([C:7]([O:9][CH3:8])=[O:10])[cH:5][cH:6]1>>[n:1]1[n:2][cH:3][c:4]
# ([C:7](=[O:9])[NH:12][NH2:13])[cH:5][cH:6]1
# The reactants are on the left-hand-side of the reaction and the product
# is on the right-hand-side of the reaction. The numbers represent atom mapping.
#
# graph_edits example:
# 23-33-1.0;23-25-0.0
# For a triplet a-b-c, a and b are the atoms that form or loss the bond.
# c specifies the particular change, 0.0 for losing a bond, 1.0, 2.0, 3.0 and
# 1.5 separately for forming a single, double, triple or aromatic bond.
reaction, graph_edits = line.strip("\r\n ").split()
reactants = reaction.split('>')[0]
mol = Chem.MolFromSmiles(reactants)
if mol is None:
continue
lines = f.readlines()
if num_processes == 1:
results = []
for li in lines:
mol, reaction, graph_edits = load_one_reaction(li)
results.append((mol, reaction, graph_edits))
else:
with Pool(processes=num_processes) as pool:
results = pool.map(load_one_reaction, lines)
# Reorder atoms according to the order specified in the atom map
atom_map_order = [-1 for _ in range(mol.GetNumAtoms())]
for i in range(mol.GetNumAtoms()):
atom = mol.GetAtomWithIdx(i)
atom_map_order[atom.GetIntProp('molAtomMapNumber') - 1] = i
mol = rdmolops.RenumberAtoms(mol, atom_map_order)
all_mols.append(mol)
all_reactions.append(reaction)
all_graph_edits.append(graph_edits)
for mol, reaction, graph_edits in results:
if mol is None:
continue
all_mols.append(mol)
all_reactions.append(reaction)
all_graph_edits.append(graph_edits)
return all_mols, all_reactions, all_graph_edits
......@@ -363,15 +472,13 @@ class WLNReactionDataset(object):
Returns
-------
str
Reaction.
Reaction
str
Graph edits for the reaction
rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants
DGLGraph
DGLGraph for the ith molecular graph of reactants
DGLGraph for the ith molecular graph
DGLGraph
Complete DGLGraph for reactants, which will be needed for predicting
Complete DGLGraph, which will be needed for predicting
scores between each pair of atoms
float32 tensor of shape (V^2, 10)
Features for each pair of atoms.
......@@ -384,7 +491,7 @@ class WLNReactionDataset(object):
if num_atoms not in self.complete_graphs:
self.complete_graphs[num_atoms] = mol_to_complete_graph(
mol, add_self_loop=True, canonical_atom_order=True)
mol, add_self_loop=True, canonical_atom_order=False)
if self.atom_pair_features[item] is None:
reactants = self.reactions[item].split('>')[0]
......@@ -393,14 +500,14 @@ class WLNReactionDataset(object):
if self.atom_pair_labels[item] is None:
self.atom_pair_labels[item] = get_pair_label(mol, self.graph_edits[item])
return self.reactions[item], self.graph_edits[item], mol, \
return self.reactions[item], self.graph_edits[item], \
self.reactant_mol_graphs[item], \
self.complete_graphs[num_atoms], \
self.atom_pair_features[item], \
self.atom_pair_labels[item]
class USPTO(WLNReactionDataset):
"""USPTO dataset for reaction prediction.
class USPTOCenter(WLNCenterDataset):
"""USPTO dataset for reaction center prediction.
The dataset contains reactions from patents granted by United States Patent
and Trademark Office (USPTO), collected by Lowe [1]. Jin et al. remove duplicates
......@@ -439,17 +546,20 @@ class USPTO(WLNReactionDataset):
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def __init__(self,
subset,
mol_to_graph=mol_to_bigraph,
node_featurizer=default_node_featurizer,
edge_featurizer=default_edge_featurizer,
node_featurizer=default_node_featurizer_center,
edge_featurizer=default_edge_featurizer_center,
atom_pair_featurizer=default_atom_pair_featurizer,
load=True):
load=True,
num_processes=1):
assert subset in ['train', 'val', 'test'], \
'Expect subset to be "train" or "val" or "test", got {}'.format(subset)
print('Preparing {} subset of USPTO'.format(subset))
print('Preparing {} subset of USPTO for reaction center prediction.'.format(subset))
self._subset = subset
if subset == 'val':
subset = 'valid'
@@ -460,23 +570,993 @@ class USPTO(WLNReactionDataset):
download(_get_dgl_url(self._url), path=data_path)
extract_archive(data_path, extracted_data_path)
super(USPTO, self).__init__(
super(USPTOCenter, self).__init__(
raw_file_path=extracted_data_path + '/{}.txt'.format(subset),
mol_graph_path=extracted_data_path + '/{}_mol_graphs.bin'.format(subset),
mol_to_graph=mol_to_graph,
node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer,
atom_pair_featurizer=atom_pair_featurizer,
load=load)
load=load,
num_processes=num_processes)
@property
def subset(self):
"""Get the subset used for USPTOCenter
Returns
-------
str
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
"""
return self._subset
def mkdir_p(path):
"""Create a folder for the given path.
Parameters
----------
path: str
Folder to create
"""
try:
os.makedirs(path)
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
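On Python 3.2+, `mkdir_p` can also be expressed with the `exist_ok` flag of `os.makedirs`, which covers the same "create intermediate folders, ignore an existing directory" behavior:

```python
import os
import tempfile

# os.makedirs(path, exist_ok=True) is equivalent to mkdir_p for the
# common case: it creates intermediate folders and does not raise if
# the directory already exists.
with tempfile.TemporaryDirectory() as root:
    target = os.path.join(root, 'a', 'b', 'c')
    os.makedirs(target, exist_ok=True)
    os.makedirs(target, exist_ok=True)  # second call is a no-op
    print(os.path.isdir(target))  # True
```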
def load_one_reaction_rank(line):
"""Load one reaction and check if the reactants are valid.
Parameters
----------
line : str
One reaction and the associated graph edits
Returns
-------
reactants_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants. None will be returned if the
line is not valid.
product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the product. None will be returned if the line is not valid.
reaction_real_bond_changes : list of 3-tuples
Real bond changes in the reaction. Each tuple is of form (atom1, atom2, change_type). For
change_type, 0.0 stands for losing a bond, while 1.0, 2.0, 3.0 and 1.5 respectively stand
for forming a single, double, triple or aromatic bond.
"""
# Each line represents a reaction and the corresponding graph edits
#
# reaction example:
# [CH3:14][OH:15].[NH2:12][NH2:13].[OH2:11].[n:1]1[n:2][cH:3][c:4]
# ([C:7]([O:9][CH3:8])=[O:10])[cH:5][cH:6]1>>[n:1]1[n:2][cH:3][c:4]
# ([C:7](=[O:9])[NH:12][NH2:13])[cH:5][cH:6]1
# The reactants are on the left-hand-side of the reaction and the product
# is on the right-hand-side of the reaction. The numbers represent atom mapping.
#
# graph_edits example:
# 23-33-1.0;23-25-0.0
# For a triplet a-b-c, a and b are the atoms that form or lose the bond, and
# c specifies the particular change: 0.0 for losing a bond; 1.0, 2.0, 3.0 and
# 1.5 respectively for forming a single, double, triple or aromatic bond.
reaction, graph_edits = line.strip("\r\n ").split()
reactants, _, product = reaction.split('>')
reactants_mol = Chem.MolFromSmiles(reactants)
if reactants_mol is None:
return None, None, None
product_mol = Chem.MolFromSmiles(product)
if product_mol is None:
return None, None, None
# Reorder atoms according to the order specified in the atom map
atom_map_order = [-1 for _ in range(reactants_mol.GetNumAtoms())]
for j in range(reactants_mol.GetNumAtoms()):
atom = reactants_mol.GetAtomWithIdx(j)
atom_map_order[atom.GetIntProp('molAtomMapNumber') - 1] = j
reactants_mol = rdmolops.RenumberAtoms(reactants_mol, atom_map_order)
reaction_real_bond_changes = []
for changed_bond in graph_edits.split(';'):
atom1, atom2, change_type = changed_bond.split('-')
atom1, atom2 = int(atom1) - 1, int(atom2) - 1
reaction_real_bond_changes.append(
(min(atom1, atom2), max(atom1, atom2), float(change_type)))
return reactants_mol, product_mol, reaction_real_bond_changes
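The graph-edits parsing in `load_one_reaction_rank` is RDKit-independent and can be exercised on its own. A sketch (the reaction-string and atom-reordering parts are omitted; `parse_graph_edits` is a hypothetical helper name):

```python
def parse_graph_edits(graph_edits):
    # Mirrors the loop in load_one_reaction_rank: each ';'-separated
    # triplet 'a-b-c' becomes a 0-based, sorted (atom1, atom2, change) tuple.
    changes = []
    for changed_bond in graph_edits.split(';'):
        atom1, atom2, change_type = changed_bond.split('-')
        atom1, atom2 = int(atom1) - 1, int(atom2) - 1
        changes.append((min(atom1, atom2), max(atom1, atom2), float(change_type)))
    return changes

print(parse_graph_edits('23-33-1.0;23-25-0.0'))
# [(22, 32, 1.0), (22, 24, 0.0)]
```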
def load_candidate_bond_changes_for_one_reaction(line):
"""Load candidate bond changes for a reaction
Parameters
----------
line : str
Candidate bond changes separated by ;. Each candidate bond change takes the
form of atom1, atom2, change_type and change_score.
Returns
-------
list of 4-tuples
Loaded candidate bond changes.
"""
reaction_candidate_bond_changes = []
elements = line.strip().split(';')[:-1]
for candidate in elements:
atom1, atom2, change_type, score = candidate.split(' ')
atom1, atom2 = int(atom1) - 1, int(atom2) - 1
reaction_candidate_bond_changes.append((
min(atom1, atom2), max(atom1, atom2), float(change_type), float(score)))
return reaction_candidate_bond_changes
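On a made-up candidate line, the loader above behaves as follows; note the trailing `;` in the file format, which is why the code drops the last element of the split:

```python
# A made-up candidate line: each candidate is 'atom1 atom2 change_type score',
# and the line ends with a trailing ';', hence the [:-1] in the loader.
line = '3 5 1.0 0.9;2 7 0.0 0.4;'
changes = []
for candidate in line.strip().split(';')[:-1]:
    atom1, atom2, change_type, score = candidate.split(' ')
    atom1, atom2 = int(atom1) - 1, int(atom2) - 1
    changes.append((min(atom1, atom2), max(atom1, atom2),
                    float(change_type), float(score)))
print(changes)  # [(2, 4, 1.0, 0.9), (1, 6, 0.0, 0.4)]
```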
def bookkeep_reactant(mol, candidate_pairs):
"""Bookkeep reaction-related information of reactants.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants.
candidate_pairs : list of 2-tuples
Pairs of atoms that ranked high by a model for reaction center prediction.
By assumption, the two atoms are different and the first atom has a smaller
index than the second.
Returns
-------
info : dict
Reaction-related information of reactants
"""
num_atoms = mol.GetNumAtoms()
info = {
# free valence of atoms
'free_val': [0 for _ in range(num_atoms)],
# Whether it is a carbon atom
'is_c': [False for _ in range(num_atoms)],
# Whether it is a carbon atom connected to a nitrogen atom in pyridine
'is_c2_of_pyridine': [False for _ in range(num_atoms)],
# Whether it is a phosphorus atom
'is_p': [False for _ in range(num_atoms)],
# Whether it is a sulfur atom
'is_s': [False for _ in range(num_atoms)],
# Whether it is an oxygen atom
'is_o': [False for _ in range(num_atoms)],
# Whether it is a nitrogen atom
'is_n': [False for _ in range(num_atoms)],
'pair_to_bond_val': dict(),
'ring_bonds': set()
}
# bookkeep atoms
for j, atom in enumerate(mol.GetAtoms()):
info['free_val'][j] += atom.GetTotalNumHs() + abs(atom.GetFormalCharge())
# An aromatic carbon atom next to an aromatic nitrogen atom can get a
# carbonyl group because of the bookkeeping of hydroxypyridines
if atom.GetSymbol() == 'C':
info['is_c'][j] = True
if atom.GetIsAromatic():
for nbr in atom.GetNeighbors():
if nbr.GetSymbol() == 'N' and nbr.GetDegree() == 2:
info['is_c2_of_pyridine'][j] = True
break
# A nitrogen atom should be allowed to become positively charged
elif atom.GetSymbol() == 'N':
info['free_val'][j] += 1 - atom.GetFormalCharge()
info['is_n'][j] = True
# Phosphorus atoms can form a phosphonium
elif atom.GetSymbol() == 'P':
info['free_val'][j] += 1 - atom.GetFormalCharge()
info['is_p'][j] = True
elif atom.GetSymbol() == 'O':
info['is_o'][j] = True
elif atom.GetSymbol() == 'S':
info['is_s'][j] = True
# bookkeep bonds
for bond in mol.GetBonds():
atom1, atom2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
atom1, atom2 = min(atom1, atom2), max(atom1, atom2)
type_val = bond.GetBondTypeAsDouble()
info['pair_to_bond_val'][(atom1, atom2)] = type_val
if (atom1, atom2) in candidate_pairs:
info['free_val'][atom1] += type_val
info['free_val'][atom2] += type_val
if bond.IsInRing():
info['ring_bonds'].add((atom1, atom2))
return info
def bookkeep_product(mol):
"""Bookkeep reaction-related information of atoms/bonds in products
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for products.
Returns
-------
info : dict
Reaction-related information of atoms/bonds in products
"""
info = {
'atoms': set()
}
for atom in mol.GetAtoms():
info['atoms'].add(atom.GetAtomMapNum() - 1)
return info
def is_connected_change_combo(combo_ids, cand_change_adj):
"""Check whether the combo of bond changes yields a connected component.
Parameters
----------
combo_ids : tuple of int
Ids for bond changes in the combination.
cand_change_adj : bool ndarray of shape (N, N)
Adjacency matrix for candidate bond changes. Two candidate bond
changes are considered adjacent if they share a common atom.
* N for the number of candidate bond changes.
Returns
-------
bool
Whether the combo of bond changes yields a connected component
"""
if len(combo_ids) == 1:
return True
multi_hop_adj = np.linalg.matrix_power(
cand_change_adj[combo_ids, :][:, combo_ids], len(combo_ids) - 1)
# The combo is connected if the distance between
# any pair of bond changes is within len(combo) - 1
return np.all(multi_hop_adj)
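On a toy example the connectivity test works as follows, assuming three candidate bond changes where only changes 0 and 1 share an atom:

```python
import numpy as np

def is_connected(combo_ids, cand_change_adj):
    # Same test as is_connected_change_combo: the combo is connected iff
    # every pair of changes is reachable within len(combo_ids) - 1 hops
    # in the candidate-change adjacency matrix (which has self-loops).
    if len(combo_ids) == 1:
        return True
    multi_hop_adj = np.linalg.matrix_power(
        cand_change_adj[combo_ids, :][:, combo_ids], len(combo_ids) - 1)
    return bool(np.all(multi_hop_adj))

adj = np.eye(3, dtype=bool)    # self-loops, as in the pre-processing code
adj[0, 1] = adj[1, 0] = True   # changes 0 and 1 share an atom
print(is_connected((0, 1), adj))     # True
print(is_connected((0, 2), adj))     # False
print(is_connected((0, 1, 2), adj))  # False: change 2 is isolated
```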
def is_valid_combo(combo_changes, reactant_info):
"""Whether the combo of bond changes is chemically valid.
Parameters
----------
combo_changes : list of 4-tuples
Each tuple consists of atom1, atom2, type of bond change (in the form of related
valence) and score for the change.
reactant_info : dict
Reaction-related information of reactants
Returns
-------
bool
Whether the combo of bond changes is chemically valid.
"""
num_atoms = len(reactant_info['free_val'])
force_even_parity = np.zeros((num_atoms,), dtype=bool)
force_odd_parity = np.zeros((num_atoms,), dtype=bool)
pair_seen = defaultdict(bool)
free_val_tmp = reactant_info['free_val'].copy()
for (atom1, atom2, change_type, score) in combo_changes:
if pair_seen[(atom1, atom2)]:
# A pair of atoms cannot have two types of changes. Even if we
# randomly picked one, that would reduce to a combo with fewer changes
return False
pair_seen[(atom1, atom2)] = True
# Special valence rules
atom1_type_val = atom2_type_val = change_type
if change_type == 2:
# to form a double bond
if reactant_info['is_o'][atom1]:
if reactant_info['is_c2_of_pyridine'][atom2]:
atom2_type_val = 1.
elif reactant_info['is_p'][atom2]:
# don't count information of =o toward valence
# but require odd valence parity
atom2_type_val = 0.
force_odd_parity[atom2] = True
elif reactant_info['is_s'][atom2]:
atom2_type_val = 0.
force_even_parity[atom2] = True
elif reactant_info['is_o'][atom2]:
if reactant_info['is_c2_of_pyridine'][atom1]:
atom1_type_val = 1.
elif reactant_info['is_p'][atom1]:
atom1_type_val = 0.
force_odd_parity[atom1] = True
elif reactant_info['is_s'][atom1]:
atom1_type_val = 0.
force_even_parity[atom1] = True
elif reactant_info['is_n'][atom1] and reactant_info['is_p'][atom2]:
atom2_type_val = 0.
force_odd_parity[atom2] = True
elif reactant_info['is_n'][atom2] and reactant_info['is_p'][atom1]:
atom1_type_val = 0.
force_odd_parity[atom1] = True
elif reactant_info['is_p'][atom1] and reactant_info['is_c'][atom2]:
atom1_type_val = 0.
force_odd_parity[atom1] = True
elif reactant_info['is_p'][atom2] and reactant_info['is_c'][atom1]:
atom2_type_val = 0.
force_odd_parity[atom2] = True
reactant_pair_val = reactant_info['pair_to_bond_val'].get((atom1, atom2), None)
if reactant_pair_val is not None:
free_val_tmp[atom1] += reactant_pair_val - atom1_type_val
free_val_tmp[atom2] += reactant_pair_val - atom2_type_val
else:
free_val_tmp[atom1] -= atom1_type_val
free_val_tmp[atom2] -= atom2_type_val
free_val_tmp = np.array(free_val_tmp)
# False if 1) there are too many connections, 2) a sulfur valence is
# not even, or 3) a phosphorus valence is not odd
if any(free_val_tmp < 0) or \
any(aval % 2 != 0 for aval in free_val_tmp[force_even_parity]) or \
any(aval % 2 != 1 for aval in free_val_tmp[force_odd_parity]):
return False
return True
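Stripped of the element-specific rules, the core of the valence check is simple bookkeeping: forming a bond consumes free valence on both end atoms, and a combo is rejected once any atom goes negative. A minimal sketch with hypothetical numbers:

```python
import numpy as np

free_val = np.array([1.0, 1.0, 0.0])  # hypothetical per-atom free valence
combo = [(0, 1, 1.0), (1, 2, 1.0)]    # two proposed single-bond formations
tmp = free_val.copy()
for atom1, atom2, change_type in combo:
    # No existing bond between the atoms, so the new bond order is
    # subtracted from both ends, as in the else-branch of is_valid_combo.
    tmp[atom1] -= change_type
    tmp[atom2] -= change_type
print(bool(np.all(tmp >= 0)))  # False: atoms 1 and 2 would be over-bonded
```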
def edit_mol(reactant_mols, edits, product_info):
"""Simulate reaction via graph editing
Parameters
----------
reactant_mols : rdkit.Chem.rdchem.Mol
RDKit molecule instances for reactants.
edits : list of 4-tuples
Bond changes for getting the product out of the reactants in a reaction.
Each 4-tuple is of form (atom1, atom2, change_type, score), where atom1
and atom2 are the end atoms to form or lose a bond, change_type is the
type of bond change and score represents the confidence for the bond change
by a model.
product_info : dict
product_info['atoms'] gives a set of atom ids in the ground truth product molecule.
Returns
-------
str
SMILES for the main products
"""
bond_change_to_type = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE,
3: Chem.rdchem.BondType.TRIPLE, 1.5: Chem.rdchem.BondType.AROMATIC}
new_mol = Chem.RWMol(reactant_mols)
for atom in new_mol.GetAtoms():
atom.SetNumExplicitHs(0)
for atom1, atom2, change_type, score in edits:
bond = new_mol.GetBondBetweenAtoms(atom1, atom2)
if bond is not None:
new_mol.RemoveBond(atom1, atom2)
if change_type > 0:
new_mol.AddBond(atom1, atom2, bond_change_to_type[change_type])
pred_mol = new_mol.GetMol()
pred_smiles = Chem.MolToSmiles(pred_mol)
pred_list = pred_smiles.split('.')
pred_mols = []
for pred_smiles in pred_list:
mol = Chem.MolFromSmiles(pred_smiles)
if mol is None:
continue
atom_set = set([atom.GetAtomMapNum() - 1 for atom in mol.GetAtoms()])
if len(atom_set & product_info['atoms']) == 0:
continue
for atom in mol.GetAtoms():
atom.SetAtomMapNum(0)
pred_mols.append(mol)
return '.'.join(sorted([Chem.MolToSmiles(mol) for mol in pred_mols]))
def get_product_smiles(reactant_mols, edits, product_info):
"""Get the product smiles of the reaction
Parameters
----------
reactant_mols : rdkit.Chem.rdchem.Mol
RDKit molecule instances for reactants.
edits : list of 4-tuples
Bond changes for getting the product out of the reactants in a reaction.
Each 4-tuple is of form (atom1, atom2, change_type, score), where atom1
and atom2 are the end atoms to form or lose a bond, change_type is the
type of bond change and score represents the confidence for the bond change
by a model.
product_info : dict
product_info['atoms'] gives a set of atom ids in the ground truth product molecule.
Returns
-------
str
SMILES for the main products
"""
smiles = edit_mol(reactant_mols, edits, product_info)
if len(smiles) != 0:
return smiles
try:
Chem.Kekulize(reactant_mols)
except Exception as e:
return smiles
return edit_mol(reactant_mols, edits, product_info)
def generate_valid_candidate_combos():
raise NotImplementedError
def pre_process_one_reaction(info, num_candidate_bond_changes, max_num_bond_changes,
max_num_change_combos, mode):
"""Pre-process one reaction for candidate ranking.
Parameters
----------
info : 4-tuple
* candidate_bond_changes : list of tuples
The candidate bond changes for the reaction
* real_bond_changes : list of tuples
The real bond changes for the reaction
* reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants
* product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for product
num_candidate_bond_changes : int
Number of candidate bond changes to consider for the ground truth reaction.
max_num_bond_changes : int
Maximum number of bond changes per reaction.
max_num_change_combos : int
Number of bond change combos to consider for each reaction.
mode : str
Whether the dataset is to be used for training, validation or test.
Returns
-------
valid_candidate_combos : list
valid_candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction.
candidate_bond_changes : list of 4-tuples
Refined candidate bond changes considered for combos.
reactant_info : dict
Reaction-related information of reactants.
"""
assert mode in ['train', 'val', 'test'], \
"Expect mode to be 'train' or 'val' or 'test', got {}".format(mode)
candidate_bond_changes_, real_bond_changes, reactant_mol, product_mol = info
candidate_pairs = [(atom1, atom2) for (atom1, atom2, _, _)
in candidate_bond_changes_]
reactant_info = bookkeep_reactant(reactant_mol, candidate_pairs)
if mode == 'train':
product_info = bookkeep_product(product_mol)
# Filter out candidate new bonds already in reactants
candidate_bond_changes = []
count = 0
for (atom1, atom2, change_type, score) in candidate_bond_changes_:
if ((atom1, atom2) not in reactant_info['pair_to_bond_val']) or \
(reactant_info['pair_to_bond_val'][(atom1, atom2)] != change_type):
candidate_bond_changes.append((atom1, atom2, change_type, score))
count += 1
if count == num_candidate_bond_changes:
break
# Check if two bond changes have an atom in common
cand_change_adj = np.eye(len(candidate_bond_changes), dtype=bool)
for i in range(len(candidate_bond_changes)):
atom1_1, atom1_2, _, _ = candidate_bond_changes[i]
for j in range(i + 1, len(candidate_bond_changes)):
atom2_1, atom2_2, _, _ = candidate_bond_changes[j]
if atom1_1 == atom2_1 or atom1_1 == atom2_2 or \
atom1_2 == atom2_1 or atom1_2 == atom2_2:
cand_change_adj[i, j] = cand_change_adj[j, i] = True
# Enumerate combinations of k candidate bond changes and record
# those that are connected and chemically valid
valid_candidate_combos = []
cand_change_ids = range(len(candidate_bond_changes))
for k in range(1, max_num_bond_changes + 1):
for combo_ids in combinations(cand_change_ids, k):
# Check if the changed bonds form a connected component
if not is_connected_change_combo(combo_ids, cand_change_adj):
continue
combo_changes = [candidate_bond_changes[j] for j in combo_ids]
# Check if the combo is chemically valid
if is_valid_combo(combo_changes, reactant_info):
valid_candidate_combos.append(combo_changes)
if mode == 'train':
random.shuffle(valid_candidate_combos)
# Index for the combo of candidate bond changes
# that is equivalent to the gold combo
real_combo_id = -1
for j, combo_changes in enumerate(valid_candidate_combos):
if set([(atom1, atom2, change_type) for
(atom1, atom2, change_type, score) in combo_changes]) == \
set(real_bond_changes):
real_combo_id = j
break
# If we fail to find the real combo, make it the first entry
if real_combo_id == -1:
valid_candidate_combos = \
[[(atom1, atom2, change_type, 0.0)
for (atom1, atom2, change_type) in real_bond_changes]] + \
valid_candidate_combos
else:
valid_candidate_combos[0], valid_candidate_combos[real_combo_id] = \
valid_candidate_combos[real_combo_id], valid_candidate_combos[0]
product_smiles = get_product_smiles(
reactant_mol, valid_candidate_combos[0], product_info)
if len(product_smiles) > 0:
# Remove combos yielding duplicate products
product_smiles = set([product_smiles])
new_candidate_combos = [valid_candidate_combos[0]]
count = 0
for combo in valid_candidate_combos[1:]:
smiles = get_product_smiles(reactant_mol, combo, product_info)
if smiles in product_smiles or len(smiles) == 0:
continue
product_smiles.add(smiles)
new_candidate_combos.append(combo)
count += 1
if count == max_num_change_combos:
break
valid_candidate_combos = new_candidate_combos
valid_candidate_combos = valid_candidate_combos[:max_num_change_combos]
return valid_candidate_combos, candidate_bond_changes, reactant_info
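The enumeration over combo sizes uses `itertools.combinations`. For three candidate changes and `max_num_bond_changes = 2`, the loop visits these subsets (before the connectivity and validity filters are applied):

```python
from itertools import combinations

cand_change_ids = range(3)
max_num_bond_changes = 2
combos = [combo_ids
          for k in range(1, max_num_bond_changes + 1)
          for combo_ids in combinations(cand_change_ids, k)]
print(combos)  # [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2)]
```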
def featurize_nodes_and_compute_combo_scores(
node_featurizer, reactant_mol, valid_candidate_combos):
"""Featurize atoms in reactants and compute scores for combos of bond changes
Parameters
----------
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants in a reaction
valid_candidate_combos : list
valid_candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction.
Returns
-------
node_feats : float32 tensor of shape (N, M)
Node features for reactants, N for the number of nodes and M for the feature size
combo_bias : float32 tensor of shape (B, 1)
Scores for combos of bond changes, B equals len(valid_candidate_combos)
"""
node_feats = node_featurizer(reactant_mol)['hv']
combo_bias = torch.zeros(len(valid_candidate_combos), 1).float()
for combo_id, combo in enumerate(valid_candidate_combos):
combo_bias[combo_id] = sum([
score for (atom1, atom2, change_type, score) in combo])
return node_feats, combo_bias
def construct_graphs_rank(info, edge_featurizer):
"""Construct graphs for reactants and candidate products in a reaction and featurize
their edges
Parameters
----------
info : 4-tuple
* reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants in a reaction
* candidate_combos : list
candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction.
* candidate_bond_changes : list of 4-tuples
Refined candidate bond changes considered for candidate products
* reactant_info : dict
Reaction-related information of reactants.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
reaction_graphs : list of DGLGraphs
DGLGraphs for reactants and candidate products with edge features in edata['he'],
where the first graph is for reactants.
"""
reactant_mol, candidate_combos, candidate_bond_changes, reactant_info = info
# Graphs for reactants and candidate products
reaction_graphs = []
# Get graph for the reactants
reactant_graph = mol_to_bigraph(reactant_mol, edge_featurizer=edge_featurizer,
canonical_atom_order=False)
reaction_graphs.append(reactant_graph)
candidate_bond_changes_no_score = [
(atom1, atom2, change_type)
for (atom1, atom2, change_type, score) in candidate_bond_changes]
# Prepare common components across all candidate products
breaking_reactant_neighbors = []
common_src_list = []
common_dst_list = []
common_edge_feats = []
num_bonds = reactant_mol.GetNumBonds()
for j in range(num_bonds):
bond = reactant_mol.GetBondWithIdx(j)
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
u_sort, v_sort = min(u, v), max(u, v)
# Whether a bond in reactants might get broken
if (u_sort, v_sort, 0.0) not in candidate_bond_changes_no_score:
common_src_list.extend([u, v])
common_dst_list.extend([v, u])
common_edge_feats.extend([reactant_graph.edata['he'][2 * j],
reactant_graph.edata['he'][2 * j + 1]])
else:
breaking_reactant_neighbors.append((
u_sort, v_sort, bond.GetBondTypeAsDouble()))
for combo in candidate_combos:
combo_src_list = deepcopy(common_src_list)
combo_dst_list = deepcopy(common_dst_list)
combo_edge_feats = deepcopy(common_edge_feats)
candidate_bond_end_atoms = [
(atom1, atom2) for (atom1, atom2, change_type, score) in combo]
for (atom1, atom2, change_type) in breaking_reactant_neighbors:
if (atom1, atom2) not in candidate_bond_end_atoms:
# If a bond might be broken in some other combos but not this,
# add it as a negative sample
combo.append((atom1, atom2, change_type, 0.0))
for (atom1, atom2, change_type, score) in combo:
if change_type == 0:
continue
combo_src_list.extend([atom1, atom2])
combo_dst_list.extend([atom2, atom1])
feats = one_hot_encoding(change_type, [1.0, 2.0, 3.0, 1.5, -1])
if (atom1, atom2) in reactant_info['ring_bonds']:
feats[-1] = 1
feats = torch.tensor(feats).float()
combo_edge_feats.extend([feats, feats.clone()])
combo_edge_feats = torch.stack(combo_edge_feats, dim=0)
combo_graph = DGLGraph()
combo_graph.add_nodes(reactant_graph.number_of_nodes())
combo_graph.add_edges(combo_src_list, combo_dst_list)
combo_graph.edata['he'] = combo_edge_feats
reaction_graphs.append(combo_graph)
return reaction_graphs
class WLNRankDataset(object):
"""Dataset for ranking candidate products with WLN
Parameters
----------
raw_file_path : str
Path to the raw reaction file, where each line is the SMILES for a reaction.
candidate_bond_path : str
Path to the candidate bond changes for product enumeration, where each line is
candidate bond changes for a reaction by a WLN for reaction center prediction.
mode : str
'train', 'val', or 'test', indicating whether the dataset is used for training,
validation or test.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we consider descriptors including atom type,
atom formal charge, atom degree, atom explicit valence, atom implicit valence,
aromaticity.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we consider descriptors including bond type
and whether bond is in ring.
size_cutoff : int
By calling ``.ignore_large(True)``, we can optionally ignore reactions whose reactants
contain more than ``size_cutoff`` atoms. Default to 100.
max_num_changes_per_reaction : int
Maximum number of bond changes per reaction. Default to 5.
num_candidate_bond_changes : int
Number of candidate bond changes to consider for each ground truth reaction.
Default to 16.
max_num_change_combos_per_reaction : int
Number of bond change combos to consider for each reaction. Default to 150.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def __init__(self,
raw_file_path,
candidate_bond_path,
mode,
node_featurizer=default_node_featurizer_rank,
edge_featurizer=default_edge_featurizer_rank,
size_cutoff=100,
max_num_changes_per_reaction=5,
num_candidate_bond_changes=16,
max_num_change_combos_per_reaction=150,
num_processes=1):
super(WLNRankDataset, self).__init__()
assert mode in ['train', 'val', 'test'], \
"Expect mode to be 'train' or 'val' or 'test', got {}".format(mode)
self.mode = mode
self.ignore_large_samples = False
self.size_cutoff = size_cutoff
path_to_reaction_file = raw_file_path + '.proc'
if not os.path.isfile(path_to_reaction_file):
print('Pre-processing graph edits from reaction data')
process_file(raw_file_path, num_processes)
self.reactant_mols, self.product_mols, self.real_bond_changes, \
self.ids_for_small_samples = self.load_reaction_data(path_to_reaction_file, num_processes)
self.candidate_bond_changes = self.load_candidate_bond_changes(candidate_bond_path)
self.num_candidate_bond_changes = num_candidate_bond_changes
self.max_num_changes_per_reaction = max_num_changes_per_reaction
self.max_num_change_combos_per_reaction = max_num_change_combos_per_reaction
self.node_featurizer = node_featurizer
self.edge_featurizer = edge_featurizer
def load_reaction_data(self, file_path, num_processes):
"""Load reaction data from the raw file.
Parameters
----------
file_path : str
Path to read the file.
num_processes : int
Number of processes to use for data pre-processing.
Returns
-------
all_reactant_mols : list of rdkit.Chem.rdchem.Mol
RDKit molecule instances for reactants.
all_product_mols : list of rdkit.Chem.rdchem.Mol
RDKit molecule instances for products if the dataset is for training and
None otherwise.
all_real_bond_changes : list of list
``all_real_bond_changes[i]`` gives a list of tuples, which are ground
truth bond changes for a reaction.
ids_for_small_samples : list of int
Indices for reactions whose reactants do not contain too many atoms
"""
print('Stage 1/2: loading reaction data...')
all_reactant_mols = []
all_product_mols = []
all_real_bond_changes = []
ids_for_small_samples = []
with open(file_path, 'r') as f:
lines = f.readlines()
def _update_from_line(id, loaded_result):
reactants_mol, product_mol, reaction_real_bond_changes = loaded_result
if reactants_mol is None:
return
all_product_mols.append(product_mol)
all_reactant_mols.append(reactants_mol)
all_real_bond_changes.append(reaction_real_bond_changes)
if reactants_mol.GetNumAtoms() <= self.size_cutoff:
ids_for_small_samples.append(id)
if num_processes == 1:
for id, li in enumerate(tqdm(lines)):
loaded_line = load_one_reaction_rank(li)
_update_from_line(id, loaded_line)
else:
with Pool(processes=num_processes) as pool:
results = pool.map(
load_one_reaction_rank,
lines, chunksize=len(lines) // num_processes)
for id in range(len(lines)):
_update_from_line(id, results[id])
return all_reactant_mols, all_product_mols, all_real_bond_changes, ids_for_small_samples
def load_candidate_bond_changes(self, file_path):
"""Load candidate bond changes predicted by a WLN for reaction center prediction.
Parameters
----------
file_path : str
Path to a file of candidate bond changes for each reaction.
Returns
-------
all_candidate_bond_changes : list of list
``all_candidate_bond_changes[i]`` gives a list of tuples, which are candidate
bond changes for a reaction.
"""
print('Stage 2/2: loading candidate bond changes...')
with open(file_path, 'r') as f:
lines = f.readlines()
all_candidate_bond_changes = []
for li in tqdm(lines):
all_candidate_bond_changes.append(
load_candidate_bond_changes_for_one_reaction(li))
return all_candidate_bond_changes
def ignore_large(self, ignore=True):
"""Whether to ignore reactions where reactants contain too many atoms.
Parameters
----------
ignore : bool
If ``ignore``, reactions where reactants contain too many atoms will be ignored.
"""
self.ignore_large_samples = ignore
def __len__(self):
"""Get the size for the dataset.
Returns
-------
int
Number of reactions in the dataset.
"""
if self.ignore_large_samples:
return len(self.ids_for_small_samples)
else:
return len(self.reactant_mols)
def __getitem__(self, item):
"""Get the i-th datapoint.
Parameters
----------
item : int
Index for the datapoint.
Returns
-------
list of B + 1 DGLGraph
The first entry in the list is the DGLGraph for the reactants and the rest are
DGLGraphs for candidate products. Each DGLGraph has edge features in edata['he'] and
node features in ndata['hv'].
candidate_scores : float32 tensor of shape (B, 1)
The sum of scores for bond changes in each combo, where B is the number of combos.
labels : int64 tensor of shape (1, 1), optional
Index for the true candidate product, which is always 0 with pre-processing. This is
returned only in the training mode.
valid_candidate_combos : list, optional
valid_candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction. Each tuple is of form (atom1, atom2,
change_type, score). atom1, atom2 are the atom mapping numbers - 1 of the two
end atoms. change_type can be 0, 1, 2, 3 or 1.5, respectively for losing a bond and
forming a single, double, triple or aromatic bond.
reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants
real_bond_changes : list of tuples
Ground truth bond changes in a reaction. Each tuple is of form (atom1, atom2,
change_type). atom1, atom2 are the atom mapping numbers - 1 of the two
end atoms. change_type can be 0, 1, 2, 3 or 1.5, respectively for losing a bond and
forming a single, double, triple or aromatic bond.
product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the product
"""
if self.ignore_large_samples:
item = self.ids_for_small_samples[item]
raw_candidate_bond_changes = self.candidate_bond_changes[item]
real_bond_changes = self.real_bond_changes[item]
reactant_mol = self.reactant_mols[item]
product_mol = self.product_mols[item]
# Get valid candidate products, candidate bond changes considered and reactant info
valid_candidate_combos, candidate_bond_changes, reactant_info = \
pre_process_one_reaction(
(raw_candidate_bond_changes, real_bond_changes,
reactant_mol, product_mol),
self.num_candidate_bond_changes, self.max_num_changes_per_reaction,
self.max_num_change_combos_per_reaction, self.mode)
# Construct DGLGraphs and featurize their edges
g_list = construct_graphs_rank(
(reactant_mol, valid_candidate_combos,
candidate_bond_changes, reactant_info),
self.edge_featurizer)
# Get node features and candidate scores
node_feats, candidate_scores = featurize_nodes_and_compute_combo_scores(
self.node_featurizer, reactant_mol, valid_candidate_combos)
for g in g_list:
g.ndata['hv'] = node_feats
if self.mode == 'train':
labels = torch.zeros(1, 1).long()
return g_list, candidate_scores, labels
else:
reactant_mol = self.reactant_mols[item]
real_bond_changes = self.real_bond_changes[item]
product_mol = self.product_mols[item]
return g_list, candidate_scores, valid_candidate_combos, \
reactant_mol, real_bond_changes, product_mol
class USPTORank(WLNRankDataset):
"""USPTO dataset for ranking candidate products.
The dataset contains reactions from patents granted by United States Patent
and Trademark Office (USPTO), collected by Lowe [1]. Jin et al. remove duplicates
and erroneous reactions, obtaining a set of 480K reactions. They divide it
into 400K, 40K, and 40K for training, validation and test.
References:
* [1] Patent reaction extraction
* [2] Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network
Parameters
----------
subset : str
Whether to use the training/validation/test set as in Jin et al.
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
candidate_bond_path : str
Path to the candidate bond changes for product enumeration, where each line is
candidate bond changes for a reaction by a WLN for reaction center prediction.
size_cutoff : int
By calling ``.ignore_large(True)``, we can optionally ignore reactions whose reactants
contain more than ``size_cutoff`` atoms. Default to 100.
max_num_changes_per_reaction : int
Maximum number of bond changes per reaction. Default to 5.
num_candidate_bond_changes : int
Number of candidate bond changes to consider for each ground truth reaction.
Default to 16.
max_num_change_combos_per_reaction : int
Number of bond change combos to consider for each reaction. Default to 150.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def __init__(self,
subset,
candidate_bond_path,
size_cutoff=100,
max_num_changes_per_reaction=5,
num_candidate_bond_changes=16,
max_num_change_combos_per_reaction=150,
num_processes=1):
assert subset in ['train', 'val', 'test'], \
'Expect subset to be "train" or "val" or "test", got {}'.format(subset)
print('Preparing {} subset of USPTO for product candidate ranking.'.format(subset))
self._subset = subset
if subset == 'val':
mode = 'val'
subset = 'valid'
else:
mode = subset
self._url = 'dataset/uspto.zip'
data_path = get_download_dir() + '/uspto.zip'
extracted_data_path = get_download_dir() + '/uspto'
download(_get_dgl_url(self._url), path=data_path)
extract_archive(data_path, extracted_data_path)
super(USPTORank, self).__init__(
raw_file_path=extracted_data_path + '/{}.txt'.format(subset),
candidate_bond_path=candidate_bond_path,
mode=mode,
size_cutoff=size_cutoff,
max_num_changes_per_reaction=max_num_changes_per_reaction,
num_candidate_bond_changes=num_candidate_bond_changes,
max_num_change_combos_per_reaction=max_num_change_combos_per_reaction,
num_processes=num_processes)
@property
def subset(self):
"""Get the subset used for USPTO
"""Get the subset used for USPTOCenter
Returns
-------
str
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
......
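For reference, the candidate bond change file consumed by ``USPTORank`` (and written by the unit test further below) packs multiple candidates per line as semicolon-separated entries. A minimal, hypothetical parser for that layout — the ``atom1 atom2 change_type score`` field names are assumptions inferred from the test's format string, not part of the dgl-lifesci API:

```python
def parse_candidate_line(line):
    """Parse one line of candidate bond changes into a list of
    (atom1, atom2, change_type, score) tuples.

    Assumed layout per line: "a1 a2 type score;a1 a2 type score;..."
    """
    candidates = []
    for entry in line.strip().split(';'):
        if not entry:
            # Trailing ';' leaves an empty entry; skip it
            continue
        atom1, atom2, change_type, score = entry.split()
        candidates.append((int(atom1), int(atom2), float(change_type), float(score)))
    return candidates

print(parse_candidate_line('1 2 0.0 0.234;2 3 0.0 0.234;\n'))
# → [(1, 2, 0.0, 0.234), (2, 3, 0.0, 0.234)]
```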
......@@ -47,9 +47,7 @@ class WLNLinear(nn.Module):
stddev = min(1.0 / math.sqrt(self.in_feats), 0.1)
nn.init.normal_(self.weight, std=stddev)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
def forward(self, feats):
"""Applies the layer.
......@@ -91,19 +89,34 @@ class WLN(nn.Module):
n_layers : int
Number of rounds of message passing. Note that the same parameters
are shared across all rounds. Default to 3.
project_in_feats : bool
Whether to project input node features. If this is False, we expect node_in_feats
to be the same as node_out_feats. Default to True.
set_comparison : bool
Whether to perform final node representation update mimicking
set comparison. Default to True.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_out_feats=300,
n_layers=3,
project_in_feats=True,
set_comparison=True):
super(WLN, self).__init__()
self.n_layers = n_layers
self.project_node_in_feats = nn.Sequential(
WLNLinear(node_in_feats, node_out_feats, bias=False),
nn.ReLU()
)
self.project_in_feats = project_in_feats
if project_in_feats:
self.project_node_in_feats = nn.Sequential(
WLNLinear(node_in_feats, node_out_feats, bias=False),
nn.ReLU()
)
else:
assert node_in_feats == node_out_feats, \
'Expect input node features to have the same size as that of output ' \
'node features, got {:d} and {:d}'.format(node_in_feats, node_out_feats)
self.project_concatenated_messages = nn.Sequential(
WLNLinear(edge_in_feats + node_out_feats, node_out_feats),
nn.ReLU()
......@@ -112,9 +125,11 @@ class WLN(nn.Module):
WLNLinear(2 * node_out_feats, node_out_feats),
nn.ReLU()
)
self.project_edge_messages = WLNLinear(edge_in_feats, node_out_feats, bias=False)
self.project_node_messages = WLNLinear(node_out_feats, node_out_feats, bias=False)
self.project_self = WLNLinear(node_out_feats, node_out_feats, bias=False)
self.set_comparison = set_comparison
if set_comparison:
self.project_edge_messages = WLNLinear(edge_in_feats, node_out_feats, bias=False)
self.project_node_messages = WLNLinear(node_out_feats, node_out_feats, bias=False)
self.project_self = WLNLinear(node_out_feats, node_out_feats, bias=False)
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
......@@ -133,7 +148,8 @@ class WLN(nn.Module):
float32 tensor of shape (V, node_out_feats)
Updated node representations.
"""
if self.project_in_feats:
node_feats = self.project_node_in_feats(node_feats)
for _ in range(self.n_layers):
g = g.local_var()
g.ndata['hv'] = node_feats
......@@ -144,9 +160,12 @@ class WLN(nn.Module):
node_feats = self.get_new_node_feats(
torch.cat([node_feats, g.ndata['hv_new']], dim=1))
if not self.set_comparison:
return node_feats
else:
g = g.local_var()
g.ndata['hv'] = self.project_node_messages(node_feats)
g.edata['he'] = self.project_edge_messages(edge_feats)
g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h_nbr'))
h_self = self.project_self(node_feats) # (V, node_out_feats)
return g.ndata['h_nbr'] * h_self
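The final "set comparison" update above can be sketched without DGL: for each node v, messages (W_node h_u) ⊙ (W_edge x_e) are summed over incoming edges (u, v) — the ``fn.u_mul_e`` + ``fn.sum`` pattern — and the result is multiplied elementwise with W_self h_v. A numpy illustration with random stand-in weights (not the trained model):

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, feat_size, edge_feat_size = 4, 3, 2
src = np.array([0, 1, 2])   # edge sources u
dst = np.array([1, 2, 3])   # edge destinations v
h = rng.normal(size=(num_nodes, feat_size))        # node features
x = rng.normal(size=(len(src), edge_feat_size))    # edge features
W_node = rng.normal(size=(feat_size, feat_size))
W_edge = rng.normal(size=(edge_feat_size, feat_size))
W_self = rng.normal(size=(feat_size, feat_size))

# Analogue of fn.u_mul_e('hv', 'he', 'm'): per-edge message
messages = (h[src] @ W_node) * (x @ W_edge)        # (E, feat_size)

# Analogue of fn.sum('m', 'h_nbr'): sum messages into destination nodes
h_nbr = np.zeros((num_nodes, feat_size))
np.add.at(h_nbr, dst, messages)

out = h_nbr * (h @ W_self)                         # elementwise product
print(out.shape)  # (4, 3)
```

Note that a node with no incoming edges (node 0 here) gets an all-zero representation, since its aggregated neighbor term is zero.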
......@@ -10,4 +10,5 @@ from .mgcn_predictor import *
from .mpnn_predictor import *
from .acnn import *
from .wln_reaction_center import *
from .wln_reaction_ranking import *
from .weave_predictor import *
"""Weisfeiler-Lehman Network (WLN) for ranking candidate products"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn
from dgl.nn.pytorch import SumPooling
from ..gnn.wln import WLN
__all__ = ['WLNReactionRanking']
# pylint: disable=W0221, E1101
class WLNReactionRanking(nn.Module):
r"""Weisfeiler-Lehman Network (WLN) for Candidate Product Ranking
The model is introduced in `Predicting Organic Reaction Outcomes with
Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__ and then
further improved in `A graph-convolutional neural network model for the
prediction of chemical reactivity
<https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract>`__
The model updates the node representations of candidate products with a WLN and
predicts a score for each candidate product being the real product.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_hidden_feats : int
Size for the hidden node representations. Default to 500.
num_encode_gnn_layers : int
Number of WLN layers for updating node representations. Default to 3.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_hidden_feats=500,
num_encode_gnn_layers=3):
super(WLNReactionRanking, self).__init__()
self.gnn = WLN(node_in_feats=node_in_feats,
edge_in_feats=edge_in_feats,
node_out_feats=node_hidden_feats,
n_layers=num_encode_gnn_layers,
set_comparison=False)
self.diff_gnn = WLN(node_in_feats=node_hidden_feats,
edge_in_feats=edge_in_feats,
node_out_feats=node_hidden_feats,
n_layers=1,
project_in_feats=False,
set_comparison=False)
self.readout = SumPooling()
self.predict = nn.Sequential(
nn.Linear(node_hidden_feats, node_hidden_feats),
nn.ReLU(),
nn.Linear(node_hidden_feats, 1)
)
def forward(self, reactant_graph, reactant_node_feats, reactant_edge_feats,
product_graphs, product_node_feats, product_edge_feats,
candidate_scores, batch_num_candidate_products):
r"""Predicts the score for candidate products to be the true product
Parameters
----------
reactant_graph : DGLGraph
DGLGraph for a batch of reactants.
reactant_node_feats : float32 tensor of shape (V1, node_in_feats)
Input node features for the reactants. V1 for the number of nodes.
reactant_edge_feats : float32 tensor of shape (E1, edge_in_feats)
Input edge features for the reactants. E1 for the number of edges in
reactant_graph.
product_graphs : DGLGraph
DGLGraph for the candidate products in a batch of reactions.
product_node_feats : float32 tensor of shape (V2, node_in_feats)
Input node features for the candidate products. V2 for the number of nodes.
product_edge_feats : float32 tensor of shape (E2, edge_in_feats)
Input edge features for the candidate products. E2 for the number of edges
in the graphs for candidate products.
candidate_scores : float32 tensor of shape (B, 1)
Scores for candidate products based on the model for reaction center prediction
batch_num_candidate_products : list of int
Number of candidate products for the reactions in the batch
Returns
-------
float32 tensor of shape (B, 1)
Predicted scores for candidate products
"""
# Update representations for nodes in both reactants and candidate products
batch_reactant_node_feats = self.gnn(
reactant_graph, reactant_node_feats, reactant_edge_feats)
batch_product_node_feats = self.gnn(
product_graphs, product_node_feats, product_edge_feats)
# Iterate over the reactions in the batch
reactant_node_start = 0
product_graph_start = 0
product_node_start = 0
batch_diff_node_feats = []
for i, num_candidate_products in enumerate(batch_num_candidate_products):
reactant_node_end = reactant_node_start + reactant_graph.batch_num_nodes[i]
product_graph_end = product_graph_start + num_candidate_products
product_node_end = product_node_start + sum(
product_graphs.batch_num_nodes[product_graph_start: product_graph_end])
# (N, node_out_feats)
reactant_node_feats = batch_reactant_node_feats[reactant_node_start:
reactant_node_end, :]
product_node_feats = batch_product_node_feats[product_node_start: product_node_end, :]
old_feats_shape = reactant_node_feats.shape
# (1, N, node_out_feats)
expanded_reactant_node_feats = reactant_node_feats.reshape((1,) + old_feats_shape)
# (B, N, node_out_feats)
expanded_reactant_node_feats = expanded_reactant_node_feats.expand(
(num_candidate_products,) + old_feats_shape)
# (B, N, node_out_feats)
candidate_product_node_feats = product_node_feats.reshape(
(num_candidate_products,) + old_feats_shape)
# Get the node representation difference between candidate products and reactants
diff_node_feats = candidate_product_node_feats - expanded_reactant_node_feats
diff_node_feats = diff_node_feats.reshape(-1, diff_node_feats.shape[-1])
batch_diff_node_feats.append(diff_node_feats)
reactant_node_start = reactant_node_end
product_graph_start = product_graph_end
product_node_start = product_node_end
batch_diff_node_feats = torch.cat(batch_diff_node_feats, dim=0)
# One more GNN layer for message passing with the node representation difference
diff_node_feats = self.diff_gnn(product_graphs, batch_diff_node_feats, product_edge_feats)
candidate_product_feats = self.readout(product_graphs, diff_node_feats)
return self.predict(candidate_product_feats) + candidate_scores
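The difference-feature computation in the loop above can be sketched in isolation: reactant node features of shape (N, D) are broadcast against B candidate products (each candidate sharing the same N atoms as the reactants), and the per-atom differences are flattened back to (B * N, D). The shapes below are illustrative stand-ins, not values from the dataset:

```python
import numpy as np

N, D, B = 5, 3, 4  # atoms per reactant, feature size, candidate products
reactant_node_feats = np.arange(N * D, dtype=float).reshape(N, D)
product_node_feats = np.ones((B * N, D))  # node feats of all B candidates, stacked

# (1, N, D) expanded to (B, N, D), mirroring reshape + expand in forward()
expanded_reactant = np.broadcast_to(reactant_node_feats[None], (B, N, D))
candidate_products = product_node_feats.reshape(B, N, D)

# Per-atom difference between each candidate product and the reactants,
# flattened back for the subsequent GNN pass over the product graphs
diff_node_feats = (candidate_products - expanded_reactant).reshape(-1, D)
print(diff_node_feats.shape)  # (20, 3)
```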
......@@ -8,7 +8,7 @@ from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_arc
from rdkit import Chem
from ..model import GCNPredictor, GATPredictor, AttentiveFPPredictor, DGMG, DGLJTNNVAE, \
WLNReactionCenter, WLNReactionRanking, WeavePredictor
__all__ = ['load_pretrained']
......@@ -22,7 +22,8 @@ URL = {
'DGMG_ZINC_canonical': 'pre_trained/dgmg_ZINC_canonical.pth',
'DGMG_ZINC_random': 'pre_trained/dgmg_ZINC_random.pth',
'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth',
'wln_center_uspto': 'dgllife/pre_trained/wln_center_uspto_v3.pth',
'wln_rank_uspto': 'dgllife/pre_trained/wln_rank_uspto.pth',
}
def download_and_load_checkpoint(model_name, model, model_postfix,
......@@ -84,6 +85,7 @@ def load_pretrained(model_name, log=True):
with a random atom order
* ``'JTNN_ZINC'``: A JTNN model pre-trained on ZINC for molecule generation
* ``'wln_center_uspto'``: A WLN model pre-trained on USPTO for reaction prediction
* ``'wln_rank_uspto'``: A WLN model pre-trained on USPTO for candidate product ranking
log : bool
Whether to print progress for model loading
......@@ -162,4 +164,10 @@ def load_pretrained(model_name, log=True):
n_layers=3,
n_tasks=5)
elif model_name == 'wln_rank_uspto':
model = WLNReactionRanking(node_in_feats=89,
edge_in_feats=5,
node_hidden_feats=500,
num_encode_gnn_layers=3)
return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
......@@ -64,10 +64,23 @@ def test_wln_reaction():
remove_file('test.txt.proc')
# Test configured dataset
dataset = WLNCenterDataset('test.txt', 'test_graphs.bin')
remove_file('test_graphs.bin')
with open('test_candidate_bond_changes.txt', 'w') as f:
for reac in reactions:
# simulate fake candidate bond changes
candidate_string = ''
for i in range(2):
candidate_string += '{} {} {:.1f} {:.3f};'.format(i+1, i+2, 0.0, 0.234)
candidate_string += '\n'
f.write(candidate_string)
dataset = WLNRankDataset('test.txt.proc', 'test_candidate_bond_changes.txt', 'train')
remove_file('test.txt')
remove_file('test.txt.proc')
remove_file('test_graphs.bin')
remove_file('test_candidate_bond_changes.txt')
if __name__ == '__main__':
test_pubchem_aromaticity()
......
import dgl
import numpy as np
import torch
from dgl import DGLGraph
......@@ -84,5 +85,71 @@ def test_wln_reaction_center():
batch_edge_feats, batch_atom_pair_feats)[0].shape == \
torch.Size([batch_complete_graph.number_of_edges(), 1])
def test_reactant_product_graph(batch_size, device):
edges = (np.array([0, 1, 2]), np.array([1, 2, 2]))
reactant_g = []
for _ in range(batch_size):
reactant_g.append(DGLGraph(edges))
reactant_g = dgl.batch(reactant_g)
reactant_node_feats = torch.arange(
reactant_g.number_of_nodes()).float().reshape(-1, 1).to(device)
reactant_edge_feats = torch.arange(
reactant_g.number_of_edges()).float().reshape(-1, 1).to(device)
product_g = []
batch_num_candidate_products = []
for i in range(1, batch_size + 1):
product_g.extend([
DGLGraph(edges) for _ in range(i)
])
batch_num_candidate_products.append(i)
product_g = dgl.batch(product_g)
product_node_feats = torch.arange(
product_g.number_of_nodes()).float().reshape(-1, 1).to(device)
product_edge_feats = torch.arange(
product_g.number_of_edges()).float().reshape(-1, 1).to(device)
product_scores = torch.randn(sum(batch_num_candidate_products), 1).to(device)
return reactant_g, reactant_node_feats, reactant_edge_feats, product_g, product_node_feats, \
product_edge_feats, product_scores, batch_num_candidate_products
def test_wln_candidate_ranking():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
reactant_g, reactant_node_feats, reactant_edge_feats, product_g, product_node_feats, \
product_edge_feats, product_scores, num_candidate_products = \
test_reactant_product_graph(batch_size=1, device=device)
batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats, batch_product_g, \
batch_product_node_feats, batch_product_edge_feats, batch_product_scores, \
batch_num_candidate_products = test_reactant_product_graph(batch_size=2, device=device)
# Test default setting
model = WLNReactionRanking(node_in_feats=1,
edge_in_feats=1).to(device)
assert model(reactant_g, reactant_node_feats, reactant_edge_feats, product_g,
product_node_feats, product_edge_feats, product_scores,
num_candidate_products).shape == torch.Size([sum(num_candidate_products), 1])
assert model(batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats,
batch_product_g, batch_product_node_feats, batch_product_edge_feats,
batch_product_scores, batch_num_candidate_products).shape == \
torch.Size([sum(batch_num_candidate_products), 1])
model = WLNReactionRanking(node_in_feats=1,
edge_in_feats=1,
node_hidden_feats=100,
num_encode_gnn_layers=2).to(device)
assert model(reactant_g, reactant_node_feats, reactant_edge_feats, product_g,
product_node_feats, product_edge_feats, product_scores,
num_candidate_products).shape == torch.Size([sum(num_candidate_products), 1])
assert model(batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats,
batch_product_g, batch_product_node_feats, batch_product_edge_feats,
batch_product_scores, batch_num_candidate_products).shape == \
torch.Size([sum(batch_num_candidate_products), 1])
if __name__ == '__main__':
test_wln_reaction_center()
test_wln_candidate_ranking()