Unverified Commit 039fefc2 authored by Mufei Li, committed by GitHub

[DGL-LifeSci] WLN for Reaction Prediction (#1530)

* Update

* Fix bug
* Finalize
parent 70dc2ee9
...@@ -180,9 +180,10 @@ SVG(Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(180, 150), useSVG=True)

Below we provide some reference numbers to show how DGL improves the speed of training models per epoch in seconds.

| Model                              | Original Implementation | DGL Implementation         | Improvement                    |
| ---------------------------------- | ----------------------- | -------------------------- | ------------------------------ |
| GCN on Tox21                       | 5.5 (DeepChem)          | 1.0                        | 5.5x                           |
| AttentiveFP on Aromaticity         | 6.0                     | 1.2                        | 5x                             |
| JTNN on ZINC                       | 1826                    | 743                        | 2.5x                           |
| WLN for reaction center prediction | 11657                   | 858 (1 GPU) / 134 (8 GPUs) | 13.6x (1 GPU) / 87.0x (8 GPUs) |
| WLN for candidate ranking          | 40122                   | 22268                      | 1.8x                           |
...@@ -41,16 +41,23 @@ Reaction Prediction

USPTO
`````

.. autoclass:: dgllife.data.USPTOCenter
    :members: __getitem__, __len__
    :show-inheritance:

.. autoclass:: dgllife.data.USPTORank
    :members: ignore_large, __getitem__, __len__
    :show-inheritance:

Adapting to New Datasets for Weisfeiler-Lehman Networks
```````````````````````````````````````````````````````

.. autoclass:: dgllife.data.WLNCenterDataset
    :members: __getitem__, __len__

.. autoclass:: dgllife.data.WLNRankDataset
    :members: ignore_large, __getitem__, __len__

Protein-Ligand Binding Affinity Prediction
------------------------------------------
...@@ -74,6 +74,11 @@ WLN for Reaction Center Prediction

.. automodule:: dgllife.model.model_zoo.wln_reaction_center
    :members:

WLN for Ranking Candidate Products
``````````````````````````````````

.. automodule:: dgllife.model.model_zoo.wln_reaction_ranking
    :members:

Protein-Ligand Binding Affinity Prediction

ACNN
...@@ -7,6 +7,16 @@ An earlier version of the work was published in NeurIPS 2017 as

["Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network"](https://arxiv.org/abs/1709.04555) with some slight differences in modeling.

This work proposes a template-free approach for reaction prediction with 2 stages:

1) Identify the reaction center (pairs of atoms that will lose a bond or form a bond)
2) Enumerate the possible combinations of bond changes and rank the corresponding candidate products

We provide a Jupyter notebook walking through a demonstration with our pre-trained models. You can download it with `wget https://data.dgl.ai/dgllife/reaction_prediction_pretrained.ipynb` and put it in this directory. Below we visualize a reaction prediction by the model:

![](https://data.dgl.ai/dgllife/wln_reaction.png)

## Dataset

The example by default works with reactions from USPTO (United States Patent and Trademark Office) granted patents,
...@@ -29,53 +39,104 @@ whose reaction centers have all been selected.

We use GPU whenever possible. To train the model with default options, simply do

```bash
python find_reaction_center_train.py
```

Once the training process starts, the progress will be printed in the terminal as follows:

```bash
Epoch 1/50, iter 8150/20452 | loss 8.4788 | grad norm 12.9927
Epoch 1/50, iter 8200/20452 | loss 8.6722 | grad norm 14.0833
```

Every time the learning rate is decayed (specified as `'decay_every'` in `configure.py`'s `reaction_center_config`), we save a checkpoint of the model and evaluate it on the validation set. The evaluation result is formatted as follows, where `total samples x` means that we have trained the model on `x` samples and `acc@k` means top-k accuracy:

```bash
total samples 800000, (epoch 2/35, iter 2443/2557) | acc@12 0.9278 | acc@16 0.9419 | acc@20 0.9496 | acc@40 0.9596 | acc@80 0.9596 |
```

All model checkpoints and evaluation results can be found under `center_results`. `model_x.pkl` stores a model checkpoint after seeing `x` training samples. `val_eval.txt` stores all evaluation results on the validation set.

You may want to terminate the training process when the validation performance no longer improves for some time.
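The `acc@k` numbers reported during validation can be computed as sketched below: a reaction counts as correct at `k` when every ground-truth reaction-center atom pair appears among the `k` highest-scoring atom pairs. The function name and data layout here are assumptions for illustration, not the repository's actual helper.

```python
# Hypothetical sketch of acc@k for reaction center prediction: a reaction is
# correct at k if all ground-truth atom pairs are among the top-k scored pairs.
def topk_accuracy(pair_scores, true_pair_ids, ks=(12, 16, 20, 40, 80)):
    # Rank atom-pair indices by descending score
    ranked = sorted(range(len(pair_scores)), key=lambda i: -pair_scores[i])
    return {k: float(all(p in set(ranked[:k]) for p in true_pair_ids)) for k in ks}

# Toy example with 4 atom pairs; the gold pairs are indices 0 and 2
print(topk_accuracy([0.9, 0.1, 0.8, 0.3], true_pair_ids=[0, 2], ks=(1, 2)))
# → {1: 0.0, 2: 1.0}
```

The reported numbers are then the mean of this per-reaction indicator over the validation set.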
### Multi-GPU Training

By default we use one GPU only. We also allow multi-GPU training. To use GPUs with ids `id1,id2,...`, do

```bash
python find_reaction_center_train.py --gpus id1,id2,...
```

A summary of the training speedup with the DGL implementation is presented below.

| Item                    | Training time (s/epoch) | Speedup |
| ----------------------- | ----------------------- | ------- |
| Authors' implementation | 11657                   | 1x      |
| DGL with 1 gpu          | 858                     | 13.6x   |
| DGL with 2 gpus         | 443                     | 26.3x   |
| DGL with 4 gpus         | 243                     | 48.0x   |
| DGL with 8 gpus         | 134                     | 87.0x   |
### Evaluation

```bash
python find_reaction_center_eval.py --model-path X
```

For example, you can evaluate the model trained for 800000 samples by setting `X` to be `center_results/model_800000.pkl`. The evaluation results will be stored at `center_results/test_eval.txt`.

For model evaluation, we can choose whether to exclude reactants not contributing heavy atoms to the product (e.g. reagents and solvents) in top-k atom pair selection, which will make the task easier. For the easier evaluation, do

```bash
python find_reaction_center_eval.py --easy
```

A summary of the model performance under various settings is as follows:

| Item                                        | Top 6 accuracy | Top 8 accuracy | Top 10 accuracy |
| ------------------------------------------- | -------------- | -------------- | --------------- |
| Paper                                       | 89.8           | 92.0           | 93.3            |
| Hard evaluation from authors' code          | 87.7           | 90.6           | 92.1            |
| Easy evaluation from authors' code          | 90.0           | 92.8           | 94.2            |
| Hard evaluation                             | 88.9           | 91.7           | 93.1            |
| Easy evaluation                             | 91.2           | 93.8           | 95.0            |
| Hard evaluation for model trained on 8 gpus | 88.1           | 91.0           | 92.5            |
| Easy evaluation for model trained on 8 gpus | 90.3           | 93.3           | 94.6            |

1. We are able to match the results reported from the authors' code for both single-GPU and multi-GPU training.
2. While multi-GPU training provides a great speedup, performance with the default hyperparameters drops slightly.
### Data Pre-processing with Multi-Processing
By default we use 32 processes for data pre-processing. If you encounter an error with
`BrokenPipeError: [Errno 32] Broken pipe`, you can specify a smaller number of processes with
```bash
python find_reaction_center_train.py -np X
```
```bash
python find_reaction_center_eval.py -np X
```
where `X` is the number of processes that you would like to use.
### Pre-trained Model

We provide a pre-trained model so users do not need to train from scratch. To evaluate the pre-trained model, simply do

```bash
python find_reaction_center_eval.py
```

### Adapting to a New Dataset

New datasets should be processed such that each line corresponds to the SMILES for a reaction like below:
...@@ -89,10 +150,127 @@ In addition, atom mapping information is provided.

You can then train a model on new datasets with

```bash
python find_reaction_center_train.py --train-path X --val-path Y
```

where `X`, `Y` are paths to the new training/validation set as described above.
For evaluation,
```bash
python find_reaction_center_eval.py --eval-path Z
```
where `Z` is the path to the new test set as described above.
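Atom-mapped reaction SMILES of this form can be split into reactants, reagents, and products with plain string operations. The helper below is a hypothetical sketch, not part of the repository:

```python
# Hypothetical sketch: a reaction SMILES has the form reactants>reagents>products,
# and atom-map numbers such as [CH3:1] link reactant atoms to product atoms.
def split_reaction_smiles(rxn):
    reactants, reagents, products = rxn.split('>')
    split = lambda s: s.split('.') if s else []
    return split(reactants), split(reagents), split(products)

r, g, p = split_reaction_smiles('[CH3:1][OH:2].[Na:3]>>[CH3:1][O-:2]')
# r == ['[CH3:1][OH:2]', '[Na:3]'], g == [], p == ['[CH3:1][O-:2]']
```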
## Candidate Ranking
### Additional Dependency
In addition to RDKit, we use MolVS as an alternative for checking whether two molecules are the same after sanitization.
- [molvs](https://molvs.readthedocs.io/en/latest/)
### Modeling
For candidate ranking, we assume that a model has been trained for reaction center prediction first.
The pipeline for predicting candidate products given a reaction proceeds as follows:
1. Select top-k bond changes for atom pairs in the reactants, ranked by the model for reaction center prediction.
By default, we use k=80 and exclude reactants not contributing heavy atoms to the ground truth product in
selecting top-k bond changes as in the paper.
2. Filter out candidate bond changes for bonds that are already in the reactants
3. Enumerate possible combinations of atom pairs with up to C pairs, where C reflects the maximum number of bond changes
(losing or forming a bond) in a reaction. A statistical analysis of USPTO suggests that setting C to 5 is enough.
4. Filter out invalid combinations where 1) atoms in candidate bond changes are not connected or 2) an atom pair is
predicted to have different types of bond changes
(e.g. two atoms are predicted simultaneously to form a single and double bond) or 3) valence constraints are violated.
5. Apply the candidate bond changes for each valid combination and get the corresponding candidate products.
6. Construct molecular graphs for the reactants and candidate products, featurize their atoms and bonds.
7. Apply a Weisfeiler-Lehman Network to the molecular graphs for reactants and candidate products and score them
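Steps 3 and 4 above can be sketched with `itertools`. The representation of a bond change as an `(atom1, atom2, bond_order)` triple and the function name are assumptions for illustration, and only the conflicting-bond-change filter from step 4 is shown:

```python
from itertools import combinations

# Hypothetical sketch of steps 3-4: enumerate combinations of up to C candidate
# bond changes and drop combos that assign two different bond orders to the
# same atom pair. A bond change is an (atom1, atom2, bond_order) triple.
def enumerate_combos(candidate_changes, max_pairs=5):
    valid = []
    for k in range(1, max_pairs + 1):
        for combo in combinations(candidate_changes, k):
            seen = {}
            ok = True
            for a1, a2, order in combo:
                pair = (min(a1, a2), max(a1, a2))
                if seen.setdefault(pair, order) != order:
                    ok = False  # one atom pair with two conflicting bond changes
                    break
            if ok:
                valid.append(combo)
    return valid

# Atom pair (0, 1) cannot simultaneously become a single and a double bond,
# so any combo containing both of the first two changes is filtered out
combos = enumerate_combos([(0, 1, 1.0), (0, 1, 2.0), (2, 3, 0.0)])
```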
### Training with Default Options
We use GPU whenever possible. To train the model with default options, simply do
```bash
python candidate_ranking_train.py -cmp X
```
where `X` is the path to a trained model for reaction center prediction. You can use our
pre-trained model by not specifying `-cmp`.
Once the training process starts, the progress will be printed in the terminal as follows:
```bash
Epoch 6/6, iter 16439/20061 | time 1.1124 | accuracy 0.8500 | grad norm 5.3218
Epoch 6/6, iter 16440/20061 | time 1.1124 | accuracy 0.9500 | grad norm 2.1163
```
Every time the learning rate is decayed (specified as `'decay_every'` in `configure.py`'s `candidate_ranking_config`),
we save a checkpoint of the model and evaluate it on the validation set. The evaluation result is formatted
as follows, where `total samples x` means that we have trained the model on `x` samples, `acc@k` means top-k accuracy,
and `gfound` means the proportion of reactions where the ground truth product can be recovered by the ground truth bond changes.
We perform the evaluation based on RDKit-sanitized molecule equivalence (marked with `[strict]`) and MolVS-sanitized
molecule equivalence (marked with `[molvs]`).
```bash
total samples 100000, (epoch 1/20, iter 5000/20061)
[strict] acc@1: 0.7732 acc@2: 0.8466 acc@3: 0.8763 acc@5: 0.8987 gfound 0.9864
[molvs] acc@1: 0.7779 acc@2: 0.8523 acc@3: 0.8826 acc@5: 0.9057 gfound 0.9953
```
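The `acc@k` numbers for candidate ranking can be computed as sketched below. Since the ground-truth product is placed at index 0 among the candidates (as in the training loop's accuracy check), a prediction is correct at `k` when index 0 is among the `k` highest-scoring candidates. The function name is hypothetical:

```python
# Hypothetical sketch of acc@k for candidate ranking: the ground-truth product
# sits at index 0, so the prediction is correct at k if index 0 appears among
# the k highest-scoring candidate products.
def rank_accuracy(candidate_scores, ks=(1, 2, 3, 5)):
    ranked = sorted(range(len(candidate_scores)), key=lambda i: -candidate_scores[i])
    return {k: float(0 in ranked[:k]) for k in ks}

# Three candidate products; the ground truth (index 0) is ranked second
print(rank_accuracy([0.2, 0.7, 0.1], ks=(1, 2)))
# → {1: 0.0, 2: 1.0}
```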
All model checkpoints and evaluation results can be found under `candidate_results`. `model_x.pkl` stores a model
checkpoint after seeing `x` training samples in total. `val_eval.txt` stores all
evaluation results on the validation set.
You may want to terminate the training process when the validation performance no longer improves for some time.
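"Terminate when the validation performance no longer improves" can be automated with a small patience counter. This helper is a hypothetical sketch, not part of the repository:

```python
# Hypothetical helper: signal a stop once validation accuracy has not improved
# for `patience` consecutive evaluations.
class EarlyStopper:
    def __init__(self, patience=5):
        self.patience = patience
        self.best = float('-inf')
        self.counter = 0

    def step(self, val_acc):
        if val_acc > self.best:
            self.best = val_acc
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience  # True -> stop training

stopper = EarlyStopper(patience=2)
for acc in [0.85, 0.86, 0.86, 0.86]:
    if stopper.step(acc):
        break  # stops on the fourth evaluation
```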
### Evaluation
```bash
python candidate_ranking_eval.py --model-path X -cmp Y
```
where `X` is the path to a trained model for candidate ranking and `Y` is the path to a trained model
for reaction center prediction. For example, you can evaluate the model trained for 800000 samples by setting `X` to be
`candidate_results/model_800000.pkl`. The evaluation results will be stored at `candidate_results/test_eval.txt`. As
in training, you can use our pre-trained model by not specifying `-cmp`.
A summary of the model performance of various settings is as follows:
| Item | Top 1 accuracy | Top 2 accuracy | Top 3 accuracy | Top 5 accuracy |
| -------------------------- | -------------- | -------------- | -------------- | -------------- |
| Authors' strict evaluation | 85.6 | 90.5 | 92.8 | 93.4 |
| DGL's strict evaluation | 85.6 | 90.0 | 91.7 | 92.9 |
| Authors' molvs evaluation | 86.2 | 91.2 | 92.8 | 94.2 |
| DGL's molvs evaluation | 86.1 | 90.6 | 92.4 | 93.6 |
### Pre-trained Model
We provide a pre-trained model so users do not need to train from scratch. To evaluate the pre-trained model,
simply do
```bash
python candidate_ranking_eval.py
```
### Adapting to a New Dataset
You can train a model on new datasets with
```bash
python candidate_ranking_train.py --train-path X --val-path Y
```
where `X`, `Y` are paths to the new training/validation set as described in the `Reaction Center Prediction` section.
For evaluation,
```bash
python candidate_ranking_eval.py --eval-path Z
```

where `Z` is the path to the new test set as described in the `Reaction Center Prediction` section.

## References
import torch
from dgllife.data import USPTORank, WLNRankDataset
from dgllife.model import WLNReactionRanking, load_pretrained
from torch.utils.data import DataLoader
from configure import candidate_ranking_config, reaction_center_config
from utils import mkdir_p, prepare_reaction_center, collate_rank_eval, candidate_ranking_eval
def main(args, path_to_candidate_bonds):
if args['test_path'] is None:
test_set = USPTORank(
subset='test', candidate_bond_path=path_to_candidate_bonds['test'],
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
else:
test_set = WLNRankDataset(
raw_file_path=args['test_path'],
candidate_bond_path=path_to_candidate_bonds['test'], mode='test',
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
test_loader = DataLoader(test_set, batch_size=1, collate_fn=collate_rank_eval,
shuffle=False, num_workers=args['num_workers'])
if args['model_path'] is None:
model = load_pretrained('wln_rank_uspto')
else:
model = WLNReactionRanking(
node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_hidden_feats=args['hidden_size'],
num_encode_gnn_layers=args['num_encode_gnn_layers'])
model.load_state_dict(torch.load(
args['model_path'], map_location='cpu')['model_state_dict'])
model = model.to(args['device'])
prediction_summary = candidate_ranking_eval(args, model, test_loader)
with open(args['result_path'] + '/test_eval.txt', 'w') as f:
f.write(prediction_summary)
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description='Candidate Ranking')
parser.add_argument('--model-path', type=str, default=None,
help='Path to saved model. If None, we will directly evaluate '
'a pretrained model on the test set.')
parser.add_argument('--result-path', type=str, default='candidate_results',
help='Path to save modeling results')
parser.add_argument('--test-path', type=str, default=None,
help='Path to a new test set. '
'If None, we will use the default test set in USPTO.')
parser.add_argument('-cmp', '--center-model-path', type=str, default=None,
help='Path to a pre-trained model for reaction center prediction. '
'By default we use the official pre-trained model. If not None, '
'the model should follow the hyperparameters specified in '
'reaction_center_config.')
parser.add_argument('-rcb', '--reaction-center-batch-size', type=int, default=200,
help='Batch size to use for preparing candidate bonds from a trained '
'model on reaction center prediction')
parser.add_argument('-np', '--num-processes', type=int, default=8,
help='Number of processes to use for data pre-processing')
parser.add_argument('-nw', '--num-workers', type=int, default=32,
help='Number of workers to use for data loading in PyTorch data loader')
args = parser.parse_args().__dict__
args.update(candidate_ranking_config)
mkdir_p(args['result_path'])
if torch.cuda.is_available():
args['device'] = torch.device('cuda:0')
else:
args['device'] = torch.device('cpu')
path_to_candidate_bonds = prepare_reaction_center(args, reaction_center_config)
main(args, path_to_candidate_bonds)
import time
import torch
from dgllife.data import USPTORank, WLNRankDataset
from dgllife.model import WLNReactionRanking
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from configure import reaction_center_config, candidate_ranking_config
from utils import prepare_reaction_center, mkdir_p, set_seed, collate_rank_train, \
collate_rank_eval, candidate_ranking_eval
def main(args, path_to_candidate_bonds):
if args['train_path'] is None:
train_set = USPTORank(
subset='train', candidate_bond_path=path_to_candidate_bonds['train'],
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_train'],
num_processes=args['num_processes'])
else:
train_set = WLNRankDataset(
raw_file_path=args['train_path'],
candidate_bond_path=path_to_candidate_bonds['train'], mode='train',
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_train'],
num_processes=args['num_processes'])
train_set.ignore_large()
if args['val_path'] is None:
val_set = USPTORank(
subset='val', candidate_bond_path=path_to_candidate_bonds['val'],
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
else:
val_set = WLNRankDataset(
raw_file_path=args['val_path'],
candidate_bond_path=path_to_candidate_bonds['val'], mode='val',
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_rank_train,
shuffle=True, num_workers=args['num_workers'])
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_rank_eval,
shuffle=False, num_workers=args['num_workers'])
model = WLNReactionRanking(
node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_hidden_feats=args['hidden_size'],
num_encode_gnn_layers=args['num_encode_gnn_layers']).to(args['device'])
criterion = CrossEntropyLoss(reduction='sum')
optimizer = Adam(model.parameters(), lr=args['lr'])
from utils import Optimizer
optimizer = Optimizer(model, args['lr'], optimizer, max_grad_norm=args['max_norm'])
acc_sum = 0
grad_norm_sum = 0
dur = []
total_samples = 0
for epoch in range(args['num_epochs']):
t0 = time.time()
model.train()
for batch_id, batch_data in enumerate(train_loader):
batch_reactant_graphs, batch_product_graphs, \
batch_combo_scores, batch_labels, batch_num_candidate_products = batch_data
batch_combo_scores = batch_combo_scores.to(args['device'])
batch_labels = batch_labels.to(args['device'])
reactant_node_feats = batch_reactant_graphs.ndata.pop('hv').to(args['device'])
reactant_edge_feats = batch_reactant_graphs.edata.pop('he').to(args['device'])
product_node_feats = batch_product_graphs.ndata.pop('hv').to(args['device'])
product_edge_feats = batch_product_graphs.edata.pop('he').to(args['device'])
pred = model(reactant_graph=batch_reactant_graphs,
reactant_node_feats=reactant_node_feats,
reactant_edge_feats=reactant_edge_feats,
product_graphs=batch_product_graphs,
product_node_feats=product_node_feats,
product_edge_feats=product_edge_feats,
candidate_scores=batch_combo_scores,
batch_num_candidate_products=batch_num_candidate_products)
# Check if the ground truth candidate has the highest score
batch_loss = 0
product_graph_start = 0
for i in range(len(batch_num_candidate_products)):
product_graph_end = product_graph_start + batch_num_candidate_products[i]
reaction_pred = pred[product_graph_start:product_graph_end, :]
acc_sum += float(reaction_pred.max(dim=0)[1].detach().cpu().data.item() == 0)
batch_loss += criterion(reaction_pred.reshape(1, -1), batch_labels[i, :])
product_graph_start = product_graph_end
grad_norm_sum += optimizer.backward_and_step(batch_loss)
total_samples += args['batch_size']
if total_samples % args['print_every'] == 0:
progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time {:.4f} | ' \
'accuracy {:.4f} | grad norm {:.4f}'.format(
epoch + 1, args['num_epochs'],
(batch_id + 1) * args['batch_size'] // args['print_every'],
len(train_set) // args['print_every'],
(sum(dur) + time.time() - t0) / total_samples * args['print_every'],
acc_sum / args['print_every'],
grad_norm_sum / args['print_every'])
print(progress)
acc_sum = 0
grad_norm_sum = 0
if total_samples % args['decay_every'] == 0:
dur.append(time.time() - t0)
old_lr = optimizer.lr
optimizer.decay_lr(args['lr_decay_factor'])
new_lr = optimizer.lr
print('Learning rate decayed from {:.4f} to {:.4f}'.format(old_lr, new_lr))
torch.save({'model_state_dict': model.state_dict()},
args['result_path'] + '/model_{:d}.pkl'.format(total_samples))
prediction_summary = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d})\n'.format(
total_samples, epoch + 1, args['num_epochs'],
(batch_id + 1) * args['batch_size'] // args['print_every'],
len(train_set) // args['print_every']) + candidate_ranking_eval(args, model, val_loader)
print(prediction_summary)
with open(args['result_path'] + '/val_eval.txt', 'a') as f:
f.write(prediction_summary)
t0 = time.time()
model.train()
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description='Candidate Ranking')
parser.add_argument('--result-path', type=str, default='candidate_results',
help='Path to save modeling results')
parser.add_argument('--train-path', type=str, default=None,
help='Path to a new training set. '
'If None, we will use the default training set in USPTO.')
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('-cmp', '--center-model-path', type=str, default=None,
help='Path to a pre-trained model for reaction center prediction. '
'By default we use the official pre-trained model. If not None, '
'the model should follow the hyperparameters specified in '
'reaction_center_config.')
parser.add_argument('-rcb', '--reaction-center-batch-size', type=int, default=200,
help='Batch size to use for preparing candidate bonds from a trained '
'model on reaction center prediction')
parser.add_argument('-np', '--num-processes', type=int, default=8,
help='Number of processes to use for data pre-processing')
parser.add_argument('-nw', '--num-workers', type=int, default=100,
help='Number of workers to use for data loading in PyTorch data loader')
args = parser.parse_args().__dict__
args.update(candidate_ranking_config)
mkdir_p(args['result_path'])
set_seed()
if torch.cuda.is_available():
args['device'] = torch.device('cuda:0')
else:
args['device'] = torch.device('cpu')
path_to_candidate_bonds = prepare_reaction_center(args, reaction_center_config)
main(args, path_to_candidate_bonds)
...@@ -10,10 +10,33 @@ reaction_center_config = {

    'n_layers': 3,
    'n_tasks': 5,
    'lr': 0.001,
    'num_epochs': 18,
    'print_every': 50,
    'decay_every': 10000,  # Learning rate decay
    'lr_decay_factor': 0.9,
    'top_ks_val': [12, 16, 20, 40, 80],
    'top_ks_test': [6, 8, 10],
    'max_k': 80
}
# Configuration for candidate ranking
candidate_ranking_config = {
'batch_size': 4,
'hidden_size': 500,
'num_encode_gnn_layers': 3,
'max_norm': 50.0,
'node_in_feats': 89,
'edge_in_feats': 5,
'lr': 0.001,
'num_epochs': 6,
'print_every': 20,
'decay_every': 100000,
'lr_decay_factor': 0.9,
'top_ks': [1, 2, 3, 5],
'max_k': 10,
'max_num_change_combos_per_reaction_train': 150,
'max_num_change_combos_per_reaction_eval': 1500,
'num_candidate_bond_changes': 16
}
candidate_ranking_config['max_norm'] = candidate_ranking_config['max_norm'] * \
candidate_ranking_config['batch_size']
import numpy as np
import time
import torch
from dgllife.data import USPTO, WLNReactionDataset
from dgllife.model import WLNReactionCenter, load_pretrained
from torch.nn import BCEWithLogitsLoss
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from utils import setup, collate, reaction_center_prediction, \
rough_eval_on_a_loader, reaction_center_final_eval
def main(args):
setup(args)
if args['train_path'] is None:
train_set = USPTO('train')
else:
train_set = WLNReactionDataset(raw_file_path=args['train_path'],
mol_graph_path='train.bin')
if args['val_path'] is None:
val_set = USPTO('val')
else:
val_set = WLNReactionDataset(raw_file_path=args['val_path'],
mol_graph_path='val.bin')
if args['test_path'] is None:
test_set = USPTO('test')
else:
test_set = WLNReactionDataset(raw_file_path=args['test_path'],
mol_graph_path='test.bin')
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=True)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=False)
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=False)
if args['pre_trained']:
model = load_pretrained('wln_center_uspto').to(args['device'])
args['num_epochs'] = 0
else:
model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_pair_in_feats=args['node_pair_in_feats'],
node_out_feats=args['node_out_feats'],
n_layers=args['n_layers'],
n_tasks=args['n_tasks']).to(args['device'])
criterion = BCEWithLogitsLoss(reduction='sum')
optimizer = Adam(model.parameters(), lr=args['lr'])
scheduler = StepLR(optimizer, step_size=args['decay_every'], gamma=args['lr_decay_factor'])
total_iter = 0
grad_norm_sum = 0
loss_sum = 0
dur = []
for epoch in range(args['num_epochs']):
t0 = time.time()
for batch_id, batch_data in enumerate(train_loader):
total_iter += 1
batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
labels = batch_atom_pair_labels.to(args['device'])
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
loss = criterion(pred, labels) / len(batch_reactions)
loss_sum += loss.cpu().detach().data.item()
optimizer.zero_grad()
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), args['max_norm'])
grad_norm_sum += grad_norm
optimizer.step()
scheduler.step()
if total_iter % args['print_every'] == 0:
progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time/minibatch {:.4f} | ' \
'loss {:.4f} | grad norm {:.4f}'.format(
epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
(np.sum(dur) + time.time() - t0) / total_iter, loss_sum / args['print_every'],
grad_norm_sum / args['print_every'])
grad_norm_sum = 0
loss_sum = 0
print(progress)
if total_iter % args['decay_every'] == 0:
torch.save(model.state_dict(), args['result_path'] + '/model.pkl')
dur.append(time.time() - t0)
print('Epoch {:d}/{:d}, validation '.format(epoch + 1, args['num_epochs']) + \
rough_eval_on_a_loader(args, model, val_loader))
del train_loader
del val_loader
del train_set
del val_set
print('Evaluation on the test set.')
test_result = reaction_center_final_eval(args, model, test_loader, args['easy'])
print(test_result)
with open(args['result_path'] + '/results.txt', 'w') as f:
f.write(test_result)
if __name__ == '__main__':
from argparse import ArgumentParser
from configure import reaction_center_config
parser = ArgumentParser(description='Reaction Center Identification')
parser.add_argument('--result-path', type=str, default='center_results',
help='Path to training results')
parser.add_argument('--train-path', type=str, default=None,
help='Path to a new training set. '
'If None, we will use the default training set in USPTO.')
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('--test-path', type=str, default=None,
help='Path to a new test set. '
'If None, we will use the default test set in USPTO.')
parser.add_argument('-p', '--pre-trained', action='store_true', default=False,
help='If true, we will directly evaluate a '
'pretrained model on the test set.')
parser.add_argument('--easy', action='store_true', default=False,
help='Whether to exclude reactants not contributing atoms to the '
'product in top-k atom pair selection, which will make the '
'task easier.')
args = parser.parse_args().__dict__
args.update(reaction_center_config)
main(args)
import torch
from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import WLNReactionCenter, load_pretrained
from torch.utils.data import DataLoader
from utils import reaction_center_final_eval, set_seed, collate_center, mkdir_p
def main(args):
set_seed()
if torch.cuda.is_available():
args['device'] = torch.device('cuda:0')
# Set current device; calling torch.cuda.set_device with a CPU device raises an error
torch.cuda.set_device(args['device'])
else:
args['device'] = torch.device('cpu')
if args['test_path'] is None:
test_set = USPTOCenter('test', num_processes=args['num_processes'])
else:
test_set = WLNCenterDataset(raw_file_path=args['test_path'],
mol_graph_path='test.bin',
num_processes=args['num_processes'])
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=False)
if args['model_path'] is None:
model = load_pretrained('wln_center_uspto')
else:
model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_pair_in_feats=args['node_pair_in_feats'],
node_out_feats=args['node_out_feats'],
n_layers=args['n_layers'],
n_tasks=args['n_tasks'])
model.load_state_dict(torch.load(
args['model_path'], map_location='cpu')['model_state_dict'])
model = model.to(args['device'])
print('Evaluation on the test set.')
test_result = reaction_center_final_eval(
args, args['top_ks_test'], model, test_loader, args['easy'])
print(test_result)
with open(args['result_path'] + '/test_eval.txt', 'w') as f:
f.write(test_result)
if __name__ == '__main__':
from argparse import ArgumentParser
from configure import reaction_center_config
parser = ArgumentParser(description='Reaction Center Identification -- Evaluation')
parser.add_argument('--model-path', type=str, default=None,
help='Path to saved model. If None, we will directly evaluate '
'a pretrained model on the test set.')
parser.add_argument('--result-path', type=str, default='center_results',
help='Path where we saved model training and evaluation results')
parser.add_argument('--test-path', type=str, default=None,
help='Path to a new test set. '
'If None, we will use the default test set in USPTO.')
parser.add_argument('--easy', action='store_true', default=False,
help='Whether to exclude reactants not contributing heavy atoms to the '
'product in top-k atom pair selection, which will make the '
'task easier.')
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='Number of processes to use for data pre-processing')
args = parser.parse_args().__dict__
args.update(reaction_center_config)
assert args['max_k'] >= max(args['top_ks_test']), \
'Expect max_k to be no smaller than the possible options ' \
'of top_ks_test, got {:d} and {:d}'.format(args['max_k'], max(args['top_ks_test']))
mkdir_p(args['result_path'])
main(args)
import numpy as np
import time
import torch
from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import WLNReactionCenter
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import collate_center, reaction_center_prediction, \
reaction_center_final_eval, mkdir_p, set_seed, synchronize, get_center_subset, \
count_parameters
def load_dataset(args):
if args['train_path'] is None:
train_set = USPTOCenter('train', num_processes=args['num_processes'])
else:
train_set = WLNCenterDataset(raw_file_path=args['train_path'],
mol_graph_path='train.bin',
num_processes=args['num_processes'])
if args['val_path'] is None:
val_set = USPTOCenter('val', num_processes=args['num_processes'])
else:
val_set = WLNCenterDataset(raw_file_path=args['val_path'],
mol_graph_path='val.bin',
num_processes=args['num_processes'])
return train_set, val_set
def main(rank, dev_id, args):
set_seed()
# Removing the line below will result in problems for multiprocessing
if args['num_devices'] > 1:
torch.set_num_threads(1)
if dev_id == -1:
args['device'] = torch.device('cpu')
else:
args['device'] = torch.device('cuda:{}'.format(dev_id))
# Set current device; calling torch.cuda.set_device with a CPU device raises an error
torch.cuda.set_device(args['device'])
train_set, val_set = load_dataset(args)
get_center_subset(train_set, rank, args['num_devices'])
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=True)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=False)
model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_pair_in_feats=args['node_pair_in_feats'],
node_out_feats=args['node_out_feats'],
n_layers=args['n_layers'],
n_tasks=args['n_tasks']).to(args['device'])
model.train()
if rank == 0:
print('# trainable parameters in the model: ', count_parameters(model))
criterion = BCEWithLogitsLoss(reduction='sum')
optimizer = Adam(model.parameters(), lr=args['lr'])
if args['num_devices'] <= 1:
from utils import Optimizer
optimizer = Optimizer(model, args['lr'], optimizer, max_grad_norm=args['max_norm'])
else:
from utils import MultiProcessOptimizer
optimizer = MultiProcessOptimizer(args['num_devices'], model, args['lr'],
optimizer, max_grad_norm=args['max_norm'])
total_iter = 0
rank_iter = 0
grad_norm_sum = 0
loss_sum = 0
dur = []
for epoch in range(args['num_epochs']):
t0 = time.time()
for batch_id, batch_data in enumerate(train_loader):
total_iter += args['num_devices']
rank_iter += 1
batch_reactions, batch_graph_edits, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
labels = batch_atom_pair_labels.to(args['device'])
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
loss = criterion(pred, labels) / len(batch_reactions)
loss_sum += loss.cpu().detach().data.item()
grad_norm_sum += optimizer.backward_and_step(loss)
if rank_iter % args['print_every'] == 0 and rank == 0:
progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | ' \
'loss {:.4f} | grad norm {:.4f}'.format(
epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
loss_sum / args['print_every'], grad_norm_sum / args['print_every'])
print(progress)
grad_norm_sum = 0
loss_sum = 0
if total_iter % args['decay_every'] == 0:
optimizer.decay_lr(args['lr_decay_factor'])
if total_iter % args['decay_every'] == 0 and rank == 0:
if epoch >= 1:
dur.append(time.time() - t0)
print('Training time per {:d} iterations: {:.4f}'.format(
rank_iter, np.mean(dur)))
total_samples = total_iter * args['batch_size']
prediction_summary = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d}) '.format(
total_samples, epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader)) + \
reaction_center_final_eval(args, args['top_ks_val'], model, val_loader, easy=True)
print(prediction_summary)
with open(args['result_path'] + '/val_eval.txt', 'a') as f:
f.write(prediction_summary)
torch.save({'model_state_dict': model.state_dict()},
args['result_path'] + '/model_{:d}.pkl'.format(total_samples))
t0 = time.time()
model.train()
synchronize(args['num_devices'])
def run(rank, dev_id, args):
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip=args['master_ip'], master_port=args['master_port'])
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=args['num_devices'],
rank=rank)
assert torch.distributed.get_rank() == rank
main(rank, dev_id, args)
if __name__ == '__main__':
from argparse import ArgumentParser
from configure import reaction_center_config
parser = ArgumentParser(description='Reaction Center Identification -- Training')
parser.add_argument('--gpus', default='0', type=str,
help='To use multi-gpu training, '
'pass multiple gpu ids with --gpus id1,id2,...')
parser.add_argument('--result-path', type=str, default='center_results',
help='Path to save modeling results')
parser.add_argument('--train-path', type=str, default=None,
help='Path to a new training set. '
'If None, we will use the default training set in USPTO.')
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='Number of processes to use for data pre-processing')
parser.add_argument('--master-ip', type=str, default='127.0.0.1',
help='master ip address')
parser.add_argument('--master-port', type=str, default='12345',
help='master port')
args = parser.parse_args().__dict__
args.update(reaction_center_config)
assert args['max_k'] >= max(args['top_ks_val']), \
'Expect max_k to be no smaller than the possible options ' \
'of top_ks_val, got {:d} and {:d}'.format(args['max_k'], max(args['top_ks_val']))
mkdir_p(args['result_path'])
devices = list(map(int, args['gpus'].split(',')))
args['num_devices'] = len(devices)
if len(devices) == 1:
device_id = devices[0] if torch.cuda.is_available() else -1
main(0, device_id, args)
else:
if (args['train_path'] is not None) or (args['val_path'] is not None):
print('First pass for constructing DGLGraphs with multiprocessing')
load_dataset(args)
# Daemonic processes are not allowed to have children, so disable multiprocess pre-processing
args['num_processes'] = 1
# With multi-gpu training, the batch size increases and we need to
# increase learning rate accordingly.
args['lr'] = args['lr'] * args['num_devices']
mp = torch.multiprocessing.get_context('spawn')
procs = []
for proc_id, device_id in enumerate(devices):
print('Preparing for gpu {:d}/{:d}'.format(proc_id + 1, args['num_devices']))
procs.append(mp.Process(target=run, args=(
proc_id, device_id, args), daemon=True))
procs[-1].start()
for p in procs:
p.join()
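With multiple GPUs the effective batch size grows linearly with the number of devices, which is why the launcher above multiplies the base learning rate by `num_devices`. A tiny sketch of that linear scaling rule (the function name is illustrative only, not part of the repository):

```python
def scale_lr(base_lr, num_devices):
    # Linear scaling: N devices together process N minibatches per optimizer
    # step, so the learning rate is multiplied by N to keep the effective
    # per-sample step size roughly constant.
    return base_lr * num_devices
```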
@@ -47,9 +47,7 @@ class WLNLinear(nn.Module):
         stddev = min(1.0 / math.sqrt(self.in_feats), 0.1)
         nn.init.normal_(self.weight, std=stddev)
         if self.bias is not None:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
-            bound = 1 / math.sqrt(fan_in)
-            nn.init.uniform_(self.bias, -bound, bound)
+            nn.init.constant_(self.bias, 0.0)

     def forward(self, feats):
         """Applies the layer.
@@ -91,19 +89,34 @@ class WLN(nn.Module):
     n_layers : int
         Number of times for message passing. Note that same parameters
         are shared across n_layers message passing. Default to 3.
+    project_in_feats : bool
+        Whether to project input node features. If this is False, we expect node_in_feats
+        to be the same as node_out_feats. Default to True.
+    set_comparison : bool
+        Whether to perform final node representation update mimicking
+        set comparison. Default to True.
     """
     def __init__(self,
                  node_in_feats,
                  edge_in_feats,
                  node_out_feats=300,
-                 n_layers=3):
+                 n_layers=3,
+                 project_in_feats=True,
+                 set_comparison=True):
         super(WLN, self).__init__()

         self.n_layers = n_layers
-        self.project_node_in_feats = nn.Sequential(
-            WLNLinear(node_in_feats, node_out_feats, bias=False),
-            nn.ReLU()
-        )
+        self.project_in_feats = project_in_feats
+        if project_in_feats:
+            self.project_node_in_feats = nn.Sequential(
+                WLNLinear(node_in_feats, node_out_feats, bias=False),
+                nn.ReLU()
+            )
+        else:
+            assert node_in_feats == node_out_feats, \
+                'Expect input node features to have the same size as that of output ' \
+                'node features, got {:d} and {:d}'.format(node_in_feats, node_out_feats)
         self.project_concatenated_messages = nn.Sequential(
             WLNLinear(edge_in_feats + node_out_feats, node_out_feats),
             nn.ReLU()
@@ -112,9 +125,11 @@ class WLN(nn.Module):
             WLNLinear(2 * node_out_feats, node_out_feats),
             nn.ReLU()
         )
-        self.project_edge_messages = WLNLinear(edge_in_feats, node_out_feats, bias=False)
-        self.project_node_messages = WLNLinear(node_out_feats, node_out_feats, bias=False)
-        self.project_self = WLNLinear(node_out_feats, node_out_feats, bias=False)
+        self.set_comparison = set_comparison
+        if set_comparison:
+            self.project_edge_messages = WLNLinear(edge_in_feats, node_out_feats, bias=False)
+            self.project_node_messages = WLNLinear(node_out_feats, node_out_feats, bias=False)
+            self.project_self = WLNLinear(node_out_feats, node_out_feats, bias=False)

     def forward(self, g, node_feats, edge_feats):
         """Performs message passing and updates node representations.
@@ -133,7 +148,8 @@ class WLN(nn.Module):
         float32 tensor of shape (V, node_out_feats)
             Updated node representations.
         """
-        node_feats = self.project_node_in_feats(node_feats)
+        if self.project_in_feats:
+            node_feats = self.project_node_in_feats(node_feats)
         for _ in range(self.n_layers):
             g = g.local_var()
             g.ndata['hv'] = node_feats
@@ -144,9 +160,12 @@ class WLN(nn.Module):
             node_feats = self.get_new_node_feats(
                 torch.cat([node_feats, g.ndata['hv_new']], dim=1))

-        g = g.local_var()
-        g.ndata['hv'] = self.project_node_messages(node_feats)
-        g.edata['he'] = self.project_edge_messages(edge_feats)
-        g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h_nbr'))
-        h_self = self.project_self(node_feats)  # (V, node_out_feats)
-        return g.ndata['h_nbr'] * h_self
+        if not self.set_comparison:
+            return node_feats
+        else:
+            g = g.local_var()
+            g.ndata['hv'] = self.project_node_messages(node_feats)
+            g.edata['he'] = self.project_edge_messages(edge_feats)
+            g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h_nbr'))
+            h_self = self.project_self(node_feats)  # (V, node_out_feats)
+            return g.ndata['h_nbr'] * h_self
@@ -10,4 +10,5 @@
 from .mpnn_predictor import *
 from .acnn import *
 from .wln_reaction_center import *
+from .wln_reaction_ranking import *
 from .weave_predictor import *
"""Weisfeiler-Lehman Network (WLN) for ranking candidate products"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn
from dgl.nn.pytorch import SumPooling
from ..gnn.wln import WLN
__all__ = ['WLNReactionRanking']
# pylint: disable=W0221, E1101
class WLNReactionRanking(nn.Module):
r"""Weisfeiler-Lehman Network (WLN) for Candidate Product Ranking
The model is introduced in `Predicting Organic Reaction Outcomes with
Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__ and then
further improved in `A graph-convolutional neural network model for the
prediction of chemical reactivity
<https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract>`__
The model updates representations of nodes in candidate products with WLN and predicts
the score for candidate products to be the real product.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_hidden_feats : int
Size for the hidden node representations. Default to 500.
num_encode_gnn_layers : int
Number of WLN layers for updating node representations.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_hidden_feats=500,
num_encode_gnn_layers=3):
super(WLNReactionRanking, self).__init__()
self.gnn = WLN(node_in_feats=node_in_feats,
edge_in_feats=edge_in_feats,
node_out_feats=node_hidden_feats,
n_layers=num_encode_gnn_layers,
set_comparison=False)
self.diff_gnn = WLN(node_in_feats=node_hidden_feats,
edge_in_feats=edge_in_feats,
node_out_feats=node_hidden_feats,
n_layers=1,
project_in_feats=False,
set_comparison=False)
self.readout = SumPooling()
self.predict = nn.Sequential(
nn.Linear(node_hidden_feats, node_hidden_feats),
nn.ReLU(),
nn.Linear(node_hidden_feats, 1)
)
def forward(self, reactant_graph, reactant_node_feats, reactant_edge_feats,
product_graphs, product_node_feats, product_edge_feats,
candidate_scores, batch_num_candidate_products):
r"""Predicts the score for candidate products to be the true product
Parameters
----------
reactant_graph : DGLGraph
DGLGraph for a batch of reactants.
reactant_node_feats : float32 tensor of shape (V1, node_in_feats)
Input node features for the reactants. V1 for the number of nodes.
reactant_edge_feats : float32 tensor of shape (E1, edge_in_feats)
Input edge features for the reactants. E1 for the number of edges in
reactant_graph.
product_graphs : DGLGraph
DGLGraph for the candidate products in a batch of reactions.
product_node_feats : float32 tensor of shape (V2, node_in_feats)
Input node features for the candidate products. V2 for the number of nodes.
product_edge_feats : float32 tensor of shape (E2, edge_in_feats)
Input edge features for the candidate products. E2 for the number of edges
in the graphs for candidate products.
candidate_scores : float32 tensor of shape (B, 1)
Scores for candidate products based on the model for reaction center prediction
batch_num_candidate_products : list of int
Number of candidate products for the reactions in the batch
Returns
-------
float32 tensor of shape (B, 1)
Predicted scores for candidate products
"""
# Update representations for nodes in both reactants and candidate products
batch_reactant_node_feats = self.gnn(
reactant_graph, reactant_node_feats, reactant_edge_feats)
batch_product_node_feats = self.gnn(
product_graphs, product_node_feats, product_edge_feats)
# Iterate over the reactions in the batch
reactant_node_start = 0
product_graph_start = 0
product_node_start = 0
batch_diff_node_feats = []
for i, num_candidate_products in enumerate(batch_num_candidate_products):
reactant_node_end = reactant_node_start + reactant_graph.batch_num_nodes[i]
product_graph_end = product_graph_start + num_candidate_products
product_node_end = product_node_start + sum(
product_graphs.batch_num_nodes[product_graph_start: product_graph_end])
# (N, node_out_feats)
reactant_node_feats = batch_reactant_node_feats[reactant_node_start:
reactant_node_end, :]
product_node_feats = batch_product_node_feats[product_node_start: product_node_end, :]
old_feats_shape = reactant_node_feats.shape
# (1, N, node_out_feats)
expanded_reactant_node_feats = reactant_node_feats.reshape((1,) + old_feats_shape)
# (B, N, node_out_feats)
expanded_reactant_node_feats = expanded_reactant_node_feats.expand(
(num_candidate_products,) + old_feats_shape)
# (B, N, node_out_feats)
candidate_product_node_feats = product_node_feats.reshape(
(num_candidate_products,) + old_feats_shape)
# Get the node representation difference between candidate products and reactants
diff_node_feats = candidate_product_node_feats - expanded_reactant_node_feats
diff_node_feats = diff_node_feats.reshape(-1, diff_node_feats.shape[-1])
batch_diff_node_feats.append(diff_node_feats)
reactant_node_start = reactant_node_end
product_graph_start = product_graph_end
product_node_start = product_node_end
batch_diff_node_feats = torch.cat(batch_diff_node_feats, dim=0)
# One more GNN layer for message passing with the node representation difference
diff_node_feats = self.diff_gnn(product_graphs, batch_diff_node_feats, product_edge_feats)
candidate_product_feats = self.readout(product_graphs, diff_node_feats)
return self.predict(candidate_product_feats) + candidate_scores
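The indexing in `WLNReactionRanking.forward` broadcasts each reaction's reactant node features across its candidate products before subtracting, since every candidate keeps the reactant atom ordering. A small NumPy sketch of that reshape/expand/subtract pattern, with made-up sizes (`N` atoms, `B` candidates, `F` features are assumptions for illustration):

```python
import numpy as np

# Hypothetical sizes: N reactant atoms, B candidate products for one
# reaction, F features per atom.
N, B, F = 4, 3, 5
reactant_feats = np.arange(N * F, dtype=float).reshape(N, F)        # (N, F)
product_feats = np.arange(B * N * F, dtype=float).reshape(B * N, F) # (B*N, F)

# View the stacked candidates as (B, N, F) and broadcast the reactant
# features over the candidate axis, mirroring reshape + expand above.
candidate_feats = product_feats.reshape(B, N, F)
diff = candidate_feats - reactant_feats[None, :, :]  # (B, N, F)

# Flatten back to (B*N, F) so the difference features can be fed to the
# extra WLN layer for one more round of message passing.
diff_flat = diff.reshape(-1, F)
```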
@@ -8,7 +8,7 @@
 from rdkit import Chem

 from ..model import GCNPredictor, GATPredictor, AttentiveFPPredictor, DGMG, DGLJTNNVAE, \
-    WLNReactionCenter, WeavePredictor
+    WLNReactionCenter, WLNReactionRanking, WeavePredictor

 __all__ = ['load_pretrained']
@@ -22,7 +22,8 @@ URL = {
     'DGMG_ZINC_canonical': 'pre_trained/dgmg_ZINC_canonical.pth',
     'DGMG_ZINC_random': 'pre_trained/dgmg_ZINC_random.pth',
     'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth',
-    'wln_center_uspto': 'dgllife/pre_trained/wln_center_uspto_v2.pth'
+    'wln_center_uspto': 'dgllife/pre_trained/wln_center_uspto_v3.pth',
+    'wln_rank_uspto': 'dgllife/pre_trained/wln_rank_uspto.pth',
 }

 def download_and_load_checkpoint(model_name, model, model_postfix,
@@ -84,6 +85,7 @@ def load_pretrained(model_name, log=True):
           with a random atom order
         * ``'JTNN_ZINC'``: A JTNN model pre-trained on ZINC for molecule generation
         * ``'wln_center_uspto'``: A WLN model pre-trained on USPTO for reaction prediction
+        * ``'wln_rank_uspto'``: A WLN model pre-trained on USPTO for candidate product ranking
     log : bool
         Whether to print progress for model loading
@@ -162,4 +164,10 @@ def load_pretrained(model_name, log=True):
                                   n_layers=3,
                                   n_tasks=5)

+    elif model_name == 'wln_rank_uspto':
+        model = WLNReactionRanking(node_in_feats=89,
+                                   edge_in_feats=5,
+                                   node_hidden_feats=500,
+                                   num_encode_gnn_layers=3)
+
     return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
@@ -64,10 +64,23 @@ def test_wln_reaction():
     remove_file('test.txt.proc')

     # Test configured dataset
-    dataset = WLNReactionDataset('test.txt', 'test_graphs.bin')
+    dataset = WLNCenterDataset('test.txt', 'test_graphs.bin')
+    remove_file('test_graphs.bin')
+
+    with open('test_candidate_bond_changes.txt', 'w') as f:
+        for reac in reactions:
+            # Simulate fake candidate bond changes
+            candidate_string = ''
+            for i in range(2):
+                candidate_string += '{} {} {:.1f} {:.3f};'.format(i + 1, i + 2, 0.0, 0.234)
+            candidate_string += '\n'
+            f.write(candidate_string)
+    dataset = WLNRankDataset('test.txt.proc', 'test_candidate_bond_changes.txt', 'train')

     remove_file('test.txt')
     remove_file('test.txt.proc')
     remove_file('test_graphs.bin')
+    remove_file('test_candidate_bond_changes.txt')

 if __name__ == '__main__':
     test_pubchem_aromaticity()
...
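The fake candidate bond changes written in the test above are serialized as `atom1 atom2 change_type score;` records, with one reaction's candidates joined on a single line. A stand-alone sketch of that line format (the helper name is ours, not part of dgllife):

```python
def format_candidate_bond_changes(changes):
    # Each change is (atom1, atom2, bond_change_type, score); records are
    # semicolon-terminated and one reaction occupies one line.
    return ''.join('{} {} {:.1f} {:.3f};'.format(a1, a2, t, s)
                   for a1, a2, t, s in changes) + '\n'

line = format_candidate_bond_changes([(1, 2, 0.0, 0.234), (2, 3, 0.0, 0.234)])
# line == '1 2 0.0 0.234;2 3 0.0 0.234;\n'
```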
 import dgl
+import numpy as np
 import torch

 from dgl import DGLGraph
@@ -84,5 +85,71 @@ def test_wln_reaction_center():
                  batch_edge_feats, batch_atom_pair_feats)[0].shape == \
         torch.Size([batch_complete_graph.number_of_edges(), 1])
def test_reactant_product_graph(batch_size, device):
edges = (np.array([0, 1, 2]), np.array([1, 2, 2]))
reactant_g = []
for _ in range(batch_size):
reactant_g.append(DGLGraph(edges))
reactant_g = dgl.batch(reactant_g)
reactant_node_feats = torch.arange(
reactant_g.number_of_nodes()).float().reshape(-1, 1).to(device)
reactant_edge_feats = torch.arange(
reactant_g.number_of_edges()).float().reshape(-1, 1).to(device)
product_g = []
batch_num_candidate_products = []
for i in range(1, batch_size + 1):
product_g.extend([
DGLGraph(edges) for _ in range(i)
])
batch_num_candidate_products.append(i)
product_g = dgl.batch(product_g)
product_node_feats = torch.arange(
product_g.number_of_nodes()).float().reshape(-1, 1).to(device)
product_edge_feats = torch.arange(
product_g.number_of_edges()).float().reshape(-1, 1).to(device)
product_scores = torch.randn(sum(batch_num_candidate_products), 1).to(device)
return reactant_g, reactant_node_feats, reactant_edge_feats, product_g, product_node_feats, \
product_edge_feats, product_scores, batch_num_candidate_products
def test_wln_candidate_ranking():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
reactant_g, reactant_node_feats, reactant_edge_feats, product_g, product_node_feats, \
product_edge_feats, product_scores, num_candidate_products = \
test_reactant_product_graph(batch_size=1, device=device)
batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats, batch_product_g, \
batch_product_node_feats, batch_product_edge_feats, batch_product_scores, \
batch_num_candidate_products = test_reactant_product_graph(batch_size=2, device=device)
# Test default setting
model = WLNReactionRanking(node_in_feats=1,
edge_in_feats=1).to(device)
assert model(reactant_g, reactant_node_feats, reactant_edge_feats, product_g,
product_node_feats, product_edge_feats, product_scores,
num_candidate_products).shape == torch.Size([sum(num_candidate_products), 1])
assert model(batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats,
batch_product_g, batch_product_node_feats, batch_product_edge_feats,
batch_product_scores, batch_num_candidate_products).shape == \
torch.Size([sum(batch_num_candidate_products), 1])
model = WLNReactionRanking(node_in_feats=1,
edge_in_feats=1,
node_hidden_feats=100,
num_encode_gnn_layers=2).to(device)
assert model(reactant_g, reactant_node_feats, reactant_edge_feats, product_g,
product_node_feats, product_edge_feats, product_scores,
num_candidate_products).shape == torch.Size([sum(num_candidate_products), 1])
assert model(batch_reactant_g, batch_reactant_node_feats, batch_reactant_edge_feats,
batch_product_g, batch_product_node_feats, batch_product_edge_feats,
batch_product_scores, batch_num_candidate_products).shape == \
torch.Size([sum(batch_num_candidate_products), 1])
 if __name__ == '__main__':
     test_wln_reaction_center()
+    test_wln_candidate_ranking()