Unverified commit 38b9c0f8 authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Allow Generating Vocabulary from a New Dataset (#1577)

* Generate vocabulary from a new dataset

* CI
parent a936f9d9
### Preprocessing
Class `JTNNDataset` will process a SMILES string into a dict, consisting of a junction tree, a graph with
encoded nodes (atoms) and edges (bonds), and other information for the model to use.
## Usage
### Training
To start training, use `python train.py`. By default, the script will use the ZINC dataset
with preprocessed vocabulary and save model checkpoints periodically in the current working directory.
### Evaluation
To start evaluation, use `python reconstruct_eval.py`. By default, we will perform evaluation with
DGL's pre-trained model. During the evaluation, the program will print out the success rate of
molecule reconstruction.
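The success rate here is simply the fraction of molecules whose decoded SMILES exactly matches the input. A minimal sketch of the metric with hypothetical (input, reconstruction) pairs — in the actual script, the reconstructions come from the trained encoder/decoder:

```python
# Hypothetical (input SMILES, reconstructed SMILES) pairs; the real
# pairs come from encoding and then decoding each test molecule
pairs = [('CCO', 'CCO'), ('Fc1ccccc1', 'Fc1ccccc1'), ('CCN', 'CCO')]

n_correct = sum(1 for inp, out in pairs if inp == out)
accuracy = 100.0 * n_correct / len(pairs)
print('% Reconstruction Accuracy: {:.1f}'.format(accuracy))  # 66.7
```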
### Pre-trained models
Below are the statistics of our pre-trained `JTNN_ZINC` model.
| Pre-trained model | % Reconstruction Accuracy |
| ----------------- | ------------------------- |
![image](https://user-images.githubusercontent.com/8686776/63773593-0d37da00-c90e-11e9-8933-0abca4b430db.png)
#### Neighbor Molecules
![image](https://user-images.githubusercontent.com/8686776/63773602-1163f780-c90e-11e9-8341-5122dc0d0c82.png)
### Dataset configuration
If you want to use your own dataset, please create a file with one SMILES string per line as below
```
CCO
Fc1ccccc1
```
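For example, the two-molecule file above can be written and read back with a few lines of Python (the file name here is arbitrary):

```python
smiles_list = ['CCO', 'Fc1ccccc1']

# One SMILES string per line, as expected by the -t/--train option
with open('my_smiles.txt', 'w') as f:
    f.write('\n'.join(smiles_list) + '\n')

with open('my_smiles.txt') as f:
    loaded = [line.strip() for line in f if line.strip()]
assert loaded == smiles_list
```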
You can generate the vocabulary file corresponding to your dataset with `python vocab.py -d X -v Y`, where `X`
is the path to the dataset and `Y` is the path to the vocabulary file to save. An example vocabulary file
corresponding to the two molecules above would be
```
CC
CF
C1=CC=CC=C1
CO
```
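Conceptually, the vocabulary is the set of unique cluster (subgraph) SMILES collected over all molecules in the dataset. A sketch of that aggregation, using hypothetical cluster decompositions for the two molecules above (the real decomposition is computed by `DGLMolTree`):

```python
# Hypothetical junction-tree clusters per molecule; in vocab.py these
# come from DGLMolTree(smiles).nodes_dict
clusters_per_molecule = {
    'CCO': ['CC', 'CO'],
    'Fc1ccccc1': ['CF', 'C1=CC=CC=C1'],
}

vocab = set()
for clusters in clusters_per_molecule.values():
    vocab.update(clusters)

# Using a set collapses clusters shared across molecules into one entry
print(sorted(vocab))  # ['C1=CC=CC=C1', 'CC', 'CF', 'CO']
```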
If you want to develop a model based on DGL's pre-trained model, it's important to make sure that the vocabulary
generated above is a subset of the vocabulary we use for the pre-trained model. By running `vocab.py` above, we
also check if the new vocabulary is a subset of the vocabulary we use for the pre-trained model and print the
result in the terminal as follows:
```
The new vocabulary is a subset of the default vocabulary: True
```
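The check itself is a plain set operation. A minimal sketch with the example vocabulary above and a hypothetical default vocabulary:

```python
# Vocabulary generated from the two example molecules above
new_vocab = {'CC', 'CF', 'C1=CC=CC=C1', 'CO'}
# Hypothetical vocabulary of the pre-trained model (the real one is
# downloaded by vocab.py)
default_vocab = {'CC', 'CF', 'C1=CC=CC=C1', 'CO', 'CN', 'C=O'}

print('The new vocabulary is a subset of the default vocabulary: {}'.format(
    new_vocab.issubset(default_vocab)))  # True
```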
To train on this new dataset, run
```
python train.py -t X
```
where `X` is the path to the new dataset. If you want to use the vocabulary generated above, also add `-v Y`, where
`Y` is the path to the vocabulary file we just saved.
To evaluate on this new dataset, run `python reconstruct_eval.py` with the same arguments as above.
"""Generate vocabulary for a new dataset."""
import argparse
import os

import rdkit

from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_archive
from jtnn.mol_tree import DGLMolTree

if __name__ == '__main__':
    parser = argparse.ArgumentParser('Generate vocabulary for a molecule dataset')
    parser.add_argument('-d', '--data-path', type=str,
                        help='Path to the dataset')
    parser.add_argument('-v', '--vocab', type=str,
                        help='Path to the vocabulary file to save')
    args = parser.parse_args()

    # Mute RDKit logging
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    # Decompose each molecule into a junction tree and collect the
    # SMILES of its clusters (tree nodes)
    vocab = set()
    with open(args.data_path, 'r') as f:
        for line in f:
            smiles = line.strip()
            mol = DGLMolTree(smiles)
            for i in mol.nodes_dict:
                vocab.add(mol.nodes_dict[i]['smiles'])

    with open(args.vocab, 'w') as f:
        for v in vocab:
            f.write(v + '\n')

    # Get the vocabulary used for the pre-trained model
    default_dir = get_download_dir()
    vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
    if not os.path.exists(vocab_file):
        zip_file_path = '{}/jtnn.zip'.format(default_dir)
        download(_get_dgl_url('dgllife/jtnn.zip'), path=zip_file_path)
        extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))
    default_vocab = set()
    with open(vocab_file, 'r') as f:
        for line in f:
            default_vocab.add(line.strip())

    print('The new vocabulary is a subset of the default vocabulary: {}'.format(
        vocab.issubset(default_vocab)))
# load graph data
from dgl.contrib.data import load_data
import numpy as np
data = load_data(dataset='aifb')
num_nodes = data.num_nodes
num_rels = data.num_rels