Unverified Commit cd9fb7ba authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Model Zoo] Refactor GCN on Tox21 (#766)

* [Model zoo] Model zoo (#765)

* tox21

* fix ci

* fix ci

* fix urls to url

* add doc

* remove binary

* model zoo

* test

* markdown

* fix typo

* fix typo

* fix typo

* raise error

* fix lint

* remove unnecessary

* fix doc

* fix

* fix

* fix

* fix

* fix

* fix

* Update

* CI

* Fix

* Fix

* Fix

* Fix

* Fix

* CI
parent bdcba9c8
Model Zoo
==========
Here are examples of using the model zoo.
# Property Prediction
## Classification
Classification tasks require assigning discrete labels to a molecule, e.g. molecule toxicity.
### Datasets
- **Tox21**. The ["Toxicology in the 21st Century" (Tox21)](https://tripod.nih.gov/tox21/challenge/) initiative created
a public database measuring toxicity of compounds, which has been used in the 2014 Tox21 Data Challenge. The dataset
contains qualitative toxicity measurements for 8014 compounds on 12 different targets, including nuclear receptors and
stress response pathways. Each target yields a binary prediction problem. MoleculeNet [1] randomly splits the dataset
into training, validation and test set with a 80/10/10 ratio. By default we follow their split method.
### Models
- **Graph Convolutional Network** [2]. Graph Convolutional Networks (GCN) have been one of the most popular graph neural
networks and they can be easily extended for graph level prediction. MoleculeNet [1] reports baseline results of graph
convolutions over multiple datasets.
### Usage
To train a model from scratch, simply call `python classification.py`. To skip training and use the pre-trained model,
call `python classification.py -p`.
We use GPU whenever it is available.
### Performance
#### GCN on Tox21
| Source | Averaged ROC-AUC Score |
| ---------------- | ---------------------- |
| MoleculeNet [1] | 0.829 |
| [DeepChem example](https://github.com/deepchem/deepchem/blob/master/examples/tox21/tox21_tensorgraph_graph_conv.py) | 0.813 |
| Pretrained model | 0.827 |
Note that due to some possible randomness you may get different numbers for DeepChem example and our model. To get
match exact results for this model, please use the pre-trained model as in the usage section.
## Dataset Customization
To customize your own dataset, see the instructions
[here](https://github.com/dmlc/dgl/tree/master/python/dgl/data/chem).
### References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
[2] Kipf et al. (2017) Semi-Supervised Classification with Graph Convolutional Networks.
*The International Conference on Learning Representations (ICLR)*.
from dgl.data import Tox21
from dgl.data.utils import split_dataset
from dgl import model_zoo
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed
def main(args):
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 128
learning_rate = 0.001
num_epochs = 100
set_random_seed()
# Interchangeable with other Dataset
dataset = Tox21()
atom_data_field = 'h'
trainset, valset, testset = split_dataset(dataset, [0.8, 0.1, 0.1])
train_loader = DataLoader(
trainset, batch_size=batch_size, collate_fn=collate_molgraphs)
val_loader = DataLoader(
valset, batch_size=batch_size, collate_fn=collate_molgraphs)
test_loader = DataLoader(
testset, batch_size=batch_size, collate_fn=collate_molgraphs)
if args.pre_trained:
num_epochs = 0
model = model_zoo.chem.load_pretrained('GCN_Tox21')
else:
# Interchangeable with other models
model = model_zoo.chem.GCNClassifier(in_feats=74,
gcn_hidden_feats=[64, 64],
n_tasks=dataset.n_tasks)
loss_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(
dataset.task_pos_weights).to(device), reduction='none')
optimizer = Adam(model.parameters(), lr=learning_rate)
stopper = EarlyStopping(patience=10)
model.to(device)
for epoch in range(num_epochs):
model.train()
print('Start training')
train_meter = Meter()
for batch_id, batch_data in enumerate(train_loader):
smiles, bg, labels, mask = batch_data
atom_feats = bg.ndata.pop(atom_data_field)
atom_feats, labels, mask = atom_feats.to(device), labels.to(device), mask.to(device)
logits = model(atom_feats, bg)
# Mask non-existing labels
loss = (loss_criterion(logits, labels)
* (mask != 0).float()).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
epoch + 1, num_epochs, batch_id + 1, len(train_loader), loss.item()))
train_meter.update(logits, labels, mask)
train_roc_auc = train_meter.roc_auc_averaged_over_tasks()
print('epoch {:d}/{:d}, training roc-auc score {:.4f}'.format(
epoch + 1, num_epochs, train_roc_auc))
val_meter = Meter()
model.eval()
with torch.no_grad():
for batch_id, batch_data in enumerate(val_loader):
smiles, bg, labels, mask = batch_data
atom_feats = bg.ndata.pop(atom_data_field)
atom_feats, labels = atom_feats.to(device), labels.to(device)
logits = model(atom_feats, bg)
val_meter.update(logits, labels, mask)
val_roc_auc = val_meter.roc_auc_averaged_over_tasks()
if stopper.step(val_roc_auc, model):
break
print('epoch {:d}/{:d}, validation roc-auc score {:.4f}, best validation roc-auc score {:.4f}'.format(
epoch + 1, num_epochs, val_roc_auc, stopper.best_score))
test_meter = Meter()
model.eval()
for batch_id, batch_data in enumerate(test_loader):
smiles, bg, labels, mask = batch_data
atom_feats = bg.ndata.pop(atom_data_field)
atom_feats, labels = atom_feats.to(device), labels.to(device)
logits = model(atom_feats, bg)
test_meter.update(logits, labels, mask)
print('test roc-auc score {:.4f}'.format(test_meter.roc_auc_averaged_over_tasks()))
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Molecule Classification')
parser.add_argument('-p', '--pre-trained', action='store_true',
help='Whether to skip training and use a pre-trained model')
args = parser.parse_args()
main(args)
{
"cells": [
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"from dgl import model_zoo\n",
"import torch\n",
"import rdkit\n",
"from rdkit import Chem\n",
"from rdkit.Chem.Draw import IPythonConsole\n",
"from dgl.data.chem.utils import smile2graph\n",
"import dgl"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading pretrained model...\n"
]
},
{
"data": {
"text/plain": [
"GCNClassifier(\n",
" (gcn_layers): ModuleList(\n",
" (0): GCNLayer(\n",
" (graph_conv): GraphConv(in=74, out=64, normalization=False, activation=<function relu at 0x7efd7f46e158>)\n",
" (dropout): Dropout(p=0.0)\n",
" (res_connection): Linear(in_features=74, out_features=64, bias=True)\n",
" (bn_layer): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (1): GCNLayer(\n",
" (graph_conv): GraphConv(in=64, out=64, normalization=False, activation=<function relu at 0x7efd7f46e158>)\n",
" (dropout): Dropout(p=0.0)\n",
" (res_connection): Linear(in_features=64, out_features=64, bias=True)\n",
" (bn_layer): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (atom_weighting): Sequential(\n",
" (0): Linear(in_features=64, out_features=1, bias=True)\n",
" (1): Sigmoid()\n",
" )\n",
" (soft_classifier): MLPBinaryClassifier(\n",
" (predict): Sequential(\n",
" (0): Dropout(p=0.0)\n",
" (1): Linear(in_features=128, out_features=64, bias=True)\n",
" (2): ReLU()\n",
" (3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (4): Linear(in_features=64, out_features=12, bias=True)\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = model_zoo.chem.load_pretrained(\"GCN_Tox21\")\n",
"model.eval()\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"tasks = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase',\n",
" 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE',\n",
" 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"smiles = \"CC[NH+](CC)c1ccc(/C=C2\\Oc3c(ccc(OCC(N)=O)c3C)C2=O)cc1\""
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcIAAACWCAIAAADCEh9HAAAABmJLR0QA/wD/AP+gvaeTAAAgAElEQVR4nO3deViU19k/8O8sjBBWlaCM4IJgBFxxF3wxkQQX1JiKOy4xmqa2Y2xqqfHNi7bmDVlsp6mvqIkL7hoXhLhChAR+0YgI4pAYIQJVdlHZt5k5vz8OTgm4IDzMM8D9uXL1Yh6Gc+6xcnv2I2GMgRBCSEtJxQ6AEELaN0qjhBDSKpRGCSGkVSiNEkJIq1AaJYSQVqE0+gy3b9+eNWtWcXGx2IEQ8pzOn8fLL8PfH8uWoaJC7Gg6MgkteHq61157LTo62s3N7fTp025ubmKHQ0jz5Odj6lTExsLWFrt2ITUVarXYMXVYlEafIS8vb8aMGVevXu3WrduJEyd8fX3FjoiQZtixAyUlWLsWAPR6eHjg5k2xY+qwqFP/DI6OjnFxcTNnzrx//76/v//+/fvFjoiQZsjNhVJZ/7VUCqkUWq2oAXVklEafzdLS8vjx48HBwTU1NYsXL96wYQM14YmpUyqRm1v/tV4PnQ5yuagBdWSURptFJpOFhoZu375dJpNt3Lhx2bJltbW1YgdFyJNNn47Dh1FaCgB792LyZLED6shobPT5nD9/fs6cOaWlpd7e3hEREfb29mJHRMgTnDmDTz6BQgGlEp9/jvBwfP01oqKgUIgdWUdDafS5paamBgQE3Llzx9XV9cyZMzR9T0xXSQnMzPDCCwAwbBiuX8epU5gxQ+ywOhrq1D+3IUOGXL582cvLKyMjY9y4cQkJCWJHRMjj/O1v6NkTBw/Wv5w/HwAOHRIxoo6K0mhLKJXKb7/9NiAgoLi42M/P7xD91SQmyMkJ1dU4fLj+5fz5kEgQGYnyclHD6oAojbaQlZVVRETEqlWrampqFi5cuGHDBrEjIuTXfvMbWFggNhY5OQDQuze8vVFZiVOnxI6so6E02nIymWzLli1qtVoikWzcuHH58uV1dXViB0XIIzY2mDwZej2++qr+CfXr2wZNMQng5MmTixYtqqys9PPzO3bsmK2trdgREQIA+OorzJmD0aPxww8AUFSEXr0AIDcXtMhEONQaFcCsWbNiY2N79OgRExPj4+OTnZ0tdkSEAACmT4etLa5cQXo6ALz4IiZNQl0djh0TO7IOhdKoMEaPHn3p0iUPDw+NRjNu3LirV6+KHREhgLl5/fKmI0fqn1C/vg1QGhVMv379Ll++PGXKlLy8vIkTJ56igXxiCnjePHCg/uWsWbCwYMnJ5XzeiQiB0qiQrK2tIyIili5dWlFR8e6771ZXV4sdEen0Xn0VDg64eRPXrwOAtfXRFSvstdpt1CAVDqVRgSkUit27d48dO7awsPD48eNih0M6PbmcBQbeHz784rlz/IHZxIn3q6posbOAKI22lcrKSkdHR7GjIAT/b+7c7snJb4aF8WU5U6dOtbOzu3bt2o8//ih2aB0EpVHh6fV6jUYDYMiQIWLHQgi8fXz69u2bnZ39/fffA+jSpcsbb7wB4OjRo2KH1kFQGhXerVu3ysvLe/fuTec/EVMgkUjmzJkDwNCRnz9/PoADhnkn0jqURoWXnJwMYPjw4WIHQkg9njePHj2q1WoBvPLKK0qlMiMjIykpSezQOgJKo8JLSUkBMGzYMLEDIaTesGHDPD09i4qKYmJiAEil0tmzZ6NB+5S0BqVR4VEaJSZo7ty5aNKvP3TokE6nEzOsDoH21AuvZ8+eBQUFWVlZffr0ETsWQur98ssvbm5uVlZWBQUFFhYWANzc3DIyMmJjYydOnCh2dO0btUYFlpOTU1BQYGdn17t3b7FjIeQ/+vfvP3LkyLKystOnT+t0uoSEhPHjx3t7e0ullARaiy4LFFhammzChP1OTpkSiUTsWAj5lfnz5ycmJu7ateuDDz4wNzfno0/z5s0bMWLEiBEjPD09PTw8PDw86K/u86JOvcA2bcIHH+CPf8TmzWKHQsiv5eXlff7558ePH09PT3dxcbG3t79x40ZVVVXD90ya9E9ANXw4hg3DsGF46SW6mPnZ6E9IYCkpAEDTS8QE2dvbX7t2LT093d3d/fvvv7ezs9PpdNnZ2WlpaUlJSUlJSYmJiTrdmLg4fPNN/Y9YWGDwYAwbhhEjMHo0RozAzz/D1RXbtsHcHMOGQa3Gnj0AsG0btFr8/vfifTzxUBoVWHIyANCaUWKCVq9efeHCBXt7+6ioKDs7OwAymczFxcXFxWX69On8PXfu1KWkIDkZKSlISUFmJq5cwZUrcHPD0aMYOBCbNyMsTNSPYXoojQqptBSZmejSBS+9JHYohPxaaGhoWFiYhYVFVFRU//79n/Q2Z2czZ2c8Sqp4+BA8q/LL7ceNQ3o6CguNEnH7QWlUSCkpYAyDB8PMTOxQyK9VVVV9+eWXV69eDQ8PFzsWERw7dmz9+vVSqXT//v1jx45t/g/a2WHiRPAFUXzAatUqbNkCpbL+DefO1X83NxcqlaBBtx8dba1Dfn6+iLVTj95kyeXyjRs37t27l58a06kkJiYuWbJEr9d/+umn/FCS1pg5E9HRqKiofzl5MuLiEBeHP/6xtXG2Xx0qjX788ceenp7ffvutWAHQ/JLJMjMz65zbHzMzM6dPn15ZWfnWW2/9UYhUJ5Vi6VLs3t36kjqOjpNGGWOXLl26f/++v7//wYMHRYmBWqOmzHCsUedZ5FdSUjJjxoyCgoLJkyeHCTcxtGQJioqEKqxDYB2IVqsNDg4GIJFIQkJC9Hq9MWuvqWEKBZNKWVmZMaslzaXT6ZycnABcunRJ7FiMoba2dtKkSQA8PT0fPnxonEq/+oq99RabMYPdv2+cCk2CbMOGDeLmcQFJpVI/Pz+lUnn27NnY2NisrKypU6fKZDLj1C6RYP58vPoqBg0yToXk+Ugkktzc3EuXLllaWk6ZMkXscNoWY+zNN9+MiIhQKpVxcXEODg7GqdfDAzNmoKICNTXo1884dZoAsfN4mzh37pyNjQ0Ab2/voqIiscMhpoJffO3g4FBXVyd2LG2LN49eeOGFK1euGLnqqiq2eDGrqTFytWLqmGmUMXb9+nVnZ2cArq6ut27dEiWG774TpVryNO7u7gAuXLggchznzrGJE9lrr7GlS1l5ubBlHz58WCKRyGSyiIgIYUt+pro69s47LCvLyNWKrONMMTUyZMiQy5cve3l5ZWRkjBs3Lj4+3vgxnD9v/DrJMzS6TkMc+flYtw4RETh/HhMmYP16AcuOj49fsmQJY+wf//jHzJkzBSy5OT77DDdv4qOP6qdbOwux83jbKisrCwgIANClS5eDBw8KUmZyMpNKWXo6Y4yFhbHdu1lyMluypP67YWHsX/9iiYns7beZlxd7+2127Jgg1RJhpKenA7CxsamsrBQtiO3b2Sef1H+t07GXXmKMsZ072aFD7KefmFbb4oIzMjJefPFFAGvWrBEiUNIsHbY1yllZWUVERKxataqmpmbhwoVCzafxncVPMXIktm3DlCnYtg2/+Y0gdRJhuLq6enl5lZaWnnt0b7sIcnP/sw1IKoVUCq0W69dj/ny4u+OFF+DpicWL8c9/IiEBlZXNLLW4uHjKlClFRUXTpk379NNP2yp40kQHT6MAZDLZli1b1Gq1RCLZuHHj8uXL6+rqWlkm7Sxu1wz3Z4gWgVKJ3Nz6r/V68Gs83n4bM2agd2/U1uLHH7FvH959FxMmwMYGnp5p77772WefxcTEFBcXP7bI6urqmTNnpqene3l5HTlyxGgLVAg61XmjJ0+eXLRoUWVlpZ+f37Fjx2xtbZ+3hJoaXLiAXr2wdSumTUNyMpTK+uPCJk/GwIHAo53FnfO4sPYiNzfX2dlZoVDk5+e34K+BAPLyEBCA2FjY2GDPHiQn45///M93S0pw4waSkvDjj0hLw9WrqKn5u6/ve4+25zk6OvIjlvlxyx4eHgCCgoIOHDjQq1evy5cv8+WxxHjEHlUwqh9++KFHjx4ABg0alNXs2USdjsXHM5WKvfgiA1hYGFu+nOl0bOxY9tlnjx8bJSbO19cXQHh4uLErLipiv/0tKytjp08zX1/26qtsyZJnbNioqmKJibH79v3ud78bN26cpaVlo19hOzs7fmKTjY3NjRs3jPVJyH90rjTKGMvMzOT/ejs6OiYmJj79zT/8wN59lymVDKj/b+hQtmULW76cMca2bWOenpRG26Vt27YBmDx5slFrrapi3t4MYEuXtqaYnJycyMjIkJCQgICAnj17ApDL5ZMmTTp//rxQkZLn0unSKGOstLSUb2KxtLR87MK6H3/88YMPPpg1a50he/bvz/77v1laGmOMJSfXp9GqKubg8LQ0WlnJJk1iRl+6R56tuLhYoVDI5fKCggIjVanXs4ULGcB69WJ37ghYcE5OTlxcnIAFkufVGdMoY6yuru63v/0tAJlMFhoayh/euXNHrVZ7e3vzvpKVlbWLS8Xq1ezy5RbWsnUrA5hMxtRqwSInQpk2bRqA//u//zNSfe+/zwBmbc2uXzdSjcRYOmka5TZt2sQvQfTz8/Px8TFciNi1a9fly5d/88032las4OPUaiaVMoCtXMnabv9hVVVVZGTkgwcPWh9w57F//34APj4+xqhs1y4GMLmcUb+7I+rUaZQx9tVXX1lYWAwcOBCAubl5QEBAeHh4RUWFoFUwCwsGMH9/VlIiYMFMq9XGx8evXLmSHyAwZMgQf3//EmHr6LjKy8stLS0lEklmZmbb1hQXxxQKBrCtW9u2IiKSzp5GGWOpqamXLl3av39/WZudcHfpEnNwYAAbMoT9+9+tLU2v1yckJKxatarhsT2DBg3iydTLyysnJ0eIqDu+efPmAfj444/bsI4ff2RduzKABQe3YS1EVJRGjSQ9nQ0YwAA2cmTm9ZaOjmk0mpCQEFdXV0P27Nu3b3Bw8M2bNxljv/zyC29WK5XKpKQkQcPvmE6dOgVg2LBhbVR+fn5+9ZgxDGCBgUyna6NaiOgojRrPvXts8uSHrq6e1tbWZ8+ebf4PZmdnq9VqLy8vQ/bs1auXSqWKj49v9M7i4mK+ItLKyioqKkrQ8Dugmpqa7t27A9BoNIIXXlFRMXr06GE9ehTPns1E3L9P2h6lUaOqqakJCgrCoy2qT3/zvXv3tm/f7u3t3XDuKygoKDIy8ilTSTU1NYsWLWpmFeStt94C8MEHHwhbrE6n45fH9evXLz8/X9jCiamhNGpser0+JCSEZ0aVSqVr0td7+PBheHh4QECAXF5//bWFhUVgYGBkZGRNk7Nwy8rK9u3bV/7rAyufWQUxOHbsmEQisbCwGDFiRFBQkFqtjo+Pb/0c43vvvUfbijoPSqPiOHz4sLm5OYBZs2bxX1q+aCkoKOiFF17g2bNLly585UDTua/q6mr+ZisrKwCPPQOwaRWkkYqKilGjRjXdIS2XywcPHhwUFLR58+aLFy/ef857hb744gsAZmZmMTExbRQ5MSmURkVz8eLFrl278kn2wMBAa2tr/jssk8leeeWVL7/8sulvr1arjYmJWb58Of9BAFKpdMKECWfOnHlsFQkJCfb29gBGjx5NXctGGva7f/755/j4eLVaHRQUNGLECIVC0SixOjo6BgQEBAcHh4eHazSap9yWePbsWblcLpFI9uzZY8yPQ0REaVRM6enpffv27datG/9d9fDwCA0NfexyJY1GExwc7OjoaPjF9vDwCAkJycjIeGYVAwYMAODk5JSSktI2n6Ndekq/u7a2VqPRhIeHq1Qqb2/vpqeB2Nraent7q1Sq8PDwq1evVldX8x/UaDT8yKj/+Z//MfoHIqKhNCqyjRs38tZiOj9P/9fS0tJCQkJ4HuT69OmjUqmuXbvW/Cru3bs3YcIEANbW1k9qt3Y2z9Xv1mq1aWlpBw4cWLt2rZ+fH2/gN2Rubj5q1KiFCxfyk+cXLlxo5Mu9ibgojYqMT9xv27at4cNGu/sBKJVKvsKpZb+f1dXV/KxiuVweFhYmUOzt1ZP63cXFxUePHk1PT3/mH3JOTk50dDQfBPDw8JBK648/d3Z2dnd3NzROSSdBaVRkgwcPBnC5wfEnFy5cMKxw6tat21tvvXXx4sXWz7bz6XtebGeevn9KvzsqKor/+VhbWzecuH/mrU0PHz6Mi4vj/yKOGTOmzWInJorSqJiqq6vNzMxkMlnDmfSqqioHB4eAgICjR482XeHUSjt37jQzMwMwe/ZsMe90E0lubm7v3r0BzJ07t2mTMzY2dtq0aUrDLUmPmJmZDR06dMmSJWq1Oi4u7uHDh48tvKKiwsrKSiKR3L59u+0/CjEhlEbFlJiYyCeLGj2va7vDoBiLjo7mzbGxY8ca77RNE8C3FQGYMGHC0/vdBQUF58+fDw0NnTdv3sCBAw19dgMXF5c33njjr3/969dff93wB/nIyUcffdTGH4WYFkqjYuITHQsWLDByvRqNpk+fPjwd/PTTT0auXRRarZZf2t6/f//CwsLn+tmampqGE/eGhb0ABgwY0PCdkZGRAIYMGSJo7MTUyUHEk5KSAmDYsGFGrtfT0/PSpUvTp09PSkry8fE5efIkn8rvwNasWXPq1Knu3bufOXOGz6c3n0Kh8PT09PT0XLx4MQCtVnvz5s2UlJSUlBQ7O7uG75w8eXL37t1TU1M1Gs2gQYOE/ADElImdxzu18ePHA4iOjhal9vLy8pkzZ8pksq1bt3bsBTr/+Mc/ACgUitjY2Laua8WKFQDWr1/f1hUR09GJLlg2NXq93tbWtry8vKioqOlSROPIyMhwc3NzcHAoKCgQJQAjOH369MyZM/V6/d69e/mhLW0qNjb2lVde6dOnT2ZmpmHFBenYGo+dE6NJT08vLy93cnISK4cCSE1NBTB8+HCxAmhr165dmzt3rk6n27RpkxFyKABfX18nJ6fs7OwffvjBCNURU0BpVDR8YNTIKWzr1q1//vOf09LSRIzBaHJycmbOnFlRUbFs2bL333/fOJVKpdI5c+YAOHTokHFqJKKjNCoaUeaXDh8+/Omnn969e1fEGIyjrKxs6tSpd+/e9fX15bfSGw1f9nT48GGtVmvMeolYKI2KJjk5GcZNYYwx3os3VGr8GIxDq9XOnj07NTXV3d395MmTTU9salMjR44cMGBAYWFhbGysMeslYqE0Khrjd6h/+eWXkpISpVLZo0cPAMXFxXfv3rW0tHRzczNaDMahUqkuXLhgb28fGRlpOFTQmHiDlPr1nQSlUXHk5eUVFBTY2dn17dvXaJU26sLzpujQoUOb7tIxqvPn8fLL8PfHsmWoqHiuHy0pKXnw4MGDBw+ysrJu3759+/bta9eurV27NiwszMLCIioqquH1f8bE0+jx48erqqpECYAYEy2/F4chhRlzTUyj9q9J9Ojz87FuHWJjYWuLXbuwfj3UanzyCaKj8eBB/XsePgRflvfoC5u6urLy8icV6ejoKJFI9u/fP3bsWGN8hMd56aWXvLy8rl27du7cuVmzZokVBjEOSqPiEGWK3JC7G8YgchqNjMT8+bC1BYClS+HhAQC3biEm5ik/ZPaon25jYyOTyQDY2tryNrWVlZVGo5HL5S+//HIbh/4M8+fPv3bt2qFDhyiNdniURsUhSgprlLtNYrVTbi4MI7NSKaRSaLV47z3Mmwc7O/Cmuq0t+LDDoy/u2dpKnjwQ4e/vf+HChWPHjvENRWJZsGBBcHBwVFRUSUkJPwuGdFQ0NioO43eoCwsLc3Nzra2tXVxcAFRVVd26dUsul4u89VupRG5u/dd6PXQ6yOVwd4efH0aOxIgRGDECrq5wcYGLC7p3R9eu6Nr1KTkUJjO9o1QqfXx8qqurT506JW4kpK1RGhVBWVnZ7du3FQqFu7u70SptNKGUmpqq1WoHDhzIbw8VzfTpOHwYpaUAsHcvJk9ufZFvvPGGhYXFt99+m5OT0/rSWsNEEjppa5RGRXD9+nW9Xj9o0CBjrmc0xR49AEdHTJ2KoUPh44O4OHz4YYtLys7Ozs/PB2BjYzNlyhS9Xn/06FHhAm2JOXPmKBSKmJiYDnxkAQGlUVGIMkXeaDTWJOaXuKtXkZWFFSuwZw+srFpWxieffNKvX7/PP/+cvzSRZmC3bt1effVVrVZ7/PhxcSMhbYrSqAge2xJcs2ZNWFhYW1faaNGo+K3RigrExUEqbWV3fuzYsYyxgwcP8hPLAgICbG1tExMTb926JVCgLWQiCZ20LXHP6euc+FUWx48fNzyJj4/n/3esW7euLY7+LC8vl8lkZmZmVVVVjDGtVsuPcC8uLha8rudz4gQD2PjxrSxGr9fz8/y///57/mTJkiUANm7c2OoQW6W8vNzS0lIikWRmZoobCWk71BoVQWhoqJmZmUqlun79On/i4+Oza9cuMzOzjz76KDAwUPCtL6mpqTqdzsPDg08o3bp1q7Kysk+fPt26dRO2oufGL+OcPr2VxUgkksDAQDRo94nSDKyurvb39//6668NTywtLceMGSOVSr///ntjRkKMSuw83hndv39/4sSJAKysrKKiogzPY2Ji+KUUgl82t3XrVgBLlizhLw8cOADg9ddfF7CKltDpWM+eDGAaTesLS0pKAvDiiy/yCwHr6ur40QHXrl1rfeHNodfr+RF5/fv3N1zpmpaWZmNj4+DgkJWVZZwwiPFRa1QEXbt2PX/+fFBQUHl5+euvv75lyxb+fNKkSQkJCX379r18+fK4ceNu3rwpVI0mOr+UmIj8fPTpA0/P1hfm5eXl4eFRVFR08eJFAHK5fPbs2TBig/T9998/evSojY3NiRMn+BqMe/fuzZgxo7S09L/+67+cnZ2NEwYRgdh5vPPS6/UhISF8T71KpdLpdPx5Xl7eqFGjAHTt2lWou4OOHz++YsWKxMRE/vLVV18FEBERIUjhLbd+PQOYSiVUeRs2bACwdOlS/jIhIQGAs7Oz4c+27ezcuROAmZnZhQsX+JPKykq+qX/UqFEVFRVtHQAREaVRkR0+fJiPV86aNcvwy1ZVVTV37lwACoUiPDxc8Er51ZjZ2dmCl/x8hgxhABPuRr+MjAwANjY2lZWVjDG9Xs8P0Pruu++EquKxYmNjefMzLCyMP9HpdG+88QaAfv365efnt2ntRHSURsWXkJDAr2MaPXp0Xl4ef8jbqgAkEklwcLCA0/d37tzhTV2RbwPNzmYSCbOyYtXVApY6cuRINFgF8Ze//AXAO++8I2AVjaSlpfER7XXr1hkevvfeewBsbW1v3LjRdlUTE0Fp1CSkp6cPGDAAQK9evZKTkw3Pv/jiC7lcDmDOnDl8rVLrRUVFAXjllVcEKa3F9mzbtmn06J9Wrxa22M2bNwOYPXs2f8nXQtjb29fW1gpbEZeXl8cXWgUGBhqGDr744gvewY+JiWmLSompoTRqKu7duzdhwgQA1tbWZ86cMTw/f/68jY0NgPHjxxcWFramCp1OFxsbywdeDYlGLP7+/gAEH7LIycmRyWTm5uYPHz7kT/jZK6dPnxa2IvaE0c+zZ8/K5XKJRLJnzx7BaySmidKoCamurl6wYAEAuVxuGGVjjKWmpvbu3RuAq6vrzz//3IKSNRpNSEhIv379+LyinZ2dXC7fuXOncLE/n7KyMnNzc5lMVlRUJHjhfDGZIYtt2rQJwKJFi4StRafT8YNE+/XrZ1idptFo+Jl4ISEhwlZHTBmlUdNiGBJtNH1/584dvj7Jx8en+aXdvHkzJCSEDxdwLi4u69evX7VqVdMqjInvMff29m6Lwrdv3w7A39+fv/zll18kEomlpWV5ebmAtaxZswZAt27dbt68yZ/k5ubyVU1z584VedyZGBelUVO0c+dOMzMz3vXmk86MsdLS0qCgoIyMjGf++N27d9Vqtbe3t+GGku7du69cuTI+Pt7w6/3YKoxm6dKlAEJDQ9ui8Pv37ysUCrlcbpgiHzNmDIAjR44IVcWOHTsajX5WVFTwPb4TJkyoFnTSjJg+SqMmKjo6mncPm7+j6f79++Hh4QEBAfxeDT5THBQUFBkZ+dgJlhZU0QIVFRWGYUpOp9Px/UVpaWltVOn06dMBbNmyhb9Uq9UQbteWYfTTMLCr1WpnzpwJoH///q0cvybtEaVR06XRaPgssIuLy08//fSkt1VWVkZGRgYGBhpOLzU3Nw8ICAgPD3/mqu9mVtECWq02Ojo6KCjI2tq60fkgfHe5i4uLgNU1cvDgwYaDBgUFBXK5XKFQtP4olhs3bjQd/fzDH/7Am/y3bt1qZfmkPaI0atLy8vL4Qshu3brFxcU1/JYhT1k9OqNTJpN5e3tv3769pKTkuaoQcNOUTqeLi4t7++23u3fvzqOSSCSNpnfWrVsHYLXQS50aKi8vX7NmzeXLlw1P/Pz8AHz55ZetLPnIkSMKhWLRokWG4ZG///3v/J+uhISEVhZO2ilKo6auvLycdxgVCsXevXt1Ol18fLxKpeI7kbgRI0ao1eoW75ZpVEXLCuGLAfhFT5yHh0dISEjTBhpfgWTkNZX8cBYnJ6dPP/00Ojr63r17LS7qypUrhtHPr7/+WiaTSSSSffv2CRQpaX8ojbYDWq129erVvGXX8Gi7oUOHfvzxx4Ls6Wy4aSokJKT5E81ZWVmhoaEDBw40ROXs7KxSqeLj45u+OTMzc926dRKJxMLCQth586f76aef3N3dG13Z4ujo6Ofnp1KpwsPDNRpNC1YsJCUlWVpaAvjwww/bImzSXlAabTe2b98+ZswYS0vL3r17q1Sqtjj/bceOHXzT1Ny5c5++aarpYoBu3boFBQVFR0c3TcHFxcXh4eF+fn78zfxfgsGDB6ekpAj+EZoKDw/nyc7d3f1vf/vb7373u/Hjx1s1ua3Ezs5u4sSJ77777p49e65fv/7MXU937951cnICsGzZMiN8CmLKKI22J1qttjkLnlrDsGnK29v7sWvjHzx4MHHiROmjK45tbW2XLFly7tw5rVbb6J0PHz7cvXv3a6+9Zlg5YGlpuWDBgs8++8zNzY2PJ4aGhrbdwtWqqiqVSj6OlTMAAAW9SURBVMWr5scSNvxuTk5OZGRkSEhIQECAo6Njo6xqZmbm4eERFBSkVqubDgKUlpYOGTIEgK+vr+FoUdJpURoljaWmpvJl5E/aNOXh4dGlSxe+GKBp37y6ujoyMjIoKIi3Afncl5+fX3h4eFlZGX9PRUWFSqXijdNJkyb9+9//FvxTZGVl8YWc5ubmarX6me/Pyck5ffr0hx9+GBgY6Orqamhlc1Kp1M3NLTAw8H//93+joqJefvll3ry9f/++4JGTdofSKHmMnJwcLy8vvoin6SlzKSkpjZaCMsa0Wm18fPzKlSt5Y5anHm9vb7Va/aSllGfPnlUqlbxJu337dgHjj4iI4Kcuubm5tWzooLS09OrVq+Hh4SqVytvb28LComFWtbGx6dGjB12vRDhKo+TxysrKAgICAHTp0uXAgQNPeefVq1dVKlXPnj0bzdHfvn37mbUUFhbynekAZs+e3ZoJdK6uri44OJi3JV9//fUHDx60skCutrY2JSVlz549q1ev9vX1PXHiRFJSkiAlkw6A0ih5Iq1W+/vf/94wfd/ou3yFk6urqyF79u3bNzg42LDHvPmOHj3atWtXAD169Gh4OdXzunPnzvjx4wHI5fLQ0FDa2E6Mg9IoeQa1Ws0nlN58883a2trs7Gy1Ws27/JyTk9OTVjg1X1ZWFj+ZSSKRrFy50jCK2nwxMTF8j6mzs7PhmmVCjIDSKHm2EydO8HvteZ7i7O3t33nnne+++06oqXa9Xq9Wq7t06QKgX79+zc/LWq02JCSE5/pp06a1fscnIc+F0ihplitXrri6uvr6+lpYWAQGBkZGRrbRQh+NRjN8+HDeMQ8ODn5mLYWFhfyGPplMFhISIsq5f6STkzDGQEgzaLXavLw8e3v7RtPWgqurq/vwww83bdqk0+lGjRq1d+/ehrukGvruu+/mzZuXl5fn4OBw4MABvnGeECOjNEpM1KVLlxYvXpyRkWFubr5hw4a1a9ca1vwDYIx9/vnna9euraur8/X1PXToUNMl9IQYB6VRYrrKysr+9Kc/8TOS/fz8du/ezfdfFhcXL168+MyZMxKJ5A9/+MPmzZv5HlZCREFplJi6s2fPLl++PC8vz9bW9l//+pe7u/ucOXMyMzO7d+++b9++KVOmiB0g6ewojZJ2oKCgYMWKFfxqaJlMptPpxo0bd+TIEb5plRBxSZ/9FkLElpub6+jouGDBAmdn50mTJgUEBFhYWOzbt0/suAgBABpRIu3AnTt3duzYMWPGjIyMDKlUevLkyTlz5jQ8epUQEVEaJe0AP3G5traWf8H/t66uTuSwCAFAnXrSLvC7oGtra/lLQ1YVMyZCHqE0StqBRs3PRlmVEHFRGiXtQKPmJ3XqiUmhNEraAerUE1NGaZS0A4/t1FNrlJgISqOkHXhsp55ao8REUBol7UCjvElTTMSk0LpR0g5YKRR/9vJS2tvzl/ZduuybMEHRvbu4URHC0Z560h4UFcHBAS++iMJCALh7F87OcHLCnTtiR0YIdepJu6BQAIChF29m9quXhIiK0ihpDxrlTZ5VaaaemAZKo6Q9aJQ3GzVOCREVpVHSHsjlkEqh1UKvByiNEtNCaZS0E7xfzxukZmaQSFBXB5ogJSaA0ihpJxoNjzbMqoSIitIoaScaDY/SZD0xGZRGSTvRaDyUJuuJyaBdTKSdaNSLd3D41UtCxEO7mEg7UVICS0vQffTE9FCnnrQTtrb45hu8/DL8/bFsGSoqxA6IkHrUGiXtRH4+pk5FbCxsbbFrF1JToVaLHRMhAKVR0m7s2IGSEqxdCwB6PTw8cPOm2DERAlCnnrQbublQKuu/lkrrNzURYgIojZJ2QqlEbm7913o9dDqabiImgtIoaSemT8fhwygtBYC9ezF5stgBEVKPxkZJ+3HmDD75BAoFlEps2QIrK7EDIgSgNEoIIa1EnXpCCGkVSqOEENIq/x9vVZpe6phePAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<rdkit.Chem.rdchem.Mol at 0x7efd736958f0>"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = Chem.MolFromSmiles(smiles)\n",
"m"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"g = smile2graph(smiles)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DGLGraph(num_nodes=28, num_edges=60,\n",
" ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}\n",
" edata_schemes={})"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"g"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"bg = dgl.batch([g])"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([28, 74])"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bg.ndata['h'].shape"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/playground/mz_dgl/python/dgl/base.py:18: UserWarning: Initializer is not set. Use zero initializer instead. To suppress this warning, use `set_initializer` to explicitly specify which initializer to use.\n",
" warnings.warn(msg, warn_type)\n"
]
}
],
"source": [
"logits = model(bg.ndata['h'], bg)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"preds = logits.data.numpy() > 0.5"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>NR-AR</th>\n",
" <th>NR-AR-LBD</th>\n",
" <th>NR-AhR</th>\n",
" <th>NR-Aromatase</th>\n",
" <th>NR-ER</th>\n",
" <th>NR-ER-LBD</th>\n",
" <th>NR-PPAR-gamma</th>\n",
" <th>SR-ARE</th>\n",
" <th>SR-ATAD5</th>\n",
" <th>SR-HSE</th>\n",
" <th>SR-MMP</th>\n",
" <th>SR-p53</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" NR-AR NR-AR-LBD NR-AhR NR-Aromatase NR-ER NR-ER-LBD NR-PPAR-gamma \\\n",
"0 False False True False True False True \n",
"\n",
" SR-ARE SR-ATAD5 SR-HSE SR-MMP SR-p53 \n",
"0 False True False False True "
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"pd.DataFrame(preds, columns=tasks)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Environment (conda_miniconda3-latest)",
"language": "python",
"name": "conda_miniconda3-latest"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import dgl
import numpy as np
import os
import random
import torch
from sklearn.metrics import roc_auc_score
def set_random_seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
class Meter(object):
def __init__(self):
self.mask = []
self.y_pred = []
self.y_true = []
def update(self, y_pred, y_true, mask):
self.y_pred.append(y_pred)
self.y_true.append(y_true)
self.mask.append(mask)
# Todo: Allow different evaluation metrics
def roc_auc_averaged_over_tasks(self):
"""Compute roc-auc score for each task and return the average."""
mask = torch.cat(self.mask, dim=0)
y_pred = torch.cat(self.y_pred, dim=0)
y_true = torch.cat(self.y_true, dim=0)
# Todo: support categorical classes
# This assumes binary case only
y_pred = torch.sigmoid(y_pred)
n_tasks = y_true.shape[1]
total_score = 0
for task in range(n_tasks):
task_w = mask[:, task]
task_y_true = y_true[:, task][task_w != 0].cpu().numpy()
task_y_pred = y_pred[:, task][task_w != 0].cpu().detach().numpy()
total_score += roc_auc_score(task_y_true, task_y_pred)
return total_score / n_tasks
class EarlyStopping(object):
def __init__(self, patience=10, filename="es_checkpoint.pth"):
assert not os.path.exists(filename), \
'Filename {} is occupied. Either rename it or delete it.'.format(filename)
self.patience = patience
self.counter = 0
self.filename = filename
self.best_score = None
self.early_stop = False
def step(self, acc, model):
score = acc
if self.best_score is None:
self.best_score = score
self.save_checkpoint(model)
# Todo: this is not true for all metrics.
elif score < self.best_score:
self.counter += 1
print(
f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
return self.early_stop
def save_checkpoint(self, model):
'''Saves model when the metric on the validation set gets improved.'''
torch.save(model.state_dict(), self.filename)
def load_checkpoint(self, model):
'''Load model saved with early stopping.'''
model.load_state_dict(torch.load(self.filename))
def collate_molgraphs(data):
"""Batching a list of datapoints for dataloader
Parameters
----------
data : list of 4-tuples
Each tuple is for a single datapoint, consisting of
A SMILE, a DGLGraph, all-task labels and all-task weights
Returns
-------
smiles : list
List of smiles
bg : BatchedDGLGraph
Batched DGLGraphs
labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and
T is the number of total tasks.
weights : Tensor of dtype float32 and shape (B, T)
Batched datapoint weights. T is the number of
total tasks.
"""
smiles, graphs, labels, mask = map(list, zip(*data))
bg = dgl.batch(graphs)
bg.set_n_initializer(dgl.init.zero_initializer)
bg.set_e_initializer(dgl.init.zero_initializer)
labels = torch.stack(labels, dim=0)
mask = torch.stack(mask, dim=0)
return smiles, bg, labels, mask
# Customize Dataset
Generally we follow the practise of PyTorch.
A Dataset class should implement `__getitem__(self, index)` and `__len__(self)`method
```python
class CustomDataset:
def __init__(self):
# Initialize Dataset and preprocess data
def __getitem__(self, index):
# Return the corresponding DGLGraph/label needed for training/evaluation based on index
return self.graphs[index], self.labels[index]
def __len__(self):
return len(self.graphs)
```
DGL supports various backends such as MXNet and PyTorch, therefore we want our dataset to be also backend agnostic.
We prefer user using numpy array in the dataset, and not including any operator/tensor from the specific backend.
If you want to convert the numpy array to the corresponding tensor, you can use the following code
```python
import dgl.backend as F
# g is a DGLGraph, h is a numpy array
g.ndata['h'] = F.zerocopy_from_numpy(h)
# Now g.ndata is a PyTorch Tensor or a MXNet NDArray based on backend used
```
If your dataset is in `.csv` format, you may use
[`CSVDataset`](https://github.com/dmlc/dgl/blob/master/python/dgl/data/chem/csv_dataset.py).
......@@ -8,14 +8,15 @@ import sys
from dgl import DGLGraph
from .utils import smile2graph
from ..utils import download, get_download_dir, _get_dgl_url, Subset
class CSVDataset(object):
"""CSVDataset
This is a general class for loading data from csv or pd.DataFrame.
In data pre-processing, we set non-existing labels to be 0, and returning mask with 1 where label exists.
In data pre-processing, we set non-existing labels to be 0,
and returning mask with 1 where label exists.
All molecules are converted into DGLGraphs. After the first-time construction, the
DGLGraphs will be saved for reloading so that we do not need to reconstruct them every time.
......@@ -38,13 +39,16 @@ class CSVDataset(object):
Path to store the preprocessed data
"""
def __init__(self, df, smile2graph=smile2graph, smile_column='smiles', cache_file_path="csvdata_dglgraph.pkl"):
def __init__(self, df, smile2graph=smile2graph, smile_column='smiles',
cache_file_path="csvdata_dglgraph.pkl"):
if 'rdkit' not in sys.modules:
from ...base import dgl_warning
dgl_warning("Please install RDKit (Recommended Version is 2018.09.3)")
dgl_warning(
"Please install RDKit (Recommended Version is 2018.09.3)")
self.df = df
self.smiles = self.df[smile_column].tolist()
self.task_names = self.df.columns.drop([smile_column]).tolist()
self.n_tasks = len(self.task_names)
self.cache_file_path = cache_file_path
self._pre_process(smile2graph)
......@@ -62,17 +66,14 @@ class CSVDataset(object):
with open(self.cache_file_path, 'rb') as f:
self.graphs = pickle.load(f)
else:
self.graphs = []
for id, s in enumerate(self.smiles):
self.graphs.append(smile2graph(s))
self.graphs = [smile2graph(s) for s in self.smiles]
with open(self.cache_file_path, 'wb') as f:
pickle.dump(self.graphs, f)
_label_values = self.df[self.task_names].values
# np.nan_to_num will also turn inf into a very large number
self.labels = F.zerocopy_from_numpy(np.nan_to_num(_label_values))
self.mask = F.zerocopy_from_numpy(~np.isnan(_label_values).astype(np.float32))
self.labels = np.nan_to_num(_label_values).astype(np.float32)
self.mask = (~np.isnan(_label_values)).astype(np.float32)
def __getitem__(self, item):
"""Get the ith datapoint
......@@ -88,7 +89,9 @@ class CSVDataset(object):
Tensor of dtype float32
Weights of the datapoint for all tasks
"""
return self.smiles[item], self.graphs[item], self.labels[item], self.mask[item]
return self.smiles[item], self.graphs[item], \
F.zerocopy_from_numpy(self.labels[item]), \
F.zerocopy_from_numpy(self.mask[item])
def __len__(self):
"""Length of Dataset
......
......@@ -3,15 +3,13 @@ import sys
from .csv_dataset import CSVDataset
from .utils import smile2graph
from ..utils import get_download_dir, download, _get_dgl_url, Subset
from ..utils import get_download_dir, download, _get_dgl_url
try:
import pandas as pd
except ImportError:
pass
class Tox21(CSVDataset):
_url = 'dataset/tox21.csv.gz'
......@@ -49,6 +47,7 @@ class Tox21(CSVDataset):
self.id = df['mol_id']
df = df.drop(columns=['mol_id'])
super().__init__(df, smile2graph, cache_file_path="tox21_dglgraph.pkl")
self._weight_balancing()
......
import dgl.backend as F
import numpy as np
import os
import pickle
from dgl import DGLGraph
......@@ -30,7 +28,6 @@ def one_hot_encoding(x, allowable_set):
"""
return list(map(lambda s: x == s, allowable_set))
class BaseAtomFeaturizer(object):
"""An abstract class for atom featurizers
......@@ -45,8 +42,7 @@ class BaseAtomFeaturizer(object):
def __call__(self, mol):
return NotImplementedError
class DefaultAtomFeaturizer(BaseAtomFeaturizer):
class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
"""A default featurizer for atoms.
The atom features include:
......@@ -76,7 +72,7 @@ class DefaultAtomFeaturizer(BaseAtomFeaturizer):
"""
def __init__(self, atom_data_field='h'):
super(DefaultAtomFeaturizer, self).__init__()
super(CanonicalAtomFeaturizer, self).__init__()
self.atom_data_field = atom_data_field
@property
......@@ -140,8 +136,7 @@ class DefaultAtomFeaturizer(BaseAtomFeaturizer):
return {self.atom_data_field: atom_features}
def smile2graph(smile, add_self_loop=False, atom_featurizer=None, bond_featurizer=None):
def smile2graph(smile, add_self_loop=False, atom_featurizer=CanonicalAtomFeaturizer(), bond_featurizer=None):
"""Convert SMILES into a DGLGraph.
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
......@@ -163,7 +158,7 @@ def smile2graph(smile, add_self_loop=False, atom_featurizer=None, bond_featurize
Whether to add self loops in DGLGraphs.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph.
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
......
Model Zoo API
==================
We provide two major APIs for the model zoo. For the time being, only PyTorch is supported.
- `model_zoo.chem.[Model_Name]` to load the model skeleton
- `model_zoo.chem.load_pretrained([Pretrained_Model_Name])` to load the model with pretrained weights
Models would be placed in `python/dgl/model_zoo/chem`.
Each Model should contain the following elements:
- Papers related to the model
- Model's input and output
- Dataset compatible with the model
- Documentation for all the customizable configs
- Credits (Contributor infomation)
"""Package for model zoo."""
from . import chem
# DGL for Chemistry
With atoms being nodes and bonds being edges, molecular graphs are among the core objects for study in drug discovery.
As drug discovery is known to be costly and time consuming, deep learning on graphs can be potentially beneficial for
improving the efficiency of drug discovery [1], [2].
With pre-trained models and training scripts, we hope this model zoo will be helpful for both
the chemistry community and the deep learning community to further their research.
## Dependencies
Before you proceed, make sure you have installed the dependencies below:
- PyTorch 1.2
- Check the [official website](https://pytorch.org/) for installation guide
- pandas 0.24.2
- Install with either `conda install pandas` or `pip install pandas`
- RDKit 2018.09.3
- We recommend installation with `conda install -c conda-forge rdkit==2018.09.3`. For other installation recipes,
see the [official documentation](https://www.rdkit.org/docs/Install.html).
- requests 2.22.0
- Install with `pip install requests`
- scikit-learn 0.21.2
- Install with `pip install -U scikit-learn` or `conda install scikit-learn`
## Property Prediction
[**Get started with our example code!**](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo/chem/property_prediction)
To evaluate molecules for drug candidates, we need to know their properties and activities. In practice, this is
mostly achieved via wet lab experiments. We can cast the problem as a regression or classification problem.
In practice, this can be quite difficult due to the scarcity of labeled data.
### Featurization and Representation Learning
Fingerprint has been a widely used concept in cheminformatics. Chemists developed hand designed rules to convert
molecules into binary strings where each bit indicates the presence or absence of a particular substructure. The
development of fingerprints makes the comparison of molecules a lot easier. Previous machine learning methods are
mostly developed based on molecule fingerprints.
Graph neural networks make it possible for a data-driven representation of molecules out of the atoms, bonds and
molecular graph topology, which may be viewed as a learned fingerprint [3].
### Models
- **Graph Convolutional Network**: Graph Convolutional Networks (GCN) have been one of the most popular graph neural
networks and they can be easily extended for graph level prediction.
## References
[1] Chen et al. (2018) The rise of deep learning in drug discovery. *Drug Discov Today* 6, 1241-1250.
[2] Vamathevan et al. (2019) Applications of machine learning in drug discovery and development.
*Nature Reviews Drug Discovery* 18, 463-477.
[3] Duvenaud et al. (2015) Convolutional networks on graphs for learning molecular fingerprints. *Advances in neural
information processing systems (NeurIPS)*, 2224-2232.
# pylint: disable=C0111
"""Model Zoo Package"""
from .gcn import GCNClassifier
from .pretrain import load_pretrained
# pylint: disable=C0111, C0103, C0200
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GraphConv
class GCNLayer(nn.Module):
def __init__(self, in_feats, out_feats, activation=F.relu,
residual=True, batchnorm=True, dropout=0.):
"""Single layer GCN for updating node features
Parameters
----------
in_feats : int
Number of input atom features
out_feats : int
Number of output atom features
activation : activation function
Default to be ReLU
residual : bool
Whether to use residual connection, default to be True
batchnorm : bool
Whether to use batch normalization on the output,
default to be True
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
super(GCNLayer, self).__init__()
self.activation = activation
self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
norm=False, activation=activation)
self.dropout = nn.Dropout(dropout)
self.residual = residual
if residual:
self.res_connection = nn.Linear(in_feats, out_feats)
self.bn = batchnorm
if batchnorm:
self.bn_layer = nn.BatchNorm1d(out_feats)
def forward(self, feats, bg):
"""Update atom representations
Parameters
----------
feats : FloatTensor of shape (N, M1)
* N is the total number of atoms in the batched graph
* M1 is the input atom feature size, must match in_feats in initialization
bg : BatchedDGLGraph
Batched DGLGraphs for processing multiple molecules in parallel
Returns
-------
new_feats : FloatTensor of shape (N, M2)
* M2 is the output atom feature size, must match out_feats in initialization
"""
new_feats = self.graph_conv(feats, bg)
if self.residual:
res_feats = self.activation(self.res_connection(feats))
new_feats = new_feats + res_feats
new_feats = self.dropout(new_feats)
if self.bn:
new_feats = self.bn_layer(new_feats)
return new_feats
class MLPBinaryClassifier(nn.Module):
def __init__(self, in_feats, hidden_feats, n_tasks, dropout=0.):
"""MLP for soft binary classification over multiple tasks from molecule representations.
Parameters
----------
in_feats : int
Number of input molecular graph features
hidden_feats : int
Number of molecular graph features in hidden layers
n_tasks : int
Number of tasks, also output size
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
super(MLPBinaryClassifier, self).__init__()
self.predict = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(in_feats, hidden_feats),
nn.ReLU(),
nn.BatchNorm1d(hidden_feats),
nn.Linear(hidden_feats, n_tasks)
)
def forward(self, h):
"""Perform soft binary classification over multiple tasks
Parameters
----------
h : FloatTensor of shape (B, M3)
* B is the number of molecules in a batch
* M3 is the input molecule feature size, must match in_feats in initialization
Returns
-------
FloatTensor of shape (B, n_tasks)
"""
return self.predict(h)
class GCNClassifier(nn.Module):
def __init__(self, in_feats, gcn_hidden_feats, n_tasks, classifier_hidden_feats=128,
dropout=0., atom_data_field='h', atom_weight_field='w'):
"""GCN based predictor for multitask prediction on molecular graphs
We assume each task requires to perform a binary classification.
Parameters
----------
in_feats : int
Number of input atom features
gcn_hidden_feats : list of int
gcn_hidden_feats[i] gives the number of output atom features
in the i+1-th gcn layer
n_tasks : int
Number of prediction tasks
classifier_hidden_feats : int
Number of molecular graph features in hidden layers of the MLP Classifier
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
atom_data_field : str
Name for storing atom features in DGLGraphs
atom_weight_field : str
Name for storing atom weights in DGLGraphs
"""
super(GCNClassifier, self).__init__()
self.atom_data_field = atom_data_field
self.gcn_layers = nn.ModuleList()
for i in range(len(gcn_hidden_feats)):
out_feats = gcn_hidden_feats[i]
self.gcn_layers.append(GCNLayer(in_feats, out_feats))
in_feats = out_feats
self.atom_weight_field = atom_weight_field
self.atom_weighting = nn.Sequential(
nn.Linear(in_feats, 1),
nn.Sigmoid()
)
self.g_feats = 2 * in_feats
self.soft_classifier = MLPBinaryClassifier(
self.g_feats, classifier_hidden_feats, n_tasks, dropout)
def forward(self, feats, bg):
"""Multi-task prediction for a batch of molecules
Parameters
----------
feats : FloatTensor of shape (N, M0)
Initial features for all atoms in the batch of molecules
bg : BatchedDGLGraph
B Batched DGLGraphs for processing multiple molecules in parallel
Returns
-------
FloatTensor of shape (B, n_tasks)
Soft prediction for all tasks on the batch of molecules
"""
# Update atom features
for gcn in self.gcn_layers:
feats = gcn(feats, bg)
# Compute molecule features from atom features
bg.ndata[self.atom_data_field] = feats
bg.ndata[self.atom_weight_field] = self.atom_weighting(feats)
h_g_sum = dgl.sum_nodes(
bg, self.atom_data_field, self.atom_weight_field)
h_g_max = dgl.max_nodes(bg, self.atom_data_field)
h_g = torch.cat([h_g_sum, h_g_max], dim=1)
# Multi-task prediction
return self.soft_classifier(h_g)
"""Utilities for using pretrained models."""
import torch
from .gcn import GCNClassifier
from ...data.utils import _get_dgl_url, download
def load_pretrained(model_name):
"""Load a pretrained model
Parameters
----------
model_name : str
Returns
-------
model
"""
if model_name == "GCN_Tox21":
print('Loading pretrained model...')
url_to_pretrained = _get_dgl_url('pre_trained/gcn_tox21.pth')
local_pretrained_path = 'pre_trained.pth'
download(url_to_pretrained, path=local_pretrained_path)
model = GCNClassifier(in_feats=74,
gcn_hidden_feats=[64, 64],
n_tasks=12,
classifier_hidden_feats=64)
checkpoint = torch.load(local_pretrained_path)
model.load_state_dict(checkpoint['model_state_dict'])
return model
else:
raise RuntimeError("Cannot find a pretrained model with name {}".format(model_name))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment