"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b79803fe089aa5a6ab5baef0545f380ff4ff059b"
Unverified Commit 331337fe authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Model Zoo] Clean up for Alchemy dataset (#800)

* Make dataset framework agnostic

* Update

* Update

* Fix

* Remove unused import
parent 9314aabd
...@@ -45,7 +45,7 @@ To customize your own dataset, see the instructions ...@@ -45,7 +45,7 @@ To customize your own dataset, see the instructions
Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy. Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.
###Dataset ### Dataset
- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) is introduced by Tencent Quantum Lab to facilitate the development of new machine learning models useful for chemistry and materials science. - **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) is introduced by Tencent Quantum Lab to facilitate the development of new machine learning models useful for chemistry and materials science.
The dataset lists 12 quantum mechanical properties of 130,000+ organic molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database. The dataset lists 12 quantum mechanical properties of 130,000+ organic molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database.
......
# -*- coding:utf-8 -*-
"""Sample training code
"""
import argparse import argparse
import torch as th import torch
import torch.nn as nn import torch.nn as nn
from dgl.data.chem import alchemy from dgl.data.chem import alchemy
from dgl import model_zoo from dgl import model_zoo
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
# from Alchemy_dataset import TencentAlchemyDataset, batcher
def train(model="sch", def train(model="sch",
epochs=80, epochs=80,
device=th.device("cpu"), device=torch.device("cpu"),
training_set_size=0.8): training_set_size=0.8):
print("start") print("start")
alchemy_dataset = alchemy.TencentAlchemyDataset() alchemy_dataset = alchemy.TencentAlchemyDataset()
train_set, test_set = alchemy_dataset.split(train_size=0.8) train_set, test_set = alchemy_dataset.split(train_size=training_set_size)
train_loader = DataLoader(dataset=train_set, train_loader = DataLoader(dataset=train_set,
batch_size=20, batch_size=20,
collate_fn=alchemy.batcher(), collate_fn=alchemy.batcher_dev,
shuffle=False, shuffle=False,
num_workers=0) num_workers=0)
test_loader = DataLoader(dataset=test_set, test_loader = DataLoader(dataset=test_set,
batch_size=20, batch_size=20,
collate_fn=alchemy.batcher(), collate_fn=alchemy.batcher_dev,
shuffle=False, shuffle=False,
num_workers=0) num_workers=0)
...@@ -42,7 +36,7 @@ def train(model="sch", ...@@ -42,7 +36,7 @@ def train(model="sch",
loss_fn = nn.MSELoss() loss_fn = nn.MSELoss()
MAE_fn = nn.L1Loss() MAE_fn = nn.L1Loss()
optimizer = th.optim.Adam(model.parameters(), lr=0.0001) optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
for epoch in range(epochs): for epoch in range(epochs):
...@@ -83,7 +77,6 @@ def train(model="sch", ...@@ -83,7 +77,6 @@ def train(model="sch",
w_mae /= idx + 1 w_mae /= idx + 1
print("MAE (test set): {:.7f}".format(w_mae)) print("MAE (test set): {:.7f}".format(w_mae))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-M", parser.add_argument("-M",
...@@ -95,8 +88,6 @@ if __name__ == "__main__": ...@@ -95,8 +88,6 @@ if __name__ == "__main__":
help="number of epochs", help="number of epochs",
default=250, default=250,
type=int) type=int)
device = th.device('cuda:0' if th.cuda.is_available() else 'cpu') device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
args = parser.parse_args() args = parser.parse_args()
assert args.model in ['sch', 'mgcn',
'mpnn'], "model name must be sch, mgcn or mpnn"
train(args.model, args.epochs, device) train(args.model, args.epochs, device)
import datetime import datetime
import dgl import dgl
import numpy as np import numpy as np
import os
import random import random
import torch import torch
from sklearn.metrics import roc_auc_score from sklearn.metrics import roc_auc_score
......
...@@ -2,20 +2,20 @@ ...@@ -2,20 +2,20 @@
"""Example dataloader of Tencent Alchemy Dataset """Example dataloader of Tencent Alchemy Dataset
https://alchemy.tencent.com/ https://alchemy.tencent.com/
""" """
import numpy as np
import os import os
import os.path as osp import os.path as osp
import zipfile import pathlib
import dgl
from dgl.data.utils import download
import pickle import pickle
import torch import zipfile
from collections import defaultdict from collections import defaultdict
from torch.utils.data import Dataset
from torch.utils.data import DataLoader import dgl
import numpy as np import dgl.backend as F
import pathlib from dgl.data.utils import download, get_download_dir
from ..utils import get_download_dir, download, _get_dgl_url
from .utils import mol_to_complete_graph
try: try:
import pandas as pd import pandas as pd
from rdkit import Chem from rdkit import Chem
...@@ -23,27 +23,44 @@ try: ...@@ -23,27 +23,44 @@ try:
from rdkit import RDConfig from rdkit import RDConfig
except ImportError: except ImportError:
pass pass
_urls = {'Alchemy': 'https://alchemy.tencent.com/data/dgl/'} _urls = {'Alchemy': 'https://alchemy.tencent.com/data/dgl/'}
class AlchemyBatcher(object):
"""Data structure for holding a batch of data.
class AlchemyBatcher: Parameters
----------
graph : dgl.BatchedDGLGraph
A batch of DGLGraphs for B molecules
labels : tensor
Labels for B molecules
"""
def __init__(self, graph=None, label=None): def __init__(self, graph=None, label=None):
self.graph = graph self.graph = graph
self.label = label self.label = label
def batcher_dev(batch):
"""Batch datapoints
def batcher(): Parameters
def batcher_dev(batch): ----------
graphs, labels = zip(*batch) batch : list
batch_graphs = dgl.batch(graphs) batch[i][0] gives the DGLGraph for the ith datapoint,
labels = torch.stack(labels, 0) and batch[i][1] gives the label for the ith datapoint.
return AlchemyBatcher(graph=batch_graphs, label=labels)
return batcher_dev
Returns
-------
AlchemyBatcher
An object holding the batch of data
"""
graphs, labels = zip(*batch)
batch_graphs = dgl.batch(graphs)
labels = F.stack(labels, 0)
class TencentAlchemyDataset(Dataset): return AlchemyBatcher(graph=batch_graphs, label=labels)
class TencentAlchemyDataset(object):
fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef') fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
chem_feature_factory = ChemicalFeatures.BuildFeatureFactory(fdef_name) chem_feature_factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)
...@@ -51,12 +68,15 @@ class TencentAlchemyDataset(Dataset): ...@@ -51,12 +68,15 @@ class TencentAlchemyDataset(Dataset):
"""Featurization for all atoms in a molecule. The atom indices """Featurization for all atoms in a molecule. The atom indices
will be preserved. will be preserved.
Args: Parameters
mol : rdkit.Chem.rdchem.Mol ----------
RDKit molecule object mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
Returns Returns
atom_feats_dict : dict -------
Dictionary for atom features atom_feats_dict : dict
Dictionary for atom features
""" """
atom_feats_dict = defaultdict(list) atom_feats_dict = defaultdict(list)
is_donor = defaultdict(int) is_donor = defaultdict(int)
...@@ -87,7 +107,7 @@ class TencentAlchemyDataset(Dataset): ...@@ -87,7 +107,7 @@ class TencentAlchemyDataset(Dataset):
aromatic = atom.GetIsAromatic() aromatic = atom.GetIsAromatic()
hybridization = atom.GetHybridization() hybridization = atom.GetHybridization()
num_h = atom.GetTotalNumHs() num_h = atom.GetTotalNumHs()
atom_feats_dict['pos'].append(torch.FloatTensor(geom[u])) atom_feats_dict['pos'].append(F.tensor(geom[u].astype(np.float32)))
atom_feats_dict['node_type'].append(atom_type) atom_feats_dict['node_type'].append(atom_type)
h_u = [] h_u = []
...@@ -105,27 +125,28 @@ class TencentAlchemyDataset(Dataset): ...@@ -105,27 +125,28 @@ class TencentAlchemyDataset(Dataset):
Chem.rdchem.HybridizationType.SP3) Chem.rdchem.HybridizationType.SP3)
] ]
h_u.append(num_h) h_u.append(num_h)
atom_feats_dict['n_feat'].append(torch.FloatTensor(h_u)) atom_feats_dict['n_feat'].append(F.tensor(np.array(h_u).astype(np.float32)))
atom_feats_dict['n_feat'] = torch.stack(atom_feats_dict['n_feat'], atom_feats_dict['n_feat'] = F.stack(atom_feats_dict['n_feat'], dim=0)
dim=0) atom_feats_dict['pos'] = F.stack(atom_feats_dict['pos'], dim=0)
atom_feats_dict['pos'] = torch.stack(atom_feats_dict['pos'], dim=0) atom_feats_dict['node_type'] = F.tensor(np.array(
atom_feats_dict['node_type'] = torch.LongTensor( atom_feats_dict['node_type']).astype(np.int64))
atom_feats_dict['node_type'])
return atom_feats_dict return atom_feats_dict
def alchemy_edges(self, mol, self_loop=True): def alchemy_edges(self, mol, self_loop=False):
"""Featurization for all bonds in a molecule. The bond indices """Featurization for all bonds in a molecule. The bond indices
will be preserved. will be preserved.
Args: Parameters
mol : rdkit.Chem.rdchem.Mol ----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object RDKit molecule object
Returns Returns
bond_feats_dict : dict -------
Dictionary for bond features bond_feats_dict : dict
Dictionary for bond features
""" """
bond_feats_dict = defaultdict(list) bond_feats_dict = defaultdict(list)
...@@ -154,50 +175,15 @@ class TencentAlchemyDataset(Dataset): ...@@ -154,50 +175,15 @@ class TencentAlchemyDataset(Dataset):
bond_feats_dict['distance'].append( bond_feats_dict['distance'].append(
np.linalg.norm(geom[u] - geom[v])) np.linalg.norm(geom[u] - geom[v]))
bond_feats_dict['e_feat'] = torch.FloatTensor( bond_feats_dict['e_feat'] = F.tensor(
bond_feats_dict['e_feat']) np.array(bond_feats_dict['e_feat']).astype(np.float32))
bond_feats_dict['distance'] = torch.FloatTensor( bond_feats_dict['distance'] = F.tensor(
bond_feats_dict['distance']).reshape(-1, 1) np.array(bond_feats_dict['distance']).astype(np.float32)).reshape(-1 , 1)
return bond_feats_dict return bond_feats_dict
def mol_to_dgl(self, mol, self_loop=False):
"""
Convert RDKit molecule object to DGLGraph
Args:
mol: Chem.rdchem.Mol read from sdf
self_loop: Whetaher to add self loop
Returns:
g: DGLGraph
l: related labels
"""
g = dgl.DGLGraph()
# add nodes
num_atoms = mol.GetNumAtoms()
atom_feats = self.alchemy_nodes(mol)
g.add_nodes(num=num_atoms, data=atom_feats)
if self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
bond_feats = self.alchemy_edges(mol, self_loop)
g.edata.update(bond_feats)
return g
def __init__(self, mode='dev', transform=None, from_raw=False): def __init__(self, mode='dev', transform=None, from_raw=False):
assert mode in ['dev', 'valid', assert mode in ['dev', 'valid', 'test'], "mode should be dev/valid/test"
'test'], "mode should be dev/valid/test"
self.mode = mode self.mode = mode
self.transform = transform self.transform = transform
...@@ -224,15 +210,11 @@ class TencentAlchemyDataset(Dataset): ...@@ -224,15 +210,11 @@ class TencentAlchemyDataset(Dataset):
def _load(self): def _load(self):
if self.mode == 'dev': if self.mode == 'dev':
if not self.from_raw: if not self.from_raw:
with open(osp.join(self.file_dir, "dev_graphs.pkl"), with open(osp.join(self.file_dir, "dev_graphs.pkl"), "rb") as f:
"rb") as f:
self.graphs = pickle.load(f) self.graphs = pickle.load(f)
with open(osp.join(self.file_dir, "dev_labels.pkl"), with open(osp.join(self.file_dir, "dev_labels.pkl"), "rb") as f:
"rb") as f:
self.labels = pickle.load(f) self.labels = pickle.load(f)
else: else:
target_file = pathlib.Path(self.file_dir, "dev_target.csv") target_file = pathlib.Path(self.file_dir, "dev_target.csv")
self.target = pd.read_csv( self.target = pd.read_csv(
target_file, target_file,
...@@ -245,15 +227,15 @@ class TencentAlchemyDataset(Dataset): ...@@ -245,15 +227,15 @@ class TencentAlchemyDataset(Dataset):
]] ]]
self.graphs, self.labels = [], [] self.graphs, self.labels = [], []
sdf_dir = pathlib.Path(self.file_dir, "sdf")
supp = Chem.SDMolSupplier( supp = Chem.SDMolSupplier(
osp.join(self.file_dir, self.mode + ".sdf")) osp.join(self.file_dir, self.mode + ".sdf"))
cnt = 0 cnt = 0
for sdf, label in zip(supp, self.target.iterrows()): for sdf, label in zip(supp, self.target.iterrows()):
graph = self.mol_to_dgl(sdf) graph = mol_to_complete_graph(sdf, atom_featurizer=self.alchemy_nodes,
bond_featurizer=self.alchemy_edges)
cnt += 1 cnt += 1
self.graphs.append(graph) self.graphs.append(graph)
label = torch.FloatTensor(label[1].tolist()) label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
self.labels.append(label) self.labels.append(label)
self.normalize() self.normalize()
...@@ -278,10 +260,7 @@ class TencentAlchemyDataset(Dataset): ...@@ -278,10 +260,7 @@ class TencentAlchemyDataset(Dataset):
return g, l return g, l
def split(self, train_size=0.8): def split(self, train_size=0.8):
""" """Split the dataset into two AlchemySubset for train&test."""
Split the dataset into two AlchemySubset for train&test.
"""
assert 0 < train_size < 1 assert 0 < train_size < 1
train_num = int(len(self.graphs) * train_size) train_num = int(len(self.graphs) * train_size)
train_set = AlchemySubset(self.graphs[:train_num], train_set = AlchemySubset(self.graphs[:train_num],
...@@ -292,15 +271,13 @@ class TencentAlchemyDataset(Dataset): ...@@ -292,15 +271,13 @@ class TencentAlchemyDataset(Dataset):
self.transform) self.transform)
return train_set, test_set return train_set, test_set
class AlchemySubset(TencentAlchemyDataset): class AlchemySubset(TencentAlchemyDataset):
""" """
Sub-dataset split from TencentAlchemyDataset. Sub-dataset split from TencentAlchemyDataset.
Used to construct the training & test set. Used to construct the training & test set.
""" """
def __init__(self, graphs, labels, mean=0, std=1, transform=None): def __init__(self, graphs, labels, mean=0, std=1, transform=None):
super(AlchemySubset, self).__init__()
self.graphs = graphs self.graphs = graphs
self.labels = labels self.labels = labels
self.mean = mean self.mean = mean
......
...@@ -7,7 +7,7 @@ import pickle ...@@ -7,7 +7,7 @@ import pickle
import sys import sys
from dgl import DGLGraph from dgl import DGLGraph
from .utils import smile2graph from .utils import smile_to_bigraph
class CSVDataset(object): class CSVDataset(object):
...@@ -28,9 +28,9 @@ class CSVDataset(object): ...@@ -28,9 +28,9 @@ class CSVDataset(object):
One column includes smiles and other columns for labels. One column includes smiles and other columns for labels.
Column names other than smiles column would be considered as task names. Column names other than smiles column would be considered as task names.
smile2graph: callable, str -> DGLGraph smile_to_graph: callable, str -> DGLGraph
A function turns smiles into a DGLGraph. Default one can be found A function turns smiles into a DGLGraph. Default one can be found
at python/dgl/data/chem/utils.py named with smile2graph. at python/dgl/data/chem/utils.py named with smile_to_bigraph.
smile_column: str smile_column: str
Column name that including smiles Column name that including smiles
...@@ -39,7 +39,7 @@ class CSVDataset(object): ...@@ -39,7 +39,7 @@ class CSVDataset(object):
Path to store the preprocessed data Path to store the preprocessed data
""" """
def __init__(self, df, smile2graph=smile2graph, smile_column='smiles', def __init__(self, df, smile_to_graph=smile_to_bigraph, smile_column='smiles',
cache_file_path="csvdata_dglgraph.pkl"): cache_file_path="csvdata_dglgraph.pkl"):
if 'rdkit' not in sys.modules: if 'rdkit' not in sys.modules:
from ...base import dgl_warning from ...base import dgl_warning
...@@ -50,9 +50,9 @@ class CSVDataset(object): ...@@ -50,9 +50,9 @@ class CSVDataset(object):
self.task_names = self.df.columns.drop([smile_column]).tolist() self.task_names = self.df.columns.drop([smile_column]).tolist()
self.n_tasks = len(self.task_names) self.n_tasks = len(self.task_names)
self.cache_file_path = cache_file_path self.cache_file_path = cache_file_path
self._pre_process(smile2graph) self._pre_process(smile_to_graph)
def _pre_process(self, smile2graph): def _pre_process(self, smile_to_graph):
"""Pre-process the dataset """Pre-process the dataset
* Convert molecules from smiles format into DGLGraphs * Convert molecules from smiles format into DGLGraphs
...@@ -66,7 +66,7 @@ class CSVDataset(object): ...@@ -66,7 +66,7 @@ class CSVDataset(object):
with open(self.cache_file_path, 'rb') as f: with open(self.cache_file_path, 'rb') as f:
self.graphs = pickle.load(f) self.graphs = pickle.load(f)
else: else:
self.graphs = [smile2graph(s) for s in self.smiles] self.graphs = [smile_to_graph(s) for s in self.smiles]
with open(self.cache_file_path, 'wb') as f: with open(self.cache_file_path, 'wb') as f:
pickle.dump(self.graphs, f) pickle.dump(self.graphs, f)
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import sys import sys
from .csv_dataset import CSVDataset from .csv_dataset import CSVDataset
from .utils import smile2graph from .utils import smile_to_bigraph
from ..utils import get_download_dir, download, _get_dgl_url from ..utils import get_download_dir, download, _get_dgl_url
try: try:
...@@ -32,11 +32,11 @@ class Tox21(CSVDataset): ...@@ -32,11 +32,11 @@ class Tox21(CSVDataset):
Parameters Parameters
---------- ----------
smile2graph: callable, str -> DGLGraph smile_to_graph: callable, str -> DGLGraph
A function turns smiles into a DGLGraph. Default one can be found A function turns smiles into a DGLGraph. Default one can be found
at python/dgl/data/chem/utils.py named with smile2graph. at python/dgl/data/chem/utils.py named with smile_to_bigraph.
""" """
def __init__(self, smile2graph=smile2graph): def __init__(self, smile_to_graph=smile_to_bigraph):
if 'pandas' not in sys.modules: if 'pandas' not in sys.modules:
from ...base import dgl_warning from ...base import dgl_warning
dgl_warning("Please install pandas") dgl_warning("Please install pandas")
...@@ -48,7 +48,7 @@ class Tox21(CSVDataset): ...@@ -48,7 +48,7 @@ class Tox21(CSVDataset):
df = df.drop(columns=['mol_id']) df = df.drop(columns=['mol_id'])
super().__init__(df, smile2graph, cache_file_path="tox21_dglgraph.pkl") super().__init__(df, smile_to_graph, cache_file_path="tox21_dglgraph.pkl")
self._weight_balancing() self._weight_balancing()
......
import dgl.backend as F import dgl.backend as F
import numpy as np import numpy as np
from functools import partial
from dgl import DGLGraph from dgl import DGLGraph
...@@ -136,8 +137,41 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer): ...@@ -136,8 +137,41 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
return {self.atom_data_field: atom_features} return {self.atom_data_field: atom_features}
def smile2graph(smile, add_self_loop=False, atom_featurizer=CanonicalAtomFeaturizer(), bond_featurizer=None): def mol_to_graph(mol, graph_constructor, atom_featurizer, bond_featurizer):
"""Convert SMILES into a DGLGraph. """Convert an RDKit molecule object into a DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
graph_constructor : callable
Takes an RDKit molecule as input and returns a DGLGraph
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
g : DGLGraph
Converted DGLGraph for the molecule
"""
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = graph_constructor(mol)
if atom_featurizer is not None:
g.ndata.update(atom_featurizer(mol))
if bond_featurizer is not None:
g.edata.update(bond_featurizer(mol))
return g
def construct_bigraph_from_mol(mol, add_self_loop=False):
"""Construct a bi-directed DGLGraph with topology only for the molecule.
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph. **i** th node in the returned DGLGraph.
...@@ -152,24 +186,23 @@ def smile2graph(smile, add_self_loop=False, atom_featurizer=CanonicalAtomFeaturi ...@@ -152,24 +186,23 @@ def smile2graph(smile, add_self_loop=False, atom_featurizer=CanonicalAtomFeaturi
Parameters Parameters
---------- ----------
smiles : str mol : rdkit.Chem.rdchem.Mol
String of SMILES RDKit molecule holder
add_self_loop : bool add_self_loop : bool
Whether to add self loops in DGLGraphs. Whether to add self loops in DGLGraphs.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update Returns
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer(). -------
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict g : DGLGraph
Featurization for bonds in a molecule, which can be used to update Empty bigraph topology of the molecule
edata for a DGLGraph.
""" """
mol = Chem.MolFromSmiles(smile)
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = DGLGraph() g = DGLGraph()
# Add nodes
num_atoms = mol.GetNumAtoms() num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms) g.add_nodes(num_atoms)
# Add edges
src_list = [] src_list = []
dst_list = [] dst_list = []
num_bonds = mol.GetNumBonds() num_bonds = mol.GetNumBonds()
...@@ -185,11 +218,146 @@ def smile2graph(smile, add_self_loop=False, atom_featurizer=CanonicalAtomFeaturi ...@@ -185,11 +218,146 @@ def smile2graph(smile, add_self_loop=False, atom_featurizer=CanonicalAtomFeaturi
nodes = g.nodes() nodes = g.nodes()
g.add_edges(nodes, nodes) g.add_edges(nodes, nodes)
# Featurization return g
if atom_featurizer is not None:
g.ndata.update(atom_featurizer(mol))
if bond_featurizer is not None: def mol_to_bigraph(mol, add_self_loop=False,
g.edata.update(bond_featurizer(mol)) atom_featurizer=CanonicalAtomFeaturizer(),
bond_featurizer=None):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
atom_featurizer, bond_featurizer)
def smile_to_bigraph(smile, add_self_loop=False,
atom_featurizer=CanonicalAtomFeaturizer(),
bond_featurizer=None):
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it.
Parameters
----------
smile : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smile)
return mol_to_bigraph(mol, add_self_loop, atom_featurizer, bond_featurizer)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
"""Construct a complete graph with topology only for the molecule
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The edges are in the order of (0, 0), (1, 0), (2, 0), ... (0, 1), (1, 1), (2, 1), ...
If self loops are not created, we will not have (0, 0), (1, 1), ...
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Returns
-------
g : DGLGraph
Empty complete graph topology of the molecule
"""
g = DGLGraph()
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
if add_self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
return g return g
def mol_to_complete_graph(mol, add_self_loop=False,
atom_featurizer=None,
bond_featurizer=None):
"""Convert an RDKit molecule into a complete DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
atom_featurizer, bond_featurizer)
def smile_to_complete_graph(smile, add_self_loop=False,
atom_featurizer=None,
bond_featurizer=None):
"""Convert a SMILES into a complete DGLGraph and featurize for it.
Parameters
----------
smile : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smile)
return mol_to_complete_graph(mol, add_self_loop, atom_featurizer, bond_featurizer)
...@@ -2,10 +2,16 @@ ...@@ -2,10 +2,16 @@
import torch import torch
from .dgmg import DGMG from .dgmg import DGMG
from .gcn import GCNClassifier from .gcn import GCNClassifier
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .sch import SchNetModel
from ...data.utils import _get_dgl_url, download from ...data.utils import _get_dgl_url, download
URL = { URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth', 'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
'MGCN_Alchemy': 'pre_trained/mgcn_alchemy.pth',
'SCHNET_Alchemy': 'pre_trained/schnet_alchemy.pth',
'MPNN_Alchemy': 'pre_trained/mpnn_alchemy.pth',
'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth', 'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth',
'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth', 'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth', 'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth',
...@@ -82,6 +88,12 @@ def load_pretrained(model_name, log=True): ...@@ -82,6 +88,12 @@ def load_pretrained(model_name, log=True):
node_hidden_size=128, node_hidden_size=128,
num_prop_rounds=2, num_prop_rounds=2,
dropout=0.2) dropout=0.2)
elif model_name == 'MGCN_Alchemy':
model = MGCNModel(norm=True, output_dim=12)
elif model_name == 'SCHNET_Alchemy':
model = SchNetModel(norm=True, output_dim=12)
elif model_name == 'MPNN_Alchemy':
model = MPNNModel(output_dim=12)
if log: if log:
print('Pretrained model loaded') print('Pretrained model loaded')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment