Unverified Commit 38b9c0f8 authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Allow Generating Vocabulary from a New Dataset (#1577)

* Generate vocabulary from a new dataset

* CI"
parent a936f9d9
@@ -25,7 +25,7 @@ molecules for training and 5000 molecules for validation.
### Preprocessing
Class `JTNNDataset` will process a SMILES string into a dict, consisting of a junction tree, a graph with
encoded nodes (atoms) and edges (bonds), and other information for the model to use.
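As a rough illustration, here is a hedged sketch of loading one preprocessed sample; the import path and keyword arguments are assumptions inferred from the `-t/--train` and `-v/--vocab` options of `train.py`, not a documented API:
```
# Hypothetical usage sketch; JTNNDataset's exact signature may differ.
from jtnn import JTNNDataset

dataset = JTNNDataset(data='train', vocab='vocab')
sample = dataset[0]
# Each sample is a dict holding the junction tree, the molecular graph with
# encoded atom/bond features, and auxiliary information used by the model.
print(sample.keys())
```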
## Usage
@@ -33,54 +33,17 @@ encoded nodes (atoms) and edges (bonds), and other information for the model to use.
### Training
To start training, use `python train.py`. By default, the script will use the ZINC dataset
with a preprocessed vocabulary, and save model checkpoints periodically in the current working directory.
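For example, a run that writes checkpoints to a custom directory with a custom batch size might look as below. The `-s SAVE_PATH` and `-b BATCH_SIZE` flags come from an earlier revision of this README, so treat them as assumptions if the script has since changed:
```
python train.py -s ./checkpoints -b 40
```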
### Evaluation
To start evaluation, use `python reconstruct_eval.py`. By default, we will perform the evaluation with
DGL's pre-trained model. During the evaluation, the program will print out the success rate of
molecule reconstruction.
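To evaluate a checkpoint of your own instead of the model-zoo weights, an earlier revision of this README documented a `-m MODEL_PATH` option, with the caveat that hyperparameters such as hidden size must be consistent with the checkpoint; treat the exact flag as an assumption if the script has since changed:
```
python reconstruct_eval.py -m MODEL_PATH
```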
### Pre-trained models
Below are the statistics of our pre-trained `JTNN_ZINC` model.

| Pre-trained model | % Reconstruction Accuracy
| ------------------ | -------
@@ -96,3 +59,43 @@ Please put this script at the current directory (`examples/pytorch/model_zoo/che
![image](https://user-images.githubusercontent.com/8686776/63773593-0d37da00-c90e-11e9-8933-0abca4b430db.png)
#### Neighbor Molecules
![image](https://user-images.githubusercontent.com/8686776/63773602-1163f780-c90e-11e9-8341-5122dc0d0c82.png)
### Dataset configuration
If you want to use your own dataset, please create a file with one SMILES string per line, as below
```
CCO
Fc1ccccc1
```
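Every line will be parsed with RDKit during preprocessing, so it can help to validate the file up front. Below is a minimal sketch; the file name `new_dataset.txt` is a placeholder:
```
from rdkit import Chem

# 'new_dataset.txt' is a hypothetical path; the file holds one SMILES string per line
with open('new_dataset.txt') as f:
    for i, line in enumerate(f, 1):
        smiles = line.strip()
        if not smiles or Chem.MolFromSmiles(smiles) is None:
            print('Line {} is empty or not a valid SMILES: {!r}'.format(i, smiles))
```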
You can generate the vocabulary file corresponding to your dataset with `python vocab.py -d X -v Y`, where `X`
is the path to the dataset and `Y` is the path where the vocabulary file will be saved. An example vocabulary file
corresponding to the two molecules above is
```
CC
CF
C1=CC=CC=C1
CO
```
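For intuition, each vocabulary entry is the SMILES string of one cluster from the junction tree decomposition: roughly, every bond outside a ring and every ring becomes a cluster. The sketch below is a simplification using plain RDKit (`DGLMolTree` additionally merges bridged rings and handles special cases), but it reproduces the four entries above:
```
from rdkit import Chem

def clique_smiles(smiles):
    """Return the SMILES of each cluster: non-ring bonds and simple rings."""
    mol = Chem.MolFromSmiles(smiles)
    # Kekulize so ring clusters print as 'C1=CC=CC=C1' rather than 'c1ccccc1'
    Chem.Kekulize(mol, clearAromaticFlags=True)
    clusters = []
    for bond in mol.GetBonds():
        if not bond.IsInRing():
            clusters.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
    for ring in Chem.GetSymmSSSR(mol):
        clusters.append(list(ring))
    return {Chem.MolFragmentToSmiles(mol, atomsToUse=c) for c in clusters}

# Expected (up to RDKit canonicalization): {'CC', 'CO'} and {'CF', 'C1=CC=CC=C1'}
print(clique_smiles('CCO'))
print(clique_smiles('Fc1ccccc1'))
```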
If you want to develop a model based on DGL's pre-trained model, it's important to make sure that the vocabulary
generated above is a subset of the vocabulary used for the pre-trained model. `vocab.py` performs this check
automatically and prints the result in the terminal as follows:
```
The new vocabulary is a subset of the default vocabulary: True
```
To train on this new dataset, run
```
python train.py -t X
```
where `X` is the path to the new dataset. If you want to use the vocabulary generated above, also add `-v Y`, where
`Y` is the path to the vocabulary file we just saved.
To evaluate on this new dataset, run `python reconstruct_eval.py` with the same arguments as above.
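For instance, assuming `reconstruct_eval.py` accepts the same `-t` and `-v` options as `train.py` (an assumption based on the sentence above):
```
python reconstruct_eval.py -t X -v Y
```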
"""Generate vocabulary for a new dataset."""
if __name__ == '__main__':
import argparse
import os
import rdkit
from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_archive
from jtnn.mol_tree import DGLMolTree
parser = argparse.ArgumentParser('Generate vocabulary for a molecule dataset')
parser.add_argument('-d', '--data-path', type=str,
help='Path to the dataset')
parser.add_argument('-v', '--vocab', type=str,
help='Path to the vocabulary file to save')
args = parser.parse_args()
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
vocab = set()
with open(args.data_path, 'r') as f:
for line in f:
smiles = line.strip()
mol = DGLMolTree(smiles)
for i in mol.nodes_dict:
vocab.add(mol.nodes_dict[i]['smiles'])
with open(args.vocab, 'w') as f:
for v in vocab:
f.write(v + '\n')
# Get the vocabulary used for the pre-trained model
default_dir = get_download_dir()
vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
if not os.path.exists(vocab_file):
zip_file_path = '{}/jtnn.zip'.format(default_dir)
download(_get_dgl_url('dgllife/jtnn.zip'), path=zip_file_path)
extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))
default_vocab = set()
with open(vocab_file, 'r') as f:
for line in f:
default_vocab.add(line.strip())
print('The new vocabulary is a subset of the default vocabulary: {}'.format(
vocab.issubset(default_vocab)))
@@ -268,7 +268,6 @@ class Model(nn.Module):
# load graph data
from dgl.contrib.data import load_data
data = load_data(dataset='aifb')
num_nodes = data.num_nodes
num_rels = data.num_rels