Unverified Commit 331337fe authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Model Zoo] Clean up for Alchemy dataset (#800)

* Make dataset framework agnostic

* Update

* Update

* Fix

* Remove unused import
parent 9314aabd
......@@ -45,7 +45,7 @@ To customize your own dataset, see the instructions
Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.
###Dataset
### Dataset
- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) is introduced by Tencent Quantum Lab to facilitate the development of new machine learning models useful for chemistry and materials science.
The dataset lists 12 quantum mechanical properties of 130,000+ organic molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database.
......
# -*- coding:utf-8 -*-
"""Sample training code
"""
import argparse
import torch as th
import torch
import torch.nn as nn
from dgl.data.chem import alchemy
from dgl import model_zoo
from torch.utils.data import DataLoader
# from Alchemy_dataset import TencentAlchemyDataset, batcher
def train(model="sch",
epochs=80,
device=th.device("cpu"),
device=torch.device("cpu"),
training_set_size=0.8):
print("start")
alchemy_dataset = alchemy.TencentAlchemyDataset()
train_set, test_set = alchemy_dataset.split(train_size=0.8)
train_set, test_set = alchemy_dataset.split(train_size=training_set_size)
train_loader = DataLoader(dataset=train_set,
batch_size=20,
collate_fn=alchemy.batcher(),
collate_fn=alchemy.batcher_dev,
shuffle=False,
num_workers=0)
test_loader = DataLoader(dataset=test_set,
batch_size=20,
collate_fn=alchemy.batcher(),
collate_fn=alchemy.batcher_dev,
shuffle=False,
num_workers=0)
......@@ -42,7 +36,7 @@ def train(model="sch",
loss_fn = nn.MSELoss()
MAE_fn = nn.L1Loss()
optimizer = th.optim.Adam(model.parameters(), lr=0.0001)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
for epoch in range(epochs):
......@@ -83,7 +77,6 @@ def train(model="sch",
w_mae /= idx + 1
print("MAE (test set): {:.7f}".format(w_mae))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-M",
......@@ -95,8 +88,6 @@ if __name__ == "__main__":
help="number of epochs",
default=250,
type=int)
device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
args = parser.parse_args()
assert args.model in ['sch', 'mgcn',
'mpnn'], "model name must be sch, mgcn or mpnn"
train(args.model, args.epochs, device)
import datetime
import dgl
import numpy as np
import os
import random
import torch
from sklearn.metrics import roc_auc_score
......
......@@ -2,20 +2,20 @@
"""Example dataloader of Tencent Alchemy Dataset
https://alchemy.tencent.com/
"""
import numpy as np
import os
import os.path as osp
import zipfile
import dgl
from dgl.data.utils import download
import pathlib
import pickle
import torch
import zipfile
from collections import defaultdict
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import pathlib
from ..utils import get_download_dir, download, _get_dgl_url
import dgl
import dgl.backend as F
from dgl.data.utils import download, get_download_dir
from .utils import mol_to_complete_graph
try:
import pandas as pd
from rdkit import Chem
......@@ -23,27 +23,44 @@ try:
from rdkit import RDConfig
except ImportError:
pass
_urls = {'Alchemy': 'https://alchemy.tencent.com/data/dgl/'}
class AlchemyBatcher(object):
"""Data structure for holding a batch of data.
class AlchemyBatcher:
Parameters
----------
graph : dgl.BatchedDGLGraph
A batch of DGLGraphs for B molecules
labels : tensor
Labels for B molecules
"""
def __init__(self, graph=None, label=None):
self.graph = graph
self.label = label
def batcher_dev(batch):
"""Batch datapoints
def batcher():
def batcher_dev(batch):
graphs, labels = zip(*batch)
batch_graphs = dgl.batch(graphs)
labels = torch.stack(labels, 0)
return AlchemyBatcher(graph=batch_graphs, label=labels)
return batcher_dev
Parameters
----------
batch : list
batch[i][0] gives the DGLGraph for the ith datapoint,
and batch[i][1] gives the label for the ith datapoint.
Returns
-------
AlchemyBatcher
An object holding the batch of data
"""
graphs, labels = zip(*batch)
batch_graphs = dgl.batch(graphs)
labels = F.stack(labels, 0)
class TencentAlchemyDataset(Dataset):
return AlchemyBatcher(graph=batch_graphs, label=labels)
class TencentAlchemyDataset(object):
fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
chem_feature_factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)
......@@ -51,12 +68,15 @@ class TencentAlchemyDataset(Dataset):
"""Featurization for all atoms in a molecule. The atom indices
will be preserved.
Args:
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
Returns
atom_feats_dict : dict
Dictionary for atom features
-------
atom_feats_dict : dict
Dictionary for atom features
"""
atom_feats_dict = defaultdict(list)
is_donor = defaultdict(int)
......@@ -87,7 +107,7 @@ class TencentAlchemyDataset(Dataset):
aromatic = atom.GetIsAromatic()
hybridization = atom.GetHybridization()
num_h = atom.GetTotalNumHs()
atom_feats_dict['pos'].append(torch.FloatTensor(geom[u]))
atom_feats_dict['pos'].append(F.tensor(geom[u].astype(np.float32)))
atom_feats_dict['node_type'].append(atom_type)
h_u = []
......@@ -105,27 +125,28 @@ class TencentAlchemyDataset(Dataset):
Chem.rdchem.HybridizationType.SP3)
]
h_u.append(num_h)
atom_feats_dict['n_feat'].append(torch.FloatTensor(h_u))
atom_feats_dict['n_feat'].append(F.tensor(np.array(h_u).astype(np.float32)))
atom_feats_dict['n_feat'] = torch.stack(atom_feats_dict['n_feat'],
dim=0)
atom_feats_dict['pos'] = torch.stack(atom_feats_dict['pos'], dim=0)
atom_feats_dict['node_type'] = torch.LongTensor(
atom_feats_dict['node_type'])
atom_feats_dict['n_feat'] = F.stack(atom_feats_dict['n_feat'], dim=0)
atom_feats_dict['pos'] = F.stack(atom_feats_dict['pos'], dim=0)
atom_feats_dict['node_type'] = F.tensor(np.array(
atom_feats_dict['node_type']).astype(np.int64))
return atom_feats_dict
def alchemy_edges(self, mol, self_loop=True):
def alchemy_edges(self, mol, self_loop=False):
"""Featurization for all bonds in a molecule. The bond indices
will be preserved.
Args:
mol : rdkit.Chem.rdchem.Mol
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
Returns
bond_feats_dict : dict
Dictionary for bond features
-------
bond_feats_dict : dict
Dictionary for bond features
"""
bond_feats_dict = defaultdict(list)
......@@ -154,50 +175,15 @@ class TencentAlchemyDataset(Dataset):
bond_feats_dict['distance'].append(
np.linalg.norm(geom[u] - geom[v]))
bond_feats_dict['e_feat'] = torch.FloatTensor(
bond_feats_dict['e_feat'])
bond_feats_dict['distance'] = torch.FloatTensor(
bond_feats_dict['distance']).reshape(-1, 1)
bond_feats_dict['e_feat'] = F.tensor(
np.array(bond_feats_dict['e_feat']).astype(np.float32))
bond_feats_dict['distance'] = F.tensor(
np.array(bond_feats_dict['distance']).astype(np.float32)).reshape(-1 , 1)
return bond_feats_dict
def mol_to_dgl(self, mol, self_loop=False):
"""
Convert RDKit molecule object to DGLGraph
Args:
mol: Chem.rdchem.Mol read from sdf
self_loop: Whetaher to add self loop
Returns:
g: DGLGraph
l: related labels
"""
g = dgl.DGLGraph()
# add nodes
num_atoms = mol.GetNumAtoms()
atom_feats = self.alchemy_nodes(mol)
g.add_nodes(num=num_atoms, data=atom_feats)
if self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
bond_feats = self.alchemy_edges(mol, self_loop)
g.edata.update(bond_feats)
return g
def __init__(self, mode='dev', transform=None, from_raw=False):
assert mode in ['dev', 'valid',
'test'], "mode should be dev/valid/test"
assert mode in ['dev', 'valid', 'test'], "mode should be dev/valid/test"
self.mode = mode
self.transform = transform
......@@ -224,15 +210,11 @@ class TencentAlchemyDataset(Dataset):
def _load(self):
if self.mode == 'dev':
if not self.from_raw:
with open(osp.join(self.file_dir, "dev_graphs.pkl"),
"rb") as f:
with open(osp.join(self.file_dir, "dev_graphs.pkl"), "rb") as f:
self.graphs = pickle.load(f)
with open(osp.join(self.file_dir, "dev_labels.pkl"),
"rb") as f:
with open(osp.join(self.file_dir, "dev_labels.pkl"), "rb") as f:
self.labels = pickle.load(f)
else:
target_file = pathlib.Path(self.file_dir, "dev_target.csv")
self.target = pd.read_csv(
target_file,
......@@ -245,15 +227,15 @@ class TencentAlchemyDataset(Dataset):
]]
self.graphs, self.labels = [], []
sdf_dir = pathlib.Path(self.file_dir, "sdf")
supp = Chem.SDMolSupplier(
osp.join(self.file_dir, self.mode + ".sdf"))
cnt = 0
for sdf, label in zip(supp, self.target.iterrows()):
graph = self.mol_to_dgl(sdf)
graph = mol_to_complete_graph(sdf, atom_featurizer=self.alchemy_nodes,
bond_featurizer=self.alchemy_edges)
cnt += 1
self.graphs.append(graph)
label = torch.FloatTensor(label[1].tolist())
label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
self.labels.append(label)
self.normalize()
......@@ -278,10 +260,7 @@ class TencentAlchemyDataset(Dataset):
return g, l
def split(self, train_size=0.8):
"""
Split the dataset into two AlchemySubset for train&test.
"""
"""Split the dataset into two AlchemySubset for train&test."""
assert 0 < train_size < 1
train_num = int(len(self.graphs) * train_size)
train_set = AlchemySubset(self.graphs[:train_num],
......@@ -292,15 +271,13 @@ class TencentAlchemyDataset(Dataset):
self.transform)
return train_set, test_set
class AlchemySubset(TencentAlchemyDataset):
"""
Sub-dataset split from TencentAlchemyDataset.
Used to construct the training & test set.
"""
def __init__(self, graphs, labels, mean=0, std=1, transform=None):
super(AlchemySubset, self).__init__()
self.graphs = graphs
self.labels = labels
self.mean = mean
......
......@@ -7,7 +7,7 @@ import pickle
import sys
from dgl import DGLGraph
from .utils import smile2graph
from .utils import smile_to_bigraph
class CSVDataset(object):
......@@ -28,9 +28,9 @@ class CSVDataset(object):
One column includes smiles and other columns for labels.
Column names other than smiles column would be considered as task names.
smile2graph: callable, str -> DGLGraph
smile_to_graph: callable, str -> DGLGraph
A function turns smiles into a DGLGraph. Default one can be found
at python/dgl/data/chem/utils.py named with smile2graph.
at python/dgl/data/chem/utils.py named with smile_to_bigraph.
smile_column: str
Column name that including smiles
......@@ -39,7 +39,7 @@ class CSVDataset(object):
Path to store the preprocessed data
"""
def __init__(self, df, smile2graph=smile2graph, smile_column='smiles',
def __init__(self, df, smile_to_graph=smile_to_bigraph, smile_column='smiles',
cache_file_path="csvdata_dglgraph.pkl"):
if 'rdkit' not in sys.modules:
from ...base import dgl_warning
......@@ -50,9 +50,9 @@ class CSVDataset(object):
self.task_names = self.df.columns.drop([smile_column]).tolist()
self.n_tasks = len(self.task_names)
self.cache_file_path = cache_file_path
self._pre_process(smile2graph)
self._pre_process(smile_to_graph)
def _pre_process(self, smile2graph):
def _pre_process(self, smile_to_graph):
"""Pre-process the dataset
* Convert molecules from smiles format into DGLGraphs
......@@ -66,7 +66,7 @@ class CSVDataset(object):
with open(self.cache_file_path, 'rb') as f:
self.graphs = pickle.load(f)
else:
self.graphs = [smile2graph(s) for s in self.smiles]
self.graphs = [smile_to_graph(s) for s in self.smiles]
with open(self.cache_file_path, 'wb') as f:
pickle.dump(self.graphs, f)
......
......@@ -2,7 +2,7 @@ import numpy as np
import sys
from .csv_dataset import CSVDataset
from .utils import smile2graph
from .utils import smile_to_bigraph
from ..utils import get_download_dir, download, _get_dgl_url
try:
......@@ -32,11 +32,11 @@ class Tox21(CSVDataset):
Parameters
----------
smile2graph: callable, str -> DGLGraph
smile_to_graph: callable, str -> DGLGraph
A function turns smiles into a DGLGraph. Default one can be found
at python/dgl/data/chem/utils.py named with smile2graph.
at python/dgl/data/chem/utils.py named with smile_to_bigraph.
"""
def __init__(self, smile2graph=smile2graph):
def __init__(self, smile_to_graph=smile_to_bigraph):
if 'pandas' not in sys.modules:
from ...base import dgl_warning
dgl_warning("Please install pandas")
......@@ -48,7 +48,7 @@ class Tox21(CSVDataset):
df = df.drop(columns=['mol_id'])
super().__init__(df, smile2graph, cache_file_path="tox21_dglgraph.pkl")
super().__init__(df, smile_to_graph, cache_file_path="tox21_dglgraph.pkl")
self._weight_balancing()
......
import dgl.backend as F
import numpy as np
from functools import partial
from dgl import DGLGraph
......@@ -136,8 +137,41 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
return {self.atom_data_field: atom_features}
def smile2graph(smile, add_self_loop=False, atom_featurizer=CanonicalAtomFeaturizer(), bond_featurizer=None):
"""Convert SMILES into a DGLGraph.
def mol_to_graph(mol, graph_constructor, atom_featurizer, bond_featurizer):
    """Build a DGLGraph for an RDKit molecule and attach features to it.

    The atoms of the molecule are renumbered first; the graph constructor and
    both featurizers are then applied to the renumbered molecule, so node and
    edge features line up with the topology that is built.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    graph_constructor : callable
        Takes an RDKit molecule as input and returns a DGLGraph
    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Skipped when None.
    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for bonds in a molecule, which can be used to update
        edata for a DGLGraph. Skipped when None.

    Returns
    -------
    g : DGLGraph
        Converted DGLGraph for the molecule
    """
    # NOTE(review): CanonicalRankAtoms yields a rank per atom, while
    # RenumberAtoms presumably expects the new atom order -- confirm that this
    # is the intended permutation.
    renumbered = rdmolops.RenumberAtoms(mol, rdmolfiles.CanonicalRankAtoms(mol))
    g = graph_constructor(renumbered)

    if atom_featurizer is not None:
        node_feats = atom_featurizer(renumbered)
        g.ndata.update(node_feats)

    if bond_featurizer is not None:
        edge_feats = bond_featurizer(renumbered)
        g.edata.update(edge_feats)

    return g
def construct_bigraph_from_mol(mol, add_self_loop=False):
"""Construct a bi-directed DGLGraph with topology only for the molecule.
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
......@@ -152,24 +186,23 @@ def smile2graph(smile, add_self_loop=False, atom_featurizer=CanonicalAtomFeaturi
Parameters
----------
smiles : str
String of SMILES
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
g : DGLGraph
Empty bigraph topology of the molecule
"""
mol = Chem.MolFromSmiles(smile)
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = DGLGraph()
# Add nodes
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
# Add edges
src_list = []
dst_list = []
num_bonds = mol.GetNumBonds()
......@@ -185,11 +218,146 @@ def smile2graph(smile, add_self_loop=False, atom_featurizer=CanonicalAtomFeaturi
nodes = g.nodes()
g.add_edges(nodes, nodes)
# Featurization
if atom_featurizer is not None:
g.ndata.update(atom_featurizer(mol))
return g
if bond_featurizer is not None:
g.edata.update(bond_featurizer(mol))
def mol_to_bigraph(mol, add_self_loop=False,
                   atom_featurizer=CanonicalAtomFeaturizer(),
                   bond_featurizer=None):
    """Turn an RDKit molecule into a featurized bi-directed DGLGraph.

    This is a thin wrapper that delegates to ``mol_to_graph`` with
    ``construct_bigraph_from_mol`` as the topology builder.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs.
    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for bonds in a molecule, which can be used to update
        edata for a DGLGraph.

    Returns
    -------
    g : DGLGraph
        Bi-directed DGLGraph for the molecule
    """
    # Bind the self-loop choice so the constructor takes only the molecule.
    build_topology = partial(construct_bigraph_from_mol,
                             add_self_loop=add_self_loop)
    return mol_to_graph(mol, build_topology, atom_featurizer, bond_featurizer)
def smile_to_bigraph(smile, add_self_loop=False,
                     atom_featurizer=CanonicalAtomFeaturizer(),
                     bond_featurizer=None):
    """Convert a SMILES into a bi-directed DGLGraph and featurize for it.

    Parameters
    ----------
    smile : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs.
    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for bonds in a molecule, which can be used to update
        edata for a DGLGraph.

    Returns
    -------
    g : DGLGraph
        Bi-directed DGLGraph for the molecule

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smile`` into a molecule.
    """
    mol = Chem.MolFromSmiles(smile)
    # MolFromSmiles signals a parse failure by returning None; fail here with
    # a clear message instead of passing None on to the graph construction.
    if mol is None:
        raise ValueError('Failed to parse SMILES string: {!r}'.format(smile))
    return mol_to_bigraph(mol, add_self_loop, atom_featurizer, bond_featurizer)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
    """Construct a complete graph with topology only for the molecule.

    The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
    **i** th node in the returned DGLGraph.

    Writing an edge as (source, destination), the edges are added with the
    source varying slowest: (0, 0), (0, 1), (0, 2), ... (1, 0), (1, 1), (1, 2), ...
    If self loops are not created, (0, 0), (1, 1), ... are left out.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs.

    Returns
    -------
    g : DGLGraph
        Empty complete graph topology of the molecule
    """
    g = DGLGraph()
    num_atoms = mol.GetNumAtoms()
    g.add_nodes(num_atoms)

    # Enumerate every ordered atom pair once, dropping the diagonal unless
    # self loops were requested.
    pairs = [(src, dst)
             for src in range(num_atoms)
             for dst in range(num_atoms)
             if add_self_loop or src != dst]
    g.add_edges([src for src, _ in pairs], [dst for _, dst in pairs])

    return g
def mol_to_complete_graph(mol, add_self_loop=False,
                          atom_featurizer=None,
                          bond_featurizer=None):
    """Convert an RDKit molecule into a complete DGLGraph and featurize for it.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs.
    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None, which is forwarded unchanged
        to ``mol_to_graph``.
    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None, which is forwarded unchanged
        to ``mol_to_graph``.

    Returns
    -------
    g : DGLGraph
        Complete DGLGraph for the molecule
    """
    # Bind the self-loop choice so the constructor takes only the molecule.
    graph_constructor = partial(construct_complete_graph_from_mol,
                                add_self_loop=add_self_loop)
    return mol_to_graph(mol, graph_constructor, atom_featurizer, bond_featurizer)
def smile_to_complete_graph(smile, add_self_loop=False,
                            atom_featurizer=None,
                            bond_featurizer=None):
    """Convert a SMILES into a complete DGLGraph and featurize for it.

    Parameters
    ----------
    smile : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs.
    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None, which is forwarded unchanged
        to ``mol_to_complete_graph``.
    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None, which is forwarded unchanged
        to ``mol_to_complete_graph``.

    Returns
    -------
    g : DGLGraph
        Complete DGLGraph for the molecule

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smile`` into a molecule.
    """
    mol = Chem.MolFromSmiles(smile)
    # MolFromSmiles signals a parse failure by returning None; fail here with
    # a clear message instead of passing None on to the graph construction.
    if mol is None:
        raise ValueError('Failed to parse SMILES string: {!r}'.format(smile))
    return mol_to_complete_graph(mol, add_self_loop, atom_featurizer, bond_featurizer)
......@@ -2,10 +2,16 @@
import torch
from .dgmg import DGMG
from .gcn import GCNClassifier
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .sch import SchNetModel
from ...data.utils import _get_dgl_url, download
URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
'MGCN_Alchemy': 'pre_trained/mgcn_alchemy.pth',
'SCHNET_Alchemy': 'pre_trained/schnet_alchemy.pth',
'MPNN_Alchemy': 'pre_trained/mpnn_alchemy.pth',
'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth',
'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth',
......@@ -82,6 +88,12 @@ def load_pretrained(model_name, log=True):
node_hidden_size=128,
num_prop_rounds=2,
dropout=0.2)
elif model_name == 'MGCN_Alchemy':
model = MGCNModel(norm=True, output_dim=12)
elif model_name == 'SCHNET_Alchemy':
model = SchNetModel(norm=True, output_dim=12)
elif model_name == 'MPNN_Alchemy':
model = MPNNModel(output_dim=12)
if log:
print('Pretrained model loaded')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment