Unverified Commit 36c7b771 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci

* Remove doc
parent 94c67203
# -*- coding:utf-8 -*-
"""Tencent Alchemy Dataset https://alchemy.tencent.com/"""
import numpy as np
import os
import os.path as osp
import pandas as pd
import pathlib
import zipfile
from collections import defaultdict
from dgl import backend as F
from dgl.data.utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
from rdkit import Chem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
from ..utils.mol_to_graph import mol_to_complete_graph
from ..utils.featurizers import atom_type_one_hot, atom_hybridization_one_hot, atom_is_aromatic
__all__ = ['TencentAlchemyDataset']
def alchemy_nodes(mol):
"""Featurization for all atoms in a molecule. The atom indices
will be preserved.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
Returns
-------
atom_feats_dict : dict
Dictionary for atom features
"""
atom_feats_dict = defaultdict(list)
is_donor = defaultdict(int)
is_acceptor = defaultdict(int)
fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
mol_feats = mol_featurizer.GetFeaturesForMol(mol)
mol_conformers = mol.GetConformers()
assert len(mol_conformers) == 1
for i in range(len(mol_feats)):
if mol_feats[i].GetFamily() == 'Donor':
node_list = mol_feats[i].GetAtomIds()
for u in node_list:
is_donor[u] = 1
elif mol_feats[i].GetFamily() == 'Acceptor':
node_list = mol_feats[i].GetAtomIds()
for u in node_list:
is_acceptor[u] = 1
num_atoms = mol.GetNumAtoms()
for u in range(num_atoms):
atom = mol.GetAtomWithIdx(u)
atom_type = atom.GetAtomicNum()
num_h = atom.GetTotalNumHs()
atom_feats_dict['node_type'].append(atom_type)
h_u = []
h_u += atom_type_one_hot(atom, ['H', 'C', 'N', 'O', 'F', 'S', 'Cl'])
h_u.append(atom_type)
h_u.append(is_acceptor[u])
h_u.append(is_donor[u])
h_u += atom_is_aromatic(atom)
h_u += atom_hybridization_one_hot(atom, [Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3])
h_u.append(num_h)
atom_feats_dict['n_feat'].append(F.tensor(np.array(h_u).astype(np.float32)))
atom_feats_dict['n_feat'] = F.stack(atom_feats_dict['n_feat'], dim=0)
atom_feats_dict['node_type'] = F.tensor(np.array(
atom_feats_dict['node_type']).astype(np.int64))
return atom_feats_dict
def alchemy_edges(mol, self_loop=False):
"""Featurization for all bonds in a molecule.
The bond indices will be preserved.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
self_loop : bool
Whether to add self loops. Default to be False.
Returns
-------
bond_feats_dict : dict
Dictionary for bond features
"""
bond_feats_dict = defaultdict(list)
mol_conformers = mol.GetConformers()
assert len(mol_conformers) == 1
geom = mol_conformers[0].GetPositions()
num_atoms = mol.GetNumAtoms()
for u in range(num_atoms):
for v in range(num_atoms):
if u == v and not self_loop:
continue
e_uv = mol.GetBondBetweenAtoms(u, v)
if e_uv is None:
bond_type = None
else:
bond_type = e_uv.GetBondType()
bond_feats_dict['e_feat'].append([
float(bond_type == x)
for x in (Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC, None)
])
bond_feats_dict['distance'].append(
np.linalg.norm(geom[u] - geom[v]))
bond_feats_dict['e_feat'] = F.tensor(
np.array(bond_feats_dict['e_feat']).astype(np.float32))
bond_feats_dict['distance'] = F.tensor(
np.array(bond_feats_dict['distance']).astype(np.float32)).reshape(-1 , 1)
return bond_feats_dict
class TencentAlchemyDataset(object):
"""
Developed by the Tencent Quantum Lab, the dataset lists 12 quantum mechanical
properties of 130, 000+ organic molecules, comprising up to 12 heavy atoms
(C, N, O, S, F and Cl), sampled from the GDBMedChem database. These properties
have been calculated using the open-source computational chemistry program
Python-based Simulation of Chemistry Framework (PySCF).
For more details, check the `paper <https://arxiv.org/abs/1906.09427>`__.
Parameters
----------
mode : str
'dev', 'valid' or 'test', separately for training, validation and test.
Default to be 'dev'. Note that 'test' is not available as the Alchemy
contest is ongoing.
mol_to_graph: callable, str -> DGLGraph
A function turning an RDKit molecule instance into a DGLGraph.
Default to :func:`dgllife.utils.mol_to_complete_graph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we construct graphs where nodes represent atoms
and node features represent atom features. We store the atomic numbers under the
name ``"node_type"`` and store the atom features under the name ``"n_feat"``.
The atom features include:
* One hot encoding for atom types
* Atomic number of atoms
* Whether the atom is a donor
* Whether the atom is an acceptor
* Whether the atom is aromatic
* One hot encoding for atom hybridization
* Total number of Hs on the atom
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we construct edges between every pair of atoms,
excluding the self loops. We store the distance between the end atoms under the name
``"distance"`` and store the edge features under the name ``"e_feat"``. The edge
features represent one hot encoding of edge types (bond types and non-bond edges).
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
def __init__(self, mode='dev',
mol_to_graph=mol_to_complete_graph,
node_featurizer=alchemy_nodes,
edge_featurizer=alchemy_edges,
load=True):
if mode == 'test':
raise ValueError('The test mode is not supported before '
'the Alchemy contest finishes.')
assert mode in ['dev', 'valid', 'test'], \
'Expect mode to be dev, valid or test, got {}.'.format(mode)
self.mode = mode
# Construct DGLGraphs from raw data or use the preprocessed data
self.load = load
file_dir = osp.join(get_download_dir(), 'Alchemy_data')
if load:
file_name = "{}_processed_dgl".format(mode)
else:
file_name = "{}_single_sdf".format(mode)
self.file_dir = pathlib.Path(file_dir, file_name)
self._url = 'dataset/alchemy/'
self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
download(_get_dgl_url(self._url + file_name + '.zip'), path=str(self.zip_file_path))
if not os.path.exists(str(self.file_dir)):
archive = zipfile.ZipFile(self.zip_file_path)
archive.extractall(file_dir)
archive.close()
self._load(mol_to_graph, node_featurizer, edge_featurizer)
def _load(self, mol_to_graph, node_featurizer, edge_featurizer):
if self.load:
self.graphs, label_dict = load_graphs(osp.join(self.file_dir, "{}_graphs.bin".format(self.mode)))
self.labels = label_dict['labels']
with open(osp.join(self.file_dir, "{}_smiles.txt".format(self.mode)), 'r') as f:
smiles_ = f.readlines()
self.smiles = [s.strip() for s in smiles_]
else:
print('Start preprocessing dataset...')
target_file = pathlib.Path(self.file_dir, "{}_target.csv".format(self.mode))
self.target = pd.read_csv(
target_file,
index_col=0,
usecols=['gdb_idx',] + ['property_{:d}'.format(x) for x in range(12)])
self.target = self.target[['property_{:d}'.format(x) for x in range(12)]]
self.graphs, self.labels, self.smiles = [], [], []
supp = Chem.SDMolSupplier(osp.join(self.file_dir, self.mode + ".sdf"))
cnt = 0
dataset_size = len(self.target)
for mol, label in zip(supp, self.target.iterrows()):
cnt += 1
print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
graph = mol_to_graph(mol, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer)
smiles = Chem.MolToSmiles(mol)
self.smiles.append(smiles)
self.graphs.append(graph)
label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
self.labels.append(label)
save_graphs(osp.join(self.file_dir, "{}_graphs.bin".format(self.mode)), self.graphs,
labels={'labels': F.stack(self.labels, dim=0)})
with open(osp.join(self.file_dir, "{}_smiles.txt".format(self.mode)), 'w') as f:
for s in self.smiles:
f.write(s + '\n')
self.set_mean_and_std()
print(len(self.graphs), "loaded!")
def __getitem__(self, item):
"""Get datapoint with index
Parameters
----------
item : int
Datapoint index
Returns
-------
str
SMILES for the ith datapoint
DGLGraph
DGLGraph for the ith datapoint
Tensor of dtype float32 and shape (T)
Labels of the datapoint for all tasks.
"""
return self.smiles[item], self.graphs[item], self.labels[item]
def __len__(self):
"""Size for the dataset.
Returns
-------
int
Size for the dataset.
"""
return len(self.graphs)
def set_mean_and_std(self, mean=None, std=None):
"""Set mean and std or compute from labels for future normalization.
The mean and std can be fetched later with ``self.mean`` and ``self.std``.
Parameters
----------
mean : float32 tensor of shape (T)
Mean of labels for all tasks.
std : float32 tensor of shape (T)
Std of labels for all tasks.
"""
labels = np.array([i.numpy() for i in self.labels])
if mean is None:
mean = np.mean(labels, axis=0)
if std is None:
std = np.std(labels, axis=0)
self.mean = mean
self.std = std
"""Creating datasets from .csv files for molecular property prediction."""
import dgl.backend as F
import numpy as np
import os
from dgl.data.utils import save_graphs, load_graphs
__all__ = ['MoleculeCSVDataset']
class MoleculeCSVDataset(object):
"""MoleculeCSVDataset
This is a general class for loading molecular data from :class:`pandas.DataFrame`.
In data pre-processing, we construct a binary mask indicating the existence of labels.
All molecules are converted into DGLGraphs. After the first-time construction, the
DGLGraphs can be saved for reloading so that we do not need to reconstruct them every time.
Parameters
----------
df: pandas.DataFrame
Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path).
One column includes smiles and some other columns include labels.
smiles_to_graph: callable, str -> DGLGraph
A function turning a SMILES into a DGLGraph.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph.
smiles_column: str
Column name for smiles in ``df``.
cache_file_path: str
Path to store the preprocessed DGLGraphs. For example, this can be ``'dglgraph.bin'``.
task_names : list of str or None
Columns in the data frame corresponding to real-valued labels. If None, we assume
all columns except the smiles_column are labels. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed. Default to 1000.
"""
def __init__(self, df, smiles_to_graph, node_featurizer, edge_featurizer,
smiles_column, cache_file_path, task_names=None, load=True, log_every=1000):
self.df = df
self.smiles = self.df[smiles_column].tolist()
if task_names is None:
self.task_names = self.df.columns.drop([smiles_column]).tolist()
else:
self.task_names = task_names
self.n_tasks = len(self.task_names)
self.cache_file_path = cache_file_path
self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer, load, log_every)
def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, load, log_every):
"""Pre-process the dataset
* Convert molecules from smiles format into DGLGraphs
and featurize their atoms
* Set missing labels to be 0 and use a binary masking
matrix to mask them
Parameters
----------
smiles_to_graph : callable, SMILES -> DGLGraph
Function for converting a SMILES (str) into a DGLGraph.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed.
"""
if os.path.exists(self.cache_file_path) and load:
# DGLGraphs have been constructed before, reload them
print('Loading previously saved dgl graphs...')
self.graphs, label_dict = load_graphs(self.cache_file_path)
self.labels = label_dict['labels']
self.mask = label_dict['mask']
else:
print('Processing dgl graphs from scratch...')
self.graphs = []
for i, s in enumerate(self.smiles):
if (i + 1) % log_every == 0:
print('Processing molecule {:d}/{:d}'.format(i+1, len(self)))
self.graphs.append(smiles_to_graph(s, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer))
_label_values = self.df[self.task_names].values
# np.nan_to_num will also turn inf into a very large number
self.labels = F.zerocopy_from_numpy(np.nan_to_num(_label_values).astype(np.float32))
self.mask = F.zerocopy_from_numpy((~np.isnan(_label_values)).astype(np.float32))
save_graphs(self.cache_file_path, self.graphs,
labels={'labels': self.labels, 'mask': self.mask})
def __getitem__(self, item):
"""Get datapoint with index
Parameters
----------
item : int
Datapoint index
Returns
-------
str
SMILES for the ith datapoint
DGLGraph
DGLGraph for the ith datapoint
Tensor of dtype float32 and shape (T)
Labels of the datapoint for all tasks
Tensor of dtype float32 and shape (T)
Binary masks indicating the existence of labels for all tasks
"""
return self.smiles[item], self.graphs[item], self.labels[item], self.mask[item]
def __len__(self):
"""Size for the dataset
Returns
-------
int
Size for the dataset
"""
return len(self.smiles)
"""PDBBind dataset processed by MoleculeNet."""
import dgl.backend as F
import numpy as np
import os
import pandas as pd
from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive
from ..utils import multiprocess_load_molecules, ACNN_graph_construction_and_featurization
__all__ = ['PDBBind']
class PDBBind(object):
"""PDBbind dataset processed by MoleculeNet.
The description below is mainly based on
`[1] <https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a#cit50>`__.
The PDBBind database consists of experimentally measured binding affinities for
bio-molecular complexes `[2] <https://www.ncbi.nlm.nih.gov/pubmed/?term=15163179%5Buid%5D>`__,
`[3] <https://www.ncbi.nlm.nih.gov/pubmed/?term=15943484%5Buid%5D>`__. It provides detailed
3D Cartesian coordinates of both ligands and their target proteins derived from experimental
(e.g., X-ray crystallography) measurements. The availability of coordinates of the
protein-ligand complexes permits structure-based featurization that is aware of the
protein-ligand binding geometry. The authors of
`[1] <https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a#cit50>`__ use the
"refined" and "core" subsets of the database
`[4] <https://www.ncbi.nlm.nih.gov/pubmed/?term=25301850%5Buid%5D>`__, more carefully
processed for data artifacts, as additional benchmarking targets.
References:
* [1] MoleculeNet: a benchmark for molecular machine learning
* [2] The PDBbind database: collection of binding affinities for protein-ligand complexes
with known three-dimensional structures
* [3] The PDBbind database: methodologies and updates
* [4] PDB-wide collection of binding data: current status of the PDBbind database
Parameters
----------
subset : str
In MoleculeNet, we can use either the "refined" subset or the "core" subset. We can
retrieve them by setting ``subset`` to be ``'refined'`` or ``'core'``. The size
of the ``'core'`` set is 195 and the size of the ``'refined'`` set is 3706.
load_binding_pocket : bool
Whether to load binding pockets or full proteins. Default to True.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
construct_graph_and_featurize : callable
Construct a DGLHeteroGraph for the use of GNNs. Mapping ``self.ligand_mols[i]``,
``self.protein_mols[i]``, ``self.ligand_coordinates[i]`` and
``self.protein_coordinates[i]`` to a DGLHeteroGraph.
Default to :func:`dgllife.utils.ACNN_graph_construction_and_featurization`.
zero_padding : bool
Whether to perform zero padding. While DGL does not necessarily require zero padding,
pooling operations for variable length inputs can introduce stochastic behaviour, which
is not desired for sensitive scenarios. Default to True.
num_processes : int or None
Number of worker processes to use. If None,
then we will use the number of CPUs in the system. Default to 64.
"""
def __init__(self, subset, load_binding_pocket=True, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True,
construct_graph_and_featurize=ACNN_graph_construction_and_featurization,
zero_padding=True, num_processes=64):
self.task_names = ['-logKd/Ki']
self.n_tasks = len(self.task_names)
self._url = 'dataset/pdbbind_v2015.tar.gz'
root_dir_path = get_download_dir()
data_path = root_dir_path + '/pdbbind_v2015.tar.gz'
extracted_data_path = root_dir_path + '/pdbbind_v2015'
download(_get_dgl_url(self._url), path=data_path)
extract_archive(data_path, extracted_data_path)
if subset == 'core':
index_label_file = extracted_data_path + '/v2015/INDEX_core_data.2013'
elif subset == 'refined':
index_label_file = extracted_data_path + '/v2015/INDEX_refined_data.2015'
else:
raise ValueError(
'Expect the subset_choice to be either '
'core or refined, got {}'.format(subset))
self._preprocess(extracted_data_path, index_label_file, load_binding_pocket,
sanitize, calc_charges, remove_hs, use_conformation,
construct_graph_and_featurize, zero_padding, num_processes)
def _filter_out_invalid(self, ligands_loaded, proteins_loaded, use_conformation):
"""Filter out invalid ligand-protein pairs.
Parameters
----------
ligands_loaded : list
Each element is a 2-tuple of the RDKit molecule instance and its associated atom
coordinates. None is used to represent invalid/non-existing molecule or coordinates.
proteins_loaded : list
Each element is a 2-tuple of the RDKit molecule instance and its associated atom
coordinates. None is used to represent invalid/non-existing molecule or coordinates.
use_conformation : bool
Whether we need conformation information (atom coordinates) and filter out molecules
without valid conformation.
"""
num_pairs = len(proteins_loaded)
self.indices, self.ligand_mols, self.protein_mols = [], [], []
if use_conformation:
self.ligand_coordinates, self.protein_coordinates = [], []
else:
# Use None for placeholders.
self.ligand_coordinates = [None for _ in range(num_pairs)]
self.protein_coordinates = [None for _ in range(num_pairs)]
for i in range(num_pairs):
ligand_mol, ligand_coordinates = ligands_loaded[i]
protein_mol, protein_coordinates = proteins_loaded[i]
if (not use_conformation) and all(v is not None for v in [protein_mol, ligand_mol]):
self.indices.append(i)
self.ligand_mols.append(ligand_mol)
self.protein_mols.append(protein_mol)
elif all(v is not None for v in [
protein_mol, protein_coordinates, ligand_mol, ligand_coordinates]):
self.indices.append(i)
self.ligand_mols.append(ligand_mol)
self.ligand_coordinates.append(ligand_coordinates)
self.protein_mols.append(protein_mol)
self.protein_coordinates.append(protein_coordinates)
def _preprocess(self, root_path, index_label_file, load_binding_pocket,
sanitize, calc_charges, remove_hs, use_conformation,
construct_graph_and_featurize, zero_padding, num_processes):
"""Preprocess the dataset.
The pre-processing proceeds as follows:
1. Load the dataset
2. Clean the dataset and filter out invalid pairs
3. Construct graphs
4. Prepare node and edge features
Parameters
----------
root_path : str
Root path for molecule files.
index_label_file : str
Path to the index file for the dataset.
load_binding_pocket : bool
Whether to load binding pockets or full proteins.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``sanitize`` to be True.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
construct_graph_and_featurize : callable
Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
zero_padding : bool
Whether to perform zero padding. While DGL does not necessarily require zero padding,
pooling operations for variable length inputs can introduce stochastic behaviour, which
is not desired for sensitive scenarios.
num_processes : int or None
Number of worker processes to use. If None,
then we will use the number of CPUs in the system.
"""
contents = []
with open(index_label_file, 'r') as f:
for line in f.readlines():
if line[0] != "#":
splitted_elements = line.split()
if len(splitted_elements) == 8:
# Ignore "//"
contents.append(splitted_elements[:5] + splitted_elements[6:])
else:
print('Incorrect data format.')
print(splitted_elements)
self.df = pd.DataFrame(contents, columns=(
'PDB_code', 'resolution', 'release_year',
'-logKd/Ki', 'Kd/Ki', 'reference', 'ligand_name'))
pdbs = self.df['PDB_code'].tolist()
self.ligand_files = [os.path.join(
root_path, 'v2015', pdb, '{}_ligand.sdf'.format(pdb)) for pdb in pdbs]
if load_binding_pocket:
self.protein_files = [os.path.join(
root_path, 'v2015', pdb, '{}_pocket.pdb'.format(pdb)) for pdb in pdbs]
else:
self.protein_files = [os.path.join(
root_path, 'v2015', pdb, '{}_protein.pdb'.format(pdb)) for pdb in pdbs]
num_processes = min(num_processes, len(pdbs))
print('Loading ligands...')
ligands_loaded = multiprocess_load_molecules(self.ligand_files,
sanitize=sanitize,
calc_charges=calc_charges,
remove_hs=remove_hs,
use_conformation=use_conformation,
num_processes=num_processes)
print('Loading proteins...')
proteins_loaded = multiprocess_load_molecules(self.protein_files,
sanitize=sanitize,
calc_charges=calc_charges,
remove_hs=remove_hs,
use_conformation=use_conformation,
num_processes=num_processes)
self._filter_out_invalid(ligands_loaded, proteins_loaded, use_conformation)
self.df = self.df.iloc[self.indices]
self.labels = F.zerocopy_from_numpy(self.df[self.task_names].values.astype(np.float32))
print('Finished cleaning the dataset, '
'got {:d}/{:d} valid pairs'.format(len(self), len(pdbs)))
# Prepare zero padding
if zero_padding:
max_num_ligand_atoms = 0
max_num_protein_atoms = 0
for i in range(len(self)):
max_num_ligand_atoms = max(
max_num_ligand_atoms, self.ligand_mols[i].GetNumAtoms())
max_num_protein_atoms = max(
max_num_protein_atoms, self.protein_mols[i].GetNumAtoms())
else:
max_num_ligand_atoms = None
max_num_protein_atoms = None
print('Start constructing graphs and featurizing them.')
self.graphs = []
for i in range(len(self)):
print('Constructing and featurizing datapoint {:d}/{:d}'.format(i+1, len(self)))
self.graphs.append(construct_graph_and_featurize(
self.ligand_mols[i], self.protein_mols[i],
self.ligand_coordinates[i], self.protein_coordinates[i],
max_num_ligand_atoms, max_num_protein_atoms))
def __len__(self):
"""Get the size of the dataset.
Returns
-------
int
Number of valid ligand-protein pairs in the dataset.
"""
return len(self.indices)
def __getitem__(self, item):
"""Get the datapoint associated with the index.
Parameters
----------
item : int
Index for the datapoint.
Returns
-------
int
Index for the datapoint.
rdkit.Chem.rdchem.Mol
RDKit molecule instance for the ligand molecule.
rdkit.Chem.rdchem.Mol
RDKit molecule instance for the protein molecule.
DGLHeteroGraph
Pre-processed DGLHeteroGraph with features extracted.
Float32 tensor
Label for the datapoint.
"""
return item, self.ligand_mols[item], self.protein_mols[item], \
self.graphs[item], self.labels[item]
"""Dataset for aromaticity prediction"""
import pandas as pd
from dgl.data.utils import get_download_dir, download, _get_dgl_url
from .csv_dataset import MoleculeCSVDataset
from ..utils.mol_to_graph import smiles_to_bigraph
__all__ = ['PubChemBioAssayAromaticity']
class PubChemBioAssayAromaticity(MoleculeCSVDataset):
"""Subset of PubChem BioAssay Dataset for aromaticity prediction.
The dataset was constructed in `Pushing the Boundaries of Molecular Representation for Drug
Discovery with the Graph Attention Mechanism
<https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__ and is accompanied by the task of predicting
the number of aromatic atoms in molecules.
The dataset was constructed by sampling 3945 molecules with 0-40 aromatic atoms from the
PubChem BioAssay dataset.
Parameters
----------
smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph.
Default to :func:`dgllife.utils.smiles_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to pre-process from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed. Default to 1000.
"""
def __init__(self, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None, edge_featurizer=None, load=True, log_every=1000):
self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv'
download(_get_dgl_url(self._url), path=data_path)
df = pd.read_csv(data_path)
super(PubChemBioAssayAromaticity, self).__init__(
df, smiles_to_graph, node_featurizer, edge_featurizer, "cano_smiles",
"pubchem_aromaticity_dglgraph.bin", load=load, log_every=log_every)
"""The Toxicology in the 21st Century initiative."""
import dgl.backend as F
import pandas as pd
from dgl.data.utils import get_download_dir, download, _get_dgl_url
from .csv_dataset import MoleculeCSVDataset
from ..utils.mol_to_graph import smiles_to_bigraph
__all__ = ['Tox21']
class Tox21(MoleculeCSVDataset):
"""Tox21 dataset.
The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
initiative created a public database measuring toxicity of compounds, which
has been used in the 2014 Tox21 Data Challenge. The dataset contains qualitative
toxicity measurements for 8014 compounds on 12 different targets, including nuclear
receptors and stress response pathways. Each target results in a binary label.
A common issue for multi-task prediction is that some datapoints are not labeled for
all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing
labels to be 0 so that they can be placed in tensors and used for masking in loss computation.
All molecules are converted into DGLGraphs. After the first-time construction,
the DGLGraphs will be saved for reloading so that we do not need to reconstruct them everytime.
Parameters
----------
smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph.
Default to :func:`dgllife.utils.smiles_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed. Default to 1000.
"""
def __init__(self, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None,
edge_featurizer=None,
load=True,
log_every=1000):
self._url = 'dataset/tox21.csv.gz'
data_path = get_download_dir() + '/tox21.csv.gz'
download(_get_dgl_url(self._url), path=data_path)
df = pd.read_csv(data_path)
self.id = df['mol_id']
df = df.drop(columns=['mol_id'])
super(Tox21, self).__init__(df, smiles_to_graph, node_featurizer, edge_featurizer,
"smiles", "tox21_dglgraph.bin",
load=load, log_every=log_every)
self._weight_balancing()
def _weight_balancing(self):
"""Perform re-balancing for each task.
It's quite common that the number of positive samples and the
number of negative samples are significantly different. To compensate
for the class imbalance issue, we can weight each datapoint in
loss computation.
In particular, for each task we will set the weight of negative samples
to be 1 and the weight of positive samples to be the number of negative
samples divided by the number of positive samples.
If weight balancing is performed, one attribute will be affected:
* self._task_pos_weights is set, which is a list of positive sample weights
for each task.
"""
num_pos = F.sum(self.labels, dim=0)
num_indices = F.sum(self.mask, dim=0)
self._task_pos_weights = (num_indices - num_pos) / num_pos
@property
def task_pos_weights(self):
"""Get weights for positive samples on each task
It's quite common that the number of positive samples and the
number of negative samples are significantly different. To compensate
for the class imbalance issue, we can weight each datapoint in
loss computation.
In particular, for each task we will set the weight of negative samples
to be 1 and the weight of positive samples to be the number of negative
samples divided by the number of positive samples.
Returns
-------
Tensor of dtype float32 and shape (T)
Weight of positive samples on all tasks
"""
return self._task_pos_weights
"""USPTO for reaction prediction"""
import errno
import numpy as np
import os
import random
import torch
from collections import defaultdict
from copy import deepcopy
from dgl import DGLGraph
from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive, \
save_graphs, load_graphs
from functools import partial
from itertools import combinations
from multiprocessing import Pool
from rdkit import Chem, RDLogger
from rdkit.Chem import rdmolops
from tqdm import tqdm
from ..utils.featurizers import BaseAtomFeaturizer, ConcatFeaturizer, one_hot_encoding, \
atom_type_one_hot, atom_degree_one_hot, atom_explicit_valence_one_hot, \
atom_implicit_valence_one_hot, atom_is_aromatic, atom_formal_charge_one_hot, \
BaseBondFeaturizer, bond_type_one_hot, bond_is_conjugated, bond_is_in_ring
from ..utils.mol_to_graph import mol_to_bigraph, mol_to_complete_graph
__all__ = ['WLNCenterDataset',
'USPTOCenter',
'WLNRankDataset',
'USPTORank']
# Disable RDKit warnings
RDLogger.DisableLog('rdApp.*')
# Atom types distinguished in featurization
atom_types = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe',
'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co',
'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
'Cr', 'Pt', 'Hg', 'Pb', 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Tc', 'Ba', 'Bi',
'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir', 'Ce', 'Gd', 'Ga', 'Cs']
default_node_featurizer_center = BaseAtomFeaturizer({
'hv': ConcatFeaturizer(
[partial(atom_type_one_hot,
allowable_set=atom_types, encode_unknown=True),
partial(atom_degree_one_hot,
allowable_set=list(range(5)), encode_unknown=True),
partial(atom_explicit_valence_one_hot,
allowable_set=list(range(1, 6)), encode_unknown=True),
partial(atom_implicit_valence_one_hot,
allowable_set=list(range(5)), encode_unknown=True),
atom_is_aromatic]
)
})
default_node_featurizer_rank = BaseAtomFeaturizer({
'hv': ConcatFeaturizer(
[partial(atom_type_one_hot,
allowable_set=atom_types, encode_unknown=True),
partial(atom_formal_charge_one_hot,
allowable_set=[-3, -2, -1, 0, 1, 2], encode_unknown=True),
partial(atom_degree_one_hot,
allowable_set=list(range(5)), encode_unknown=True),
partial(atom_explicit_valence_one_hot,
allowable_set=list(range(1, 6)), encode_unknown=True),
partial(atom_implicit_valence_one_hot,
allowable_set=list(range(5)), encode_unknown=True),
atom_is_aromatic]
)
})
default_edge_featurizer_center = BaseBondFeaturizer({
'he': ConcatFeaturizer([
bond_type_one_hot, bond_is_conjugated, bond_is_in_ring]
)
})
default_edge_featurizer_rank = BaseBondFeaturizer({
'he': ConcatFeaturizer([
bond_type_one_hot, bond_is_in_ring]
)
})
def default_atom_pair_featurizer(reactants):
"""Featurize each pair of atoms, which will be used in updating
the edata of a complete DGLGraph.
The features include the bond type between the atoms (if any) and whether
they belong to the same molecule. It is used in the global attention mechanism.
Parameters
----------
reactants : str
SMILES for reactants
data_field : str
Key for storing the features in DGLGraph.edata. Default to 'atom_pair'
Returns
-------
float32 tensor of shape (V^2, 10)
features for each pair of atoms.
"""
# Decide the reactant membership for each atom
atom_to_reactant = dict()
reactant_list = reactants.split('.')
for id, s in enumerate(reactant_list):
mol = Chem.MolFromSmiles(s)
for atom in mol.GetAtoms():
atom_to_reactant[atom.GetIntProp('molAtomMapNumber') - 1] = id
# Construct mapping from atom pair to RDKit bond object
all_reactant_mol = Chem.MolFromSmiles(reactants)
atom_pair_to_bond = dict()
for bond in all_reactant_mol.GetBonds():
atom1 = bond.GetBeginAtom().GetIntProp('molAtomMapNumber') - 1
atom2 = bond.GetEndAtom().GetIntProp('molAtomMapNumber') - 1
atom_pair_to_bond[(atom1, atom2)] = bond
atom_pair_to_bond[(atom2, atom1)] = bond
def _featurize_a_bond(bond):
return bond_type_one_hot(bond) + bond_is_conjugated(bond) + bond_is_in_ring(bond)
features = []
num_atoms = all_reactant_mol.GetNumAtoms()
for i in range(num_atoms):
for j in range(num_atoms):
pair_feature = np.zeros(10)
if i == j:
features.append(pair_feature)
continue
bond = atom_pair_to_bond.get((i, j), None)
if bond is not None:
pair_feature[1:7] = _featurize_a_bond(bond)
else:
pair_feature[0] = 1.
pair_feature[-4] = 1. if atom_to_reactant[i] != atom_to_reactant[j] else 0.
pair_feature[-3] = 1. if atom_to_reactant[i] == atom_to_reactant[j] else 0.
pair_feature[-2] = 1. if len(reactant_list) == 1 else 0.
pair_feature[-1] = 1. if len(reactant_list) > 1 else 0.
features.append(pair_feature)
return torch.from_numpy(np.stack(features, axis=0).astype(np.float32))
def get_pair_label(reactants_mol, graph_edits):
"""Construct labels for each pair of atoms in reaction center prediction
Parameters
----------
reactants_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for all reactants in a reaction
graph_edits : str
Specifying which pairs of atoms loss a bond or form a particular bond in the reaction
Returns
-------
float32 tensor of shape (V^2, 5)
Labels constructed. V for the number of atoms in the reactants.
"""
# 0 for losing the bond
# 1, 2, 3, 1.5 separately for forming a single, double, triple or aromatic bond.
bond_change_to_id = {0.0: 0, 1:1, 2:2, 3:3, 1.5:4}
pair_to_changes = defaultdict(list)
for edit in graph_edits.split(';'):
a1, a2, change = edit.split('-')
atom1 = int(a1) - 1
atom2 = int(a2) - 1
change = bond_change_to_id[float(change)]
pair_to_changes[(atom1, atom2)].append(change)
pair_to_changes[(atom2, atom1)].append(change)
num_atoms = reactants_mol.GetNumAtoms()
labels = torch.zeros((num_atoms, num_atoms, 5))
for pair in pair_to_changes.keys():
i, j = pair
labels[i, j, pair_to_changes[(j, i)]] = 1.
return labels.reshape(-1, 5)
def get_bond_changes(reaction):
"""Get the bond changes in a reaction.
Parameters
----------
reaction : str
SMILES for a reaction, e.g. [CH3:14][NH2:15].[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7]
(=[O:8])[OH:9])[cH:10][cH:11][c:12]1[Cl:13].[OH2:16]>>[N+:1](=[O:2])([O-:3])[c:4]1[cH:5]
[c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[NH:15][CH3:14]. It consists of reactants,
products and the atom mapping.
Returns
-------
bond_changes : set of 3-tuples
Each tuple consists of (atom1, atom2, change type)
There are 5 possible values for change type. 0 for losing the bond, and 1, 2, 3, 1.5
separately for forming a single, double, triple or aromatic bond.
"""
reactants = Chem.MolFromSmiles(reaction.split('>')[0])
products = Chem.MolFromSmiles(reaction.split('>')[2])
conserved_maps = [
a.GetProp('molAtomMapNumber')
for a in products.GetAtoms() if a.HasProp('molAtomMapNumber')]
bond_changes = set() # keep track of bond changes
# Look at changed bonds
bonds_prev = {}
for bond in reactants.GetBonds():
nums = sorted(
[bond.GetBeginAtom().GetProp('molAtomMapNumber'),
bond.GetEndAtom().GetProp('molAtomMapNumber')])
if (nums[0] not in conserved_maps) and (nums[1] not in conserved_maps):
continue
bonds_prev['{}~{}'.format(nums[0], nums[1])] = bond.GetBondTypeAsDouble()
bonds_new = {}
for bond in products.GetBonds():
nums = sorted(
[bond.GetBeginAtom().GetProp('molAtomMapNumber'),
bond.GetEndAtom().GetProp('molAtomMapNumber')])
bonds_new['{}~{}'.format(nums[0], nums[1])] = bond.GetBondTypeAsDouble()
for bond in bonds_prev:
if bond not in bonds_new:
# lost bond
bond_changes.add((bond.split('~')[0], bond.split('~')[1], 0.0))
else:
if bonds_prev[bond] != bonds_new[bond]:
# changed bond
bond_changes.add((bond.split('~')[0], bond.split('~')[1], bonds_new[bond]))
for bond in bonds_new:
if bond not in bonds_prev:
# new bond
bond_changes.add((bond.split('~')[0], bond.split('~')[1], bonds_new[bond]))
return bond_changes
def process_line(line):
"""Process one line consisting of one reaction for working with WLN.
Parameters
----------
line : str
One reaction in one line
Returns
-------
formatted_reaction : str
Formatted reaction
"""
reaction = line.strip()
bond_changes = get_bond_changes(reaction)
formatted_reaction = '{} {}\n'.format(
reaction, ';'.join(['{}-{}-{}'.format(x[0], x[1], x[2]) for x in bond_changes]))
return formatted_reaction
def process_file(path, num_processes=1):
"""Pre-process a file of reactions for working with WLN.
Parameters
----------
path : str
Path to the file of reactions
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
with open(path, 'r') as input_file:
lines = input_file.readlines()
if num_processes == 1:
results = []
for li in lines:
results.append(process_line(li))
else:
with Pool(processes=num_processes) as pool:
results = pool.map(process_line, lines)
with open(path + '.proc', 'w') as output_file:
for line in results:
output_file.write(line)
print('Finished processing {}'.format(path))
def load_one_reaction(line):
"""Load one reaction and check if the reactants are valid.
Parameters
----------
line : str
One reaction and the associated graph edits
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants. None will be returned if the
reactants are not valid.
reaction : str
Reaction. None will be returned if the reactants are not valid.
graph_edits : str
Graph edits associated with the reaction. None will be returned if the
reactants are not valid.
"""
# Each line represents a reaction and the corresponding graph edits
#
# reaction example:
# [CH3:14][OH:15].[NH2:12][NH2:13].[OH2:11].[n:1]1[n:2][cH:3][c:4]
# ([C:7]([O:9][CH3:8])=[O:10])[cH:5][cH:6]1>>[n:1]1[n:2][cH:3][c:4]
# ([C:7](=[O:9])[NH:12][NH2:13])[cH:5][cH:6]1
# The reactants are on the left-hand-side of the reaction and the product
# is on the right-hand-side of the reaction. The numbers represent atom mapping.
#
# graph_edits example:
# 23-33-1.0;23-25-0.0
# For a triplet a-b-c, a and b are the atoms that form or loss the bond.
# c specifies the particular change, 0.0 for losing a bond, 1.0, 2.0, 3.0 and
# 1.5 separately for forming a single, double, triple or aromatic bond.
reaction, graph_edits = line.strip("\r\n ").split()
reactants = reaction.split('>')[0]
mol = Chem.MolFromSmiles(reactants)
if mol is None:
return None, None, None
# Reorder atoms according to the order specified in the atom map
atom_map_order = [-1 for _ in range(mol.GetNumAtoms())]
for j in range(mol.GetNumAtoms()):
atom = mol.GetAtomWithIdx(j)
atom_map_order[atom.GetIntProp('molAtomMapNumber') - 1] = j
mol = rdmolops.RenumberAtoms(mol, atom_map_order)
return mol, reaction, graph_edits
class WLNCenterDataset(object):
"""Dataset for reaction center prediction with WLN
Parameters
----------
raw_file_path : str
Path to the raw reaction file, where each line is the SMILES for a reaction.
We will check if raw_file_path + '.proc' exists, where each line has the reaction
SMILES and the corresponding graph edits. If not, we will preprocess
the raw reaction file.
mol_graph_path : str
Path to save/load DGLGraphs for molecules.
mol_to_graph: callable, str -> DGLGraph
A function turning RDKit molecule instances into DGLGraphs.
Default to :func:`dgllife.utils.mol_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we consider descriptors including atom type,
atom degree, atom explicit valence, atom implicit valence, aromaticity.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we consider descriptors including bond type,
whether bond is conjugated and whether bond is in ring.
atom_pair_featurizer : callable, str -> dict
Featurization for each pair of atoms in multiple reactants. The result will be
used to update edata in the complete DGLGraphs. By default, the features include
the bond type between the atoms (if any) and whether they belong to the same molecule.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def __init__(self,
raw_file_path,
mol_graph_path,
mol_to_graph=mol_to_bigraph,
node_featurizer=default_node_featurizer_center,
edge_featurizer=default_edge_featurizer_center,
atom_pair_featurizer=default_atom_pair_featurizer,
load=True,
num_processes=1):
super(WLNCenterDataset, self).__init__()
self._atom_pair_featurizer = atom_pair_featurizer
self.atom_pair_features = []
self.atom_pair_labels = []
# Map number of nodes to a corresponding complete graph
self.complete_graphs = dict()
path_to_reaction_file = raw_file_path + '.proc'
if not os.path.isfile(path_to_reaction_file):
print('Pre-processing graph edits from reaction data')
process_file(raw_file_path, num_processes)
import time
t0 = time.time()
full_mols, full_reactions, full_graph_edits = \
self.load_reaction_data(path_to_reaction_file, num_processes)
print('Time spent', time.time() - t0)
if load and os.path.isfile(mol_graph_path):
print('Loading previously saved graphs...')
self.reactant_mol_graphs, _ = load_graphs(mol_graph_path)
else:
print('Constructing graphs from scratch...')
if num_processes == 1:
self.reactant_mol_graphs = []
for mol in full_mols:
self.reactant_mol_graphs.append(mol_to_graph(
mol, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer, canonical_atom_order=False))
else:
torch.multiprocessing.set_sharing_strategy('file_system')
with Pool(processes=num_processes) as pool:
self.reactant_mol_graphs = pool.map(
partial(mol_to_graph, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer, canonical_atom_order=False),
full_mols)
save_graphs(mol_graph_path, self.reactant_mol_graphs)
self.mols = full_mols
self.reactions = full_reactions
self.graph_edits = full_graph_edits
self.atom_pair_features.extend([None for _ in range(len(self.mols))])
self.atom_pair_labels.extend([None for _ in range(len(self.mols))])
def load_reaction_data(self, file_path, num_processes):
"""Load reaction data from the raw file.
Parameters
----------
file_path : str
Path to read the file.
num_processes : int
Number of processes to use for data pre-processing.
Returns
-------
all_mols : list of rdkit.Chem.rdchem.Mol
RDKit molecule instances
all_reactions : list of str
Reactions
all_graph_edits : list of str
Graph edits in the reactions.
"""
all_mols = []
all_reactions = []
all_graph_edits = []
with open(file_path, 'r') as f:
lines = f.readlines()
if num_processes == 1:
results = []
for li in lines:
mol, reaction, graph_edits = load_one_reaction(li)
results.append((mol, reaction, graph_edits))
else:
with Pool(processes=num_processes) as pool:
results = pool.map(load_one_reaction, lines)
for mol, reaction, graph_edits in results:
if mol is None:
continue
all_mols.append(mol)
all_reactions.append(reaction)
all_graph_edits.append(graph_edits)
return all_mols, all_reactions, all_graph_edits
def __len__(self):
"""Get the size for the dataset.
Returns
-------
int
Number of reactions in the dataset.
"""
return len(self.mols)
def __getitem__(self, item):
"""Get the i-th datapoint.
Returns
-------
str
Reaction
str
Graph edits for the reaction
DGLGraph
DGLGraph for the ith molecular graph
DGLGraph
Complete DGLGraph, which will be needed for predicting
scores between each pair of atoms
float32 tensor of shape (V^2, 10)
Features for each pair of atoms.
float32 tensor of shape (V^2, 5)
Labels for reaction center prediction.
V for the number of atoms in the reactants.
"""
mol = self.mols[item]
num_atoms = mol.GetNumAtoms()
if num_atoms not in self.complete_graphs:
self.complete_graphs[num_atoms] = mol_to_complete_graph(
mol, add_self_loop=True, canonical_atom_order=False)
if self.atom_pair_features[item] is None:
reactants = self.reactions[item].split('>')[0]
self.atom_pair_features[item] = self._atom_pair_featurizer(reactants)
if self.atom_pair_labels[item] is None:
self.atom_pair_labels[item] = get_pair_label(mol, self.graph_edits[item])
return self.reactions[item], self.graph_edits[item], \
self.reactant_mol_graphs[item], \
self.complete_graphs[num_atoms], \
self.atom_pair_features[item], \
self.atom_pair_labels[item]
class USPTOCenter(WLNCenterDataset):
"""USPTO dataset for reaction center prediction.
The dataset contains reactions from patents granted by United States Patent
and Trademark Office (USPTO), collected by Lowe [1]. Jin et al. removes duplicates
and erroneous reactions, obtaining a set of 480K reactions. They divide it
into 400K, 40K, and 40K for training, validation and test.
References:
* [1] Patent reaction extraction
* [2] Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network
Parameters
----------
subset : str
Whether to use the training/validation/test set as in Jin et al.
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
mol_to_graph: callable, str -> DGLGraph
A function turning RDKit molecule instances into DGLGraphs.
Default to :func:`dgllife.utils.mol_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we consider descriptors including atom type,
atom degree, atom explicit valence, atom implicit valence, aromaticity.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we consider descriptors including bond type,
whether bond is conjugated and whether bond is in ring.
atom_pair_featurizer : callable, str -> dict
Featurization for each pair of atoms in multiple reactants. The result will be
used to update edata in the complete DGLGraphs. By default, the features include
the bond type between the atoms (if any) and whether they belong to the same molecule.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def __init__(self,
subset,
mol_to_graph=mol_to_bigraph,
node_featurizer=default_node_featurizer_center,
edge_featurizer=default_edge_featurizer_center,
atom_pair_featurizer=default_atom_pair_featurizer,
load=True,
num_processes=1):
assert subset in ['train', 'val', 'test'], \
'Expect subset to be "train" or "val" or "test", got {}'.format(subset)
print('Preparing {} subset of USPTO for reaction center prediction.'.format(subset))
self._subset = subset
if subset == 'val':
subset = 'valid'
self._url = 'dataset/uspto.zip'
data_path = get_download_dir() + '/uspto.zip'
extracted_data_path = get_download_dir() + '/uspto'
download(_get_dgl_url(self._url), path=data_path)
extract_archive(data_path, extracted_data_path)
super(USPTOCenter, self).__init__(
raw_file_path=extracted_data_path + '/{}.txt'.format(subset),
mol_graph_path=extracted_data_path + '/{}_mol_graphs.bin'.format(subset),
mol_to_graph=mol_to_graph,
node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer,
atom_pair_featurizer=atom_pair_featurizer,
load=load,
num_processes=num_processes)
@property
def subset(self):
"""Get the subset used for USPTOCenter
Returns
-------
str
* 'full' for the complete dataset
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
"""
return self._subset
def mkdir_p(path):
"""Create a folder for the given path.
Parameters
----------
path: str
Folder to create
"""
try:
os.makedirs(path)
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
def load_one_reaction_rank(line):
"""Load one reaction and check if the reactants are valid.
Parameters
----------
line : str
One reaction and the associated graph edits
Returns
-------
reactants_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants. None will be returned if the
line is not valid.
product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the product. None will be returned if the line is not valid.
reaction_real_bond_changes : list of 3-tuples
Real bond changes in the reaction. Each tuple is of form (atom1, atom2, change_type). For
change_type, 0.0 stands for losing a bond, 1.0, 2.0, 3.0 and 1.5 separately stands for
forming a single, double, triple or aromatic bond.
"""
# Each line represents a reaction and the corresponding graph edits
#
# reaction example:
# [CH3:14][OH:15].[NH2:12][NH2:13].[OH2:11].[n:1]1[n:2][cH:3][c:4]
# ([C:7]([O:9][CH3:8])=[O:10])[cH:5][cH:6]1>>[n:1]1[n:2][cH:3][c:4]
# ([C:7](=[O:9])[NH:12][NH2:13])[cH:5][cH:6]1
# The reactants are on the left-hand-side of the reaction and the product
# is on the right-hand-side of the reaction. The numbers represent atom mapping.
#
# graph_edits example:
# 23-33-1.0;23-25-0.0
# For a triplet a-b-c, a and b are the atoms that form or loss the bond.
# c specifies the particular change, 0.0 for losing a bond, 1.0, 2.0, 3.0 and
# 1.5 separately for forming a single, double, triple or aromatic bond.
reaction, graph_edits = line.strip("\r\n ").split()
reactants, _, product = reaction.split('>')
reactants_mol = Chem.MolFromSmiles(reactants)
if reactants_mol is None:
return None, None, None, None, None
product_mol = Chem.MolFromSmiles(product)
if product_mol is None:
return None, None, None, None, None
# Reorder atoms according to the order specified in the atom map
atom_map_order = [-1 for _ in range(reactants_mol.GetNumAtoms())]
for j in range(reactants_mol.GetNumAtoms()):
atom = reactants_mol.GetAtomWithIdx(j)
atom_map_order[atom.GetIntProp('molAtomMapNumber') - 1] = j
reactants_mol = rdmolops.RenumberAtoms(reactants_mol, atom_map_order)
reaction_real_bond_changes = []
for changed_bond in graph_edits.split(';'):
atom1, atom2, change_type = changed_bond.split('-')
atom1, atom2 = int(atom1) - 1, int(atom2) - 1
reaction_real_bond_changes.append(
(min(atom1, atom2), max(atom1, atom2), float(change_type)))
return reactants_mol, product_mol, reaction_real_bond_changes
def load_candidate_bond_changes_for_one_reaction(line):
"""Load candidate bond changes for a reaction
Parameters
----------
line : str
Candidate bond changes separated by ;. Each candidate bond change takes the
form of atom1, atom2, change_type and change_score.
Returns
-------
list of 4-tuples
Loaded candidate bond changes.
"""
reaction_candidate_bond_changes = []
elements = line.strip().split(';')[:-1]
for candidate in elements:
atom1, atom2, change_type, score = candidate.split(' ')
atom1, atom2 = int(atom1) - 1, int(atom2) - 1
reaction_candidate_bond_changes.append((
min(atom1, atom2), max(atom1, atom2), float(change_type), float(score)))
return reaction_candidate_bond_changes
def bookkeep_reactant(mol, candidate_pairs):
"""Bookkeep reaction-related information of reactants.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants.
candidate_pairs : list of 2-tuples
Pairs of atoms that ranked high by a model for reaction center prediction.
By assumption, the two atoms are different and the first atom has a smaller
index than the second.
Returns
-------
info : dict
Reaction-related information of reactants
"""
num_atoms = mol.GetNumAtoms()
info = {
# free valence of atoms
'free_val': [0 for _ in range(num_atoms)],
# Whether it is a carbon atom
'is_c': [False for _ in range(num_atoms)],
# Whether it is a carbon atom connected to a nitrogen atom in pyridine
'is_c2_of_pyridine': [False for _ in range(num_atoms)],
# Whether it is a phosphorous atom
'is_p': [False for _ in range(num_atoms)],
# Whether it is a sulfur atom
'is_s': [False for _ in range(num_atoms)],
# Whether it is an oxygen atom
'is_o': [False for _ in range(num_atoms)],
# Whether it is a nitrogen atom
'is_n': [False for _ in range(num_atoms)],
'pair_to_bond_val': dict(),
'ring_bonds': set()
}
# bookkeep atoms
for j, atom in enumerate(mol.GetAtoms()):
info['free_val'][j] += atom.GetTotalNumHs() + abs(atom.GetFormalCharge())
# An aromatic carbon atom next to an aromatic nitrogen atom can get a
# carbonyl b/c of bookkeeping of hydroxypyridines
if atom.GetSymbol() == 'C':
info['is_c'][j] = True
if atom.GetIsAromatic():
for nbr in atom.GetNeighbors():
if nbr.GetSymbol() == 'N' and nbr.GetDegree() == 2:
info['is_c2_of_pyridine'][j] = True
break
# A nitrogen atom should be allowed to become positively charged
elif atom.GetSymbol() == 'N':
info['free_val'][j] += 1 - atom.GetFormalCharge()
info['is_n'][j] = True
# Phosphorous atoms can form a phosphonium
elif atom.GetSymbol() == 'P':
info['free_val'][j] += 1 - atom.GetFormalCharge()
info['is_p'][j] = True
elif atom.GetSymbol() == 'O':
info['is_o'][j] = True
elif atom.GetSymbol() == 'S':
info['is_s'][j] = True
# bookkeep bonds
for bond in mol.GetBonds():
atom1, atom2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
atom1, atom2 = min(atom1, atom2), max(atom1, atom2)
type_val = bond.GetBondTypeAsDouble()
info['pair_to_bond_val'][(atom1, atom2)] = type_val
if (atom1, atom2) in candidate_pairs:
info['free_val'][atom1] += type_val
info['free_val'][atom2] += type_val
if bond.IsInRing():
info['ring_bonds'].add((atom1, atom2))
return info
def bookkeep_product(mol):
"""Bookkeep reaction-related information of atoms/bonds in products
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for products.
Returns
-------
info : dict
Reaction-related information of atoms/bonds in products
"""
info = {
'atoms': set()
}
for atom in mol.GetAtoms():
info['atoms'].add(atom.GetAtomMapNum() - 1)
return info
def is_connected_change_combo(combo_ids, cand_change_adj):
"""Check whether the combo of bond changes yields a connected component.
Parameters
----------
combo_ids : tuple of int
Ids for bond changes in the combination.
cand_change_adj : bool ndarray of shape (N, N)
Adjacency matrix for candidate bond changes. Two candidate bond
changes are considered adjacent if they share a common atom.
* N for the number of candidate bond changes.
Returns
-------
bool
Whether the combo of bond changes yields a connected component
"""
if len(combo_ids) == 1:
return True
multi_hop_adj = np.linalg.matrix_power(
cand_change_adj[combo_ids, :][:, combo_ids], len(combo_ids) - 1)
# The combo is connected if the distance between
# any pair of bond changes is within len(combo) - 1
return np.all(multi_hop_adj)
def is_valid_combo(combo_changes, reactant_info):
"""Whether the combo of bond changes is chemically valid.
Parameters
----------
combo_changes : list of 4-tuples
Each tuple consists of atom1, atom2, type of bond change (in the form of related
valence) and score for the change.
reactant_info : dict
Reaction-related information of reactants
Returns
-------
bool
Whether the combo of bond changes is chemically valid.
"""
num_atoms = len(reactant_info['free_val'])
force_even_parity = np.zeros((num_atoms,), dtype=bool)
force_odd_parity = np.zeros((num_atoms,), dtype=bool)
pair_seen = defaultdict(bool)
free_val_tmp = reactant_info['free_val'].copy()
for (atom1, atom2, change_type, score) in combo_changes:
if pair_seen[(atom1, atom2)]:
# A pair of atoms cannot have two types of changes. Even if we
# randomly pick one, that will be reduced to a combo of less changes
return False
pair_seen[(atom1, atom2)] = True
# Special valence rules
atom1_type_val = atom2_type_val = change_type
if change_type == 2:
# to form a double bond
if reactant_info['is_o'][atom1]:
if reactant_info['is_c2_of_pyridine'][atom2]:
atom2_type_val = 1.
elif reactant_info['is_p'][atom2]:
# don't count information of =o toward valence
# but require odd valence parity
atom2_type_val = 0.
force_odd_parity[atom2] = True
elif reactant_info['is_s'][atom2]:
atom2_type_val = 0.
force_even_parity[atom2] = True
elif reactant_info['is_o'][atom2]:
if reactant_info['is_c2_of_pyridine'][atom1]:
atom1_type_val = 1.
elif reactant_info['is_p'][atom1]:
atom1_type_val = 0.
force_odd_parity[atom1] = True
elif reactant_info['is_s'][atom1]:
atom1_type_val = 0.
force_even_parity[atom1] = True
elif reactant_info['is_n'][atom1] and reactant_info['is_p'][atom2]:
atom2_type_val = 0.
force_odd_parity[atom2] = True
elif reactant_info['is_n'][atom2] and reactant_info['is_p'][atom1]:
atom1_type_val = 0.
force_odd_parity[atom1] = True
elif reactant_info['is_p'][atom1] and reactant_info['is_c'][atom2]:
atom1_type_val = 0.
force_odd_parity[atom1] = True
elif reactant_info['is_p'][atom2] and reactant_info['is_c'][atom1]:
atom2_type_val = 0.
force_odd_parity[atom2] = True
reactant_pair_val = reactant_info['pair_to_bond_val'].get((atom1, atom2), None)
if reactant_pair_val is not None:
free_val_tmp[atom1] += reactant_pair_val - atom1_type_val
free_val_tmp[atom2] += reactant_pair_val - atom2_type_val
else:
free_val_tmp[atom1] -= atom1_type_val
free_val_tmp[atom2] -= atom2_type_val
free_val_tmp = np.array(free_val_tmp)
# False if 1) too many connections 2) sulfur valence not even
# 3) phosphorous valence not odd
if any(free_val_tmp < 0) or \
any(aval % 2 != 0 for aval in free_val_tmp[force_even_parity]) or \
any(aval % 2 != 1 for aval in free_val_tmp[force_odd_parity]):
return False
return True
def edit_mol(reactant_mols, edits, product_info):
"""Simulate reaction via graph editing
Parameters
----------
reactant_mols : rdkit.Chem.rdchem.Mol
RDKit molecule instances for reactants.
edits : list of 4-tuples
Bond changes for getting the product out of the reactants in a reaction.
Each 4-tuple is of form (atom1, atom2, change_type, score), where atom1
and atom2 are the end atoms to form or lose a bond, change_type is the
type of bond change and score represents the confidence for the bond change
by a model.
product_info : dict
proeduct_info['atoms'] gives a set of atom ids in the ground truth product molecule.
Returns
-------
str
SMILES for the main products
"""
bond_change_to_type = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE,
3: Chem.rdchem.BondType.TRIPLE, 1.5: Chem.rdchem.BondType.AROMATIC}
new_mol = Chem.RWMol(reactant_mols)
[atom.SetNumExplicitHs(0) for atom in new_mol.GetAtoms()]
for atom1, atom2, change_type, score in edits:
bond = new_mol.GetBondBetweenAtoms(atom1, atom2)
if bond is not None:
new_mol.RemoveBond(atom1, atom2)
if change_type > 0:
new_mol.AddBond(atom1, atom2, bond_change_to_type[change_type])
pred_mol = new_mol.GetMol()
pred_smiles = Chem.MolToSmiles(pred_mol)
pred_list = pred_smiles.split('.')
pred_mols = []
for pred_smiles in pred_list:
mol = Chem.MolFromSmiles(pred_smiles)
if mol is None:
continue
atom_set = set([atom.GetAtomMapNum() - 1 for atom in mol.GetAtoms()])
if len(atom_set & product_info['atoms']) == 0:
continue
for atom in mol.GetAtoms():
atom.SetAtomMapNum(0)
pred_mols.append(mol)
return '.'.join(sorted([Chem.MolToSmiles(mol) for mol in pred_mols]))
def get_product_smiles(reactant_mols, edits, product_info):
"""Get the product smiles of the reaction
Parameters
----------
reactant_mols : rdkit.Chem.rdchem.Mol
RDKit molecule instances for reactants.
edits : list of 4-tuples
Bond changes for getting the product out of the reactants in a reaction.
Each 4-tuple is of form (atom1, atom2, change_type, score), where atom1
and atom2 are the end atoms to form or lose a bond, change_type is the
type of bond change and score represents the confidence for the bond change
by a model.
product_info : dict
proeduct_info['atoms'] gives a set of atom ids in the ground truth product molecule.
Returns
-------
str
SMILES for the main products
"""
smiles = edit_mol(reactant_mols, edits, product_info)
if len(smiles) != 0:
return smiles
try:
Chem.Kekulize(reactant_mols)
except Exception as e:
return smiles
return edit_mol(reactant_mols, edits, product_info)
def generate_valid_candidate_combos():
return NotImplementedError
def pre_process_one_reaction(info, num_candidate_bond_changes, max_num_bond_changes,
max_num_change_combos, mode):
"""Pre-process one reaction for candidate ranking.
Parameters
----------
info : 4-tuple
* candidate_bond_changes : list of tuples
The candidate bond changes for the reaction
* real_bond_changes : list of tuples
The real bond changes for the reaction
* reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants
* product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for product
num_candidate_bond_changes : int
Number of candidate bond changes to consider for the ground truth reaction.
max_num_bond_changes : int
Maximum number of bond changes per reaction.
max_num_change_combos : int
Number of bond change combos to consider for each reaction.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
mode : str
Whether the dataset is to be used for training, validation or test.
Returns
-------
valid_candidate_combos : list
valid_candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction.
candidate_bond_changes : list of 4-tuples
Refined candidate bond changes considered for combos.
reactant_info : dict
Reaction-related information of reactants.
"""
assert mode in ['train', 'val', 'test'], \
"Expect mode to be 'train' or 'val' or 'test', got {}".format(mode)
candidate_bond_changes_, real_bond_changes, reactant_mol, product_mol = info
candidate_pairs = [(atom1, atom2) for (atom1, atom2, _, _)
in candidate_bond_changes_]
reactant_info = bookkeep_reactant(reactant_mol, candidate_pairs)
if mode == 'train':
product_info = bookkeep_product(product_mol)
# Filter out candidate new bonds already in reactants
candidate_bond_changes = []
count = 0
for (atom1, atom2, change_type, score) in candidate_bond_changes_:
if ((atom1, atom2) not in reactant_info['pair_to_bond_val']) or \
(reactant_info['pair_to_bond_val'][(atom1, atom2)] != change_type):
candidate_bond_changes.append((atom1, atom2, change_type, score))
count += 1
if count == num_candidate_bond_changes:
break
# Check if two bond changes have atom in common
cand_change_adj = np.eye(len(candidate_bond_changes), dtype=bool)
for i in range(len(candidate_bond_changes)):
atom1_1, atom1_2, _, _ = candidate_bond_changes[i]
for j in range(i + 1, len(candidate_bond_changes)):
atom2_1, atom2_2, _, _ = candidate_bond_changes[j]
if atom1_1 == atom2_1 or atom1_1 == atom2_2 or \
atom1_2 == atom2_1 or atom1_2 == atom2_2:
cand_change_adj[i, j] = cand_change_adj[j, i] = True
# Enumerate combinations of k candidate bond changes and record
# those that are connected and chemically valid
valid_candidate_combos = []
cand_change_ids = range(len(candidate_bond_changes))
for k in range(1, max_num_bond_changes + 1):
for combo_ids in combinations(cand_change_ids, k):
# Check if the changed bonds form a connected component
if not is_connected_change_combo(combo_ids, cand_change_adj):
continue
combo_changes = [candidate_bond_changes[j] for j in combo_ids]
# Check if the combo is chemically valid
if is_valid_combo(combo_changes, reactant_info):
valid_candidate_combos.append(combo_changes)
if mode == 'train':
random.shuffle(valid_candidate_combos)
# Index for the combo of candidate bond changes
# that is equivalent to the gold combo
real_combo_id = -1
for j, combo_changes in enumerate(valid_candidate_combos):
if set([(atom1, atom2, change_type) for
(atom1, atom2, change_type, score) in combo_changes]) == \
set(real_bond_changes):
real_combo_id = j
break
# If we fail to find the real combo, make it the first entry
if real_combo_id == -1:
valid_candidate_combos = \
[[(atom1, atom2, change_type, 0.0)
for (atom1, atom2, change_type) in real_bond_changes]] + \
valid_candidate_combos
else:
valid_candidate_combos[0], valid_candidate_combos[real_combo_id] = \
valid_candidate_combos[real_combo_id], valid_candidate_combos[0]
product_smiles = get_product_smiles(
reactant_mol, valid_candidate_combos[0], product_info)
if len(product_smiles) > 0:
# Remove combos yielding duplicate products
product_smiles = set([product_smiles])
new_candidate_combos = [valid_candidate_combos[0]]
count = 0
for combo in valid_candidate_combos[1:]:
smiles = get_product_smiles(reactant_mol, combo, product_info)
if smiles in product_smiles or len(smiles) == 0:
continue
product_smiles.add(smiles)
new_candidate_combos.append(combo)
count += 1
if count == max_num_change_combos:
break
valid_candidate_combos = new_candidate_combos
valid_candidate_combos = valid_candidate_combos[:max_num_change_combos]
return valid_candidate_combos, candidate_bond_changes, reactant_info
def featurize_nodes_and_compute_combo_scores(
node_featurizer, reactant_mol, valid_candidate_combos):
"""Featurize atoms in reactants and compute scores for combos of bond changes
Parameters
----------
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants in a reaction
valid_candidate_combos : list
valid_candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction.
Returns
-------
node_feats : float32 tensor of shape (N, M)
Node features for reactants, N for the number of nodes and M for the feature size
combo_bias : float32 tensor of shape (B, 1)
Scores for combos of bond changes, B equals len(valid_candidate_combos)
"""
node_feats = node_featurizer(reactant_mol)['hv']
combo_bias = torch.zeros(len(valid_candidate_combos), 1).float()
for combo_id, combo in enumerate(valid_candidate_combos):
combo_bias[combo_id] = sum([
score for (atom1, atom2, change_type, score) in combo])
return node_feats, combo_bias
def construct_graphs_rank(info, edge_featurizer):
"""Construct graphs for reactants and candidate products in a reaction and featurize
their edges
Parameters
----------
info : 4-tuple
* reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants in a reaction
* candidate_combos : list
candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction.
* candidate_bond_changes : list of 4-tuples
Refined candidate bond changes considered for candidate products
* reactant_info : dict
Reaction-related information of reactants.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
reaction_graphs : list of DGLGraphs
DGLGraphs for reactants and candidate products with edge features in edata['he'],
where the first graph is for reactants.
"""
reactant_mol, candidate_combos, candidate_bond_changes, reactant_info = info
# Graphs for reactants and candidate products
reaction_graphs = []
# Get graph for the reactants
reactant_graph = mol_to_bigraph(reactant_mol, edge_featurizer=edge_featurizer,
canonical_atom_order=False)
reaction_graphs.append(reactant_graph)
candidate_bond_changes_no_score = [
(atom1, atom2, change_type)
for (atom1, atom2, change_type, score) in candidate_bond_changes]
# Prepare common components across all candidate products
breaking_reactant_neighbors = []
common_src_list = []
common_dst_list = []
common_edge_feats = []
num_bonds = reactant_mol.GetNumBonds()
for j in range(num_bonds):
bond = reactant_mol.GetBondWithIdx(j)
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
u_sort, v_sort = min(u, v), max(u, v)
# Whether a bond in reactants might get broken
if (u_sort, v_sort, 0.0) not in candidate_bond_changes_no_score:
common_src_list.extend([u, v])
common_dst_list.extend([v, u])
common_edge_feats.extend([reactant_graph.edata['he'][2 * j],
reactant_graph.edata['he'][2 * j + 1]])
else:
breaking_reactant_neighbors.append((
u_sort, v_sort, bond.GetBondTypeAsDouble()))
for combo in candidate_combos:
combo_src_list = deepcopy(common_src_list)
combo_dst_list = deepcopy(common_dst_list)
combo_edge_feats = deepcopy(common_edge_feats)
candidate_bond_end_atoms = [
(atom1, atom2) for (atom1, atom2, change_type, score) in combo]
for (atom1, atom2, change_type) in breaking_reactant_neighbors:
if (atom1, atom2) not in candidate_bond_end_atoms:
# If a bond might be broken in some other combos but not this,
# add it as a negative sample
combo.append((atom1, atom2, change_type, 0.0))
for (atom1, atom2, change_type, score) in combo:
if change_type == 0:
continue
combo_src_list.extend([atom1, atom2])
combo_dst_list.extend([atom2, atom1])
feats = one_hot_encoding(change_type, [1.0, 2.0, 3.0, 1.5, -1])
if (atom1, atom2) in reactant_info['ring_bonds']:
feats[-1] = 1
feats = torch.tensor(feats).float()
combo_edge_feats.extend([feats, feats.clone()])
combo_edge_feats = torch.stack(combo_edge_feats, dim=0)
combo_graph = DGLGraph()
combo_graph.add_nodes(reactant_graph.number_of_nodes())
combo_graph.add_edges(combo_src_list, combo_dst_list)
combo_graph.edata['he'] = combo_edge_feats
reaction_graphs.append(combo_graph)
return reaction_graphs
class WLNRankDataset(object):
"""Dataset for ranking candidate products with WLN
Parameters
----------
raw_file_path : str
Path to the raw reaction file, where each line is the SMILES for a reaction.
candidate_bond_path : str
Path to the candidate bond changes for product enumeration, where each line is
candidate bond changes for a reaction by a WLN for reaction center prediction.
mode : str
'train', 'val', or 'test', indicating whether the dataset is used for training,
validation or test.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we consider descriptors including atom type,
atom formal charge, atom degree, atom explicit valence, atom implicit valence,
aromaticity.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we consider descriptors including bond type
and whether bond is in ring.
size_cutoff : int
By calling ``.ignore_large(True)``, we can optionally ignore reactions whose reactants
contain more than ``size_cutoff`` atoms. Default to 100.
max_num_changes_per_reaction : int
Maximum number of bond changes per reaction. Default to 5.
num_candidate_bond_changes : int
Number of candidate bond changes to consider for each ground truth reaction.
Default to 16.
max_num_change_combos_per_reaction : int
Number of bond change combos to consider for each reaction. Default to 150.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def __init__(self,
raw_file_path,
candidate_bond_path,
mode,
node_featurizer=default_node_featurizer_rank,
edge_featurizer=default_edge_featurizer_rank,
size_cutoff=100,
max_num_changes_per_reaction=5,
num_candidate_bond_changes=16,
max_num_change_combos_per_reaction=150,
num_processes=1):
super(WLNRankDataset, self).__init__()
assert mode in ['train', 'val', 'test'], \
"Expect mode to be 'train' or 'val' or 'test', got {}".format(mode)
self.mode = mode
self.ignore_large_samples = False
self.size_cutoff = size_cutoff
path_to_reaction_file = raw_file_path + '.proc'
if not os.path.isfile(path_to_reaction_file):
print('Pre-processing graph edits from reaction data')
process_file(raw_file_path, num_processes)
self.reactant_mols, self.product_mols, self.real_bond_changes, \
self.ids_for_small_samples = self.load_reaction_data(path_to_reaction_file, num_processes)
self.candidate_bond_changes = self.load_candidate_bond_changes(candidate_bond_path)
self.num_candidate_bond_changes = num_candidate_bond_changes
self.max_num_changes_per_reaction = max_num_changes_per_reaction
self.max_num_change_combos_per_reaction = max_num_change_combos_per_reaction
self.node_featurizer = node_featurizer
self.edge_featurizer = edge_featurizer
def load_reaction_data(self, file_path, num_processes):
"""Load reaction data from the raw file.
Parameters
----------
file_path : str
Path to read the file.
num_processes : int
Number of processes to use for data pre-processing.
Returns
-------
all_reactant_mols : list of rdkit.Chem.rdchem.Mol
RDKit molecule instances for reactants.
all_product_mols : list of rdkit.Chem.rdchem.Mol
RDKit molecule instances for products if the dataset is for training and
None otherwise.
all_real_bond_changes : list of list
``all_real_bond_changes[i]`` gives a list of tuples, which are ground
truth bond changes for a reaction.
ids_for_small_samples : list of int
Indices for reactions whose reactants do not contain too many atoms
"""
print('Stage 1/2: loading reaction data...')
all_reactant_mols = []
all_product_mols = []
all_real_bond_changes = []
ids_for_small_samples = []
with open(file_path, 'r') as f:
lines = f.readlines()
def _update_from_line(id, loaded_result):
reactants_mol, product_mol, reaction_real_bond_changes = loaded_result
if reactants_mol is None:
return
all_product_mols.append(product_mol)
all_reactant_mols.append(reactants_mol)
all_real_bond_changes.append(reaction_real_bond_changes)
if reactants_mol.GetNumAtoms() <= self.size_cutoff:
ids_for_small_samples.append(id)
if num_processes == 1:
for id, li in enumerate(tqdm(lines)):
loaded_line = load_one_reaction_rank(li)
_update_from_line(id, loaded_line)
else:
with Pool(processes=num_processes) as pool:
results = pool.map(
load_one_reaction_rank,
lines, chunksize=len(lines) // num_processes)
for id in range(len(lines)):
_update_from_line(id, results[id])
return all_reactant_mols, all_product_mols, all_real_bond_changes, ids_for_small_samples
def load_candidate_bond_changes(self, file_path):
"""Load candidate bond changes predicted by a WLN for reaction center prediction.
Parameters
----------
file_path : str
Path to a file of candidate bond changes for each reaction.
Returns
-------
all_candidate_bond_changes : list of list
``all_candidate_bond_changes[i]`` gives a list of tuples, which are candidate
bond changes for a reaction.
"""
print('Stage 2/2: loading candidate bond changes...')
with open(file_path, 'r') as f:
lines = f.readlines()
all_candidate_bond_changes = []
for li in tqdm(lines):
all_candidate_bond_changes.append(
load_candidate_bond_changes_for_one_reaction(li))
return all_candidate_bond_changes
def ignore_large(self, ignore=True):
"""Whether to ignore reactions where reactants contain too many atoms.
Parameters
----------
ignore : bool
If ``ignore``, reactions where reactants contain too many atoms will be ignored.
"""
self.ignore_large_samples = ignore
def __len__(self):
"""Get the size for the dataset.
Returns
-------
int
Number of reactions in the dataset.
"""
if self.ignore_large_samples:
return len(self.ids_for_small_samples)
else:
return len(self.reactant_mols)
def __getitem__(self, item):
"""Get the i-th datapoint.
Parameters
----------
item : int
Index for the datapoint.
Returns
-------
list of B + 1 DGLGraph
The first entry in the list is the DGLGraph for the reactants and the rest are
DGLGraphs for candidate products. Each DGLGraph has edge features in edata['he'] and
node features in ndata['hv'].
candidate_scores : float32 tensor of shape (B, 1)
The sum of scores for bond changes in each combo, where B is the number of combos.
labels : int64 tensor of shape (1, 1), optional
Index for the true candidate product, which is always 0 with pre-processing. This is
returned only when we are not in the training mode.
valid_candidate_combos : list, optional
valid_candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction. Each tuple is of form (atom1, atom2,
change_type, score). atom1, atom2 are the atom mapping numbers - 1 of the two
end atoms. change_type can be 0, 1, 2, 3, 1.5, separately for losing a bond, forming
a single, double, triple, and aromatic bond.
reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants
real_bond_changes : list of tuples
Ground truth bond changes in a reaction. Each tuple is of form (atom1, atom2,
change_type). atom1, atom2 are the atom mapping numbers - 1 of the two
end atoms. change_type can be 0, 1, 2, 3, 1.5, separately for losing a bond, forming
a single, double, triple, and aromatic bond.
product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the product
"""
if self.ignore_large_samples:
item = self.ids_for_small_samples[item]
raw_candidate_bond_changes = self.candidate_bond_changes[item]
real_bond_changes = self.real_bond_changes[item]
reactant_mol = self.reactant_mols[item]
product_mol = self.product_mols[item]
# Get valid candidate products, candidate bond changes considered and reactant info
valid_candidate_combos, candidate_bond_changes, reactant_info = \
pre_process_one_reaction(
(raw_candidate_bond_changes, real_bond_changes,
reactant_mol, product_mol),
self.num_candidate_bond_changes, self.max_num_changes_per_reaction,
self.max_num_change_combos_per_reaction, self.mode)
# Construct DGLGraphs and featurize their edges
g_list = construct_graphs_rank(
(reactant_mol, valid_candidate_combos,
candidate_bond_changes, reactant_info),
self.edge_featurizer)
# Get node features and candidate scores
node_feats, candidate_scores = featurize_nodes_and_compute_combo_scores(
self.node_featurizer, reactant_mol, valid_candidate_combos)
for g in g_list:
g.ndata['hv'] = node_feats
if self.mode == 'train':
labels = torch.zeros(1, 1).long()
return g_list, candidate_scores, labels
else:
reactant_mol = self.reactant_mols[item]
real_bond_changes = self.real_bond_changes[item]
product_mol = self.product_mols[item]
return g_list, candidate_scores, valid_candidate_combos, \
reactant_mol, real_bond_changes, product_mol
class USPTORank(WLNRankDataset):
"""USPTO dataset for ranking candidate products.
The dataset contains reactions from patents granted by United States Patent
and Trademark Office (USPTO), collected by Lowe [1]. Jin et al. removes duplicates
and erroneous reactions, obtaining a set of 480K reactions. They divide it
into 400K, 40K, and 40K for training, validation and test.
References:
* [1] Patent reaction extraction
* [2] Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network
Parameters
----------
subset : str
Whether to use the training/validation/test set as in Jin et al.
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
candidate_bond_path : str
Path to the candidate bond changes for product enumeration, where each line is
candidate bond changes for a reaction by a WLN for reaction center prediction.
size_cutoff : int
By calling ``.ignore_large(True)``, we can optionally ignore reactions whose reactants
contain more than ``size_cutoff`` atoms. Default to 100.
max_num_changes_per_reaction : int
Maximum number of bond changes per reaction. Default to 5.
num_candidate_bond_changes : int
Number of candidate bond changes to consider for each ground truth reaction.
Default to 16.
max_num_change_combos_per_reaction : int
Number of bond change combos to consider for each reaction. Default to 150.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def __init__(self,
subset,
candidate_bond_path,
size_cutoff=100,
max_num_changes_per_reaction=5,
num_candidate_bond_changes=16,
max_num_change_combos_per_reaction=150,
num_processes=1):
assert subset in ['train', 'val', 'test'], \
'Expect subset to be "train" or "val" or "test", got {}'.format(subset)
print('Preparing {} subset of USPTO for product candidate ranking.'.format(subset))
self._subset = subset
if subset == 'val':
mode = 'val'
subset = 'valid'
else:
mode = subset
self._url = 'dataset/uspto.zip'
data_path = get_download_dir() + '/uspto.zip'
extracted_data_path = get_download_dir() + '/uspto'
download(_get_dgl_url(self._url), path=data_path)
extract_archive(data_path, extracted_data_path)
super(USPTORank, self).__init__(
raw_file_path=extracted_data_path + '/{}.txt'.format(subset),
candidate_bond_path=candidate_bond_path,
mode=mode,
size_cutoff=size_cutoff,
max_num_changes_per_reaction=max_num_changes_per_reaction,
num_candidate_bond_changes=num_candidate_bond_changes,
max_num_change_combos_per_reaction=max_num_change_combos_per_reaction,
num_processes=num_processes)
@property
def subset(self):
"""Get the subset used for USPTOCenter
Returns
-------
str
* 'full' for the complete dataset
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
"""
return self._subset
"""Information for the library."""
# current version
__version__ = '0.2.2'
"""Model architectures and components at different levels."""
from .gnn import *
from .readout import *
from .model_zoo import *
from .pretrain import *
"""Graph neural networks for updating node representations."""
from .attentivefp import *
from .gat import *
from .gcn import *
from .mgcn import *
from .mpnn import *
from .schnet import *
from .wln import *
from .weave import *
from .gin import *
"""AttentiveFP"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import edge_softmax
__all__ = ['AttentiveFPGNN']
# pylint: disable=W0221, C0103, E1101
class AttentiveGRU1(nn.Module):
"""Update node features with attention and GRU.
This will be used for incorporating the information of edge features
into node features for message passing.
Parameters
----------
node_feat_size : int
Size for the input node features.
edge_feat_size : int
Size for the input edge (bond) features.
edge_hidden_size : int
Size for the intermediate edge (bond) representations.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, edge_feat_size, edge_hidden_size, dropout):
super(AttentiveGRU1, self).__init__()
self.edge_transform = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(edge_feat_size, edge_hidden_size)
)
self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)
def forward(self, g, edge_logits, edge_feats, node_feats):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
edge_logits : float32 tensor of shape (E, 1)
The edge logits based on which softmax will be performed for weighting
edges within 1-hop neighborhoods. E represents the number of edges.
edge_feats : float32 tensor of shape (E, edge_feat_size)
Previous edge features.
node_feats : float32 tensor of shape (V, node_feat_size)
Previous node features. V represents the number of nodes.
Returns
-------
float32 tensor of shape (V, node_feat_size)
Updated node features.
"""
g = g.local_var()
g.edata['e'] = edge_softmax(g, edge_logits) * self.edge_transform(edge_feats)
g.update_all(fn.copy_edge('e', 'm'), fn.sum('m', 'c'))
context = F.elu(g.ndata['c'])
return F.relu(self.gru(context, node_feats))
class AttentiveGRU2(nn.Module):
"""Update node features with attention and GRU.
This will be used in GNN layers for updating node representations.
Parameters
----------
node_feat_size : int
Size for the input node features.
edge_hidden_size : int
Size for the intermediate edge (bond) representations.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, edge_hidden_size, dropout):
super(AttentiveGRU2, self).__init__()
self.project_node = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(node_feat_size, edge_hidden_size)
)
self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)
def forward(self, g, edge_logits, node_feats):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
edge_logits : float32 tensor of shape (E, 1)
The edge logits based on which softmax will be performed for weighting
edges within 1-hop neighborhoods. E represents the number of edges.
node_feats : float32 tensor of shape (V, node_feat_size)
Previous node features. V represents the number of nodes.
Returns
-------
float32 tensor of shape (V, node_feat_size)
Updated node features.
"""
g = g.local_var()
g.edata['a'] = edge_softmax(g, edge_logits)
g.ndata['hv'] = self.project_node(node_feats)
g.update_all(fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'c'))
context = F.elu(g.ndata['c'])
return F.relu(self.gru(context, node_feats))
class GetContext(nn.Module):
"""Generate context for each node by message passing at the beginning.
This layer incorporates the information of edge features into node
representations so that message passing needs to be only performed over
node representations.
Parameters
----------
node_feat_size : int
Size for the input node features.
edge_feat_size : int
Size for the input edge (bond) features.
graph_feat_size : int
Size of the learned graph representation (molecular fingerprint).
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, edge_feat_size, graph_feat_size, dropout):
super(GetContext, self).__init__()
self.project_node = nn.Sequential(
nn.Linear(node_feat_size, graph_feat_size),
nn.LeakyReLU()
)
self.project_edge1 = nn.Sequential(
nn.Linear(node_feat_size + edge_feat_size, graph_feat_size),
nn.LeakyReLU()
)
self.project_edge2 = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(2 * graph_feat_size, 1),
nn.LeakyReLU()
)
self.attentive_gru = AttentiveGRU1(graph_feat_size, graph_feat_size,
graph_feat_size, dropout)
def apply_edges1(self, edges):
"""Edge feature update.
Parameters
----------
edges : EdgeBatch
Container for a batch of edges
Returns
-------
dict
Mapping ``'he1'`` to updated edge features.
"""
return {'he1': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}
def apply_edges2(self, edges):
"""Edge feature update.
Parameters
----------
edges : EdgeBatch
Container for a batch of edges
Returns
-------
dict
Mapping ``'he2'`` to updated edge features.
"""
return {'he2': torch.cat([edges.dst['hv_new'], edges.data['he1']], dim=1)}
def forward(self, g, node_feats, edge_feats):
"""Incorporate edge features and update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feat_size)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_feat_size)
Input edge features. E for the number of edges.
Returns
-------
float32 tensor of shape (V, graph_feat_size)
Updated node features.
"""
g = g.local_var()
g.ndata['hv'] = node_feats
g.ndata['hv_new'] = self.project_node(node_feats)
g.edata['he'] = edge_feats
g.apply_edges(self.apply_edges1)
g.edata['he1'] = self.project_edge1(g.edata['he1'])
g.apply_edges(self.apply_edges2)
logits = self.project_edge2(g.edata['he2'])
return self.attentive_gru(g, logits, g.edata['he1'], g.ndata['hv_new'])
class GNNLayer(nn.Module):
"""GNNLayer for updating node features.
This layer performs message passing over node representations and update them.
Parameters
----------
node_feat_size : int
Size for the input node features.
graph_feat_size : int
Size for the graph representations to be computed.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, graph_feat_size, dropout):
super(GNNLayer, self).__init__()
self.project_edge = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(2 * node_feat_size, 1),
nn.LeakyReLU()
)
self.attentive_gru = AttentiveGRU2(node_feat_size, graph_feat_size, dropout)
def apply_edges(self, edges):
"""Edge feature generation.
Generate edge features by concatenating the features of the destination
and source nodes.
Parameters
----------
edges : EdgeBatch
Container for a batch of edges.
Returns
-------
dict
Mapping ``'he'`` to the generated edge features.
"""
return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}
def forward(self, g, node_feats):
"""Perform message passing and update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feat_size)
Input node features. V for the number of nodes.
Returns
-------
float32 tensor of shape (V, graph_feat_size)
Updated node features.
"""
g = g.local_var()
g.ndata['hv'] = node_feats
g.apply_edges(self.apply_edges)
logits = self.project_edge(g.edata['he'])
return self.attentive_gru(g, logits, node_feats)
class AttentiveFPGNN(nn.Module):
"""`Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
Attention Mechanism <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__
This class performs message passing in AttentiveFP and returns the updated node representations.
Parameters
----------
node_feat_size : int
Size for the input node features.
edge_feat_size : int
Size for the input edge features.
num_layers : int
Number of GNN layers. Default to 2.
graph_feat_size : int
Size for the graph representations to be computed. Default to 200.
dropout : float
The probability for performing dropout. Default to 0.
"""
def __init__(self,
node_feat_size,
edge_feat_size,
num_layers=2,
graph_feat_size=200,
dropout=0.):
super(AttentiveFPGNN, self).__init__()
self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
self.gnn_layers = nn.ModuleList()
for _ in range(num_layers - 1):
self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feat_size)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_feat_size)
Input edge features. E for the number of edges.
Returns
-------
node_feats : float32 tensor of shape (V, graph_feat_size)
Updated node representations.
"""
node_feats = self.init_context(g, node_feats, edge_feats)
for gnn in self.gnn_layers:
node_feats = gnn(g, node_feats)
return node_feats
"""Graph Attention Networks"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
__all__ = ['GAT']
# pylint: disable=W0221
class GATLayer(nn.Module):
r"""Single GAT layer from `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__
Parameters
----------
in_feats : int
Number of input node features
out_feats : int
Number of output node features
num_heads : int
Number of attention heads
feat_drop : float
Dropout applied to the input features
attn_drop : float
Dropout applied to attention values of edges
alpha : float
Hyperparameter in LeakyReLU, which is the slope for negative values.
Default to 0.2.
residual : bool
Whether to perform skip connection, default to True.
agg_mode : str
The way to aggregate multi-head attention results, can be either
'flatten' for concatenating all-head results or 'mean' for averaging
all head results.
activation : activation function or None
Activation function applied to the aggregated multi-head results, default to None.
"""
def __init__(self, in_feats, out_feats, num_heads, feat_drop, attn_drop,
alpha=0.2, residual=True, agg_mode='flatten', activation=None):
super(GATLayer, self).__init__()
self.gat_conv = GATConv(in_feats=in_feats, out_feats=out_feats, num_heads=num_heads,
feat_drop=feat_drop, attn_drop=attn_drop,
negative_slope=alpha, residual=residual)
assert agg_mode in ['flatten', 'mean']
self.agg_mode = agg_mode
self.activation = activation
def forward(self, bg, feats):
"""Update node representations
Parameters
----------
bg : DGLGraph
DGLGraph for a batch of graphs.
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which equals in_feats in initialization
Returns
-------
feats : FloatTensor of shape (N, M2)
* N is the total number of nodes in the batch of graphs
* M2 is the output node representation size, which equals
out_feats in initialization if self.agg_mode == 'mean' and
out_feats * num_heads in initialization otherwise.
"""
feats = self.gat_conv(bg, feats)
if self.agg_mode == 'flatten':
feats = feats.flatten(1)
else:
feats = feats.mean(1)
if self.activation is not None:
feats = self.activation(feats)
return feats
class GAT(nn.Module):
r"""GAT from `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__
Parameters
----------
in_feats : int
Number of input node features
hidden_feats : list of int
``hidden_feats[i]`` gives the output size of an attention head in the i-th GAT layer.
``len(hidden_feats)`` equals the number of GAT layers. By default, we use ``[32, 32]``.
num_heads : list of int
``num_heads[i]`` gives the number of attention heads in the i-th GAT layer.
``len(num_heads)`` equals the number of GAT layers. By default, we use 4 attention heads
for each GAT layer.
feat_drops : list of float
``feat_drops[i]`` gives the dropout applied to the input features in the i-th GAT layer.
``len(feat_drops)`` equals the number of GAT layers. By default, this will be zero for
all GAT layers.
attn_drops : list of float
``attn_drops[i]`` gives the dropout applied to attention values of edges in the i-th GAT
layer. ``len(attn_drops)`` equals the number of GAT layers. By default, this will be zero
for all GAT layers.
alphas : list of float
Hyperparameters in LeakyReLU, which are the slopes for negative values. ``alphas[i]``
gives the slope for negative value in the i-th GAT layer. ``len(alphas)`` equals the
number of GAT layers. By default, this will be 0.2 for all GAT layers.
residuals : list of bool
``residual[i]`` decides if residual connection is to be used for the i-th GAT layer.
``len(residual)`` equals the number of GAT layers. By default, residual connection
is performed for each GAT layer.
agg_modes : list of str
The way to aggregate multi-head attention results for each GAT layer, which can be either
'flatten' for concatenating all-head results or 'mean' for averaging all-head results.
``agg_modes[i]`` gives the way to aggregate multi-head attention results for the i-th
GAT layer. ``len(agg_modes)`` equals the number of GAT layers. By default, we flatten
all-head results for each GAT layer.
activations : list of activation function or None
``activations[i]`` gives the activation function applied to the aggregated multi-head
results for the i-th GAT layer. ``len(activations)`` equals the number of GAT layers.
By default, no activation is applied for each GAT layer.
"""
def __init__(self, in_feats, hidden_feats=None, num_heads=None, feat_drops=None,
attn_drops=None, alphas=None, residuals=None, agg_modes=None, activations=None):
super(GAT, self).__init__()
if hidden_feats is None:
hidden_feats = [32, 32]
n_layers = len(hidden_feats)
if num_heads is None:
num_heads = [4 for _ in range(n_layers)]
if feat_drops is None:
feat_drops = [0. for _ in range(n_layers)]
if attn_drops is None:
attn_drops = [0. for _ in range(n_layers)]
if alphas is None:
alphas = [0.2 for _ in range(n_layers)]
if residuals is None:
residuals = [True for _ in range(n_layers)]
if agg_modes is None:
agg_modes = ['flatten' for _ in range(n_layers - 1)]
agg_modes.append('mean')
if activations is None:
activations = [F.elu for _ in range(n_layers - 1)]
activations.append(None)
lengths = [len(hidden_feats), len(num_heads), len(feat_drops), len(attn_drops),
len(alphas), len(residuals), len(agg_modes), len(activations)]
assert len(set(lengths)) == 1, 'Expect the lengths of hidden_feats, num_heads, ' \
'feat_drops, attn_drops, alphas, residuals, ' \
'agg_modes and activations to be the same, ' \
'got {}'.format(lengths)
self.hidden_feats = hidden_feats
self.num_heads = num_heads
self.agg_modes = agg_modes
self.gnn_layers = nn.ModuleList()
for i in range(n_layers):
self.gnn_layers.append(GATLayer(in_feats, hidden_feats[i], num_heads[i],
feat_drops[i], attn_drops[i], alphas[i],
residuals[i], agg_modes[i], activations[i]))
if agg_modes[i] == 'flatten':
in_feats = hidden_feats[i] * num_heads[i]
else:
in_feats = hidden_feats[i]
def forward(self, g, feats):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which equals in_feats in initialization
Returns
-------
feats : FloatTensor of shape (N, M2)
* N is the total number of nodes in the batch of graphs
* M2 is the output node representation size, which equals
hidden_sizes[-1] if agg_modes[-1] == 'mean' and
hidden_sizes[-1] * num_heads[-1] otherwise.
"""
for gnn in self.gnn_layers:
feats = gnn(g, feats)
return feats
"""Graph Convolutional Networks."""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv
__all__ = ['GCN']
# pylint: disable=W0221, C0103
class GCNLayer(nn.Module):
r"""Single GCN layer from `Semi-Supervised Classification with Graph Convolutional Networks
<https://arxiv.org/abs/1609.02907>`__
Parameters
----------
in_feats : int
Number of input node features.
out_feats : int
Number of output node features.
activation : activation function
Default to be None.
residual : bool
Whether to use residual connection, default to be True.
batchnorm : bool
Whether to use batch normalization on the output,
default to be True.
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def __init__(self, in_feats, out_feats, activation=None,
residual=True, batchnorm=True, dropout=0.):
super(GCNLayer, self).__init__()
self.activation = activation
self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
norm='none', activation=activation)
self.dropout = nn.Dropout(dropout)
self.residual = residual
if residual:
self.res_connection = nn.Linear(in_feats, out_feats)
self.bn = batchnorm
if batchnorm:
self.bn_layer = nn.BatchNorm1d(out_feats)
def forward(self, g, feats):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which must match in_feats in initialization
Returns
-------
new_feats : FloatTensor of shape (N, M2)
* M2 is the output node feature size, which must match out_feats in initialization
"""
new_feats = self.graph_conv(g, feats)
if self.residual:
res_feats = self.activation(self.res_connection(feats))
new_feats = new_feats + res_feats
new_feats = self.dropout(new_feats)
if self.bn:
new_feats = self.bn_layer(new_feats)
return new_feats
class GCN(nn.Module):
r"""GCN from `Semi-Supervised Classification with Graph Convolutional Networks
<https://arxiv.org/abs/1609.02907>`__
Parameters
----------
in_feats : int
Number of input node features.
hidden_feats : list of int
``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
``len(hidden_feats)`` equals the number of GCN layers. By default, we use
``[64, 64]``.
activation : list of activation functions or None
If None, no activation will be applied. If not None, ``activation[i]`` gives the
activation function to be used for the i-th GCN layer. ``len(activation)`` equals
the number of GCN layers. By default, ReLU is applied for all GCN layers.
residual : list of bool
``residual[i]`` decides if residual connection is to be used for the i-th GCN layer.
``len(residual)`` equals the number of GCN layers. By default, residual connection
is performed for each GCN layer.
batchnorm : list of bool
``batchnorm[i]`` decides if batch normalization is to be applied on the output of
the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
batch normalization is applied for all GCN layers.
dropout : list of float
``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
``len(dropout)`` equals the number of GCN layers. By default, no dropout is
performed for all layers.
"""
def __init__(self, in_feats, hidden_feats=None, activation=None, residual=None,
batchnorm=None, dropout=None):
super(GCN, self).__init__()
if hidden_feats is None:
hidden_feats = [64, 64]
n_layers = len(hidden_feats)
if activation is None:
activation = [F.relu for _ in range(n_layers)]
if residual is None:
residual = [True for _ in range(n_layers)]
if batchnorm is None:
batchnorm = [True for _ in range(n_layers)]
if dropout is None:
dropout = [0. for _ in range(n_layers)]
lengths = [len(hidden_feats), len(activation),
len(residual), len(batchnorm), len(dropout)]
assert len(set(lengths)) == 1, 'Expect the lengths of hidden_feats, activation, ' \
'residual, batchnorm and dropout to be the same, ' \
'got {}'.format(lengths)
self.hidden_feats = hidden_feats
self.gnn_layers = nn.ModuleList()
for i in range(n_layers):
self.gnn_layers.append(GCNLayer(in_feats, hidden_feats[i], activation[i],
residual[i], batchnorm[i], dropout[i]))
in_feats = hidden_feats[i]
def forward(self, g, feats):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which equals in_feats in initialization
Returns
-------
feats : FloatTensor of shape (N, M2)
* N is the total number of nodes in the batch of graphs
* M2 is the output node representation size, which equals
hidden_sizes[-1] in initialization.
"""
for gnn in self.gnn_layers:
feats = gnn(g, feats)
return feats
"""Graph Isomorphism Networks."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['GIN']
# pylint: disable=W0221, C0103
class GINLayer(nn.Module):
r"""Single Layer GIN from `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
Parameters
----------
num_edge_emb_list : list of int
num_edge_emb_list[i] gives the number of items to embed for the
i-th categorical edge feature variables. E.g. num_edge_emb_list[0] can be
the number of bond types and num_edge_emb_list[1] can be the number of
bond direction types.
emb_dim : int
The size of each embedding vector.
batch_norm : bool
Whether to apply batch normalization to the output of message passing.
Default to True.
activation : None or callable
Activation function to apply to the output node representations.
Default to None.
"""
def __init__(self, num_edge_emb_list, emb_dim, batch_norm=True, activation=None):
super(GINLayer, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(emb_dim, 2 * emb_dim),
nn.ReLU(),
nn.Linear(2 * emb_dim, emb_dim)
)
self.edge_embeddings = nn.ModuleList()
for num_emb in num_edge_emb_list:
emb_module = nn.Embedding(num_emb, emb_dim)
nn.init.xavier_uniform_(emb_module.weight.data)
self.edge_embeddings.append(emb_module)
if batch_norm:
self.bn = nn.BatchNorm1d(emb_dim)
else:
self.bn = None
self.activation = activation
def forward(self, g, node_feats, categorical_edge_feats):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : FloatTensor of shape (N, emb_dim)
* Input node features
* N is the total number of nodes in the batch of graphs
* emb_dim is the input node feature size, which must match emb_dim in initialization
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
* len(categorical_edge_feats) should be the same as len(self.edge_embeddings)
* E is the total number of edges in the batch of graphs
Returns
-------
node_feats : float32 tensor of shape (N, emb_dim)
Output node representations
"""
edge_embeds = []
for i, feats in enumerate(categorical_edge_feats):
edge_embeds.append(self.edge_embeddings[i](feats))
edge_embeds = torch.stack(edge_embeds, dim=0).sum(0)
g = g.local_var()
g.ndata['feat'] = node_feats
g.edata['feat'] = edge_embeds
g.update_all(fn.u_add_e('feat', 'feat', 'm'), fn.sum('m', 'feat'))
node_feats = self.mlp(g.ndata.pop('feat'))
if self.bn is not None:
node_feats = self.bn(node_feats)
if self.activation is not None:
node_feats = self.activation(node_feats)
return node_feats
class GIN(nn.Module):
r"""Graph Isomorphism Network from `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
This module is for updating node representations only.
Parameters
----------
num_node_emb_list : list of int
num_node_emb_list[i] gives the number of items to embed for the
i-th categorical node feature variables. E.g. num_node_emb_list[0] can be
the number of atom types and num_node_emb_list[1] can be the number of
atom chirality types.
num_edge_emb_list : list of int
num_edge_emb_list[i] gives the number of items to embed for the
i-th categorical edge feature variables. E.g. num_edge_emb_list[0] can be
the number of bond types and num_edge_emb_list[1] can be the number of
bond direction types.
num_layers : int
Number of GIN layers to use. Default to 5.
emb_dim : int
The size of each embedding vector. Default to 300.
JK : str
JK for jumping knowledge as in `Representation Learning on Graphs with
Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It decides
how we are going to combine the all-layer node representations for the final output.
There can be four options for this argument, ``concat``, ``last``, ``max`` and ``sum``.
Default to 'last'.
* ``'concat'``: concatenate the output node representations from all GIN layers
* ``'last'``: use the node representations from the last GIN layer
* ``'max'``: apply max pooling to the node representations across all GIN layers
* ``'sum'``: sum the output node representations from all GIN layers
dropout : float
Dropout to apply to the output of each GIN layer. Default to 0.5
"""
def __init__(self, num_node_emb_list, num_edge_emb_list,
num_layers=5, emb_dim=300, JK='last', dropout=0.5):
super(GIN, self).__init__()
self.num_layers = num_layers
self.JK = JK
self.dropout = nn.Dropout(dropout)
if num_layers < 2:
raise ValueError('Number of GNN layers must be greater '
'than 1, got {:d}'.format(num_layers))
self.node_embeddings = nn.ModuleList()
for num_emb in num_node_emb_list:
emb_module = nn.Embedding(num_emb, emb_dim)
nn.init.xavier_uniform_(emb_module.weight.data)
self.node_embeddings.append(emb_module)
self.gnn_layers = nn.ModuleList()
for layer in range(num_layers):
if layer == num_layers - 1:
self.gnn_layers.append(GINLayer(num_edge_emb_list, emb_dim))
else:
self.gnn_layers.append(GINLayer(num_edge_emb_list, emb_dim, activation=F.relu))
def forward(self, g, categorical_node_feats, categorical_edge_feats):
"""Update node representations
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
categorical_node_feats : list of LongTensor of shape (N)
* Input categorical node features
* len(categorical_node_feats) should be the same as len(self.node_embeddings)
* N is the total number of nodes in the batch of graphs
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
* len(categorical_edge_feats) should be the same as
len(num_edge_emb_list) in the arguments
* E is the total number of edges in the batch of graphs
Returns
-------
final_node_feats : float32 tensor of shape (N, M)
Output node representations, N for the number of nodes and
M for output size. In particular, M will be emb_dim * (num_layers + 1)
if self.JK == 'concat' and emb_dim otherwise.
"""
node_embeds = []
for i, feats in enumerate(categorical_node_feats):
node_embeds.append(self.node_embeddings[i](feats))
node_embeds = torch.stack(node_embeds, dim=0).sum(0)
all_layer_node_feats = [node_embeds]
for layer in range(self.num_layers):
node_feats = self.gnn_layers[layer](g, all_layer_node_feats[layer],
categorical_edge_feats)
node_feats = self.dropout(node_feats)
all_layer_node_feats.append(node_feats)
if self.JK == 'concat':
final_node_feats = torch.cat(all_layer_node_feats, dim=1)
elif self.JK == 'last':
final_node_feats = all_layer_node_feats[-1]
elif self.JK == 'max':
all_layer_node_feats = [h.unsqueeze_(0) for h in all_layer_node_feats]
final_node_feats = torch.max(torch.cat(all_layer_node_feats, dim=0), dim=0)[0]
elif self.JK == 'sum':
all_layer_node_feats = [h.unsqueeze_(0) for h in all_layer_node_feats]
final_node_feats = torch.sum(torch.cat(all_layer_node_feats, dim=0), dim=0)
else:
return ValueError("Expect self.JK to be 'concat', 'last', "
"'max' or 'sum', got {}".format(self.JK))
return final_node_feats
"""MGCN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
from .schnet import RBFExpansion
__all__ = ['MGCNGNN']
# pylint: disable=W0221, E1101
class EdgeEmbedding(nn.Module):
"""Module for embedding edges.
Edges whose end nodes have the same combination of types
share the same initial embedding.
Parameters
----------
num_types : int
Number of edge types to embed.
edge_feats : int
Size for the edge representations to learn.
"""
def __init__(self, num_types, edge_feats):
super(EdgeEmbedding, self).__init__()
self.embed = nn.Embedding(num_types, edge_feats)
def get_edge_types(self, edges):
"""Generates edge types.
The edge type is based on the type of the source and destination nodes.
Note that directions are not distinguished, e.g. C-O and O-C are the same edge type.
To map each pair of node types to a unique number, we use an unordered pairing function.
See more details in this discussion:
https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers
Note that the number of edge types should be larger than the square of the maximum node
type in the dataset.
Parameters
----------
edges : EdgeBatch
Container for a batch of edges.
Returns
-------
dict
Mapping 'type' to the computed edge types.
"""
node_type1 = edges.src['type']
node_type2 = edges.dst['type']
return {
'type': node_type1 * node_type2 + \
(torch.abs(node_type1 - node_type2) - 1) ** 2 / 4
}
def forward(self, g, node_types):
"""Embeds edge types.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
Returns
-------
float32 tensor of shape (E, edge_feats)
Edge representations.
"""
g = g.local_var()
g.ndata['type'] = node_types
g.apply_edges(self.get_edge_types)
return self.embed(g.edata['type'])
class VEConv(nn.Module):
"""Vertex-Edge Convolution in MGCN
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__.
This layer combines both node and edge features in updating node representations.
Parameters
----------
dist_feats : int
Size for the expanded distances.
feats : int
Size for the input and output node and edge representations.
update_edge : bool
Whether to update edge representations. Default to True.
"""
def __init__(self, dist_feats, feats, update_edge=True):
super(VEConv, self).__init__()
self.update_dists = nn.Sequential(
nn.Linear(dist_feats, feats),
nn.Softplus(beta=0.5, threshold=14),
nn.Linear(feats, feats)
)
if update_edge:
self.update_edge_feats = nn.Linear(feats, feats)
else:
self.update_edge_feats = None
def forward(self, g, node_feats, edge_feats, expanded_dists):
"""Performs message passing and updates node and edge representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, feats)
Input node features.
edge_feats : float32 tensor of shape (E, feats)
Input edge features.
expanded_dists : float32 tensor of shape (E, dist_feats)
Expanded distances, i.e. the output of RBFExpansion.
Returns
-------
node_feats : float32 tensor of shape (V, feats)
Updated node representations.
edge_feats : float32 tensor of shape (E, feats)
Edge representations, updated if ``update_edge == True`` in initialization.
"""
expanded_dists = self.update_dists(expanded_dists)
if self.update_edge_feats is not None:
edge_feats = self.update_edge_feats(edge_feats)
g = g.local_var()
g.ndata.update({'hv': node_feats})
g.edata.update({'dist': expanded_dists, 'he': edge_feats})
g.update_all(message_func=[fn.u_mul_e('hv', 'dist', 'm_0'),
fn.copy_e('he', 'm_1')],
reduce_func=[fn.sum('m_0', 'hv_0'),
fn.sum('m_1', 'hv_1')])
node_feats = g.ndata.pop('hv_0') + g.ndata.pop('hv_1')
return node_feats, edge_feats
class MultiLevelInteraction(nn.Module):
"""Building block for MGCN.
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__. This layer combines node features,
edge features and expanded distances in message passing and updates node and edge
representations.
Parameters
----------
feats : int
Size for the input and output node and edge representations.
dist_feats : int
Size for the expanded distances.
"""
def __init__(self, feats, dist_feats):
super(MultiLevelInteraction, self).__init__()
self.project_in_node_feats = nn.Linear(feats, feats)
self.conv = VEConv(dist_feats, feats)
self.project_out_node_feats = nn.Sequential(
nn.Linear(feats, feats),
nn.Softplus(beta=0.5, threshold=14),
nn.Linear(feats, feats)
)
self.project_edge_feats = nn.Sequential(
nn.Linear(feats, feats),
nn.Softplus(beta=0.5, threshold=14)
)
def forward(self, g, node_feats, edge_feats, expanded_dists):
"""Performs message passing and updates node and edge representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, feats)
Input node features.
edge_feats : float32 tensor of shape (E, feats)
Input edge features
expanded_dists : float32 tensor of shape (E, dist_feats)
Expanded distances, i.e. the output of RBFExpansion.
Returns
-------
node_feats : float32 tensor of shape (V, feats)
Updated node representations.
edge_feats : float32 tensor of shape (E, feats)
Updated edge representations.
"""
new_node_feats = self.project_in_node_feats(node_feats)
new_node_feats, edge_feats = self.conv(g, new_node_feats, edge_feats, expanded_dists)
new_node_feats = self.project_out_node_feats(new_node_feats)
node_feats = node_feats + new_node_feats
edge_feats = self.project_edge_feats(edge_feats)
return node_feats, edge_feats
class MGCNGNN(nn.Module):
"""MGCN.
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__.
This class performs message passing in MGCN and returns the updated node representations.
Parameters
----------
feats : int
Size for the node and edge embeddings to learn. Default to 128.
n_layers : int
Number of gnn layers to use. Default to 3.
num_node_types : int
Number of node types to embed. Default to 100.
num_edge_types : int
Number of edge types to embed. Default to 3000.
cutoff : float
Largest center in RBF expansion. Default to 30.
gap : float
Difference between two adjacent centers in RBF expansion. Default to 0.1.
"""
def __init__(self, feats=128, n_layers=3, num_node_types=100,
num_edge_types=3000, cutoff=30., gap=0.1):
super(MGCNGNN, self).__init__()
self.node_embed = nn.Embedding(num_node_types, feats)
self.edge_embed = EdgeEmbedding(num_edge_types, feats)
self.rbf = RBFExpansion(high=cutoff, gap=gap)
self.gnn_layers = nn.ModuleList()
for _ in range(n_layers):
self.gnn_layers.append(MultiLevelInteraction(feats, len(self.rbf.centers)))
def forward(self, g, node_types, edge_dists):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
float32 tensor of shape (V, feats * (n_layers + 1))
Output node representations.
"""
node_feats = self.node_embed(node_types)
edge_feats = self.edge_embed(g, node_types)
expanded_dists = self.rbf(edge_dists)
all_layer_node_feats = [node_feats]
for gnn in self.gnn_layers:
node_feats, edge_feats = gnn(g, node_feats, edge_feats, expanded_dists)
all_layer_node_feats.append(node_feats)
return torch.cat(all_layer_node_feats, dim=1)
"""MPNN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import NNConv
__all__ = ['MPNNGNN']
# pylint: disable=W0221
class MPNNGNN(nn.Module):
"""MPNN.
MPNN is introduced in `Neural Message Passing for Quantum Chemistry
<https://arxiv.org/abs/1704.01212>`__.
This class performs message passing in MPNN and returns the updated node representations.
Parameters
----------
node_in_feats : int
Size for the input node features.
node_out_feats : int
Size for the output node representations. Default to 64.
edge_in_feats : int
Size for the input edge features. Default to 128.
edge_hidden_feats : int
Size for the hidden edge representations.
num_step_message_passing : int
Number of message passing steps. Default to 6.
"""
def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
edge_hidden_feats=128, num_step_message_passing=6):
super(MPNNGNN, self).__init__()
self.project_node_feats = nn.Sequential(
nn.Linear(node_in_feats, node_out_feats),
nn.ReLU()
)
self.num_step_message_passing = num_step_message_passing
edge_network = nn.Sequential(
nn.Linear(edge_in_feats, edge_hidden_feats),
nn.ReLU(),
nn.Linear(edge_hidden_feats, node_out_feats * node_out_feats)
)
self.gnn_layer = NNConv(
in_feats=node_out_feats,
out_feats=node_out_feats,
edge_func=edge_network,
aggregator_type='sum'
)
self.gru = nn.GRU(node_out_feats, node_out_feats)
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges in the batch of graphs.
Returns
-------
node_feats : float32 tensor of shape (V, node_out_feats)
Output node representations.
"""
node_feats = self.project_node_feats(node_feats) # (V, node_out_feats)
hidden_feats = node_feats.unsqueeze(0) # (1, V, node_out_feats)
for _ in range(self.num_step_message_passing):
node_feats = F.relu(self.gnn_layer(g, node_feats, edge_feats))
node_feats, hidden_feats = self.gru(node_feats.unsqueeze(0), hidden_feats)
node_feats = node_feats.squeeze(0)
return node_feats
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621, W0221, E1102, E1101
"""SchNet"""
import numpy as np
import torch
import torch.nn as nn
from dgl.nn.pytorch import CFConv
__all__ = ['SchNetGNN']
class RBFExpansion(nn.Module):
r"""Expand distances between nodes by radial basis functions.
.. math::
\exp(- \gamma * ||d - \mu||^2)
where :math:`d` is the distance between two nodes and :math:`\mu` helps centralizes
the distances. We use multiple centers evenly distributed in the range of
:math:`[\text{low}, \text{high}]` with the difference between two adjacent centers
being :math:`gap`.
The number of centers is decided by :math:`(\text{high} - \text{low}) / \text{gap}`.
Choosing fewer centers corresponds to reducing the resolution of the filter.
Parameters
----------
low : float
Smallest center. Default to 0.
high : float
Largest center. Default to 30.
gap : float
Difference between two adjacent centers. :math:`\gamma` will be computed as the
reciprocal of gap. Default to 0.1.
"""
def __init__(self, low=0., high=30., gap=0.1):
super(RBFExpansion, self).__init__()
num_centers = int(np.ceil((high - low) / gap))
centers = np.linspace(low, high, num_centers)
self.centers = nn.Parameter(torch.tensor(centers).float(), requires_grad=False)
self.gamma = 1 / gap
def forward(self, edge_dists):
"""Expand distances.
Parameters
----------
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
float32 tensor of shape (E, len(self.centers))
Expanded distances.
"""
radial = edge_dists - self.centers
coef = - self.gamma
return torch.exp(coef * (radial ** 2))
class Interaction(nn.Module):
"""Building block for SchNet.
SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.
This layer combines node and edge features in message passing and updates node
representations.
Parameters
----------
node_feats : int
Size for the input and output node features.
edge_in_feats : int
Size for the input edge features.
hidden_feats : int
Size for hidden representations.
"""
def __init__(self, node_feats, edge_in_feats, hidden_feats):
super(Interaction, self).__init__()
self.conv = CFConv(node_feats, edge_in_feats, hidden_feats, node_feats)
self.project_out = nn.Linear(node_feats, node_feats)
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feats)
Input node features, V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features, E for the number of edges.
Returns
-------
float32 tensor of shape (V, node_feats)
Updated node representations.
"""
node_feats = self.conv(g, node_feats, edge_feats)
return self.project_out(node_feats)
class SchNetGNN(nn.Module):
"""SchNet.
SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.
This class performs message passing in SchNet and returns the updated node representations.
Parameters
----------
node_feats : int
Size for node representations to learn. Default to 64.
hidden_feats : list of int
``hidden_feats[i]`` gives the size of hidden representations for the i-th interaction
layer. ``len(hidden_feats)`` equals the number of interaction layers.
Default to ``[64, 64, 64]``.
num_node_types : int
Number of node types to embed. Default to 100.
cutoff : float
Largest center in RBF expansion. Default to 30.
gap : float
Difference between two adjacent centers in RBF expansion. Default to 0.1.
"""
def __init__(self, node_feats=64, hidden_feats=None, num_node_types=100, cutoff=30., gap=0.1):
super(SchNetGNN, self).__init__()
if hidden_feats is None:
hidden_feats = [64, 64, 64]
self.embed = nn.Embedding(num_node_types, node_feats)
self.rbf = RBFExpansion(high=cutoff, gap=gap)
n_layers = len(hidden_feats)
self.gnn_layers = nn.ModuleList()
for i in range(n_layers):
self.gnn_layers.append(
Interaction(node_feats, len(self.rbf.centers), hidden_feats[i]))
def forward(self, g, node_types, edge_dists):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
node_feats : float32 tensor of shape (V, node_feats)
Updated node representations.
"""
node_feats = self.embed(node_types)
expanded_dists = self.rbf(edge_dists)
for gnn in self.gnn_layers:
node_feats = gnn(g, node_feats, expanded_dists)
return node_feats
"""Weave"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['WeaveGNN']
# pylint: disable=W0221, E1101
class WeaveLayer(nn.Module):
r"""Single Weave layer from `Molecular Graph Convolutions: Moving Beyond Fingerprints
<https://arxiv.org/abs/1603.00856>`__
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_node_hidden_feats : int
Size for the hidden node representations in updating node representations.
Default to 50.
edge_node_hidden_feats : int
Size for the hidden edge representations in updating node representations.
Default to 50.
node_out_feats : int
Size for the output node representations. Default to 50.
node_edge_hidden_feats : int
Size for the hidden node representations in updating edge representations.
Default to 50.
edge_edge_hidden_feats : int
Size for the hidden edge representations in updating edge representations.
Default to 50.
edge_out_feats : int
Size for the output edge representations. Default to 50.
activation : callable
Activation function to apply. Default to ReLU.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_node_hidden_feats=50,
edge_node_hidden_feats=50,
node_out_feats=50,
node_edge_hidden_feats=50,
edge_edge_hidden_feats=50,
edge_out_feats=50,
activation=F.relu):
super(WeaveLayer, self).__init__()
self.activation = activation
# Layers for updating node representations
self.node_to_node = nn.Linear(node_in_feats, node_node_hidden_feats)
self.edge_to_node = nn.Linear(edge_in_feats, edge_node_hidden_feats)
self.update_node = nn.Linear(
node_node_hidden_feats + edge_node_hidden_feats, node_out_feats)
# Layers for updating edge representations
self.left_node_to_edge = nn.Linear(node_in_feats, node_edge_hidden_feats)
self.right_node_to_edge = nn.Linear(node_in_feats, node_edge_hidden_feats)
self.edge_to_edge = nn.Linear(edge_in_feats, edge_edge_hidden_feats)
self.update_edge = nn.Linear(
2 * node_edge_hidden_feats + edge_edge_hidden_feats, edge_out_feats)
def forward(self, g, node_feats, edge_feats, node_only=False):
r"""Update node and edge representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges in the batch of graphs.
node_only : bool
Whether to update node representations only. If False, edge representations
will be updated as well. Default to False.
Returns
-------
new_node_feats : float32 tensor of shape (V, node_out_feats)
Updated node representations.
new_edge_feats : float32 tensor of shape (E, edge_out_feats)
Updated edge representations.
"""
g = g.local_var()
# Update node features
node_node_feats = self.activation(self.node_to_node(node_feats))
g.edata['e2n'] = self.activation(self.edge_to_node(edge_feats))
g.update_all(fn.copy_edge('e2n', 'm'), fn.sum('m', 'e2n'))
edge_node_feats = g.ndata.pop('e2n')
new_node_feats = self.activation(self.update_node(
torch.cat([node_node_feats, edge_node_feats], dim=1)))
if node_only:
return new_node_feats
# Update edge features
g.ndata['left_hv'] = self.left_node_to_edge(node_feats)
g.ndata['right_hv'] = self.right_node_to_edge(node_feats)
g.apply_edges(fn.u_add_v('left_hv', 'right_hv', 'first'))
g.apply_edges(fn.u_add_v('right_hv', 'left_hv', 'second'))
first_edge_feats = self.activation(g.edata.pop('first'))
second_edge_feats = self.activation(g.edata.pop('second'))
third_edge_feats = self.activation(self.edge_to_edge(edge_feats))
new_edge_feats = self.activation(self.update_edge(
torch.cat([first_edge_feats, second_edge_feats, third_edge_feats], dim=1)))
return new_node_feats, new_edge_feats
class WeaveGNN(nn.Module):
r"""The component of Weave for updating node and edge representations.
Weave is introduced in `Molecular Graph Convolutions: Moving Beyond Fingerprints
<https://arxiv.org/abs/1603.00856>`__.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
num_layers : int
Number of Weave layers to use, which is equivalent to the times of message passing.
Default to 2.
hidden_feats : int
Size for the hidden node and edge representations. Default to 50.
activation : callable
Activation function to be used. It cannot be None. Default to ReLU.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
num_layers=2,
hidden_feats=50,
activation=F.relu):
super(WeaveGNN, self).__init__()
self.gnn_layers = nn.ModuleList()
for i in range(num_layers):
if i == 0:
self.gnn_layers.append(WeaveLayer(node_in_feats=node_in_feats,
edge_in_feats=edge_in_feats,
node_node_hidden_feats=hidden_feats,
edge_node_hidden_feats=hidden_feats,
node_out_feats=hidden_feats,
node_edge_hidden_feats=hidden_feats,
edge_edge_hidden_feats=hidden_feats,
edge_out_feats=hidden_feats,
activation=activation))
else:
self.gnn_layers.append(WeaveLayer(node_in_feats=hidden_feats,
edge_in_feats=hidden_feats,
node_node_hidden_feats=hidden_feats,
edge_node_hidden_feats=hidden_feats,
node_out_feats=hidden_feats,
node_edge_hidden_feats=hidden_feats,
edge_edge_hidden_feats=hidden_feats,
edge_out_feats=hidden_feats,
activation=activation))
def forward(self, g, node_feats, edge_feats, node_only=True):
"""Updates node representations (and edge representations).
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges in the batch of graphs.
node_only : bool
Whether to return updated node representations only or to return both
node and edge representations. Default to True.
Returns
-------
float32 tensor of shape (V, gnn_hidden_feats)
Updated node representations.
float32 tensor of shape (E, gnn_hidden_feats), optional
This is returned only when ``node_only==False``. Updated edge representations.
"""
for i in range(len(self.gnn_layers) - 1):
node_feats, edge_feats = self.gnn_layers[i](g, node_feats, edge_feats)
return self.gnn_layers[-1](g, node_feats, edge_feats, node_only)
"""WLN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
__all__ = ['WLN']
class WLNLinear(nn.Module):
r"""Linear layer for WLN
Let stddev be
.. math::
\min(\frac{1.0}{\sqrt{in_feats}}, 0.1)
The weight of the linear layer is initialized from a normal distribution
with mean 0 and std as specified in stddev.
Parameters
----------
in_feats : int
Size for the input.
out_feats : int
Size for the output.
bias : bool
Whether bias will be added to the output. Default to True.
"""
def __init__(self, in_feats, out_feats, bias=True):
super(WLNLinear, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.weight = Parameter(torch.Tensor(out_feats, in_feats))
if bias:
self.bias = Parameter(torch.Tensor(out_feats))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
"""Initialize model parameters."""
stddev = min(1.0 / math.sqrt(self.in_feats), 0.1)
nn.init.normal_(self.weight, std=stddev)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
def forward(self, feats):
"""Applies the layer.
Parameters
----------
feats : float32 tensor of shape (N, *, in_feats)
N for the number of samples, * for any additional dimensions.
Returns
-------
float32 tensor of shape (N, *, out_feats)
Result of the layer.
"""
return F.linear(feats, self.weight, self.bias)
def extra_repr(self):
"""Return a description of the layer."""
return 'in_feats={}, out_feats={}, bias={}'.format(
self.in_feats, self.out_feats, self.bias is not None
)
class WLN(nn.Module):
"""Weisfeiler-Lehman Network (WLN)
WLN is introduced in `Predicting Organic Reaction Outcomes with
Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.
This class performs message passing and updates node representations.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_out_feats : int
Size for the output node representations. Default to 300.
n_layers : int
Number of times for message passing. Note that same parameters
are shared across n_layers message passing. Default to 3.
project_in_feats : bool
Whether to project input node features. If this is False, we expect node_in_feats
to be the same as node_out_feats. Default to True.
set_comparison : bool
Whether to perform final node representation update mimicking
set comparison. Default to True.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_out_feats=300,
n_layers=3,
project_in_feats=True,
set_comparison=True):
super(WLN, self).__init__()
self.n_layers = n_layers
self.project_in_feats = project_in_feats
if project_in_feats:
self.project_node_in_feats = nn.Sequential(
WLNLinear(node_in_feats, node_out_feats, bias=False),
nn.ReLU()
)
else:
assert node_in_feats == node_out_feats, \
'Expect input node features to have the same size as that of output ' \
'node features, got {:d} and {:d}'.format(node_in_feats, node_out_feats)
self.project_concatenated_messages = nn.Sequential(
WLNLinear(edge_in_feats + node_out_feats, node_out_feats),
nn.ReLU()
)
self.get_new_node_feats = nn.Sequential(
WLNLinear(2 * node_out_feats, node_out_feats),
nn.ReLU()
)
self.set_comparison = set_comparison
if set_comparison:
self.project_edge_messages = WLNLinear(edge_in_feats, node_out_feats, bias=False)
self.project_node_messages = WLNLinear(node_out_feats, node_out_feats, bias=False)
self.project_self = WLNLinear(node_out_feats, node_out_feats, bias=False)
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges.
Returns
-------
float32 tensor of shape (V, node_out_feats)
Updated node representations.
"""
if self.project_in_feats:
node_feats = self.project_node_in_feats(node_feats)
for _ in range(self.n_layers):
g = g.local_var()
g.ndata['hv'] = node_feats
g.apply_edges(fn.copy_src('hv', 'he_src'))
concat_edge_feats = torch.cat([g.edata['he_src'], edge_feats], dim=1)
g.edata['he'] = self.project_concatenated_messages(concat_edge_feats)
g.update_all(fn.copy_edge('he', 'm'), fn.sum('m', 'hv_new'))
node_feats = self.get_new_node_feats(
torch.cat([node_feats, g.ndata['hv_new']], dim=1))
if not self.set_comparison:
return node_feats
else:
g = g.local_var()
g.ndata['hv'] = self.project_node_messages(node_feats)
g.edata['he'] = self.project_edge_messages(edge_feats)
g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h_nbr'))
h_self = self.project_self(node_feats) # (V, node_out_feats)
return g.ndata['h_nbr'] * h_self
"""Collection of model architectures"""
from .jtnn import *
from .dgmg import *
from .attentivefp_predictor import *
from .gat_predictor import *
from .gcn_predictor import *
from .mlp_predictor import *
from .schnet_predictor import *
from .mgcn_predictor import *
from .mpnn_predictor import *
from .acnn import *
from .wln_reaction_center import *
from .wln_reaction_ranking import *
from .weave_predictor import *
from .gin_predictor import *
"""Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity"""
# pylint: disable=C0103, C0123, W0221, E1101, R1721
import itertools
import numpy as np
import torch
import torch.nn as nn
from dgl import BatchedDGLHeteroGraph
from dgl.nn.pytorch import AtomicConv
__all__ = ['ACNN']
def truncated_normal_(tensor, mean=0., std=1.):
"""Fills the given tensor in-place with elements sampled from the truncated normal
distribution parameterized by mean and std.
The generated values follow a normal distribution with specified mean and
standard deviation, except that values whose magnitude is more than 2 std
from the mean are dropped.
We credit to Ruotian Luo for this implementation:
https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15.
Parameters
----------
tensor : Float32 tensor of arbitrary shape
Tensor to be filled.
mean : float
Mean of the truncated normal distribution.
std : float
Standard deviation of the truncated normal distribution.
"""
shape = tensor.shape
tmp = tensor.new_empty(shape + (4,)).normal_()
valid = (tmp < 2) & (tmp > -2)
ind = valid.max(-1, keepdim=True)[1]
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
tensor.data.mul_(std).add_(mean)
class ACNNPredictor(nn.Module):
"""Predictor for ACNN.
Parameters
----------
in_size : int
Number of radial filters used.
hidden_sizes : list of int
Specifying the hidden sizes for all layers in the predictor.
weight_init_stddevs : list of float
Specifying the standard deviations to use for truncated normal
distributions in initialzing weights for the predictor.
dropouts : list of float
Specifying the dropouts to use for all layers in the predictor.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. Default to None.
num_tasks : int
Output size.
"""
def __init__(self, in_size, hidden_sizes, weight_init_stddevs,
dropouts, features_to_use, num_tasks):
super(ACNNPredictor, self).__init__()
if type(features_to_use) != type(None):
in_size *= len(features_to_use)
modules = []
for i, h in enumerate(hidden_sizes):
linear_layer = nn.Linear(in_size, h)
truncated_normal_(linear_layer.weight, std=weight_init_stddevs[i])
modules.append(linear_layer)
modules.append(nn.ReLU())
modules.append(nn.Dropout(dropouts[i]))
in_size = h
linear_layer = nn.Linear(in_size, num_tasks)
truncated_normal_(linear_layer.weight, std=weight_init_stddevs[-1])
modules.append(linear_layer)
self.project = nn.Sequential(*modules)
def forward(self, batch_size, frag1_node_indices_in_complex, frag2_node_indices_in_complex,
ligand_conv_out, protein_conv_out, complex_conv_out):
"""Perform the prediction.
Parameters
----------
batch_size : int
Number of datapoints in a batch.
frag1_node_indices_in_complex : Int64 tensor of shape (V1)
Indices for atoms in the first fragment (protein) in the batched complex.
frag2_node_indices_in_complex : list of int of length V2
Indices for atoms in the second fragment (ligand) in the batched complex.
ligand_conv_out : Float32 tensor of shape (V2, K * T)
Updated ligand node representations. V2 for the number of atoms in the
ligand, K for the number of radial filters, and T for the number of types
of atomic numbers.
protein_conv_out : Float32 tensor of shape (V1, K * T)
Updated protein node representations. V1 for the number of
atoms in the protein, K for the number of radial filters,
and T for the number of types of atomic numbers.
complex_conv_out : Float32 tensor of shape (V1 + V2, K * T)
Updated complex node representations. V1 and V2 separately
for the number of atoms in the ligand and protein, K for
the number of radial filters, and T for the number of
types of atomic numbers.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_feats = self.project(ligand_conv_out) # (V1, O)
protein_feats = self.project(protein_conv_out) # (V2, O)
complex_feats = self.project(complex_conv_out) # (V1+V2, O)
ligand_energy = ligand_feats.reshape(batch_size, -1).sum(-1, keepdim=True) # (B, O)
protein_energy = protein_feats.reshape(batch_size, -1).sum(-1, keepdim=True) # (B, O)
complex_ligand_energy = complex_feats[frag1_node_indices_in_complex].reshape(
batch_size, -1).sum(-1, keepdim=True)
complex_protein_energy = complex_feats[frag2_node_indices_in_complex].reshape(
batch_size, -1).sum(-1, keepdim=True)
complex_energy = complex_ligand_energy + complex_protein_energy
return complex_energy - (ligand_energy + protein_energy)
class ACNN(nn.Module):
"""Atomic Convolutional Networks.
The model was proposed in `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
The prediction proceeds as follows:
1. Perform message passing to update atom representations for the
ligand, protein and protein-ligand complex.
2. Predict the energy of atoms from their representations with an MLP.
3. Take the sum of predicted energy of atoms within each molecule for
predicted energy of the ligand, protein and protein-ligand complex.
4. Make the final prediction by subtracting the predicted ligand and protein
energy from the predicted complex energy.
Parameters
----------
hidden_sizes : list of int
``hidden_sizes[i]`` gives the size of hidden representations in the i-th
hidden layer of the MLP. By Default, ``[32, 32, 16]`` will be used.
weight_init_stddevs : list of float
``weight_init_stddevs[i]`` gives the std to initialize parameters in the
i-th layer of the MLP. Note that ``len(weight_init_stddevs) == len(hidden_sizes) + 1``
due to the output layer. By default, we use ``1 / sqrt(hidden_sizes[i])`` for hidden
layers and 0.01 for the output layer.
dropouts : list of float
``dropouts[i]`` gives the dropout in the i-th hidden layer of the MLP. By default,
no dropout is used.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. If None, we use same parameters
for all atoms regardless of their type. Default to None.
radial : list
The list consists of 3 sublists of floats, separately for the
options of interaction cutoff, the options of rbf kernel mean and the
options of rbf kernel scaling. By default,
``[[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]`` will be used.
num_tasks : int
Number of output tasks. Default to 1.
"""
def __init__(self, hidden_sizes=None, weight_init_stddevs=None, dropouts=None,
features_to_use=None, radial=None, num_tasks=1):
super(ACNN, self).__init__()
if hidden_sizes is None:
hidden_sizes = [32, 32, 16]
if weight_init_stddevs is None:
weight_init_stddevs = [1. / float(np.sqrt(hidden_sizes[i]))
for i in range(len(hidden_sizes))]
weight_init_stddevs.append(0.01)
if dropouts is None:
dropouts = [0. for _ in range(len(hidden_sizes))]
if radial is None:
radial = [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]
# Take the product of sets of options and get a list of 3-tuples.
radial_params = [x for x in itertools.product(*radial)]
radial_params = torch.stack(list(map(torch.tensor, zip(*radial_params))), dim=1)
interaction_cutoffs = radial_params[:, 0]
rbf_kernel_means = radial_params[:, 1]
rbf_kernel_scaling = radial_params[:, 2]
self.ligand_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.protein_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.complex_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.predictor = ACNNPredictor(radial_params.shape[0], hidden_sizes,
weight_init_stddevs, dropouts, features_to_use, num_tasks)
def forward(self, graph):
"""Apply the model for prediction.
Parameters
----------
graph : DGLHeteroGraph
DGLHeteroGraph consisting of the ligand graph, the protein graph
and the complex graph, along with preprocessed features. For a batch of
protein-ligand pairs, we assume zero padding is performed so that the
number of ligand and protein atoms is the same in all pairs.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_graph = graph[('ligand_atom', 'ligand', 'ligand_atom')]
ligand_graph_node_feats = ligand_graph.ndata['atomic_number']
assert ligand_graph_node_feats.shape[-1] == 1
ligand_graph_distances = ligand_graph.edata['distance']
ligand_conv_out = self.ligand_conv(ligand_graph,
ligand_graph_node_feats,
ligand_graph_distances)
protein_graph = graph[('protein_atom', 'protein', 'protein_atom')]
protein_graph_node_feats = protein_graph.ndata['atomic_number']
assert protein_graph_node_feats.shape[-1] == 1
protein_graph_distances = protein_graph.edata['distance']
protein_conv_out = self.protein_conv(protein_graph,
protein_graph_node_feats,
protein_graph_distances)
complex_graph = graph[:, 'complex', :]
complex_graph_node_feats = complex_graph.ndata['atomic_number']
assert complex_graph_node_feats.shape[-1] == 1
complex_graph_distances = complex_graph.edata['distance']
complex_conv_out = self.complex_conv(complex_graph,
complex_graph_node_feats,
complex_graph_distances)
frag1_node_indices_in_complex = torch.where(complex_graph.ndata['_TYPE'] == 0)[0]
frag2_node_indices_in_complex = list(set(range(complex_graph.number_of_nodes())) -
set(frag1_node_indices_in_complex.tolist()))
# Hack the case when we are working with a single graph.
if not isinstance(graph, BatchedDGLHeteroGraph):
graph.batch_size = 1
return self.predictor(
graph.batch_size,
frag1_node_indices_in_complex,
frag2_node_indices_in_complex,
ligand_conv_out, protein_conv_out, complex_conv_out)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment