Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
36c7b771
Unverified
Commit
36c7b771
authored
Jun 05, 2020
by
Mufei Li
Committed by
GitHub
Jun 05, 2020
Browse files
[LifeSci] Move to Independent Repo (#1592)
* Move LifeSci * Remove doc
parent
94c67203
Changes
203
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
4452 deletions
+0
-4452
apps/life_sci/python/dgllife/data/alchemy.py
apps/life_sci/python/dgllife/data/alchemy.py
+0
-295
apps/life_sci/python/dgllife/data/csv_dataset.py
apps/life_sci/python/dgllife/data/csv_dataset.py
+0
-134
apps/life_sci/python/dgllife/data/pdbbind.py
apps/life_sci/python/dgllife/data/pdbbind.py
+0
-284
apps/life_sci/python/dgllife/data/pubchem_aromaticity.py
apps/life_sci/python/dgllife/data/pubchem_aromaticity.py
+0
-49
apps/life_sci/python/dgllife/data/tox21.py
apps/life_sci/python/dgllife/data/tox21.py
+0
-103
apps/life_sci/python/dgllife/data/uspto.py
apps/life_sci/python/dgllife/data/uspto.py
+0
-1564
apps/life_sci/python/dgllife/libinfo.py
apps/life_sci/python/dgllife/libinfo.py
+0
-4
apps/life_sci/python/dgllife/model/__init__.py
apps/life_sci/python/dgllife/model/__init__.py
+0
-5
apps/life_sci/python/dgllife/model/gnn/__init__.py
apps/life_sci/python/dgllife/model/gnn/__init__.py
+0
-10
apps/life_sci/python/dgllife/model/gnn/attentivefp.py
apps/life_sci/python/dgllife/model/gnn/attentivefp.py
+0
-326
apps/life_sci/python/dgllife/model/gnn/gat.py
apps/life_sci/python/dgllife/model/gnn/gat.py
+0
-184
apps/life_sci/python/dgllife/model/gnn/gcn.py
apps/life_sci/python/dgllife/model/gnn/gcn.py
+0
-154
apps/life_sci/python/dgllife/model/gnn/gin.py
apps/life_sci/python/dgllife/model/gnn/gin.py
+0
-200
apps/life_sci/python/dgllife/model/gnn/mgcn.py
apps/life_sci/python/dgllife/model/gnn/mgcn.py
+0
-265
apps/life_sci/python/dgllife/model/gnn/mpnn.py
apps/life_sci/python/dgllife/model/gnn/mpnn.py
+0
-79
apps/life_sci/python/dgllife/model/gnn/schnet.py
apps/life_sci/python/dgllife/model/gnn/schnet.py
+0
-164
apps/life_sci/python/dgllife/model/gnn/weave.py
apps/life_sci/python/dgllife/model/gnn/weave.py
+0
-191
apps/life_sci/python/dgllife/model/gnn/wln.py
apps/life_sci/python/dgllife/model/gnn/wln.py
+0
-171
apps/life_sci/python/dgllife/model/model_zoo/__init__.py
apps/life_sci/python/dgllife/model/model_zoo/__init__.py
+0
-15
apps/life_sci/python/dgllife/model/model_zoo/acnn.py
apps/life_sci/python/dgllife/model/model_zoo/acnn.py
+0
-255
No files found.
apps/life_sci/python/dgllife/data/alchemy.py
deleted
100644 → 0
View file @
94c67203
# -*- coding:utf-8 -*-
"""Tencent Alchemy Dataset https://alchemy.tencent.com/"""
import
numpy
as
np
import
os
import
os.path
as
osp
import
pandas
as
pd
import
pathlib
import
zipfile
from
collections
import
defaultdict
from
dgl
import
backend
as
F
from
dgl.data.utils
import
download
,
get_download_dir
,
_get_dgl_url
,
save_graphs
,
load_graphs
from
rdkit
import
Chem
from
rdkit.Chem
import
ChemicalFeatures
from
rdkit
import
RDConfig
from
..utils.mol_to_graph
import
mol_to_complete_graph
from
..utils.featurizers
import
atom_type_one_hot
,
atom_hybridization_one_hot
,
atom_is_aromatic
__all__
=
[
'TencentAlchemyDataset'
]
def
alchemy_nodes
(
mol
):
"""Featurization for all atoms in a molecule. The atom indices
will be preserved.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
Returns
-------
atom_feats_dict : dict
Dictionary for atom features
"""
atom_feats_dict
=
defaultdict
(
list
)
is_donor
=
defaultdict
(
int
)
is_acceptor
=
defaultdict
(
int
)
fdef_name
=
osp
.
join
(
RDConfig
.
RDDataDir
,
'BaseFeatures.fdef'
)
mol_featurizer
=
ChemicalFeatures
.
BuildFeatureFactory
(
fdef_name
)
mol_feats
=
mol_featurizer
.
GetFeaturesForMol
(
mol
)
mol_conformers
=
mol
.
GetConformers
()
assert
len
(
mol_conformers
)
==
1
for
i
in
range
(
len
(
mol_feats
)):
if
mol_feats
[
i
].
GetFamily
()
==
'Donor'
:
node_list
=
mol_feats
[
i
].
GetAtomIds
()
for
u
in
node_list
:
is_donor
[
u
]
=
1
elif
mol_feats
[
i
].
GetFamily
()
==
'Acceptor'
:
node_list
=
mol_feats
[
i
].
GetAtomIds
()
for
u
in
node_list
:
is_acceptor
[
u
]
=
1
num_atoms
=
mol
.
GetNumAtoms
()
for
u
in
range
(
num_atoms
):
atom
=
mol
.
GetAtomWithIdx
(
u
)
atom_type
=
atom
.
GetAtomicNum
()
num_h
=
atom
.
GetTotalNumHs
()
atom_feats_dict
[
'node_type'
].
append
(
atom_type
)
h_u
=
[]
h_u
+=
atom_type_one_hot
(
atom
,
[
'H'
,
'C'
,
'N'
,
'O'
,
'F'
,
'S'
,
'Cl'
])
h_u
.
append
(
atom_type
)
h_u
.
append
(
is_acceptor
[
u
])
h_u
.
append
(
is_donor
[
u
])
h_u
+=
atom_is_aromatic
(
atom
)
h_u
+=
atom_hybridization_one_hot
(
atom
,
[
Chem
.
rdchem
.
HybridizationType
.
SP
,
Chem
.
rdchem
.
HybridizationType
.
SP2
,
Chem
.
rdchem
.
HybridizationType
.
SP3
])
h_u
.
append
(
num_h
)
atom_feats_dict
[
'n_feat'
].
append
(
F
.
tensor
(
np
.
array
(
h_u
).
astype
(
np
.
float32
)))
atom_feats_dict
[
'n_feat'
]
=
F
.
stack
(
atom_feats_dict
[
'n_feat'
],
dim
=
0
)
atom_feats_dict
[
'node_type'
]
=
F
.
tensor
(
np
.
array
(
atom_feats_dict
[
'node_type'
]).
astype
(
np
.
int64
))
return
atom_feats_dict
def
alchemy_edges
(
mol
,
self_loop
=
False
):
"""Featurization for all bonds in a molecule.
The bond indices will be preserved.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
self_loop : bool
Whether to add self loops. Default to be False.
Returns
-------
bond_feats_dict : dict
Dictionary for bond features
"""
bond_feats_dict
=
defaultdict
(
list
)
mol_conformers
=
mol
.
GetConformers
()
assert
len
(
mol_conformers
)
==
1
geom
=
mol_conformers
[
0
].
GetPositions
()
num_atoms
=
mol
.
GetNumAtoms
()
for
u
in
range
(
num_atoms
):
for
v
in
range
(
num_atoms
):
if
u
==
v
and
not
self_loop
:
continue
e_uv
=
mol
.
GetBondBetweenAtoms
(
u
,
v
)
if
e_uv
is
None
:
bond_type
=
None
else
:
bond_type
=
e_uv
.
GetBondType
()
bond_feats_dict
[
'e_feat'
].
append
([
float
(
bond_type
==
x
)
for
x
in
(
Chem
.
rdchem
.
BondType
.
SINGLE
,
Chem
.
rdchem
.
BondType
.
DOUBLE
,
Chem
.
rdchem
.
BondType
.
TRIPLE
,
Chem
.
rdchem
.
BondType
.
AROMATIC
,
None
)
])
bond_feats_dict
[
'distance'
].
append
(
np
.
linalg
.
norm
(
geom
[
u
]
-
geom
[
v
]))
bond_feats_dict
[
'e_feat'
]
=
F
.
tensor
(
np
.
array
(
bond_feats_dict
[
'e_feat'
]).
astype
(
np
.
float32
))
bond_feats_dict
[
'distance'
]
=
F
.
tensor
(
np
.
array
(
bond_feats_dict
[
'distance'
]).
astype
(
np
.
float32
)).
reshape
(
-
1
,
1
)
return
bond_feats_dict
class
TencentAlchemyDataset
(
object
):
"""
Developed by the Tencent Quantum Lab, the dataset lists 12 quantum mechanical
properties of 130, 000+ organic molecules, comprising up to 12 heavy atoms
(C, N, O, S, F and Cl), sampled from the GDBMedChem database. These properties
have been calculated using the open-source computational chemistry program
Python-based Simulation of Chemistry Framework (PySCF).
For more details, check the `paper <https://arxiv.org/abs/1906.09427>`__.
Parameters
----------
mode : str
'dev', 'valid' or 'test', separately for training, validation and test.
Default to be 'dev'. Note that 'test' is not available as the Alchemy
contest is ongoing.
mol_to_graph: callable, str -> DGLGraph
A function turning an RDKit molecule instance into a DGLGraph.
Default to :func:`dgllife.utils.mol_to_complete_graph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we construct graphs where nodes represent atoms
and node features represent atom features. We store the atomic numbers under the
name ``"node_type"`` and store the atom features under the name ``"n_feat"``.
The atom features include:
* One hot encoding for atom types
* Atomic number of atoms
* Whether the atom is a donor
* Whether the atom is an acceptor
* Whether the atom is aromatic
* One hot encoding for atom hybridization
* Total number of Hs on the atom
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we construct edges between every pair of atoms,
excluding the self loops. We store the distance between the end atoms under the name
``"distance"`` and store the edge features under the name ``"e_feat"``. The edge
features represent one hot encoding of edge types (bond types and non-bond edges).
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
def
__init__
(
self
,
mode
=
'dev'
,
mol_to_graph
=
mol_to_complete_graph
,
node_featurizer
=
alchemy_nodes
,
edge_featurizer
=
alchemy_edges
,
load
=
True
):
if
mode
==
'test'
:
raise
ValueError
(
'The test mode is not supported before '
'the Alchemy contest finishes.'
)
assert
mode
in
[
'dev'
,
'valid'
,
'test'
],
\
'Expect mode to be dev, valid or test, got {}.'
.
format
(
mode
)
self
.
mode
=
mode
# Construct DGLGraphs from raw data or use the preprocessed data
self
.
load
=
load
file_dir
=
osp
.
join
(
get_download_dir
(),
'Alchemy_data'
)
if
load
:
file_name
=
"{}_processed_dgl"
.
format
(
mode
)
else
:
file_name
=
"{}_single_sdf"
.
format
(
mode
)
self
.
file_dir
=
pathlib
.
Path
(
file_dir
,
file_name
)
self
.
_url
=
'dataset/alchemy/'
self
.
zip_file_path
=
pathlib
.
Path
(
file_dir
,
file_name
+
'.zip'
)
download
(
_get_dgl_url
(
self
.
_url
+
file_name
+
'.zip'
),
path
=
str
(
self
.
zip_file_path
))
if
not
os
.
path
.
exists
(
str
(
self
.
file_dir
)):
archive
=
zipfile
.
ZipFile
(
self
.
zip_file_path
)
archive
.
extractall
(
file_dir
)
archive
.
close
()
self
.
_load
(
mol_to_graph
,
node_featurizer
,
edge_featurizer
)
def
_load
(
self
,
mol_to_graph
,
node_featurizer
,
edge_featurizer
):
if
self
.
load
:
self
.
graphs
,
label_dict
=
load_graphs
(
osp
.
join
(
self
.
file_dir
,
"{}_graphs.bin"
.
format
(
self
.
mode
)))
self
.
labels
=
label_dict
[
'labels'
]
with
open
(
osp
.
join
(
self
.
file_dir
,
"{}_smiles.txt"
.
format
(
self
.
mode
)),
'r'
)
as
f
:
smiles_
=
f
.
readlines
()
self
.
smiles
=
[
s
.
strip
()
for
s
in
smiles_
]
else
:
print
(
'Start preprocessing dataset...'
)
target_file
=
pathlib
.
Path
(
self
.
file_dir
,
"{}_target.csv"
.
format
(
self
.
mode
))
self
.
target
=
pd
.
read_csv
(
target_file
,
index_col
=
0
,
usecols
=
[
'gdb_idx'
,]
+
[
'property_{:d}'
.
format
(
x
)
for
x
in
range
(
12
)])
self
.
target
=
self
.
target
[[
'property_{:d}'
.
format
(
x
)
for
x
in
range
(
12
)]]
self
.
graphs
,
self
.
labels
,
self
.
smiles
=
[],
[],
[]
supp
=
Chem
.
SDMolSupplier
(
osp
.
join
(
self
.
file_dir
,
self
.
mode
+
".sdf"
))
cnt
=
0
dataset_size
=
len
(
self
.
target
)
for
mol
,
label
in
zip
(
supp
,
self
.
target
.
iterrows
()):
cnt
+=
1
print
(
'Processing molecule {:d}/{:d}'
.
format
(
cnt
,
dataset_size
))
graph
=
mol_to_graph
(
mol
,
node_featurizer
=
node_featurizer
,
edge_featurizer
=
edge_featurizer
)
smiles
=
Chem
.
MolToSmiles
(
mol
)
self
.
smiles
.
append
(
smiles
)
self
.
graphs
.
append
(
graph
)
label
=
F
.
tensor
(
np
.
array
(
label
[
1
].
tolist
()).
astype
(
np
.
float32
))
self
.
labels
.
append
(
label
)
save_graphs
(
osp
.
join
(
self
.
file_dir
,
"{}_graphs.bin"
.
format
(
self
.
mode
)),
self
.
graphs
,
labels
=
{
'labels'
:
F
.
stack
(
self
.
labels
,
dim
=
0
)})
with
open
(
osp
.
join
(
self
.
file_dir
,
"{}_smiles.txt"
.
format
(
self
.
mode
)),
'w'
)
as
f
:
for
s
in
self
.
smiles
:
f
.
write
(
s
+
'
\n
'
)
self
.
set_mean_and_std
()
print
(
len
(
self
.
graphs
),
"loaded!"
)
def
__getitem__
(
self
,
item
):
"""Get datapoint with index
Parameters
----------
item : int
Datapoint index
Returns
-------
str
SMILES for the ith datapoint
DGLGraph
DGLGraph for the ith datapoint
Tensor of dtype float32 and shape (T)
Labels of the datapoint for all tasks.
"""
return
self
.
smiles
[
item
],
self
.
graphs
[
item
],
self
.
labels
[
item
]
def
__len__
(
self
):
"""Size for the dataset.
Returns
-------
int
Size for the dataset.
"""
return
len
(
self
.
graphs
)
def
set_mean_and_std
(
self
,
mean
=
None
,
std
=
None
):
"""Set mean and std or compute from labels for future normalization.
The mean and std can be fetched later with ``self.mean`` and ``self.std``.
Parameters
----------
mean : float32 tensor of shape (T)
Mean of labels for all tasks.
std : float32 tensor of shape (T)
Std of labels for all tasks.
"""
labels
=
np
.
array
([
i
.
numpy
()
for
i
in
self
.
labels
])
if
mean
is
None
:
mean
=
np
.
mean
(
labels
,
axis
=
0
)
if
std
is
None
:
std
=
np
.
std
(
labels
,
axis
=
0
)
self
.
mean
=
mean
self
.
std
=
std
apps/life_sci/python/dgllife/data/csv_dataset.py
deleted
100644 → 0
View file @
94c67203
"""Creating datasets from .csv files for molecular property prediction."""
import
dgl.backend
as
F
import
numpy
as
np
import
os
from
dgl.data.utils
import
save_graphs
,
load_graphs
__all__
=
[
'MoleculeCSVDataset'
]
class
MoleculeCSVDataset
(
object
):
"""MoleculeCSVDataset
This is a general class for loading molecular data from :class:`pandas.DataFrame`.
In data pre-processing, we construct a binary mask indicating the existence of labels.
All molecules are converted into DGLGraphs. After the first-time construction, the
DGLGraphs can be saved for reloading so that we do not need to reconstruct them every time.
Parameters
----------
df: pandas.DataFrame
Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path).
One column includes smiles and some other columns include labels.
smiles_to_graph: callable, str -> DGLGraph
A function turning a SMILES into a DGLGraph.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph.
smiles_column: str
Column name for smiles in ``df``.
cache_file_path: str
Path to store the preprocessed DGLGraphs. For example, this can be ``'dglgraph.bin'``.
task_names : list of str or None
Columns in the data frame corresponding to real-valued labels. If None, we assume
all columns except the smiles_column are labels. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed. Default to 1000.
"""
def
__init__
(
self
,
df
,
smiles_to_graph
,
node_featurizer
,
edge_featurizer
,
smiles_column
,
cache_file_path
,
task_names
=
None
,
load
=
True
,
log_every
=
1000
):
self
.
df
=
df
self
.
smiles
=
self
.
df
[
smiles_column
].
tolist
()
if
task_names
is
None
:
self
.
task_names
=
self
.
df
.
columns
.
drop
([
smiles_column
]).
tolist
()
else
:
self
.
task_names
=
task_names
self
.
n_tasks
=
len
(
self
.
task_names
)
self
.
cache_file_path
=
cache_file_path
self
.
_pre_process
(
smiles_to_graph
,
node_featurizer
,
edge_featurizer
,
load
,
log_every
)
def
_pre_process
(
self
,
smiles_to_graph
,
node_featurizer
,
edge_featurizer
,
load
,
log_every
):
"""Pre-process the dataset
* Convert molecules from smiles format into DGLGraphs
and featurize their atoms
* Set missing labels to be 0 and use a binary masking
matrix to mask them
Parameters
----------
smiles_to_graph : callable, SMILES -> DGLGraph
Function for converting a SMILES (str) into a DGLGraph.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed.
"""
if
os
.
path
.
exists
(
self
.
cache_file_path
)
and
load
:
# DGLGraphs have been constructed before, reload them
print
(
'Loading previously saved dgl graphs...'
)
self
.
graphs
,
label_dict
=
load_graphs
(
self
.
cache_file_path
)
self
.
labels
=
label_dict
[
'labels'
]
self
.
mask
=
label_dict
[
'mask'
]
else
:
print
(
'Processing dgl graphs from scratch...'
)
self
.
graphs
=
[]
for
i
,
s
in
enumerate
(
self
.
smiles
):
if
(
i
+
1
)
%
log_every
==
0
:
print
(
'Processing molecule {:d}/{:d}'
.
format
(
i
+
1
,
len
(
self
)))
self
.
graphs
.
append
(
smiles_to_graph
(
s
,
node_featurizer
=
node_featurizer
,
edge_featurizer
=
edge_featurizer
))
_label_values
=
self
.
df
[
self
.
task_names
].
values
# np.nan_to_num will also turn inf into a very large number
self
.
labels
=
F
.
zerocopy_from_numpy
(
np
.
nan_to_num
(
_label_values
).
astype
(
np
.
float32
))
self
.
mask
=
F
.
zerocopy_from_numpy
((
~
np
.
isnan
(
_label_values
)).
astype
(
np
.
float32
))
save_graphs
(
self
.
cache_file_path
,
self
.
graphs
,
labels
=
{
'labels'
:
self
.
labels
,
'mask'
:
self
.
mask
})
def
__getitem__
(
self
,
item
):
"""Get datapoint with index
Parameters
----------
item : int
Datapoint index
Returns
-------
str
SMILES for the ith datapoint
DGLGraph
DGLGraph for the ith datapoint
Tensor of dtype float32 and shape (T)
Labels of the datapoint for all tasks
Tensor of dtype float32 and shape (T)
Binary masks indicating the existence of labels for all tasks
"""
return
self
.
smiles
[
item
],
self
.
graphs
[
item
],
self
.
labels
[
item
],
self
.
mask
[
item
]
def
__len__
(
self
):
"""Size for the dataset
Returns
-------
int
Size for the dataset
"""
return
len
(
self
.
smiles
)
apps/life_sci/python/dgllife/data/pdbbind.py
deleted
100644 → 0
View file @
94c67203
"""PDBBind dataset processed by MoleculeNet."""
import
dgl.backend
as
F
import
numpy
as
np
import
os
import
pandas
as
pd
from
dgl.data.utils
import
get_download_dir
,
download
,
_get_dgl_url
,
extract_archive
from
..utils
import
multiprocess_load_molecules
,
ACNN_graph_construction_and_featurization
__all__
=
[
'PDBBind'
]
class
PDBBind
(
object
):
"""PDBbind dataset processed by MoleculeNet.
The description below is mainly based on
`[1] <https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a#cit50>`__.
The PDBBind database consists of experimentally measured binding affinities for
bio-molecular complexes `[2] <https://www.ncbi.nlm.nih.gov/pubmed/?term=15163179%5Buid%5D>`__,
`[3] <https://www.ncbi.nlm.nih.gov/pubmed/?term=15943484%5Buid%5D>`__. It provides detailed
3D Cartesian coordinates of both ligands and their target proteins derived from experimental
(e.g., X-ray crystallography) measurements. The availability of coordinates of the
protein-ligand complexes permits structure-based featurization that is aware of the
protein-ligand binding geometry. The authors of
`[1] <https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a#cit50>`__ use the
"refined" and "core" subsets of the database
`[4] <https://www.ncbi.nlm.nih.gov/pubmed/?term=25301850%5Buid%5D>`__, more carefully
processed for data artifacts, as additional benchmarking targets.
References:
* [1] MoleculeNet: a benchmark for molecular machine learning
* [2] The PDBbind database: collection of binding affinities for protein-ligand complexes
with known three-dimensional structures
* [3] The PDBbind database: methodologies and updates
* [4] PDB-wide collection of binding data: current status of the PDBbind database
Parameters
----------
subset : str
In MoleculeNet, we can use either the "refined" subset or the "core" subset. We can
retrieve them by setting ``subset`` to be ``'refined'`` or ``'core'``. The size
of the ``'core'`` set is 195 and the size of the ``'refined'`` set is 3706.
load_binding_pocket : bool
Whether to load binding pockets or full proteins. Default to True.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
construct_graph_and_featurize : callable
Construct a DGLHeteroGraph for the use of GNNs. Mapping ``self.ligand_mols[i]``,
``self.protein_mols[i]``, ``self.ligand_coordinates[i]`` and
``self.protein_coordinates[i]`` to a DGLHeteroGraph.
Default to :func:`dgllife.utils.ACNN_graph_construction_and_featurization`.
zero_padding : bool
Whether to perform zero padding. While DGL does not necessarily require zero padding,
pooling operations for variable length inputs can introduce stochastic behaviour, which
is not desired for sensitive scenarios. Default to True.
num_processes : int or None
Number of worker processes to use. If None,
then we will use the number of CPUs in the system. Default to 64.
"""
def
__init__
(
self
,
subset
,
load_binding_pocket
=
True
,
sanitize
=
False
,
calc_charges
=
False
,
remove_hs
=
False
,
use_conformation
=
True
,
construct_graph_and_featurize
=
ACNN_graph_construction_and_featurization
,
zero_padding
=
True
,
num_processes
=
64
):
self
.
task_names
=
[
'-logKd/Ki'
]
self
.
n_tasks
=
len
(
self
.
task_names
)
self
.
_url
=
'dataset/pdbbind_v2015.tar.gz'
root_dir_path
=
get_download_dir
()
data_path
=
root_dir_path
+
'/pdbbind_v2015.tar.gz'
extracted_data_path
=
root_dir_path
+
'/pdbbind_v2015'
download
(
_get_dgl_url
(
self
.
_url
),
path
=
data_path
)
extract_archive
(
data_path
,
extracted_data_path
)
if
subset
==
'core'
:
index_label_file
=
extracted_data_path
+
'/v2015/INDEX_core_data.2013'
elif
subset
==
'refined'
:
index_label_file
=
extracted_data_path
+
'/v2015/INDEX_refined_data.2015'
else
:
raise
ValueError
(
'Expect the subset_choice to be either '
'core or refined, got {}'
.
format
(
subset
))
self
.
_preprocess
(
extracted_data_path
,
index_label_file
,
load_binding_pocket
,
sanitize
,
calc_charges
,
remove_hs
,
use_conformation
,
construct_graph_and_featurize
,
zero_padding
,
num_processes
)
def
_filter_out_invalid
(
self
,
ligands_loaded
,
proteins_loaded
,
use_conformation
):
"""Filter out invalid ligand-protein pairs.
Parameters
----------
ligands_loaded : list
Each element is a 2-tuple of the RDKit molecule instance and its associated atom
coordinates. None is used to represent invalid/non-existing molecule or coordinates.
proteins_loaded : list
Each element is a 2-tuple of the RDKit molecule instance and its associated atom
coordinates. None is used to represent invalid/non-existing molecule or coordinates.
use_conformation : bool
Whether we need conformation information (atom coordinates) and filter out molecules
without valid conformation.
"""
num_pairs
=
len
(
proteins_loaded
)
self
.
indices
,
self
.
ligand_mols
,
self
.
protein_mols
=
[],
[],
[]
if
use_conformation
:
self
.
ligand_coordinates
,
self
.
protein_coordinates
=
[],
[]
else
:
# Use None for placeholders.
self
.
ligand_coordinates
=
[
None
for
_
in
range
(
num_pairs
)]
self
.
protein_coordinates
=
[
None
for
_
in
range
(
num_pairs
)]
for
i
in
range
(
num_pairs
):
ligand_mol
,
ligand_coordinates
=
ligands_loaded
[
i
]
protein_mol
,
protein_coordinates
=
proteins_loaded
[
i
]
if
(
not
use_conformation
)
and
all
(
v
is
not
None
for
v
in
[
protein_mol
,
ligand_mol
]):
self
.
indices
.
append
(
i
)
self
.
ligand_mols
.
append
(
ligand_mol
)
self
.
protein_mols
.
append
(
protein_mol
)
elif
all
(
v
is
not
None
for
v
in
[
protein_mol
,
protein_coordinates
,
ligand_mol
,
ligand_coordinates
]):
self
.
indices
.
append
(
i
)
self
.
ligand_mols
.
append
(
ligand_mol
)
self
.
ligand_coordinates
.
append
(
ligand_coordinates
)
self
.
protein_mols
.
append
(
protein_mol
)
self
.
protein_coordinates
.
append
(
protein_coordinates
)
def
_preprocess
(
self
,
root_path
,
index_label_file
,
load_binding_pocket
,
sanitize
,
calc_charges
,
remove_hs
,
use_conformation
,
construct_graph_and_featurize
,
zero_padding
,
num_processes
):
"""Preprocess the dataset.
The pre-processing proceeds as follows:
1. Load the dataset
2. Clean the dataset and filter out invalid pairs
3. Construct graphs
4. Prepare node and edge features
Parameters
----------
root_path : str
Root path for molecule files.
index_label_file : str
Path to the index file for the dataset.
load_binding_pocket : bool
Whether to load binding pockets or full proteins.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``sanitize`` to be True.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
construct_graph_and_featurize : callable
Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
zero_padding : bool
Whether to perform zero padding. While DGL does not necessarily require zero padding,
pooling operations for variable length inputs can introduce stochastic behaviour, which
is not desired for sensitive scenarios.
num_processes : int or None
Number of worker processes to use. If None,
then we will use the number of CPUs in the system.
"""
contents
=
[]
with
open
(
index_label_file
,
'r'
)
as
f
:
for
line
in
f
.
readlines
():
if
line
[
0
]
!=
"#"
:
splitted_elements
=
line
.
split
()
if
len
(
splitted_elements
)
==
8
:
# Ignore "//"
contents
.
append
(
splitted_elements
[:
5
]
+
splitted_elements
[
6
:])
else
:
print
(
'Incorrect data format.'
)
print
(
splitted_elements
)
self
.
df
=
pd
.
DataFrame
(
contents
,
columns
=
(
'PDB_code'
,
'resolution'
,
'release_year'
,
'-logKd/Ki'
,
'Kd/Ki'
,
'reference'
,
'ligand_name'
))
pdbs
=
self
.
df
[
'PDB_code'
].
tolist
()
self
.
ligand_files
=
[
os
.
path
.
join
(
root_path
,
'v2015'
,
pdb
,
'{}_ligand.sdf'
.
format
(
pdb
))
for
pdb
in
pdbs
]
if
load_binding_pocket
:
self
.
protein_files
=
[
os
.
path
.
join
(
root_path
,
'v2015'
,
pdb
,
'{}_pocket.pdb'
.
format
(
pdb
))
for
pdb
in
pdbs
]
else
:
self
.
protein_files
=
[
os
.
path
.
join
(
root_path
,
'v2015'
,
pdb
,
'{}_protein.pdb'
.
format
(
pdb
))
for
pdb
in
pdbs
]
num_processes
=
min
(
num_processes
,
len
(
pdbs
))
print
(
'Loading ligands...'
)
ligands_loaded
=
multiprocess_load_molecules
(
self
.
ligand_files
,
sanitize
=
sanitize
,
calc_charges
=
calc_charges
,
remove_hs
=
remove_hs
,
use_conformation
=
use_conformation
,
num_processes
=
num_processes
)
print
(
'Loading proteins...'
)
proteins_loaded
=
multiprocess_load_molecules
(
self
.
protein_files
,
sanitize
=
sanitize
,
calc_charges
=
calc_charges
,
remove_hs
=
remove_hs
,
use_conformation
=
use_conformation
,
num_processes
=
num_processes
)
self
.
_filter_out_invalid
(
ligands_loaded
,
proteins_loaded
,
use_conformation
)
self
.
df
=
self
.
df
.
iloc
[
self
.
indices
]
self
.
labels
=
F
.
zerocopy_from_numpy
(
self
.
df
[
self
.
task_names
].
values
.
astype
(
np
.
float32
))
print
(
'Finished cleaning the dataset, '
'got {:d}/{:d} valid pairs'
.
format
(
len
(
self
),
len
(
pdbs
)))
# Prepare zero padding
if
zero_padding
:
max_num_ligand_atoms
=
0
max_num_protein_atoms
=
0
for
i
in
range
(
len
(
self
)):
max_num_ligand_atoms
=
max
(
max_num_ligand_atoms
,
self
.
ligand_mols
[
i
].
GetNumAtoms
())
max_num_protein_atoms
=
max
(
max_num_protein_atoms
,
self
.
protein_mols
[
i
].
GetNumAtoms
())
else
:
max_num_ligand_atoms
=
None
max_num_protein_atoms
=
None
print
(
'Start constructing graphs and featurizing them.'
)
self
.
graphs
=
[]
for
i
in
range
(
len
(
self
)):
print
(
'Constructing and featurizing datapoint {:d}/{:d}'
.
format
(
i
+
1
,
len
(
self
)))
self
.
graphs
.
append
(
construct_graph_and_featurize
(
self
.
ligand_mols
[
i
],
self
.
protein_mols
[
i
],
self
.
ligand_coordinates
[
i
],
self
.
protein_coordinates
[
i
],
max_num_ligand_atoms
,
max_num_protein_atoms
))
def
__len__
(
self
):
"""Get the size of the dataset.
Returns
-------
int
Number of valid ligand-protein pairs in the dataset.
"""
return
len
(
self
.
indices
)
def
__getitem__
(
self
,
item
):
"""Get the datapoint associated with the index.
Parameters
----------
item : int
Index for the datapoint.
Returns
-------
int
Index for the datapoint.
rdkit.Chem.rdchem.Mol
RDKit molecule instance for the ligand molecule.
rdkit.Chem.rdchem.Mol
RDKit molecule instance for the protein molecule.
DGLHeteroGraph
Pre-processed DGLHeteroGraph with features extracted.
Float32 tensor
Label for the datapoint.
"""
return
item
,
self
.
ligand_mols
[
item
],
self
.
protein_mols
[
item
],
\
self
.
graphs
[
item
],
self
.
labels
[
item
]
apps/life_sci/python/dgllife/data/pubchem_aromaticity.py
deleted
100644 → 0
View file @
94c67203
"""Dataset for aromaticity prediction"""
import
pandas
as
pd
from
dgl.data.utils
import
get_download_dir
,
download
,
_get_dgl_url
from
.csv_dataset
import
MoleculeCSVDataset
from
..utils.mol_to_graph
import
smiles_to_bigraph
__all__
=
[
'PubChemBioAssayAromaticity'
]
class
PubChemBioAssayAromaticity
(
MoleculeCSVDataset
):
"""Subset of PubChem BioAssay Dataset for aromaticity prediction.
The dataset was constructed in `Pushing the Boundaries of Molecular Representation for Drug
Discovery with the Graph Attention Mechanism
<https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__ and is accompanied by the task of predicting
the number of aromatic atoms in molecules.
The dataset was constructed by sampling 3945 molecules with 0-40 aromatic atoms from the
PubChem BioAssay dataset.
Parameters
----------
smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph.
Default to :func:`dgllife.utils.smiles_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to pre-process from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed. Default to 1000.
"""
def
__init__
(
self
,
smiles_to_graph
=
smiles_to_bigraph
,
node_featurizer
=
None
,
edge_featurizer
=
None
,
load
=
True
,
log_every
=
1000
):
self
.
_url
=
'dataset/pubchem_bioassay_aromaticity.csv'
data_path
=
get_download_dir
()
+
'/pubchem_bioassay_aromaticity.csv'
download
(
_get_dgl_url
(
self
.
_url
),
path
=
data_path
)
df
=
pd
.
read_csv
(
data_path
)
super
(
PubChemBioAssayAromaticity
,
self
).
__init__
(
df
,
smiles_to_graph
,
node_featurizer
,
edge_featurizer
,
"cano_smiles"
,
"pubchem_aromaticity_dglgraph.bin"
,
load
=
load
,
log_every
=
log_every
)
apps/life_sci/python/dgllife/data/tox21.py
deleted
100644 → 0
View file @
94c67203
"""The Toxicology in the 21st Century initiative."""
import
dgl.backend
as
F
import
pandas
as
pd
from
dgl.data.utils
import
get_download_dir
,
download
,
_get_dgl_url
from
.csv_dataset
import
MoleculeCSVDataset
from
..utils.mol_to_graph
import
smiles_to_bigraph
__all__
=
[
'Tox21'
]
class
Tox21
(
MoleculeCSVDataset
):
"""Tox21 dataset.
The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
initiative created a public database measuring toxicity of compounds, which
has been used in the 2014 Tox21 Data Challenge. The dataset contains qualitative
toxicity measurements for 8014 compounds on 12 different targets, including nuclear
receptors and stress response pathways. Each target results in a binary label.
A common issue for multi-task prediction is that some datapoints are not labeled for
all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing
labels to be 0 so that they can be placed in tensors and used for masking in loss computation.
All molecules are converted into DGLGraphs. After the first-time construction,
the DGLGraphs will be saved for reloading so that we do not need to reconstruct them everytime.
Parameters
----------
smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph.
Default to :func:`dgllife.utils.smiles_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed. Default to 1000.
"""
def
__init__
(
self
,
smiles_to_graph
=
smiles_to_bigraph
,
node_featurizer
=
None
,
edge_featurizer
=
None
,
load
=
True
,
log_every
=
1000
):
self
.
_url
=
'dataset/tox21.csv.gz'
data_path
=
get_download_dir
()
+
'/tox21.csv.gz'
download
(
_get_dgl_url
(
self
.
_url
),
path
=
data_path
)
df
=
pd
.
read_csv
(
data_path
)
self
.
id
=
df
[
'mol_id'
]
df
=
df
.
drop
(
columns
=
[
'mol_id'
])
super
(
Tox21
,
self
).
__init__
(
df
,
smiles_to_graph
,
node_featurizer
,
edge_featurizer
,
"smiles"
,
"tox21_dglgraph.bin"
,
load
=
load
,
log_every
=
log_every
)
self
.
_weight_balancing
()
def
_weight_balancing
(
self
):
"""Perform re-balancing for each task.
It's quite common that the number of positive samples and the
number of negative samples are significantly different. To compensate
for the class imbalance issue, we can weight each datapoint in
loss computation.
In particular, for each task we will set the weight of negative samples
to be 1 and the weight of positive samples to be the number of negative
samples divided by the number of positive samples.
If weight balancing is performed, one attribute will be affected:
* self._task_pos_weights is set, which is a list of positive sample weights
for each task.
"""
num_pos
=
F
.
sum
(
self
.
labels
,
dim
=
0
)
num_indices
=
F
.
sum
(
self
.
mask
,
dim
=
0
)
self
.
_task_pos_weights
=
(
num_indices
-
num_pos
)
/
num_pos
@
property
def
task_pos_weights
(
self
):
"""Get weights for positive samples on each task
It's quite common that the number of positive samples and the
number of negative samples are significantly different. To compensate
for the class imbalance issue, we can weight each datapoint in
loss computation.
In particular, for each task we will set the weight of negative samples
to be 1 and the weight of positive samples to be the number of negative
samples divided by the number of positive samples.
Returns
-------
Tensor of dtype float32 and shape (T)
Weight of positive samples on all tasks
"""
return
self
.
_task_pos_weights
apps/life_sci/python/dgllife/data/uspto.py
deleted
100644 → 0
View file @
94c67203
"""USPTO for reaction prediction"""
import
errno
import
numpy
as
np
import
os
import
random
import
torch
from
collections
import
defaultdict
from
copy
import
deepcopy
from
dgl
import
DGLGraph
from
dgl.data.utils
import
get_download_dir
,
download
,
_get_dgl_url
,
extract_archive
,
\
save_graphs
,
load_graphs
from
functools
import
partial
from
itertools
import
combinations
from
multiprocessing
import
Pool
from
rdkit
import
Chem
,
RDLogger
from
rdkit.Chem
import
rdmolops
from
tqdm
import
tqdm
from
..utils.featurizers
import
BaseAtomFeaturizer
,
ConcatFeaturizer
,
one_hot_encoding
,
\
atom_type_one_hot
,
atom_degree_one_hot
,
atom_explicit_valence_one_hot
,
\
atom_implicit_valence_one_hot
,
atom_is_aromatic
,
atom_formal_charge_one_hot
,
\
BaseBondFeaturizer
,
bond_type_one_hot
,
bond_is_conjugated
,
bond_is_in_ring
from
..utils.mol_to_graph
import
mol_to_bigraph
,
mol_to_complete_graph
__all__
=
[
'WLNCenterDataset'
,
'USPTOCenter'
,
'WLNRankDataset'
,
'USPTORank'
]
# Disable RDKit warnings
RDLogger
.
DisableLog
(
'rdApp.*'
)
# Atom types distinguished in featurization
atom_types
=
[
'C'
,
'N'
,
'O'
,
'S'
,
'F'
,
'Si'
,
'P'
,
'Cl'
,
'Br'
,
'Mg'
,
'Na'
,
'Ca'
,
'Fe'
,
'As'
,
'Al'
,
'I'
,
'B'
,
'V'
,
'K'
,
'Tl'
,
'Yb'
,
'Sb'
,
'Sn'
,
'Ag'
,
'Pd'
,
'Co'
,
'Se'
,
'Ti'
,
'Zn'
,
'H'
,
'Li'
,
'Ge'
,
'Cu'
,
'Au'
,
'Ni'
,
'Cd'
,
'In'
,
'Mn'
,
'Zr'
,
'Cr'
,
'Pt'
,
'Hg'
,
'Pb'
,
'W'
,
'Ru'
,
'Nb'
,
'Re'
,
'Te'
,
'Rh'
,
'Tc'
,
'Ba'
,
'Bi'
,
'Hf'
,
'Mo'
,
'U'
,
'Sm'
,
'Os'
,
'Ir'
,
'Ce'
,
'Gd'
,
'Ga'
,
'Cs'
]
default_node_featurizer_center
=
BaseAtomFeaturizer
({
'hv'
:
ConcatFeaturizer
(
[
partial
(
atom_type_one_hot
,
allowable_set
=
atom_types
,
encode_unknown
=
True
),
partial
(
atom_degree_one_hot
,
allowable_set
=
list
(
range
(
5
)),
encode_unknown
=
True
),
partial
(
atom_explicit_valence_one_hot
,
allowable_set
=
list
(
range
(
1
,
6
)),
encode_unknown
=
True
),
partial
(
atom_implicit_valence_one_hot
,
allowable_set
=
list
(
range
(
5
)),
encode_unknown
=
True
),
atom_is_aromatic
]
)
})
default_node_featurizer_rank
=
BaseAtomFeaturizer
({
'hv'
:
ConcatFeaturizer
(
[
partial
(
atom_type_one_hot
,
allowable_set
=
atom_types
,
encode_unknown
=
True
),
partial
(
atom_formal_charge_one_hot
,
allowable_set
=
[
-
3
,
-
2
,
-
1
,
0
,
1
,
2
],
encode_unknown
=
True
),
partial
(
atom_degree_one_hot
,
allowable_set
=
list
(
range
(
5
)),
encode_unknown
=
True
),
partial
(
atom_explicit_valence_one_hot
,
allowable_set
=
list
(
range
(
1
,
6
)),
encode_unknown
=
True
),
partial
(
atom_implicit_valence_one_hot
,
allowable_set
=
list
(
range
(
5
)),
encode_unknown
=
True
),
atom_is_aromatic
]
)
})
default_edge_featurizer_center
=
BaseBondFeaturizer
({
'he'
:
ConcatFeaturizer
([
bond_type_one_hot
,
bond_is_conjugated
,
bond_is_in_ring
]
)
})
default_edge_featurizer_rank
=
BaseBondFeaturizer
({
'he'
:
ConcatFeaturizer
([
bond_type_one_hot
,
bond_is_in_ring
]
)
})
def
default_atom_pair_featurizer
(
reactants
):
"""Featurize each pair of atoms, which will be used in updating
the edata of a complete DGLGraph.
The features include the bond type between the atoms (if any) and whether
they belong to the same molecule. It is used in the global attention mechanism.
Parameters
----------
reactants : str
SMILES for reactants
data_field : str
Key for storing the features in DGLGraph.edata. Default to 'atom_pair'
Returns
-------
float32 tensor of shape (V^2, 10)
features for each pair of atoms.
"""
# Decide the reactant membership for each atom
atom_to_reactant
=
dict
()
reactant_list
=
reactants
.
split
(
'.'
)
for
id
,
s
in
enumerate
(
reactant_list
):
mol
=
Chem
.
MolFromSmiles
(
s
)
for
atom
in
mol
.
GetAtoms
():
atom_to_reactant
[
atom
.
GetIntProp
(
'molAtomMapNumber'
)
-
1
]
=
id
# Construct mapping from atom pair to RDKit bond object
all_reactant_mol
=
Chem
.
MolFromSmiles
(
reactants
)
atom_pair_to_bond
=
dict
()
for
bond
in
all_reactant_mol
.
GetBonds
():
atom1
=
bond
.
GetBeginAtom
().
GetIntProp
(
'molAtomMapNumber'
)
-
1
atom2
=
bond
.
GetEndAtom
().
GetIntProp
(
'molAtomMapNumber'
)
-
1
atom_pair_to_bond
[(
atom1
,
atom2
)]
=
bond
atom_pair_to_bond
[(
atom2
,
atom1
)]
=
bond
def
_featurize_a_bond
(
bond
):
return
bond_type_one_hot
(
bond
)
+
bond_is_conjugated
(
bond
)
+
bond_is_in_ring
(
bond
)
features
=
[]
num_atoms
=
all_reactant_mol
.
GetNumAtoms
()
for
i
in
range
(
num_atoms
):
for
j
in
range
(
num_atoms
):
pair_feature
=
np
.
zeros
(
10
)
if
i
==
j
:
features
.
append
(
pair_feature
)
continue
bond
=
atom_pair_to_bond
.
get
((
i
,
j
),
None
)
if
bond
is
not
None
:
pair_feature
[
1
:
7
]
=
_featurize_a_bond
(
bond
)
else
:
pair_feature
[
0
]
=
1.
pair_feature
[
-
4
]
=
1.
if
atom_to_reactant
[
i
]
!=
atom_to_reactant
[
j
]
else
0.
pair_feature
[
-
3
]
=
1.
if
atom_to_reactant
[
i
]
==
atom_to_reactant
[
j
]
else
0.
pair_feature
[
-
2
]
=
1.
if
len
(
reactant_list
)
==
1
else
0.
pair_feature
[
-
1
]
=
1.
if
len
(
reactant_list
)
>
1
else
0.
features
.
append
(
pair_feature
)
return
torch
.
from_numpy
(
np
.
stack
(
features
,
axis
=
0
).
astype
(
np
.
float32
))
def
get_pair_label
(
reactants_mol
,
graph_edits
):
"""Construct labels for each pair of atoms in reaction center prediction
Parameters
----------
reactants_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for all reactants in a reaction
graph_edits : str
Specifying which pairs of atoms loss a bond or form a particular bond in the reaction
Returns
-------
float32 tensor of shape (V^2, 5)
Labels constructed. V for the number of atoms in the reactants.
"""
# 0 for losing the bond
# 1, 2, 3, 1.5 separately for forming a single, double, triple or aromatic bond.
bond_change_to_id
=
{
0.0
:
0
,
1
:
1
,
2
:
2
,
3
:
3
,
1.5
:
4
}
pair_to_changes
=
defaultdict
(
list
)
for
edit
in
graph_edits
.
split
(
';'
):
a1
,
a2
,
change
=
edit
.
split
(
'-'
)
atom1
=
int
(
a1
)
-
1
atom2
=
int
(
a2
)
-
1
change
=
bond_change_to_id
[
float
(
change
)]
pair_to_changes
[(
atom1
,
atom2
)].
append
(
change
)
pair_to_changes
[(
atom2
,
atom1
)].
append
(
change
)
num_atoms
=
reactants_mol
.
GetNumAtoms
()
labels
=
torch
.
zeros
((
num_atoms
,
num_atoms
,
5
))
for
pair
in
pair_to_changes
.
keys
():
i
,
j
=
pair
labels
[
i
,
j
,
pair_to_changes
[(
j
,
i
)]]
=
1.
return
labels
.
reshape
(
-
1
,
5
)
def
get_bond_changes
(
reaction
):
"""Get the bond changes in a reaction.
Parameters
----------
reaction : str
SMILES for a reaction, e.g. [CH3:14][NH2:15].[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7]
(=[O:8])[OH:9])[cH:10][cH:11][c:12]1[Cl:13].[OH2:16]>>[N+:1](=[O:2])([O-:3])[c:4]1[cH:5]
[c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[NH:15][CH3:14]. It consists of reactants,
products and the atom mapping.
Returns
-------
bond_changes : set of 3-tuples
Each tuple consists of (atom1, atom2, change type)
There are 5 possible values for change type. 0 for losing the bond, and 1, 2, 3, 1.5
separately for forming a single, double, triple or aromatic bond.
"""
reactants
=
Chem
.
MolFromSmiles
(
reaction
.
split
(
'>'
)[
0
])
products
=
Chem
.
MolFromSmiles
(
reaction
.
split
(
'>'
)[
2
])
conserved_maps
=
[
a
.
GetProp
(
'molAtomMapNumber'
)
for
a
in
products
.
GetAtoms
()
if
a
.
HasProp
(
'molAtomMapNumber'
)]
bond_changes
=
set
()
# keep track of bond changes
# Look at changed bonds
bonds_prev
=
{}
for
bond
in
reactants
.
GetBonds
():
nums
=
sorted
(
[
bond
.
GetBeginAtom
().
GetProp
(
'molAtomMapNumber'
),
bond
.
GetEndAtom
().
GetProp
(
'molAtomMapNumber'
)])
if
(
nums
[
0
]
not
in
conserved_maps
)
and
(
nums
[
1
]
not
in
conserved_maps
):
continue
bonds_prev
[
'{}~{}'
.
format
(
nums
[
0
],
nums
[
1
])]
=
bond
.
GetBondTypeAsDouble
()
bonds_new
=
{}
for
bond
in
products
.
GetBonds
():
nums
=
sorted
(
[
bond
.
GetBeginAtom
().
GetProp
(
'molAtomMapNumber'
),
bond
.
GetEndAtom
().
GetProp
(
'molAtomMapNumber'
)])
bonds_new
[
'{}~{}'
.
format
(
nums
[
0
],
nums
[
1
])]
=
bond
.
GetBondTypeAsDouble
()
for
bond
in
bonds_prev
:
if
bond
not
in
bonds_new
:
# lost bond
bond_changes
.
add
((
bond
.
split
(
'~'
)[
0
],
bond
.
split
(
'~'
)[
1
],
0.0
))
else
:
if
bonds_prev
[
bond
]
!=
bonds_new
[
bond
]:
# changed bond
bond_changes
.
add
((
bond
.
split
(
'~'
)[
0
],
bond
.
split
(
'~'
)[
1
],
bonds_new
[
bond
]))
for
bond
in
bonds_new
:
if
bond
not
in
bonds_prev
:
# new bond
bond_changes
.
add
((
bond
.
split
(
'~'
)[
0
],
bond
.
split
(
'~'
)[
1
],
bonds_new
[
bond
]))
return
bond_changes
def
process_line
(
line
):
"""Process one line consisting of one reaction for working with WLN.
Parameters
----------
line : str
One reaction in one line
Returns
-------
formatted_reaction : str
Formatted reaction
"""
reaction
=
line
.
strip
()
bond_changes
=
get_bond_changes
(
reaction
)
formatted_reaction
=
'{} {}
\n
'
.
format
(
reaction
,
';'
.
join
([
'{}-{}-{}'
.
format
(
x
[
0
],
x
[
1
],
x
[
2
])
for
x
in
bond_changes
]))
return
formatted_reaction
def
process_file
(
path
,
num_processes
=
1
):
"""Pre-process a file of reactions for working with WLN.
Parameters
----------
path : str
Path to the file of reactions
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
with
open
(
path
,
'r'
)
as
input_file
:
lines
=
input_file
.
readlines
()
if
num_processes
==
1
:
results
=
[]
for
li
in
lines
:
results
.
append
(
process_line
(
li
))
else
:
with
Pool
(
processes
=
num_processes
)
as
pool
:
results
=
pool
.
map
(
process_line
,
lines
)
with
open
(
path
+
'.proc'
,
'w'
)
as
output_file
:
for
line
in
results
:
output_file
.
write
(
line
)
print
(
'Finished processing {}'
.
format
(
path
))
def
load_one_reaction
(
line
):
"""Load one reaction and check if the reactants are valid.
Parameters
----------
line : str
One reaction and the associated graph edits
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants. None will be returned if the
reactants are not valid.
reaction : str
Reaction. None will be returned if the reactants are not valid.
graph_edits : str
Graph edits associated with the reaction. None will be returned if the
reactants are not valid.
"""
# Each line represents a reaction and the corresponding graph edits
#
# reaction example:
# [CH3:14][OH:15].[NH2:12][NH2:13].[OH2:11].[n:1]1[n:2][cH:3][c:4]
# ([C:7]([O:9][CH3:8])=[O:10])[cH:5][cH:6]1>>[n:1]1[n:2][cH:3][c:4]
# ([C:7](=[O:9])[NH:12][NH2:13])[cH:5][cH:6]1
# The reactants are on the left-hand-side of the reaction and the product
# is on the right-hand-side of the reaction. The numbers represent atom mapping.
#
# graph_edits example:
# 23-33-1.0;23-25-0.0
# For a triplet a-b-c, a and b are the atoms that form or loss the bond.
# c specifies the particular change, 0.0 for losing a bond, 1.0, 2.0, 3.0 and
# 1.5 separately for forming a single, double, triple or aromatic bond.
reaction
,
graph_edits
=
line
.
strip
(
"
\r\n
"
).
split
()
reactants
=
reaction
.
split
(
'>'
)[
0
]
mol
=
Chem
.
MolFromSmiles
(
reactants
)
if
mol
is
None
:
return
None
,
None
,
None
# Reorder atoms according to the order specified in the atom map
atom_map_order
=
[
-
1
for
_
in
range
(
mol
.
GetNumAtoms
())]
for
j
in
range
(
mol
.
GetNumAtoms
()):
atom
=
mol
.
GetAtomWithIdx
(
j
)
atom_map_order
[
atom
.
GetIntProp
(
'molAtomMapNumber'
)
-
1
]
=
j
mol
=
rdmolops
.
RenumberAtoms
(
mol
,
atom_map_order
)
return
mol
,
reaction
,
graph_edits
class
WLNCenterDataset
(
object
):
"""Dataset for reaction center prediction with WLN
Parameters
----------
raw_file_path : str
Path to the raw reaction file, where each line is the SMILES for a reaction.
We will check if raw_file_path + '.proc' exists, where each line has the reaction
SMILES and the corresponding graph edits. If not, we will preprocess
the raw reaction file.
mol_graph_path : str
Path to save/load DGLGraphs for molecules.
mol_to_graph: callable, str -> DGLGraph
A function turning RDKit molecule instances into DGLGraphs.
Default to :func:`dgllife.utils.mol_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we consider descriptors including atom type,
atom degree, atom explicit valence, atom implicit valence, aromaticity.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we consider descriptors including bond type,
whether bond is conjugated and whether bond is in ring.
atom_pair_featurizer : callable, str -> dict
Featurization for each pair of atoms in multiple reactants. The result will be
used to update edata in the complete DGLGraphs. By default, the features include
the bond type between the atoms (if any) and whether they belong to the same molecule.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def
__init__
(
self
,
raw_file_path
,
mol_graph_path
,
mol_to_graph
=
mol_to_bigraph
,
node_featurizer
=
default_node_featurizer_center
,
edge_featurizer
=
default_edge_featurizer_center
,
atom_pair_featurizer
=
default_atom_pair_featurizer
,
load
=
True
,
num_processes
=
1
):
super
(
WLNCenterDataset
,
self
).
__init__
()
self
.
_atom_pair_featurizer
=
atom_pair_featurizer
self
.
atom_pair_features
=
[]
self
.
atom_pair_labels
=
[]
# Map number of nodes to a corresponding complete graph
self
.
complete_graphs
=
dict
()
path_to_reaction_file
=
raw_file_path
+
'.proc'
if
not
os
.
path
.
isfile
(
path_to_reaction_file
):
print
(
'Pre-processing graph edits from reaction data'
)
process_file
(
raw_file_path
,
num_processes
)
import
time
t0
=
time
.
time
()
full_mols
,
full_reactions
,
full_graph_edits
=
\
self
.
load_reaction_data
(
path_to_reaction_file
,
num_processes
)
print
(
'Time spent'
,
time
.
time
()
-
t0
)
if
load
and
os
.
path
.
isfile
(
mol_graph_path
):
print
(
'Loading previously saved graphs...'
)
self
.
reactant_mol_graphs
,
_
=
load_graphs
(
mol_graph_path
)
else
:
print
(
'Constructing graphs from scratch...'
)
if
num_processes
==
1
:
self
.
reactant_mol_graphs
=
[]
for
mol
in
full_mols
:
self
.
reactant_mol_graphs
.
append
(
mol_to_graph
(
mol
,
node_featurizer
=
node_featurizer
,
edge_featurizer
=
edge_featurizer
,
canonical_atom_order
=
False
))
else
:
torch
.
multiprocessing
.
set_sharing_strategy
(
'file_system'
)
with
Pool
(
processes
=
num_processes
)
as
pool
:
self
.
reactant_mol_graphs
=
pool
.
map
(
partial
(
mol_to_graph
,
node_featurizer
=
node_featurizer
,
edge_featurizer
=
edge_featurizer
,
canonical_atom_order
=
False
),
full_mols
)
save_graphs
(
mol_graph_path
,
self
.
reactant_mol_graphs
)
self
.
mols
=
full_mols
self
.
reactions
=
full_reactions
self
.
graph_edits
=
full_graph_edits
self
.
atom_pair_features
.
extend
([
None
for
_
in
range
(
len
(
self
.
mols
))])
self
.
atom_pair_labels
.
extend
([
None
for
_
in
range
(
len
(
self
.
mols
))])
def
load_reaction_data
(
self
,
file_path
,
num_processes
):
"""Load reaction data from the raw file.
Parameters
----------
file_path : str
Path to read the file.
num_processes : int
Number of processes to use for data pre-processing.
Returns
-------
all_mols : list of rdkit.Chem.rdchem.Mol
RDKit molecule instances
all_reactions : list of str
Reactions
all_graph_edits : list of str
Graph edits in the reactions.
"""
all_mols
=
[]
all_reactions
=
[]
all_graph_edits
=
[]
with
open
(
file_path
,
'r'
)
as
f
:
lines
=
f
.
readlines
()
if
num_processes
==
1
:
results
=
[]
for
li
in
lines
:
mol
,
reaction
,
graph_edits
=
load_one_reaction
(
li
)
results
.
append
((
mol
,
reaction
,
graph_edits
))
else
:
with
Pool
(
processes
=
num_processes
)
as
pool
:
results
=
pool
.
map
(
load_one_reaction
,
lines
)
for
mol
,
reaction
,
graph_edits
in
results
:
if
mol
is
None
:
continue
all_mols
.
append
(
mol
)
all_reactions
.
append
(
reaction
)
all_graph_edits
.
append
(
graph_edits
)
return
all_mols
,
all_reactions
,
all_graph_edits
def
__len__
(
self
):
"""Get the size for the dataset.
Returns
-------
int
Number of reactions in the dataset.
"""
return
len
(
self
.
mols
)
def
__getitem__
(
self
,
item
):
"""Get the i-th datapoint.
Returns
-------
str
Reaction
str
Graph edits for the reaction
DGLGraph
DGLGraph for the ith molecular graph
DGLGraph
Complete DGLGraph, which will be needed for predicting
scores between each pair of atoms
float32 tensor of shape (V^2, 10)
Features for each pair of atoms.
float32 tensor of shape (V^2, 5)
Labels for reaction center prediction.
V for the number of atoms in the reactants.
"""
mol
=
self
.
mols
[
item
]
num_atoms
=
mol
.
GetNumAtoms
()
if
num_atoms
not
in
self
.
complete_graphs
:
self
.
complete_graphs
[
num_atoms
]
=
mol_to_complete_graph
(
mol
,
add_self_loop
=
True
,
canonical_atom_order
=
False
)
if
self
.
atom_pair_features
[
item
]
is
None
:
reactants
=
self
.
reactions
[
item
].
split
(
'>'
)[
0
]
self
.
atom_pair_features
[
item
]
=
self
.
_atom_pair_featurizer
(
reactants
)
if
self
.
atom_pair_labels
[
item
]
is
None
:
self
.
atom_pair_labels
[
item
]
=
get_pair_label
(
mol
,
self
.
graph_edits
[
item
])
return
self
.
reactions
[
item
],
self
.
graph_edits
[
item
],
\
self
.
reactant_mol_graphs
[
item
],
\
self
.
complete_graphs
[
num_atoms
],
\
self
.
atom_pair_features
[
item
],
\
self
.
atom_pair_labels
[
item
]
class
USPTOCenter
(
WLNCenterDataset
):
"""USPTO dataset for reaction center prediction.
The dataset contains reactions from patents granted by United States Patent
and Trademark Office (USPTO), collected by Lowe [1]. Jin et al. removes duplicates
and erroneous reactions, obtaining a set of 480K reactions. They divide it
into 400K, 40K, and 40K for training, validation and test.
References:
* [1] Patent reaction extraction
* [2] Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network
Parameters
----------
subset : str
Whether to use the training/validation/test set as in Jin et al.
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
mol_to_graph: callable, str -> DGLGraph
A function turning RDKit molecule instances into DGLGraphs.
Default to :func:`dgllife.utils.mol_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we consider descriptors including atom type,
atom degree, atom explicit valence, atom implicit valence, aromaticity.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we consider descriptors including bond type,
whether bond is conjugated and whether bond is in ring.
atom_pair_featurizer : callable, str -> dict
Featurization for each pair of atoms in multiple reactants. The result will be
used to update edata in the complete DGLGraphs. By default, the features include
the bond type between the atoms (if any) and whether they belong to the same molecule.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def
__init__
(
self
,
subset
,
mol_to_graph
=
mol_to_bigraph
,
node_featurizer
=
default_node_featurizer_center
,
edge_featurizer
=
default_edge_featurizer_center
,
atom_pair_featurizer
=
default_atom_pair_featurizer
,
load
=
True
,
num_processes
=
1
):
assert
subset
in
[
'train'
,
'val'
,
'test'
],
\
'Expect subset to be "train" or "val" or "test", got {}'
.
format
(
subset
)
print
(
'Preparing {} subset of USPTO for reaction center prediction.'
.
format
(
subset
))
self
.
_subset
=
subset
if
subset
==
'val'
:
subset
=
'valid'
self
.
_url
=
'dataset/uspto.zip'
data_path
=
get_download_dir
()
+
'/uspto.zip'
extracted_data_path
=
get_download_dir
()
+
'/uspto'
download
(
_get_dgl_url
(
self
.
_url
),
path
=
data_path
)
extract_archive
(
data_path
,
extracted_data_path
)
super
(
USPTOCenter
,
self
).
__init__
(
raw_file_path
=
extracted_data_path
+
'/{}.txt'
.
format
(
subset
),
mol_graph_path
=
extracted_data_path
+
'/{}_mol_graphs.bin'
.
format
(
subset
),
mol_to_graph
=
mol_to_graph
,
node_featurizer
=
node_featurizer
,
edge_featurizer
=
edge_featurizer
,
atom_pair_featurizer
=
atom_pair_featurizer
,
load
=
load
,
num_processes
=
num_processes
)
@
property
def
subset
(
self
):
"""Get the subset used for USPTOCenter
Returns
-------
str
* 'full' for the complete dataset
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
"""
return
self
.
_subset
def
mkdir_p
(
path
):
"""Create a folder for the given path.
Parameters
----------
path: str
Folder to create
"""
try
:
os
.
makedirs
(
path
)
except
OSError
as
exc
:
if
exc
.
errno
==
errno
.
EEXIST
and
os
.
path
.
isdir
(
path
):
pass
else
:
raise
def
load_one_reaction_rank
(
line
):
"""Load one reaction and check if the reactants are valid.
Parameters
----------
line : str
One reaction and the associated graph edits
Returns
-------
reactants_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants. None will be returned if the
line is not valid.
product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the product. None will be returned if the line is not valid.
reaction_real_bond_changes : list of 3-tuples
Real bond changes in the reaction. Each tuple is of form (atom1, atom2, change_type). For
change_type, 0.0 stands for losing a bond, 1.0, 2.0, 3.0 and 1.5 separately stands for
forming a single, double, triple or aromatic bond.
"""
# Each line represents a reaction and the corresponding graph edits
#
# reaction example:
# [CH3:14][OH:15].[NH2:12][NH2:13].[OH2:11].[n:1]1[n:2][cH:3][c:4]
# ([C:7]([O:9][CH3:8])=[O:10])[cH:5][cH:6]1>>[n:1]1[n:2][cH:3][c:4]
# ([C:7](=[O:9])[NH:12][NH2:13])[cH:5][cH:6]1
# The reactants are on the left-hand-side of the reaction and the product
# is on the right-hand-side of the reaction. The numbers represent atom mapping.
#
# graph_edits example:
# 23-33-1.0;23-25-0.0
# For a triplet a-b-c, a and b are the atoms that form or loss the bond.
# c specifies the particular change, 0.0 for losing a bond, 1.0, 2.0, 3.0 and
# 1.5 separately for forming a single, double, triple or aromatic bond.
reaction
,
graph_edits
=
line
.
strip
(
"
\r\n
"
).
split
()
reactants
,
_
,
product
=
reaction
.
split
(
'>'
)
reactants_mol
=
Chem
.
MolFromSmiles
(
reactants
)
if
reactants_mol
is
None
:
return
None
,
None
,
None
,
None
,
None
product_mol
=
Chem
.
MolFromSmiles
(
product
)
if
product_mol
is
None
:
return
None
,
None
,
None
,
None
,
None
# Reorder atoms according to the order specified in the atom map
atom_map_order
=
[
-
1
for
_
in
range
(
reactants_mol
.
GetNumAtoms
())]
for
j
in
range
(
reactants_mol
.
GetNumAtoms
()):
atom
=
reactants_mol
.
GetAtomWithIdx
(
j
)
atom_map_order
[
atom
.
GetIntProp
(
'molAtomMapNumber'
)
-
1
]
=
j
reactants_mol
=
rdmolops
.
RenumberAtoms
(
reactants_mol
,
atom_map_order
)
reaction_real_bond_changes
=
[]
for
changed_bond
in
graph_edits
.
split
(
';'
):
atom1
,
atom2
,
change_type
=
changed_bond
.
split
(
'-'
)
atom1
,
atom2
=
int
(
atom1
)
-
1
,
int
(
atom2
)
-
1
reaction_real_bond_changes
.
append
(
(
min
(
atom1
,
atom2
),
max
(
atom1
,
atom2
),
float
(
change_type
)))
return
reactants_mol
,
product_mol
,
reaction_real_bond_changes
def
load_candidate_bond_changes_for_one_reaction
(
line
):
"""Load candidate bond changes for a reaction
Parameters
----------
line : str
Candidate bond changes separated by ;. Each candidate bond change takes the
form of atom1, atom2, change_type and change_score.
Returns
-------
list of 4-tuples
Loaded candidate bond changes.
"""
reaction_candidate_bond_changes
=
[]
elements
=
line
.
strip
().
split
(
';'
)[:
-
1
]
for
candidate
in
elements
:
atom1
,
atom2
,
change_type
,
score
=
candidate
.
split
(
' '
)
atom1
,
atom2
=
int
(
atom1
)
-
1
,
int
(
atom2
)
-
1
reaction_candidate_bond_changes
.
append
((
min
(
atom1
,
atom2
),
max
(
atom1
,
atom2
),
float
(
change_type
),
float
(
score
)))
return
reaction_candidate_bond_changes
def
bookkeep_reactant
(
mol
,
candidate_pairs
):
"""Bookkeep reaction-related information of reactants.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants.
candidate_pairs : list of 2-tuples
Pairs of atoms that ranked high by a model for reaction center prediction.
By assumption, the two atoms are different and the first atom has a smaller
index than the second.
Returns
-------
info : dict
Reaction-related information of reactants
"""
num_atoms
=
mol
.
GetNumAtoms
()
info
=
{
# free valence of atoms
'free_val'
:
[
0
for
_
in
range
(
num_atoms
)],
# Whether it is a carbon atom
'is_c'
:
[
False
for
_
in
range
(
num_atoms
)],
# Whether it is a carbon atom connected to a nitrogen atom in pyridine
'is_c2_of_pyridine'
:
[
False
for
_
in
range
(
num_atoms
)],
# Whether it is a phosphorous atom
'is_p'
:
[
False
for
_
in
range
(
num_atoms
)],
# Whether it is a sulfur atom
'is_s'
:
[
False
for
_
in
range
(
num_atoms
)],
# Whether it is an oxygen atom
'is_o'
:
[
False
for
_
in
range
(
num_atoms
)],
# Whether it is a nitrogen atom
'is_n'
:
[
False
for
_
in
range
(
num_atoms
)],
'pair_to_bond_val'
:
dict
(),
'ring_bonds'
:
set
()
}
# bookkeep atoms
for
j
,
atom
in
enumerate
(
mol
.
GetAtoms
()):
info
[
'free_val'
][
j
]
+=
atom
.
GetTotalNumHs
()
+
abs
(
atom
.
GetFormalCharge
())
# An aromatic carbon atom next to an aromatic nitrogen atom can get a
# carbonyl b/c of bookkeeping of hydroxypyridines
if
atom
.
GetSymbol
()
==
'C'
:
info
[
'is_c'
][
j
]
=
True
if
atom
.
GetIsAromatic
():
for
nbr
in
atom
.
GetNeighbors
():
if
nbr
.
GetSymbol
()
==
'N'
and
nbr
.
GetDegree
()
==
2
:
info
[
'is_c2_of_pyridine'
][
j
]
=
True
break
# A nitrogen atom should be allowed to become positively charged
elif
atom
.
GetSymbol
()
==
'N'
:
info
[
'free_val'
][
j
]
+=
1
-
atom
.
GetFormalCharge
()
info
[
'is_n'
][
j
]
=
True
# Phosphorous atoms can form a phosphonium
elif
atom
.
GetSymbol
()
==
'P'
:
info
[
'free_val'
][
j
]
+=
1
-
atom
.
GetFormalCharge
()
info
[
'is_p'
][
j
]
=
True
elif
atom
.
GetSymbol
()
==
'O'
:
info
[
'is_o'
][
j
]
=
True
elif
atom
.
GetSymbol
()
==
'S'
:
info
[
'is_s'
][
j
]
=
True
# bookkeep bonds
for
bond
in
mol
.
GetBonds
():
atom1
,
atom2
=
bond
.
GetBeginAtomIdx
(),
bond
.
GetEndAtomIdx
()
atom1
,
atom2
=
min
(
atom1
,
atom2
),
max
(
atom1
,
atom2
)
type_val
=
bond
.
GetBondTypeAsDouble
()
info
[
'pair_to_bond_val'
][(
atom1
,
atom2
)]
=
type_val
if
(
atom1
,
atom2
)
in
candidate_pairs
:
info
[
'free_val'
][
atom1
]
+=
type_val
info
[
'free_val'
][
atom2
]
+=
type_val
if
bond
.
IsInRing
():
info
[
'ring_bonds'
].
add
((
atom1
,
atom2
))
return
info
def
bookkeep_product
(
mol
):
"""Bookkeep reaction-related information of atoms/bonds in products
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for products.
Returns
-------
info : dict
Reaction-related information of atoms/bonds in products
"""
info
=
{
'atoms'
:
set
()
}
for
atom
in
mol
.
GetAtoms
():
info
[
'atoms'
].
add
(
atom
.
GetAtomMapNum
()
-
1
)
return
info
def
is_connected_change_combo
(
combo_ids
,
cand_change_adj
):
"""Check whether the combo of bond changes yields a connected component.
Parameters
----------
combo_ids : tuple of int
Ids for bond changes in the combination.
cand_change_adj : bool ndarray of shape (N, N)
Adjacency matrix for candidate bond changes. Two candidate bond
changes are considered adjacent if they share a common atom.
* N for the number of candidate bond changes.
Returns
-------
bool
Whether the combo of bond changes yields a connected component
"""
if
len
(
combo_ids
)
==
1
:
return
True
multi_hop_adj
=
np
.
linalg
.
matrix_power
(
cand_change_adj
[
combo_ids
,
:][:,
combo_ids
],
len
(
combo_ids
)
-
1
)
# The combo is connected if the distance between
# any pair of bond changes is within len(combo) - 1
return
np
.
all
(
multi_hop_adj
)
def
is_valid_combo
(
combo_changes
,
reactant_info
):
"""Whether the combo of bond changes is chemically valid.
Parameters
----------
combo_changes : list of 4-tuples
Each tuple consists of atom1, atom2, type of bond change (in the form of related
valence) and score for the change.
reactant_info : dict
Reaction-related information of reactants
Returns
-------
bool
Whether the combo of bond changes is chemically valid.
"""
num_atoms
=
len
(
reactant_info
[
'free_val'
])
force_even_parity
=
np
.
zeros
((
num_atoms
,),
dtype
=
bool
)
force_odd_parity
=
np
.
zeros
((
num_atoms
,),
dtype
=
bool
)
pair_seen
=
defaultdict
(
bool
)
free_val_tmp
=
reactant_info
[
'free_val'
].
copy
()
for
(
atom1
,
atom2
,
change_type
,
score
)
in
combo_changes
:
if
pair_seen
[(
atom1
,
atom2
)]:
# A pair of atoms cannot have two types of changes. Even if we
# randomly pick one, that will be reduced to a combo of less changes
return
False
pair_seen
[(
atom1
,
atom2
)]
=
True
# Special valence rules
atom1_type_val
=
atom2_type_val
=
change_type
if
change_type
==
2
:
# to form a double bond
if
reactant_info
[
'is_o'
][
atom1
]:
if
reactant_info
[
'is_c2_of_pyridine'
][
atom2
]:
atom2_type_val
=
1.
elif
reactant_info
[
'is_p'
][
atom2
]:
# don't count information of =o toward valence
# but require odd valence parity
atom2_type_val
=
0.
force_odd_parity
[
atom2
]
=
True
elif
reactant_info
[
'is_s'
][
atom2
]:
atom2_type_val
=
0.
force_even_parity
[
atom2
]
=
True
elif
reactant_info
[
'is_o'
][
atom2
]:
if
reactant_info
[
'is_c2_of_pyridine'
][
atom1
]:
atom1_type_val
=
1.
elif
reactant_info
[
'is_p'
][
atom1
]:
atom1_type_val
=
0.
force_odd_parity
[
atom1
]
=
True
elif
reactant_info
[
'is_s'
][
atom1
]:
atom1_type_val
=
0.
force_even_parity
[
atom1
]
=
True
elif
reactant_info
[
'is_n'
][
atom1
]
and
reactant_info
[
'is_p'
][
atom2
]:
atom2_type_val
=
0.
force_odd_parity
[
atom2
]
=
True
elif
reactant_info
[
'is_n'
][
atom2
]
and
reactant_info
[
'is_p'
][
atom1
]:
atom1_type_val
=
0.
force_odd_parity
[
atom1
]
=
True
elif
reactant_info
[
'is_p'
][
atom1
]
and
reactant_info
[
'is_c'
][
atom2
]:
atom1_type_val
=
0.
force_odd_parity
[
atom1
]
=
True
elif
reactant_info
[
'is_p'
][
atom2
]
and
reactant_info
[
'is_c'
][
atom1
]:
atom2_type_val
=
0.
force_odd_parity
[
atom2
]
=
True
reactant_pair_val
=
reactant_info
[
'pair_to_bond_val'
].
get
((
atom1
,
atom2
),
None
)
if
reactant_pair_val
is
not
None
:
free_val_tmp
[
atom1
]
+=
reactant_pair_val
-
atom1_type_val
free_val_tmp
[
atom2
]
+=
reactant_pair_val
-
atom2_type_val
else
:
free_val_tmp
[
atom1
]
-=
atom1_type_val
free_val_tmp
[
atom2
]
-=
atom2_type_val
free_val_tmp
=
np
.
array
(
free_val_tmp
)
# False if 1) too many connections 2) sulfur valence not even
# 3) phosphorous valence not odd
if
any
(
free_val_tmp
<
0
)
or
\
any
(
aval
%
2
!=
0
for
aval
in
free_val_tmp
[
force_even_parity
])
or
\
any
(
aval
%
2
!=
1
for
aval
in
free_val_tmp
[
force_odd_parity
]):
return
False
return
True
def
edit_mol
(
reactant_mols
,
edits
,
product_info
):
"""Simulate reaction via graph editing
Parameters
----------
reactant_mols : rdkit.Chem.rdchem.Mol
RDKit molecule instances for reactants.
edits : list of 4-tuples
Bond changes for getting the product out of the reactants in a reaction.
Each 4-tuple is of form (atom1, atom2, change_type, score), where atom1
and atom2 are the end atoms to form or lose a bond, change_type is the
type of bond change and score represents the confidence for the bond change
by a model.
product_info : dict
proeduct_info['atoms'] gives a set of atom ids in the ground truth product molecule.
Returns
-------
str
SMILES for the main products
"""
bond_change_to_type
=
{
1
:
Chem
.
rdchem
.
BondType
.
SINGLE
,
2
:
Chem
.
rdchem
.
BondType
.
DOUBLE
,
3
:
Chem
.
rdchem
.
BondType
.
TRIPLE
,
1.5
:
Chem
.
rdchem
.
BondType
.
AROMATIC
}
new_mol
=
Chem
.
RWMol
(
reactant_mols
)
[
atom
.
SetNumExplicitHs
(
0
)
for
atom
in
new_mol
.
GetAtoms
()]
for
atom1
,
atom2
,
change_type
,
score
in
edits
:
bond
=
new_mol
.
GetBondBetweenAtoms
(
atom1
,
atom2
)
if
bond
is
not
None
:
new_mol
.
RemoveBond
(
atom1
,
atom2
)
if
change_type
>
0
:
new_mol
.
AddBond
(
atom1
,
atom2
,
bond_change_to_type
[
change_type
])
pred_mol
=
new_mol
.
GetMol
()
pred_smiles
=
Chem
.
MolToSmiles
(
pred_mol
)
pred_list
=
pred_smiles
.
split
(
'.'
)
pred_mols
=
[]
for
pred_smiles
in
pred_list
:
mol
=
Chem
.
MolFromSmiles
(
pred_smiles
)
if
mol
is
None
:
continue
atom_set
=
set
([
atom
.
GetAtomMapNum
()
-
1
for
atom
in
mol
.
GetAtoms
()])
if
len
(
atom_set
&
product_info
[
'atoms'
])
==
0
:
continue
for
atom
in
mol
.
GetAtoms
():
atom
.
SetAtomMapNum
(
0
)
pred_mols
.
append
(
mol
)
return
'.'
.
join
(
sorted
([
Chem
.
MolToSmiles
(
mol
)
for
mol
in
pred_mols
]))
def
get_product_smiles
(
reactant_mols
,
edits
,
product_info
):
"""Get the product smiles of the reaction
Parameters
----------
reactant_mols : rdkit.Chem.rdchem.Mol
RDKit molecule instances for reactants.
edits : list of 4-tuples
Bond changes for getting the product out of the reactants in a reaction.
Each 4-tuple is of form (atom1, atom2, change_type, score), where atom1
and atom2 are the end atoms to form or lose a bond, change_type is the
type of bond change and score represents the confidence for the bond change
by a model.
product_info : dict
proeduct_info['atoms'] gives a set of atom ids in the ground truth product molecule.
Returns
-------
str
SMILES for the main products
"""
smiles
=
edit_mol
(
reactant_mols
,
edits
,
product_info
)
if
len
(
smiles
)
!=
0
:
return
smiles
try
:
Chem
.
Kekulize
(
reactant_mols
)
except
Exception
as
e
:
return
smiles
return
edit_mol
(
reactant_mols
,
edits
,
product_info
)
def
generate_valid_candidate_combos
():
return
NotImplementedError
def
pre_process_one_reaction
(
info
,
num_candidate_bond_changes
,
max_num_bond_changes
,
max_num_change_combos
,
mode
):
"""Pre-process one reaction for candidate ranking.
Parameters
----------
info : 4-tuple
* candidate_bond_changes : list of tuples
The candidate bond changes for the reaction
* real_bond_changes : list of tuples
The real bond changes for the reaction
* reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants
* product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for product
num_candidate_bond_changes : int
Number of candidate bond changes to consider for the ground truth reaction.
max_num_bond_changes : int
Maximum number of bond changes per reaction.
max_num_change_combos : int
Number of bond change combos to consider for each reaction.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
mode : str
Whether the dataset is to be used for training, validation or test.
Returns
-------
valid_candidate_combos : list
valid_candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction.
candidate_bond_changes : list of 4-tuples
Refined candidate bond changes considered for combos.
reactant_info : dict
Reaction-related information of reactants.
"""
assert
mode
in
[
'train'
,
'val'
,
'test'
],
\
"Expect mode to be 'train' or 'val' or 'test', got {}"
.
format
(
mode
)
candidate_bond_changes_
,
real_bond_changes
,
reactant_mol
,
product_mol
=
info
candidate_pairs
=
[(
atom1
,
atom2
)
for
(
atom1
,
atom2
,
_
,
_
)
in
candidate_bond_changes_
]
reactant_info
=
bookkeep_reactant
(
reactant_mol
,
candidate_pairs
)
if
mode
==
'train'
:
product_info
=
bookkeep_product
(
product_mol
)
# Filter out candidate new bonds already in reactants
candidate_bond_changes
=
[]
count
=
0
for
(
atom1
,
atom2
,
change_type
,
score
)
in
candidate_bond_changes_
:
if
((
atom1
,
atom2
)
not
in
reactant_info
[
'pair_to_bond_val'
])
or
\
(
reactant_info
[
'pair_to_bond_val'
][(
atom1
,
atom2
)]
!=
change_type
):
candidate_bond_changes
.
append
((
atom1
,
atom2
,
change_type
,
score
))
count
+=
1
if
count
==
num_candidate_bond_changes
:
break
# Check if two bond changes have atom in common
cand_change_adj
=
np
.
eye
(
len
(
candidate_bond_changes
),
dtype
=
bool
)
for
i
in
range
(
len
(
candidate_bond_changes
)):
atom1_1
,
atom1_2
,
_
,
_
=
candidate_bond_changes
[
i
]
for
j
in
range
(
i
+
1
,
len
(
candidate_bond_changes
)):
atom2_1
,
atom2_2
,
_
,
_
=
candidate_bond_changes
[
j
]
if
atom1_1
==
atom2_1
or
atom1_1
==
atom2_2
or
\
atom1_2
==
atom2_1
or
atom1_2
==
atom2_2
:
cand_change_adj
[
i
,
j
]
=
cand_change_adj
[
j
,
i
]
=
True
# Enumerate combinations of k candidate bond changes and record
# those that are connected and chemically valid
valid_candidate_combos
=
[]
cand_change_ids
=
range
(
len
(
candidate_bond_changes
))
for
k
in
range
(
1
,
max_num_bond_changes
+
1
):
for
combo_ids
in
combinations
(
cand_change_ids
,
k
):
# Check if the changed bonds form a connected component
if
not
is_connected_change_combo
(
combo_ids
,
cand_change_adj
):
continue
combo_changes
=
[
candidate_bond_changes
[
j
]
for
j
in
combo_ids
]
# Check if the combo is chemically valid
if
is_valid_combo
(
combo_changes
,
reactant_info
):
valid_candidate_combos
.
append
(
combo_changes
)
if
mode
==
'train'
:
random
.
shuffle
(
valid_candidate_combos
)
# Index for the combo of candidate bond changes
# that is equivalent to the gold combo
real_combo_id
=
-
1
for
j
,
combo_changes
in
enumerate
(
valid_candidate_combos
):
if
set
([(
atom1
,
atom2
,
change_type
)
for
(
atom1
,
atom2
,
change_type
,
score
)
in
combo_changes
])
==
\
set
(
real_bond_changes
):
real_combo_id
=
j
break
# If we fail to find the real combo, make it the first entry
if
real_combo_id
==
-
1
:
valid_candidate_combos
=
\
[[(
atom1
,
atom2
,
change_type
,
0.0
)
for
(
atom1
,
atom2
,
change_type
)
in
real_bond_changes
]]
+
\
valid_candidate_combos
else
:
valid_candidate_combos
[
0
],
valid_candidate_combos
[
real_combo_id
]
=
\
valid_candidate_combos
[
real_combo_id
],
valid_candidate_combos
[
0
]
product_smiles
=
get_product_smiles
(
reactant_mol
,
valid_candidate_combos
[
0
],
product_info
)
if
len
(
product_smiles
)
>
0
:
# Remove combos yielding duplicate products
product_smiles
=
set
([
product_smiles
])
new_candidate_combos
=
[
valid_candidate_combos
[
0
]]
count
=
0
for
combo
in
valid_candidate_combos
[
1
:]:
smiles
=
get_product_smiles
(
reactant_mol
,
combo
,
product_info
)
if
smiles
in
product_smiles
or
len
(
smiles
)
==
0
:
continue
product_smiles
.
add
(
smiles
)
new_candidate_combos
.
append
(
combo
)
count
+=
1
if
count
==
max_num_change_combos
:
break
valid_candidate_combos
=
new_candidate_combos
valid_candidate_combos
=
valid_candidate_combos
[:
max_num_change_combos
]
return
valid_candidate_combos
,
candidate_bond_changes
,
reactant_info
def
featurize_nodes_and_compute_combo_scores
(
node_featurizer
,
reactant_mol
,
valid_candidate_combos
):
"""Featurize atoms in reactants and compute scores for combos of bond changes
Parameters
----------
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants in a reaction
valid_candidate_combos : list
valid_candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction.
Returns
-------
node_feats : float32 tensor of shape (N, M)
Node features for reactants, N for the number of nodes and M for the feature size
combo_bias : float32 tensor of shape (B, 1)
Scores for combos of bond changes, B equals len(valid_candidate_combos)
"""
node_feats
=
node_featurizer
(
reactant_mol
)[
'hv'
]
combo_bias
=
torch
.
zeros
(
len
(
valid_candidate_combos
),
1
).
float
()
for
combo_id
,
combo
in
enumerate
(
valid_candidate_combos
):
combo_bias
[
combo_id
]
=
sum
([
score
for
(
atom1
,
atom2
,
change_type
,
score
)
in
combo
])
return
node_feats
,
combo_bias
def
construct_graphs_rank
(
info
,
edge_featurizer
):
"""Construct graphs for reactants and candidate products in a reaction and featurize
their edges
Parameters
----------
info : 4-tuple
* reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants in a reaction
* candidate_combos : list
candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction.
* candidate_bond_changes : list of 4-tuples
Refined candidate bond changes considered for candidate products
* reactant_info : dict
Reaction-related information of reactants.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
reaction_graphs : list of DGLGraphs
DGLGraphs for reactants and candidate products with edge features in edata['he'],
where the first graph is for reactants.
"""
reactant_mol
,
candidate_combos
,
candidate_bond_changes
,
reactant_info
=
info
# Graphs for reactants and candidate products
reaction_graphs
=
[]
# Get graph for the reactants
reactant_graph
=
mol_to_bigraph
(
reactant_mol
,
edge_featurizer
=
edge_featurizer
,
canonical_atom_order
=
False
)
reaction_graphs
.
append
(
reactant_graph
)
candidate_bond_changes_no_score
=
[
(
atom1
,
atom2
,
change_type
)
for
(
atom1
,
atom2
,
change_type
,
score
)
in
candidate_bond_changes
]
# Prepare common components across all candidate products
breaking_reactant_neighbors
=
[]
common_src_list
=
[]
common_dst_list
=
[]
common_edge_feats
=
[]
num_bonds
=
reactant_mol
.
GetNumBonds
()
for
j
in
range
(
num_bonds
):
bond
=
reactant_mol
.
GetBondWithIdx
(
j
)
u
=
bond
.
GetBeginAtomIdx
()
v
=
bond
.
GetEndAtomIdx
()
u_sort
,
v_sort
=
min
(
u
,
v
),
max
(
u
,
v
)
# Whether a bond in reactants might get broken
if
(
u_sort
,
v_sort
,
0.0
)
not
in
candidate_bond_changes_no_score
:
common_src_list
.
extend
([
u
,
v
])
common_dst_list
.
extend
([
v
,
u
])
common_edge_feats
.
extend
([
reactant_graph
.
edata
[
'he'
][
2
*
j
],
reactant_graph
.
edata
[
'he'
][
2
*
j
+
1
]])
else
:
breaking_reactant_neighbors
.
append
((
u_sort
,
v_sort
,
bond
.
GetBondTypeAsDouble
()))
for
combo
in
candidate_combos
:
combo_src_list
=
deepcopy
(
common_src_list
)
combo_dst_list
=
deepcopy
(
common_dst_list
)
combo_edge_feats
=
deepcopy
(
common_edge_feats
)
candidate_bond_end_atoms
=
[
(
atom1
,
atom2
)
for
(
atom1
,
atom2
,
change_type
,
score
)
in
combo
]
for
(
atom1
,
atom2
,
change_type
)
in
breaking_reactant_neighbors
:
if
(
atom1
,
atom2
)
not
in
candidate_bond_end_atoms
:
# If a bond might be broken in some other combos but not this,
# add it as a negative sample
combo
.
append
((
atom1
,
atom2
,
change_type
,
0.0
))
for
(
atom1
,
atom2
,
change_type
,
score
)
in
combo
:
if
change_type
==
0
:
continue
combo_src_list
.
extend
([
atom1
,
atom2
])
combo_dst_list
.
extend
([
atom2
,
atom1
])
feats
=
one_hot_encoding
(
change_type
,
[
1.0
,
2.0
,
3.0
,
1.5
,
-
1
])
if
(
atom1
,
atom2
)
in
reactant_info
[
'ring_bonds'
]:
feats
[
-
1
]
=
1
feats
=
torch
.
tensor
(
feats
).
float
()
combo_edge_feats
.
extend
([
feats
,
feats
.
clone
()])
combo_edge_feats
=
torch
.
stack
(
combo_edge_feats
,
dim
=
0
)
combo_graph
=
DGLGraph
()
combo_graph
.
add_nodes
(
reactant_graph
.
number_of_nodes
())
combo_graph
.
add_edges
(
combo_src_list
,
combo_dst_list
)
combo_graph
.
edata
[
'he'
]
=
combo_edge_feats
reaction_graphs
.
append
(
combo_graph
)
return
reaction_graphs
class
WLNRankDataset
(
object
):
"""Dataset for ranking candidate products with WLN
Parameters
----------
raw_file_path : str
Path to the raw reaction file, where each line is the SMILES for a reaction.
candidate_bond_path : str
Path to the candidate bond changes for product enumeration, where each line is
candidate bond changes for a reaction by a WLN for reaction center prediction.
mode : str
'train', 'val', or 'test', indicating whether the dataset is used for training,
validation or test.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we consider descriptors including atom type,
atom formal charge, atom degree, atom explicit valence, atom implicit valence,
aromaticity.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we consider descriptors including bond type
and whether bond is in ring.
size_cutoff : int
By calling ``.ignore_large(True)``, we can optionally ignore reactions whose reactants
contain more than ``size_cutoff`` atoms. Default to 100.
max_num_changes_per_reaction : int
Maximum number of bond changes per reaction. Default to 5.
num_candidate_bond_changes : int
Number of candidate bond changes to consider for each ground truth reaction.
Default to 16.
max_num_change_combos_per_reaction : int
Number of bond change combos to consider for each reaction. Default to 150.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def
__init__
(
self
,
raw_file_path
,
candidate_bond_path
,
mode
,
node_featurizer
=
default_node_featurizer_rank
,
edge_featurizer
=
default_edge_featurizer_rank
,
size_cutoff
=
100
,
max_num_changes_per_reaction
=
5
,
num_candidate_bond_changes
=
16
,
max_num_change_combos_per_reaction
=
150
,
num_processes
=
1
):
super
(
WLNRankDataset
,
self
).
__init__
()
assert
mode
in
[
'train'
,
'val'
,
'test'
],
\
"Expect mode to be 'train' or 'val' or 'test', got {}"
.
format
(
mode
)
self
.
mode
=
mode
self
.
ignore_large_samples
=
False
self
.
size_cutoff
=
size_cutoff
path_to_reaction_file
=
raw_file_path
+
'.proc'
if
not
os
.
path
.
isfile
(
path_to_reaction_file
):
print
(
'Pre-processing graph edits from reaction data'
)
process_file
(
raw_file_path
,
num_processes
)
self
.
reactant_mols
,
self
.
product_mols
,
self
.
real_bond_changes
,
\
self
.
ids_for_small_samples
=
self
.
load_reaction_data
(
path_to_reaction_file
,
num_processes
)
self
.
candidate_bond_changes
=
self
.
load_candidate_bond_changes
(
candidate_bond_path
)
self
.
num_candidate_bond_changes
=
num_candidate_bond_changes
self
.
max_num_changes_per_reaction
=
max_num_changes_per_reaction
self
.
max_num_change_combos_per_reaction
=
max_num_change_combos_per_reaction
self
.
node_featurizer
=
node_featurizer
self
.
edge_featurizer
=
edge_featurizer
def
load_reaction_data
(
self
,
file_path
,
num_processes
):
"""Load reaction data from the raw file.
Parameters
----------
file_path : str
Path to read the file.
num_processes : int
Number of processes to use for data pre-processing.
Returns
-------
all_reactant_mols : list of rdkit.Chem.rdchem.Mol
RDKit molecule instances for reactants.
all_product_mols : list of rdkit.Chem.rdchem.Mol
RDKit molecule instances for products if the dataset is for training and
None otherwise.
all_real_bond_changes : list of list
``all_real_bond_changes[i]`` gives a list of tuples, which are ground
truth bond changes for a reaction.
ids_for_small_samples : list of int
Indices for reactions whose reactants do not contain too many atoms
"""
print
(
'Stage 1/2: loading reaction data...'
)
all_reactant_mols
=
[]
all_product_mols
=
[]
all_real_bond_changes
=
[]
ids_for_small_samples
=
[]
with
open
(
file_path
,
'r'
)
as
f
:
lines
=
f
.
readlines
()
def
_update_from_line
(
id
,
loaded_result
):
reactants_mol
,
product_mol
,
reaction_real_bond_changes
=
loaded_result
if
reactants_mol
is
None
:
return
all_product_mols
.
append
(
product_mol
)
all_reactant_mols
.
append
(
reactants_mol
)
all_real_bond_changes
.
append
(
reaction_real_bond_changes
)
if
reactants_mol
.
GetNumAtoms
()
<=
self
.
size_cutoff
:
ids_for_small_samples
.
append
(
id
)
if
num_processes
==
1
:
for
id
,
li
in
enumerate
(
tqdm
(
lines
)):
loaded_line
=
load_one_reaction_rank
(
li
)
_update_from_line
(
id
,
loaded_line
)
else
:
with
Pool
(
processes
=
num_processes
)
as
pool
:
results
=
pool
.
map
(
load_one_reaction_rank
,
lines
,
chunksize
=
len
(
lines
)
//
num_processes
)
for
id
in
range
(
len
(
lines
)):
_update_from_line
(
id
,
results
[
id
])
return
all_reactant_mols
,
all_product_mols
,
all_real_bond_changes
,
ids_for_small_samples
def
load_candidate_bond_changes
(
self
,
file_path
):
"""Load candidate bond changes predicted by a WLN for reaction center prediction.
Parameters
----------
file_path : str
Path to a file of candidate bond changes for each reaction.
Returns
-------
all_candidate_bond_changes : list of list
``all_candidate_bond_changes[i]`` gives a list of tuples, which are candidate
bond changes for a reaction.
"""
print
(
'Stage 2/2: loading candidate bond changes...'
)
with
open
(
file_path
,
'r'
)
as
f
:
lines
=
f
.
readlines
()
all_candidate_bond_changes
=
[]
for
li
in
tqdm
(
lines
):
all_candidate_bond_changes
.
append
(
load_candidate_bond_changes_for_one_reaction
(
li
))
return
all_candidate_bond_changes
def
ignore_large
(
self
,
ignore
=
True
):
"""Whether to ignore reactions where reactants contain too many atoms.
Parameters
----------
ignore : bool
If ``ignore``, reactions where reactants contain too many atoms will be ignored.
"""
self
.
ignore_large_samples
=
ignore
def
__len__
(
self
):
"""Get the size for the dataset.
Returns
-------
int
Number of reactions in the dataset.
"""
if
self
.
ignore_large_samples
:
return
len
(
self
.
ids_for_small_samples
)
else
:
return
len
(
self
.
reactant_mols
)
def
__getitem__
(
self
,
item
):
"""Get the i-th datapoint.
Parameters
----------
item : int
Index for the datapoint.
Returns
-------
list of B + 1 DGLGraph
The first entry in the list is the DGLGraph for the reactants and the rest are
DGLGraphs for candidate products. Each DGLGraph has edge features in edata['he'] and
node features in ndata['hv'].
candidate_scores : float32 tensor of shape (B, 1)
The sum of scores for bond changes in each combo, where B is the number of combos.
labels : int64 tensor of shape (1, 1), optional
Index for the true candidate product, which is always 0 with pre-processing. This is
returned only when we are not in the training mode.
valid_candidate_combos : list, optional
valid_candidate_combos[i] gives a list of tuples, which is the i-th valid combo
of candidate bond changes for the reaction. Each tuple is of form (atom1, atom2,
change_type, score). atom1, atom2 are the atom mapping numbers - 1 of the two
end atoms. change_type can be 0, 1, 2, 3, 1.5, separately for losing a bond, forming
a single, double, triple, and aromatic bond.
reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants
real_bond_changes : list of tuples
Ground truth bond changes in a reaction. Each tuple is of form (atom1, atom2,
change_type). atom1, atom2 are the atom mapping numbers - 1 of the two
end atoms. change_type can be 0, 1, 2, 3, 1.5, separately for losing a bond, forming
a single, double, triple, and aromatic bond.
product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the product
"""
if
self
.
ignore_large_samples
:
item
=
self
.
ids_for_small_samples
[
item
]
raw_candidate_bond_changes
=
self
.
candidate_bond_changes
[
item
]
real_bond_changes
=
self
.
real_bond_changes
[
item
]
reactant_mol
=
self
.
reactant_mols
[
item
]
product_mol
=
self
.
product_mols
[
item
]
# Get valid candidate products, candidate bond changes considered and reactant info
valid_candidate_combos
,
candidate_bond_changes
,
reactant_info
=
\
pre_process_one_reaction
(
(
raw_candidate_bond_changes
,
real_bond_changes
,
reactant_mol
,
product_mol
),
self
.
num_candidate_bond_changes
,
self
.
max_num_changes_per_reaction
,
self
.
max_num_change_combos_per_reaction
,
self
.
mode
)
# Construct DGLGraphs and featurize their edges
g_list
=
construct_graphs_rank
(
(
reactant_mol
,
valid_candidate_combos
,
candidate_bond_changes
,
reactant_info
),
self
.
edge_featurizer
)
# Get node features and candidate scores
node_feats
,
candidate_scores
=
featurize_nodes_and_compute_combo_scores
(
self
.
node_featurizer
,
reactant_mol
,
valid_candidate_combos
)
for
g
in
g_list
:
g
.
ndata
[
'hv'
]
=
node_feats
if
self
.
mode
==
'train'
:
labels
=
torch
.
zeros
(
1
,
1
).
long
()
return
g_list
,
candidate_scores
,
labels
else
:
reactant_mol
=
self
.
reactant_mols
[
item
]
real_bond_changes
=
self
.
real_bond_changes
[
item
]
product_mol
=
self
.
product_mols
[
item
]
return
g_list
,
candidate_scores
,
valid_candidate_combos
,
\
reactant_mol
,
real_bond_changes
,
product_mol
class
USPTORank
(
WLNRankDataset
):
"""USPTO dataset for ranking candidate products.
The dataset contains reactions from patents granted by United States Patent
and Trademark Office (USPTO), collected by Lowe [1]. Jin et al. removes duplicates
and erroneous reactions, obtaining a set of 480K reactions. They divide it
into 400K, 40K, and 40K for training, validation and test.
References:
* [1] Patent reaction extraction
* [2] Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network
Parameters
----------
subset : str
Whether to use the training/validation/test set as in Jin et al.
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
candidate_bond_path : str
Path to the candidate bond changes for product enumeration, where each line is
candidate bond changes for a reaction by a WLN for reaction center prediction.
size_cutoff : int
By calling ``.ignore_large(True)``, we can optionally ignore reactions whose reactants
contain more than ``size_cutoff`` atoms. Default to 100.
max_num_changes_per_reaction : int
Maximum number of bond changes per reaction. Default to 5.
num_candidate_bond_changes : int
Number of candidate bond changes to consider for each ground truth reaction.
Default to 16.
max_num_change_combos_per_reaction : int
Number of bond change combos to consider for each reaction. Default to 150.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
"""
def
__init__
(
self
,
subset
,
candidate_bond_path
,
size_cutoff
=
100
,
max_num_changes_per_reaction
=
5
,
num_candidate_bond_changes
=
16
,
max_num_change_combos_per_reaction
=
150
,
num_processes
=
1
):
assert
subset
in
[
'train'
,
'val'
,
'test'
],
\
'Expect subset to be "train" or "val" or "test", got {}'
.
format
(
subset
)
print
(
'Preparing {} subset of USPTO for product candidate ranking.'
.
format
(
subset
))
self
.
_subset
=
subset
if
subset
==
'val'
:
mode
=
'val'
subset
=
'valid'
else
:
mode
=
subset
self
.
_url
=
'dataset/uspto.zip'
data_path
=
get_download_dir
()
+
'/uspto.zip'
extracted_data_path
=
get_download_dir
()
+
'/uspto'
download
(
_get_dgl_url
(
self
.
_url
),
path
=
data_path
)
extract_archive
(
data_path
,
extracted_data_path
)
super
(
USPTORank
,
self
).
__init__
(
raw_file_path
=
extracted_data_path
+
'/{}.txt'
.
format
(
subset
),
candidate_bond_path
=
candidate_bond_path
,
mode
=
mode
,
size_cutoff
=
size_cutoff
,
max_num_changes_per_reaction
=
max_num_changes_per_reaction
,
num_candidate_bond_changes
=
num_candidate_bond_changes
,
max_num_change_combos_per_reaction
=
max_num_change_combos_per_reaction
,
num_processes
=
num_processes
)
@
property
def
subset
(
self
):
"""Get the subset used for USPTOCenter
Returns
-------
str
* 'full' for the complete dataset
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
"""
return
self
.
_subset
apps/life_sci/python/dgllife/libinfo.py
deleted
100644 → 0
View file @
94c67203
"""Information for the library."""
# current version
__version__
=
'0.2.2'
apps/life_sci/python/dgllife/model/__init__.py
deleted
100644 → 0
View file @
94c67203
"""Model architectures and components at different levels."""
from
.gnn
import
*
from
.readout
import
*
from
.model_zoo
import
*
from
.pretrain
import
*
apps/life_sci/python/dgllife/model/gnn/__init__.py
deleted
100644 → 0
View file @
94c67203
"""Graph neural networks for updating node representations."""
from
.attentivefp
import
*
from
.gat
import
*
from
.gcn
import
*
from
.mgcn
import
*
from
.mpnn
import
*
from
.schnet
import
*
from
.wln
import
*
from
.weave
import
*
from
.gin
import
*
apps/life_sci/python/dgllife/model/gnn/attentivefp.py
deleted
100644 → 0
View file @
94c67203
"""AttentiveFP"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
dgl.function
as
fn
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
dgl.nn.pytorch
import
edge_softmax
__all__
=
[
'AttentiveFPGNN'
]
# pylint: disable=W0221, C0103, E1101
class
AttentiveGRU1
(
nn
.
Module
):
"""Update node features with attention and GRU.
This will be used for incorporating the information of edge features
into node features for message passing.
Parameters
----------
node_feat_size : int
Size for the input node features.
edge_feat_size : int
Size for the input edge (bond) features.
edge_hidden_size : int
Size for the intermediate edge (bond) representations.
dropout : float
The probability for performing dropout.
"""
def
__init__
(
self
,
node_feat_size
,
edge_feat_size
,
edge_hidden_size
,
dropout
):
super
(
AttentiveGRU1
,
self
).
__init__
()
self
.
edge_transform
=
nn
.
Sequential
(
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
edge_feat_size
,
edge_hidden_size
)
)
self
.
gru
=
nn
.
GRUCell
(
edge_hidden_size
,
node_feat_size
)
def
forward
(
self
,
g
,
edge_logits
,
edge_feats
,
node_feats
):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
edge_logits : float32 tensor of shape (E, 1)
The edge logits based on which softmax will be performed for weighting
edges within 1-hop neighborhoods. E represents the number of edges.
edge_feats : float32 tensor of shape (E, edge_feat_size)
Previous edge features.
node_feats : float32 tensor of shape (V, node_feat_size)
Previous node features. V represents the number of nodes.
Returns
-------
float32 tensor of shape (V, node_feat_size)
Updated node features.
"""
g
=
g
.
local_var
()
g
.
edata
[
'e'
]
=
edge_softmax
(
g
,
edge_logits
)
*
self
.
edge_transform
(
edge_feats
)
g
.
update_all
(
fn
.
copy_edge
(
'e'
,
'm'
),
fn
.
sum
(
'm'
,
'c'
))
context
=
F
.
elu
(
g
.
ndata
[
'c'
])
return
F
.
relu
(
self
.
gru
(
context
,
node_feats
))
class
AttentiveGRU2
(
nn
.
Module
):
"""Update node features with attention and GRU.
This will be used in GNN layers for updating node representations.
Parameters
----------
node_feat_size : int
Size for the input node features.
edge_hidden_size : int
Size for the intermediate edge (bond) representations.
dropout : float
The probability for performing dropout.
"""
def
__init__
(
self
,
node_feat_size
,
edge_hidden_size
,
dropout
):
super
(
AttentiveGRU2
,
self
).
__init__
()
self
.
project_node
=
nn
.
Sequential
(
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
node_feat_size
,
edge_hidden_size
)
)
self
.
gru
=
nn
.
GRUCell
(
edge_hidden_size
,
node_feat_size
)
def
forward
(
self
,
g
,
edge_logits
,
node_feats
):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
edge_logits : float32 tensor of shape (E, 1)
The edge logits based on which softmax will be performed for weighting
edges within 1-hop neighborhoods. E represents the number of edges.
node_feats : float32 tensor of shape (V, node_feat_size)
Previous node features. V represents the number of nodes.
Returns
-------
float32 tensor of shape (V, node_feat_size)
Updated node features.
"""
g
=
g
.
local_var
()
g
.
edata
[
'a'
]
=
edge_softmax
(
g
,
edge_logits
)
g
.
ndata
[
'hv'
]
=
self
.
project_node
(
node_feats
)
g
.
update_all
(
fn
.
src_mul_edge
(
'hv'
,
'a'
,
'm'
),
fn
.
sum
(
'm'
,
'c'
))
context
=
F
.
elu
(
g
.
ndata
[
'c'
])
return
F
.
relu
(
self
.
gru
(
context
,
node_feats
))
class
GetContext
(
nn
.
Module
):
"""Generate context for each node by message passing at the beginning.
This layer incorporates the information of edge features into node
representations so that message passing needs to be only performed over
node representations.
Parameters
----------
node_feat_size : int
Size for the input node features.
edge_feat_size : int
Size for the input edge (bond) features.
graph_feat_size : int
Size of the learned graph representation (molecular fingerprint).
dropout : float
The probability for performing dropout.
"""
def
__init__
(
self
,
node_feat_size
,
edge_feat_size
,
graph_feat_size
,
dropout
):
super
(
GetContext
,
self
).
__init__
()
self
.
project_node
=
nn
.
Sequential
(
nn
.
Linear
(
node_feat_size
,
graph_feat_size
),
nn
.
LeakyReLU
()
)
self
.
project_edge1
=
nn
.
Sequential
(
nn
.
Linear
(
node_feat_size
+
edge_feat_size
,
graph_feat_size
),
nn
.
LeakyReLU
()
)
self
.
project_edge2
=
nn
.
Sequential
(
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
2
*
graph_feat_size
,
1
),
nn
.
LeakyReLU
()
)
self
.
attentive_gru
=
AttentiveGRU1
(
graph_feat_size
,
graph_feat_size
,
graph_feat_size
,
dropout
)
def
apply_edges1
(
self
,
edges
):
"""Edge feature update.
Parameters
----------
edges : EdgeBatch
Container for a batch of edges
Returns
-------
dict
Mapping ``'he1'`` to updated edge features.
"""
return
{
'he1'
:
torch
.
cat
([
edges
.
src
[
'hv'
],
edges
.
data
[
'he'
]],
dim
=
1
)}
def
apply_edges2
(
self
,
edges
):
"""Edge feature update.
Parameters
----------
edges : EdgeBatch
Container for a batch of edges
Returns
-------
dict
Mapping ``'he2'`` to updated edge features.
"""
return
{
'he2'
:
torch
.
cat
([
edges
.
dst
[
'hv_new'
],
edges
.
data
[
'he1'
]],
dim
=
1
)}
def
forward
(
self
,
g
,
node_feats
,
edge_feats
):
"""Incorporate edge features and update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feat_size)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_feat_size)
Input edge features. E for the number of edges.
Returns
-------
float32 tensor of shape (V, graph_feat_size)
Updated node features.
"""
g
=
g
.
local_var
()
g
.
ndata
[
'hv'
]
=
node_feats
g
.
ndata
[
'hv_new'
]
=
self
.
project_node
(
node_feats
)
g
.
edata
[
'he'
]
=
edge_feats
g
.
apply_edges
(
self
.
apply_edges1
)
g
.
edata
[
'he1'
]
=
self
.
project_edge1
(
g
.
edata
[
'he1'
])
g
.
apply_edges
(
self
.
apply_edges2
)
logits
=
self
.
project_edge2
(
g
.
edata
[
'he2'
])
return
self
.
attentive_gru
(
g
,
logits
,
g
.
edata
[
'he1'
],
g
.
ndata
[
'hv_new'
])
class
GNNLayer
(
nn
.
Module
):
"""GNNLayer for updating node features.
This layer performs message passing over node representations and update them.
Parameters
----------
node_feat_size : int
Size for the input node features.
graph_feat_size : int
Size for the graph representations to be computed.
dropout : float
The probability for performing dropout.
"""
def
__init__
(
self
,
node_feat_size
,
graph_feat_size
,
dropout
):
super
(
GNNLayer
,
self
).
__init__
()
self
.
project_edge
=
nn
.
Sequential
(
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
2
*
node_feat_size
,
1
),
nn
.
LeakyReLU
()
)
self
.
attentive_gru
=
AttentiveGRU2
(
node_feat_size
,
graph_feat_size
,
dropout
)
def
apply_edges
(
self
,
edges
):
"""Edge feature generation.
Generate edge features by concatenating the features of the destination
and source nodes.
Parameters
----------
edges : EdgeBatch
Container for a batch of edges.
Returns
-------
dict
Mapping ``'he'`` to the generated edge features.
"""
return
{
'he'
:
torch
.
cat
([
edges
.
dst
[
'hv'
],
edges
.
src
[
'hv'
]],
dim
=
1
)}
def
forward
(
self
,
g
,
node_feats
):
"""Perform message passing and update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feat_size)
Input node features. V for the number of nodes.
Returns
-------
float32 tensor of shape (V, graph_feat_size)
Updated node features.
"""
g
=
g
.
local_var
()
g
.
ndata
[
'hv'
]
=
node_feats
g
.
apply_edges
(
self
.
apply_edges
)
logits
=
self
.
project_edge
(
g
.
edata
[
'he'
])
return
self
.
attentive_gru
(
g
,
logits
,
node_feats
)
class
AttentiveFPGNN
(
nn
.
Module
):
"""`Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
Attention Mechanism <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__
This class performs message passing in AttentiveFP and returns the updated node representations.
Parameters
----------
node_feat_size : int
Size for the input node features.
edge_feat_size : int
Size for the input edge features.
num_layers : int
Number of GNN layers. Default to 2.
graph_feat_size : int
Size for the graph representations to be computed. Default to 200.
dropout : float
The probability for performing dropout. Default to 0.
"""
def
__init__
(
self
,
node_feat_size
,
edge_feat_size
,
num_layers
=
2
,
graph_feat_size
=
200
,
dropout
=
0.
):
super
(
AttentiveFPGNN
,
self
).
__init__
()
self
.
init_context
=
GetContext
(
node_feat_size
,
edge_feat_size
,
graph_feat_size
,
dropout
)
self
.
gnn_layers
=
nn
.
ModuleList
()
for
_
in
range
(
num_layers
-
1
):
self
.
gnn_layers
.
append
(
GNNLayer
(
graph_feat_size
,
graph_feat_size
,
dropout
))
def
forward
(
self
,
g
,
node_feats
,
edge_feats
):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feat_size)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_feat_size)
Input edge features. E for the number of edges.
Returns
-------
node_feats : float32 tensor of shape (V, graph_feat_size)
Updated node representations.
"""
node_feats
=
self
.
init_context
(
g
,
node_feats
,
edge_feats
)
for
gnn
in
self
.
gnn_layers
:
node_feats
=
gnn
(
g
,
node_feats
)
return
node_feats
apps/life_sci/python/dgllife/model/gnn/gat.py
deleted
100644 → 0
View file @
94c67203
"""Graph Attention Networks"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
dgl.nn.pytorch
import
GATConv
__all__
=
[
'GAT'
]
# pylint: disable=W0221
class
GATLayer
(
nn
.
Module
):
r
"""Single GAT layer from `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__
Parameters
----------
in_feats : int
Number of input node features
out_feats : int
Number of output node features
num_heads : int
Number of attention heads
feat_drop : float
Dropout applied to the input features
attn_drop : float
Dropout applied to attention values of edges
alpha : float
Hyperparameter in LeakyReLU, which is the slope for negative values.
Default to 0.2.
residual : bool
Whether to perform skip connection, default to True.
agg_mode : str
The way to aggregate multi-head attention results, can be either
'flatten' for concatenating all-head results or 'mean' for averaging
all head results.
activation : activation function or None
Activation function applied to the aggregated multi-head results, default to None.
"""
def
__init__
(
self
,
in_feats
,
out_feats
,
num_heads
,
feat_drop
,
attn_drop
,
alpha
=
0.2
,
residual
=
True
,
agg_mode
=
'flatten'
,
activation
=
None
):
super
(
GATLayer
,
self
).
__init__
()
self
.
gat_conv
=
GATConv
(
in_feats
=
in_feats
,
out_feats
=
out_feats
,
num_heads
=
num_heads
,
feat_drop
=
feat_drop
,
attn_drop
=
attn_drop
,
negative_slope
=
alpha
,
residual
=
residual
)
assert
agg_mode
in
[
'flatten'
,
'mean'
]
self
.
agg_mode
=
agg_mode
self
.
activation
=
activation
def
forward
(
self
,
bg
,
feats
):
"""Update node representations
Parameters
----------
bg : DGLGraph
DGLGraph for a batch of graphs.
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which equals in_feats in initialization
Returns
-------
feats : FloatTensor of shape (N, M2)
* N is the total number of nodes in the batch of graphs
* M2 is the output node representation size, which equals
out_feats in initialization if self.agg_mode == 'mean' and
out_feats * num_heads in initialization otherwise.
"""
feats
=
self
.
gat_conv
(
bg
,
feats
)
if
self
.
agg_mode
==
'flatten'
:
feats
=
feats
.
flatten
(
1
)
else
:
feats
=
feats
.
mean
(
1
)
if
self
.
activation
is
not
None
:
feats
=
self
.
activation
(
feats
)
return
feats
class
GAT
(
nn
.
Module
):
r
"""GAT from `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__
Parameters
----------
in_feats : int
Number of input node features
hidden_feats : list of int
``hidden_feats[i]`` gives the output size of an attention head in the i-th GAT layer.
``len(hidden_feats)`` equals the number of GAT layers. By default, we use ``[32, 32]``.
num_heads : list of int
``num_heads[i]`` gives the number of attention heads in the i-th GAT layer.
``len(num_heads)`` equals the number of GAT layers. By default, we use 4 attention heads
for each GAT layer.
feat_drops : list of float
``feat_drops[i]`` gives the dropout applied to the input features in the i-th GAT layer.
``len(feat_drops)`` equals the number of GAT layers. By default, this will be zero for
all GAT layers.
attn_drops : list of float
``attn_drops[i]`` gives the dropout applied to attention values of edges in the i-th GAT
layer. ``len(attn_drops)`` equals the number of GAT layers. By default, this will be zero
for all GAT layers.
alphas : list of float
Hyperparameters in LeakyReLU, which are the slopes for negative values. ``alphas[i]``
gives the slope for negative value in the i-th GAT layer. ``len(alphas)`` equals the
number of GAT layers. By default, this will be 0.2 for all GAT layers.
residuals : list of bool
``residual[i]`` decides if residual connection is to be used for the i-th GAT layer.
``len(residual)`` equals the number of GAT layers. By default, residual connection
is performed for each GAT layer.
agg_modes : list of str
The way to aggregate multi-head attention results for each GAT layer, which can be either
'flatten' for concatenating all-head results or 'mean' for averaging all-head results.
``agg_modes[i]`` gives the way to aggregate multi-head attention results for the i-th
GAT layer. ``len(agg_modes)`` equals the number of GAT layers. By default, we flatten
all-head results for each GAT layer.
activations : list of activation function or None
``activations[i]`` gives the activation function applied to the aggregated multi-head
results for the i-th GAT layer. ``len(activations)`` equals the number of GAT layers.
By default, no activation is applied for each GAT layer.
"""
def
__init__
(
self
,
in_feats
,
hidden_feats
=
None
,
num_heads
=
None
,
feat_drops
=
None
,
attn_drops
=
None
,
alphas
=
None
,
residuals
=
None
,
agg_modes
=
None
,
activations
=
None
):
super
(
GAT
,
self
).
__init__
()
if
hidden_feats
is
None
:
hidden_feats
=
[
32
,
32
]
n_layers
=
len
(
hidden_feats
)
if
num_heads
is
None
:
num_heads
=
[
4
for
_
in
range
(
n_layers
)]
if
feat_drops
is
None
:
feat_drops
=
[
0.
for
_
in
range
(
n_layers
)]
if
attn_drops
is
None
:
attn_drops
=
[
0.
for
_
in
range
(
n_layers
)]
if
alphas
is
None
:
alphas
=
[
0.2
for
_
in
range
(
n_layers
)]
if
residuals
is
None
:
residuals
=
[
True
for
_
in
range
(
n_layers
)]
if
agg_modes
is
None
:
agg_modes
=
[
'flatten'
for
_
in
range
(
n_layers
-
1
)]
agg_modes
.
append
(
'mean'
)
if
activations
is
None
:
activations
=
[
F
.
elu
for
_
in
range
(
n_layers
-
1
)]
activations
.
append
(
None
)
lengths
=
[
len
(
hidden_feats
),
len
(
num_heads
),
len
(
feat_drops
),
len
(
attn_drops
),
len
(
alphas
),
len
(
residuals
),
len
(
agg_modes
),
len
(
activations
)]
assert
len
(
set
(
lengths
))
==
1
,
'Expect the lengths of hidden_feats, num_heads, '
\
'feat_drops, attn_drops, alphas, residuals, '
\
'agg_modes and activations to be the same, '
\
'got {}'
.
format
(
lengths
)
self
.
hidden_feats
=
hidden_feats
self
.
num_heads
=
num_heads
self
.
agg_modes
=
agg_modes
self
.
gnn_layers
=
nn
.
ModuleList
()
for
i
in
range
(
n_layers
):
self
.
gnn_layers
.
append
(
GATLayer
(
in_feats
,
hidden_feats
[
i
],
num_heads
[
i
],
feat_drops
[
i
],
attn_drops
[
i
],
alphas
[
i
],
residuals
[
i
],
agg_modes
[
i
],
activations
[
i
]))
if
agg_modes
[
i
]
==
'flatten'
:
in_feats
=
hidden_feats
[
i
]
*
num_heads
[
i
]
else
:
in_feats
=
hidden_feats
[
i
]
def
forward
(
self
,
g
,
feats
):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which equals in_feats in initialization
Returns
-------
feats : FloatTensor of shape (N, M2)
* N is the total number of nodes in the batch of graphs
* M2 is the output node representation size, which equals
hidden_sizes[-1] if agg_modes[-1] == 'mean' and
hidden_sizes[-1] * num_heads[-1] otherwise.
"""
for
gnn
in
self
.
gnn_layers
:
feats
=
gnn
(
g
,
feats
)
return
feats
apps/life_sci/python/dgllife/model/gnn/gcn.py
deleted
100644 → 0
View file @
94c67203
"""Graph Convolutional Networks."""
# pylint: disable= no-member, arguments-differ, invalid-name
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
dgl.nn.pytorch
import
GraphConv
__all__
=
[
'GCN'
]
# pylint: disable=W0221, C0103
class
GCNLayer
(
nn
.
Module
):
r
"""Single GCN layer from `Semi-Supervised Classification with Graph Convolutional Networks
<https://arxiv.org/abs/1609.02907>`__
Parameters
----------
in_feats : int
Number of input node features.
out_feats : int
Number of output node features.
activation : activation function
Default to be None.
residual : bool
Whether to use residual connection, default to be True.
batchnorm : bool
Whether to use batch normalization on the output,
default to be True.
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def
__init__
(
self
,
in_feats
,
out_feats
,
activation
=
None
,
residual
=
True
,
batchnorm
=
True
,
dropout
=
0.
):
super
(
GCNLayer
,
self
).
__init__
()
self
.
activation
=
activation
self
.
graph_conv
=
GraphConv
(
in_feats
=
in_feats
,
out_feats
=
out_feats
,
norm
=
'none'
,
activation
=
activation
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
residual
=
residual
if
residual
:
self
.
res_connection
=
nn
.
Linear
(
in_feats
,
out_feats
)
self
.
bn
=
batchnorm
if
batchnorm
:
self
.
bn_layer
=
nn
.
BatchNorm1d
(
out_feats
)
def
forward
(
self
,
g
,
feats
):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which must match in_feats in initialization
Returns
-------
new_feats : FloatTensor of shape (N, M2)
* M2 is the output node feature size, which must match out_feats in initialization
"""
new_feats
=
self
.
graph_conv
(
g
,
feats
)
if
self
.
residual
:
res_feats
=
self
.
activation
(
self
.
res_connection
(
feats
))
new_feats
=
new_feats
+
res_feats
new_feats
=
self
.
dropout
(
new_feats
)
if
self
.
bn
:
new_feats
=
self
.
bn_layer
(
new_feats
)
return
new_feats
class
GCN
(
nn
.
Module
):
r
"""GCN from `Semi-Supervised Classification with Graph Convolutional Networks
<https://arxiv.org/abs/1609.02907>`__
Parameters
----------
in_feats : int
Number of input node features.
hidden_feats : list of int
``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
``len(hidden_feats)`` equals the number of GCN layers. By default, we use
``[64, 64]``.
activation : list of activation functions or None
If None, no activation will be applied. If not None, ``activation[i]`` gives the
activation function to be used for the i-th GCN layer. ``len(activation)`` equals
the number of GCN layers. By default, ReLU is applied for all GCN layers.
residual : list of bool
``residual[i]`` decides if residual connection is to be used for the i-th GCN layer.
``len(residual)`` equals the number of GCN layers. By default, residual connection
is performed for each GCN layer.
batchnorm : list of bool
``batchnorm[i]`` decides if batch normalization is to be applied on the output of
the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
batch normalization is applied for all GCN layers.
dropout : list of float
``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
``len(dropout)`` equals the number of GCN layers. By default, no dropout is
performed for all layers.
"""
def
__init__
(
self
,
in_feats
,
hidden_feats
=
None
,
activation
=
None
,
residual
=
None
,
batchnorm
=
None
,
dropout
=
None
):
super
(
GCN
,
self
).
__init__
()
if
hidden_feats
is
None
:
hidden_feats
=
[
64
,
64
]
n_layers
=
len
(
hidden_feats
)
if
activation
is
None
:
activation
=
[
F
.
relu
for
_
in
range
(
n_layers
)]
if
residual
is
None
:
residual
=
[
True
for
_
in
range
(
n_layers
)]
if
batchnorm
is
None
:
batchnorm
=
[
True
for
_
in
range
(
n_layers
)]
if
dropout
is
None
:
dropout
=
[
0.
for
_
in
range
(
n_layers
)]
lengths
=
[
len
(
hidden_feats
),
len
(
activation
),
len
(
residual
),
len
(
batchnorm
),
len
(
dropout
)]
assert
len
(
set
(
lengths
))
==
1
,
'Expect the lengths of hidden_feats, activation, '
\
'residual, batchnorm and dropout to be the same, '
\
'got {}'
.
format
(
lengths
)
self
.
hidden_feats
=
hidden_feats
self
.
gnn_layers
=
nn
.
ModuleList
()
for
i
in
range
(
n_layers
):
self
.
gnn_layers
.
append
(
GCNLayer
(
in_feats
,
hidden_feats
[
i
],
activation
[
i
],
residual
[
i
],
batchnorm
[
i
],
dropout
[
i
]))
in_feats
=
hidden_feats
[
i
]
def
forward
(
self
,
g
,
feats
):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which equals in_feats in initialization
Returns
-------
feats : FloatTensor of shape (N, M2)
* N is the total number of nodes in the batch of graphs
* M2 is the output node representation size, which equals
hidden_sizes[-1] in initialization.
"""
for
gnn
in
self
.
gnn_layers
:
feats
=
gnn
(
g
,
feats
)
return
feats
apps/life_sci/python/dgllife/model/gnn/gin.py
deleted
100644 → 0
View file @
94c67203
"""Graph Isomorphism Networks."""
# pylint: disable= no-member, arguments-differ, invalid-name
import
dgl.function
as
fn
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
__all__
=
[
'GIN'
]
# pylint: disable=W0221, C0103
class
GINLayer
(
nn
.
Module
):
r
"""Single Layer GIN from `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
Parameters
----------
num_edge_emb_list : list of int
num_edge_emb_list[i] gives the number of items to embed for the
i-th categorical edge feature variables. E.g. num_edge_emb_list[0] can be
the number of bond types and num_edge_emb_list[1] can be the number of
bond direction types.
emb_dim : int
The size of each embedding vector.
batch_norm : bool
Whether to apply batch normalization to the output of message passing.
Default to True.
activation : None or callable
Activation function to apply to the output node representations.
Default to None.
"""
def
__init__
(
self
,
num_edge_emb_list
,
emb_dim
,
batch_norm
=
True
,
activation
=
None
):
super
(
GINLayer
,
self
).
__init__
()
self
.
mlp
=
nn
.
Sequential
(
nn
.
Linear
(
emb_dim
,
2
*
emb_dim
),
nn
.
ReLU
(),
nn
.
Linear
(
2
*
emb_dim
,
emb_dim
)
)
self
.
edge_embeddings
=
nn
.
ModuleList
()
for
num_emb
in
num_edge_emb_list
:
emb_module
=
nn
.
Embedding
(
num_emb
,
emb_dim
)
nn
.
init
.
xavier_uniform_
(
emb_module
.
weight
.
data
)
self
.
edge_embeddings
.
append
(
emb_module
)
if
batch_norm
:
self
.
bn
=
nn
.
BatchNorm1d
(
emb_dim
)
else
:
self
.
bn
=
None
self
.
activation
=
activation
def
forward
(
self
,
g
,
node_feats
,
categorical_edge_feats
):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : FloatTensor of shape (N, emb_dim)
* Input node features
* N is the total number of nodes in the batch of graphs
* emb_dim is the input node feature size, which must match emb_dim in initialization
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
* len(categorical_edge_feats) should be the same as len(self.edge_embeddings)
* E is the total number of edges in the batch of graphs
Returns
-------
node_feats : float32 tensor of shape (N, emb_dim)
Output node representations
"""
edge_embeds
=
[]
for
i
,
feats
in
enumerate
(
categorical_edge_feats
):
edge_embeds
.
append
(
self
.
edge_embeddings
[
i
](
feats
))
edge_embeds
=
torch
.
stack
(
edge_embeds
,
dim
=
0
).
sum
(
0
)
g
=
g
.
local_var
()
g
.
ndata
[
'feat'
]
=
node_feats
g
.
edata
[
'feat'
]
=
edge_embeds
g
.
update_all
(
fn
.
u_add_e
(
'feat'
,
'feat'
,
'm'
),
fn
.
sum
(
'm'
,
'feat'
))
node_feats
=
self
.
mlp
(
g
.
ndata
.
pop
(
'feat'
))
if
self
.
bn
is
not
None
:
node_feats
=
self
.
bn
(
node_feats
)
if
self
.
activation
is
not
None
:
node_feats
=
self
.
activation
(
node_feats
)
return
node_feats
class
GIN
(
nn
.
Module
):
r
"""Graph Isomorphism Network from `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
This module is for updating node representations only.
Parameters
----------
num_node_emb_list : list of int
num_node_emb_list[i] gives the number of items to embed for the
i-th categorical node feature variables. E.g. num_node_emb_list[0] can be
the number of atom types and num_node_emb_list[1] can be the number of
atom chirality types.
num_edge_emb_list : list of int
num_edge_emb_list[i] gives the number of items to embed for the
i-th categorical edge feature variables. E.g. num_edge_emb_list[0] can be
the number of bond types and num_edge_emb_list[1] can be the number of
bond direction types.
num_layers : int
Number of GIN layers to use. Default to 5.
emb_dim : int
The size of each embedding vector. Default to 300.
JK : str
JK for jumping knowledge as in `Representation Learning on Graphs with
Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It decides
how we are going to combine the all-layer node representations for the final output.
There can be four options for this argument, ``concat``, ``last``, ``max`` and ``sum``.
Default to 'last'.
* ``'concat'``: concatenate the output node representations from all GIN layers
* ``'last'``: use the node representations from the last GIN layer
* ``'max'``: apply max pooling to the node representations across all GIN layers
* ``'sum'``: sum the output node representations from all GIN layers
dropout : float
Dropout to apply to the output of each GIN layer. Default to 0.5
"""
def
__init__
(
self
,
num_node_emb_list
,
num_edge_emb_list
,
num_layers
=
5
,
emb_dim
=
300
,
JK
=
'last'
,
dropout
=
0.5
):
super
(
GIN
,
self
).
__init__
()
self
.
num_layers
=
num_layers
self
.
JK
=
JK
self
.
dropout
=
nn
.
Dropout
(
dropout
)
if
num_layers
<
2
:
raise
ValueError
(
'Number of GNN layers must be greater '
'than 1, got {:d}'
.
format
(
num_layers
))
self
.
node_embeddings
=
nn
.
ModuleList
()
for
num_emb
in
num_node_emb_list
:
emb_module
=
nn
.
Embedding
(
num_emb
,
emb_dim
)
nn
.
init
.
xavier_uniform_
(
emb_module
.
weight
.
data
)
self
.
node_embeddings
.
append
(
emb_module
)
self
.
gnn_layers
=
nn
.
ModuleList
()
for
layer
in
range
(
num_layers
):
if
layer
==
num_layers
-
1
:
self
.
gnn_layers
.
append
(
GINLayer
(
num_edge_emb_list
,
emb_dim
))
else
:
self
.
gnn_layers
.
append
(
GINLayer
(
num_edge_emb_list
,
emb_dim
,
activation
=
F
.
relu
))
def
forward
(
self
,
g
,
categorical_node_feats
,
categorical_edge_feats
):
"""Update node representations
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
categorical_node_feats : list of LongTensor of shape (N)
* Input categorical node features
* len(categorical_node_feats) should be the same as len(self.node_embeddings)
* N is the total number of nodes in the batch of graphs
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
* len(categorical_edge_feats) should be the same as
len(num_edge_emb_list) in the arguments
* E is the total number of edges in the batch of graphs
Returns
-------
final_node_feats : float32 tensor of shape (N, M)
Output node representations, N for the number of nodes and
M for output size. In particular, M will be emb_dim * (num_layers + 1)
if self.JK == 'concat' and emb_dim otherwise.
"""
node_embeds
=
[]
for
i
,
feats
in
enumerate
(
categorical_node_feats
):
node_embeds
.
append
(
self
.
node_embeddings
[
i
](
feats
))
node_embeds
=
torch
.
stack
(
node_embeds
,
dim
=
0
).
sum
(
0
)
all_layer_node_feats
=
[
node_embeds
]
for
layer
in
range
(
self
.
num_layers
):
node_feats
=
self
.
gnn_layers
[
layer
](
g
,
all_layer_node_feats
[
layer
],
categorical_edge_feats
)
node_feats
=
self
.
dropout
(
node_feats
)
all_layer_node_feats
.
append
(
node_feats
)
if
self
.
JK
==
'concat'
:
final_node_feats
=
torch
.
cat
(
all_layer_node_feats
,
dim
=
1
)
elif
self
.
JK
==
'last'
:
final_node_feats
=
all_layer_node_feats
[
-
1
]
elif
self
.
JK
==
'max'
:
all_layer_node_feats
=
[
h
.
unsqueeze_
(
0
)
for
h
in
all_layer_node_feats
]
final_node_feats
=
torch
.
max
(
torch
.
cat
(
all_layer_node_feats
,
dim
=
0
),
dim
=
0
)[
0
]
elif
self
.
JK
==
'sum'
:
all_layer_node_feats
=
[
h
.
unsqueeze_
(
0
)
for
h
in
all_layer_node_feats
]
final_node_feats
=
torch
.
sum
(
torch
.
cat
(
all_layer_node_feats
,
dim
=
0
),
dim
=
0
)
else
:
return
ValueError
(
"Expect self.JK to be 'concat', 'last', "
"'max' or 'sum', got {}"
.
format
(
self
.
JK
))
return
final_node_feats
apps/life_sci/python/dgllife/model/gnn/mgcn.py
deleted
100644 → 0
View file @
94c67203
"""MGCN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
dgl.function
as
fn
import
torch
import
torch.nn
as
nn
from
.schnet
import
RBFExpansion
__all__
=
[
'MGCNGNN'
]
# pylint: disable=W0221, E1101
class
EdgeEmbedding
(
nn
.
Module
):
"""Module for embedding edges.
Edges whose end nodes have the same combination of types
share the same initial embedding.
Parameters
----------
num_types : int
Number of edge types to embed.
edge_feats : int
Size for the edge representations to learn.
"""
def
__init__
(
self
,
num_types
,
edge_feats
):
super
(
EdgeEmbedding
,
self
).
__init__
()
self
.
embed
=
nn
.
Embedding
(
num_types
,
edge_feats
)
def
get_edge_types
(
self
,
edges
):
"""Generates edge types.
The edge type is based on the type of the source and destination nodes.
Note that directions are not distinguished, e.g. C-O and O-C are the same edge type.
To map each pair of node types to a unique number, we use an unordered pairing function.
See more details in this discussion:
https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers
Note that the number of edge types should be larger than the square of the maximum node
type in the dataset.
Parameters
----------
edges : EdgeBatch
Container for a batch of edges.
Returns
-------
dict
Mapping 'type' to the computed edge types.
"""
node_type1
=
edges
.
src
[
'type'
]
node_type2
=
edges
.
dst
[
'type'
]
return
{
'type'
:
node_type1
*
node_type2
+
\
(
torch
.
abs
(
node_type1
-
node_type2
)
-
1
)
**
2
/
4
}
def
forward
(
self
,
g
,
node_types
):
"""Embeds edge types.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
Returns
-------
float32 tensor of shape (E, edge_feats)
Edge representations.
"""
g
=
g
.
local_var
()
g
.
ndata
[
'type'
]
=
node_types
g
.
apply_edges
(
self
.
get_edge_types
)
return
self
.
embed
(
g
.
edata
[
'type'
])
class
VEConv
(
nn
.
Module
):
"""Vertex-Edge Convolution in MGCN
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__.
This layer combines both node and edge features in updating node representations.
Parameters
----------
dist_feats : int
Size for the expanded distances.
feats : int
Size for the input and output node and edge representations.
update_edge : bool
Whether to update edge representations. Default to True.
"""
def
__init__
(
self
,
dist_feats
,
feats
,
update_edge
=
True
):
super
(
VEConv
,
self
).
__init__
()
self
.
update_dists
=
nn
.
Sequential
(
nn
.
Linear
(
dist_feats
,
feats
),
nn
.
Softplus
(
beta
=
0.5
,
threshold
=
14
),
nn
.
Linear
(
feats
,
feats
)
)
if
update_edge
:
self
.
update_edge_feats
=
nn
.
Linear
(
feats
,
feats
)
else
:
self
.
update_edge_feats
=
None
def
forward
(
self
,
g
,
node_feats
,
edge_feats
,
expanded_dists
):
"""Performs message passing and updates node and edge representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, feats)
Input node features.
edge_feats : float32 tensor of shape (E, feats)
Input edge features.
expanded_dists : float32 tensor of shape (E, dist_feats)
Expanded distances, i.e. the output of RBFExpansion.
Returns
-------
node_feats : float32 tensor of shape (V, feats)
Updated node representations.
edge_feats : float32 tensor of shape (E, feats)
Edge representations, updated if ``update_edge == True`` in initialization.
"""
expanded_dists
=
self
.
update_dists
(
expanded_dists
)
if
self
.
update_edge_feats
is
not
None
:
edge_feats
=
self
.
update_edge_feats
(
edge_feats
)
g
=
g
.
local_var
()
g
.
ndata
.
update
({
'hv'
:
node_feats
})
g
.
edata
.
update
({
'dist'
:
expanded_dists
,
'he'
:
edge_feats
})
g
.
update_all
(
message_func
=
[
fn
.
u_mul_e
(
'hv'
,
'dist'
,
'm_0'
),
fn
.
copy_e
(
'he'
,
'm_1'
)],
reduce_func
=
[
fn
.
sum
(
'm_0'
,
'hv_0'
),
fn
.
sum
(
'm_1'
,
'hv_1'
)])
node_feats
=
g
.
ndata
.
pop
(
'hv_0'
)
+
g
.
ndata
.
pop
(
'hv_1'
)
return
node_feats
,
edge_feats
class
MultiLevelInteraction
(
nn
.
Module
):
"""Building block for MGCN.
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__. This layer combines node features,
edge features and expanded distances in message passing and updates node and edge
representations.
Parameters
----------
feats : int
Size for the input and output node and edge representations.
dist_feats : int
Size for the expanded distances.
"""
def
__init__
(
self
,
feats
,
dist_feats
):
super
(
MultiLevelInteraction
,
self
).
__init__
()
self
.
project_in_node_feats
=
nn
.
Linear
(
feats
,
feats
)
self
.
conv
=
VEConv
(
dist_feats
,
feats
)
self
.
project_out_node_feats
=
nn
.
Sequential
(
nn
.
Linear
(
feats
,
feats
),
nn
.
Softplus
(
beta
=
0.5
,
threshold
=
14
),
nn
.
Linear
(
feats
,
feats
)
)
self
.
project_edge_feats
=
nn
.
Sequential
(
nn
.
Linear
(
feats
,
feats
),
nn
.
Softplus
(
beta
=
0.5
,
threshold
=
14
)
)
def
forward
(
self
,
g
,
node_feats
,
edge_feats
,
expanded_dists
):
"""Performs message passing and updates node and edge representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, feats)
Input node features.
edge_feats : float32 tensor of shape (E, feats)
Input edge features
expanded_dists : float32 tensor of shape (E, dist_feats)
Expanded distances, i.e. the output of RBFExpansion.
Returns
-------
node_feats : float32 tensor of shape (V, feats)
Updated node representations.
edge_feats : float32 tensor of shape (E, feats)
Updated edge representations.
"""
new_node_feats
=
self
.
project_in_node_feats
(
node_feats
)
new_node_feats
,
edge_feats
=
self
.
conv
(
g
,
new_node_feats
,
edge_feats
,
expanded_dists
)
new_node_feats
=
self
.
project_out_node_feats
(
new_node_feats
)
node_feats
=
node_feats
+
new_node_feats
edge_feats
=
self
.
project_edge_feats
(
edge_feats
)
return
node_feats
,
edge_feats
class
MGCNGNN
(
nn
.
Module
):
"""MGCN.
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__.
This class performs message passing in MGCN and returns the updated node representations.
Parameters
----------
feats : int
Size for the node and edge embeddings to learn. Default to 128.
n_layers : int
Number of gnn layers to use. Default to 3.
num_node_types : int
Number of node types to embed. Default to 100.
num_edge_types : int
Number of edge types to embed. Default to 3000.
cutoff : float
Largest center in RBF expansion. Default to 30.
gap : float
Difference between two adjacent centers in RBF expansion. Default to 0.1.
"""
def
__init__
(
self
,
feats
=
128
,
n_layers
=
3
,
num_node_types
=
100
,
num_edge_types
=
3000
,
cutoff
=
30.
,
gap
=
0.1
):
super
(
MGCNGNN
,
self
).
__init__
()
self
.
node_embed
=
nn
.
Embedding
(
num_node_types
,
feats
)
self
.
edge_embed
=
EdgeEmbedding
(
num_edge_types
,
feats
)
self
.
rbf
=
RBFExpansion
(
high
=
cutoff
,
gap
=
gap
)
self
.
gnn_layers
=
nn
.
ModuleList
()
for
_
in
range
(
n_layers
):
self
.
gnn_layers
.
append
(
MultiLevelInteraction
(
feats
,
len
(
self
.
rbf
.
centers
)))
def
forward
(
self
,
g
,
node_types
,
edge_dists
):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
float32 tensor of shape (V, feats * (n_layers + 1))
Output node representations.
"""
node_feats
=
self
.
node_embed
(
node_types
)
edge_feats
=
self
.
edge_embed
(
g
,
node_types
)
expanded_dists
=
self
.
rbf
(
edge_dists
)
all_layer_node_feats
=
[
node_feats
]
for
gnn
in
self
.
gnn_layers
:
node_feats
,
edge_feats
=
gnn
(
g
,
node_feats
,
edge_feats
,
expanded_dists
)
all_layer_node_feats
.
append
(
node_feats
)
return
torch
.
cat
(
all_layer_node_feats
,
dim
=
1
)
apps/life_sci/python/dgllife/model/gnn/mpnn.py
deleted
100644 → 0
View file @
94c67203
"""MPNN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
dgl.nn.pytorch
import
NNConv
__all__
=
[
'MPNNGNN'
]
# pylint: disable=W0221
class
MPNNGNN
(
nn
.
Module
):
"""MPNN.
MPNN is introduced in `Neural Message Passing for Quantum Chemistry
<https://arxiv.org/abs/1704.01212>`__.
This class performs message passing in MPNN and returns the updated node representations.
Parameters
----------
node_in_feats : int
Size for the input node features.
node_out_feats : int
Size for the output node representations. Default to 64.
edge_in_feats : int
Size for the input edge features. Default to 128.
edge_hidden_feats : int
Size for the hidden edge representations.
num_step_message_passing : int
Number of message passing steps. Default to 6.
"""
def
__init__
(
self
,
node_in_feats
,
edge_in_feats
,
node_out_feats
=
64
,
edge_hidden_feats
=
128
,
num_step_message_passing
=
6
):
super
(
MPNNGNN
,
self
).
__init__
()
self
.
project_node_feats
=
nn
.
Sequential
(
nn
.
Linear
(
node_in_feats
,
node_out_feats
),
nn
.
ReLU
()
)
self
.
num_step_message_passing
=
num_step_message_passing
edge_network
=
nn
.
Sequential
(
nn
.
Linear
(
edge_in_feats
,
edge_hidden_feats
),
nn
.
ReLU
(),
nn
.
Linear
(
edge_hidden_feats
,
node_out_feats
*
node_out_feats
)
)
self
.
gnn_layer
=
NNConv
(
in_feats
=
node_out_feats
,
out_feats
=
node_out_feats
,
edge_func
=
edge_network
,
aggregator_type
=
'sum'
)
self
.
gru
=
nn
.
GRU
(
node_out_feats
,
node_out_feats
)
def
forward
(
self
,
g
,
node_feats
,
edge_feats
):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges in the batch of graphs.
Returns
-------
node_feats : float32 tensor of shape (V, node_out_feats)
Output node representations.
"""
node_feats
=
self
.
project_node_feats
(
node_feats
)
# (V, node_out_feats)
hidden_feats
=
node_feats
.
unsqueeze
(
0
)
# (1, V, node_out_feats)
for
_
in
range
(
self
.
num_step_message_passing
):
node_feats
=
F
.
relu
(
self
.
gnn_layer
(
g
,
node_feats
,
edge_feats
))
node_feats
,
hidden_feats
=
self
.
gru
(
node_feats
.
unsqueeze
(
0
),
hidden_feats
)
node_feats
=
node_feats
.
squeeze
(
0
)
return
node_feats
apps/life_sci/python/dgllife/model/gnn/schnet.py
deleted
100644 → 0
View file @
94c67203
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621, W0221, E1102, E1101
"""SchNet"""
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
dgl.nn.pytorch
import
CFConv
__all__
=
[
'SchNetGNN'
]
class
RBFExpansion
(
nn
.
Module
):
r
"""Expand distances between nodes by radial basis functions.
.. math::
\exp(- \gamma * ||d - \mu||^2)
where :math:`d` is the distance between two nodes and :math:`\mu` helps centralizes
the distances. We use multiple centers evenly distributed in the range of
:math:`[\text{low}, \text{high}]` with the difference between two adjacent centers
being :math:`gap`.
The number of centers is decided by :math:`(\text{high} - \text{low}) / \text{gap}`.
Choosing fewer centers corresponds to reducing the resolution of the filter.
Parameters
----------
low : float
Smallest center. Default to 0.
high : float
Largest center. Default to 30.
gap : float
Difference between two adjacent centers. :math:`\gamma` will be computed as the
reciprocal of gap. Default to 0.1.
"""
def
__init__
(
self
,
low
=
0.
,
high
=
30.
,
gap
=
0.1
):
super
(
RBFExpansion
,
self
).
__init__
()
num_centers
=
int
(
np
.
ceil
((
high
-
low
)
/
gap
))
centers
=
np
.
linspace
(
low
,
high
,
num_centers
)
self
.
centers
=
nn
.
Parameter
(
torch
.
tensor
(
centers
).
float
(),
requires_grad
=
False
)
self
.
gamma
=
1
/
gap
def
forward
(
self
,
edge_dists
):
"""Expand distances.
Parameters
----------
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
float32 tensor of shape (E, len(self.centers))
Expanded distances.
"""
radial
=
edge_dists
-
self
.
centers
coef
=
-
self
.
gamma
return
torch
.
exp
(
coef
*
(
radial
**
2
))
class
Interaction
(
nn
.
Module
):
"""Building block for SchNet.
SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.
This layer combines node and edge features in message passing and updates node
representations.
Parameters
----------
node_feats : int
Size for the input and output node features.
edge_in_feats : int
Size for the input edge features.
hidden_feats : int
Size for hidden representations.
"""
def
__init__
(
self
,
node_feats
,
edge_in_feats
,
hidden_feats
):
super
(
Interaction
,
self
).
__init__
()
self
.
conv
=
CFConv
(
node_feats
,
edge_in_feats
,
hidden_feats
,
node_feats
)
self
.
project_out
=
nn
.
Linear
(
node_feats
,
node_feats
)
def
forward
(
self
,
g
,
node_feats
,
edge_feats
):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feats)
Input node features, V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features, E for the number of edges.
Returns
-------
float32 tensor of shape (V, node_feats)
Updated node representations.
"""
node_feats
=
self
.
conv
(
g
,
node_feats
,
edge_feats
)
return
self
.
project_out
(
node_feats
)
class
SchNetGNN
(
nn
.
Module
):
"""SchNet.
SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.
This class performs message passing in SchNet and returns the updated node representations.
Parameters
----------
node_feats : int
Size for node representations to learn. Default to 64.
hidden_feats : list of int
``hidden_feats[i]`` gives the size of hidden representations for the i-th interaction
layer. ``len(hidden_feats)`` equals the number of interaction layers.
Default to ``[64, 64, 64]``.
num_node_types : int
Number of node types to embed. Default to 100.
cutoff : float
Largest center in RBF expansion. Default to 30.
gap : float
Difference between two adjacent centers in RBF expansion. Default to 0.1.
"""
def
__init__
(
self
,
node_feats
=
64
,
hidden_feats
=
None
,
num_node_types
=
100
,
cutoff
=
30.
,
gap
=
0.1
):
super
(
SchNetGNN
,
self
).
__init__
()
if
hidden_feats
is
None
:
hidden_feats
=
[
64
,
64
,
64
]
self
.
embed
=
nn
.
Embedding
(
num_node_types
,
node_feats
)
self
.
rbf
=
RBFExpansion
(
high
=
cutoff
,
gap
=
gap
)
n_layers
=
len
(
hidden_feats
)
self
.
gnn_layers
=
nn
.
ModuleList
()
for
i
in
range
(
n_layers
):
self
.
gnn_layers
.
append
(
Interaction
(
node_feats
,
len
(
self
.
rbf
.
centers
),
hidden_feats
[
i
]))
def
forward
(
self
,
g
,
node_types
,
edge_dists
):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
node_feats : float32 tensor of shape (V, node_feats)
Updated node representations.
"""
node_feats
=
self
.
embed
(
node_types
)
expanded_dists
=
self
.
rbf
(
edge_dists
)
for
gnn
in
self
.
gnn_layers
:
node_feats
=
gnn
(
g
,
node_feats
,
expanded_dists
)
return
node_feats
apps/life_sci/python/dgllife/model/gnn/weave.py
deleted
100644 → 0
View file @
94c67203
"""Weave"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
dgl.function
as
fn
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
__all__
=
[
'WeaveGNN'
]
# pylint: disable=W0221, E1101
class
WeaveLayer
(
nn
.
Module
):
r
"""Single Weave layer from `Molecular Graph Convolutions: Moving Beyond Fingerprints
<https://arxiv.org/abs/1603.00856>`__
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_node_hidden_feats : int
Size for the hidden node representations in updating node representations.
Default to 50.
edge_node_hidden_feats : int
Size for the hidden edge representations in updating node representations.
Default to 50.
node_out_feats : int
Size for the output node representations. Default to 50.
node_edge_hidden_feats : int
Size for the hidden node representations in updating edge representations.
Default to 50.
edge_edge_hidden_feats : int
Size for the hidden edge representations in updating edge representations.
Default to 50.
edge_out_feats : int
Size for the output edge representations. Default to 50.
activation : callable
Activation function to apply. Default to ReLU.
"""
def
__init__
(
self
,
node_in_feats
,
edge_in_feats
,
node_node_hidden_feats
=
50
,
edge_node_hidden_feats
=
50
,
node_out_feats
=
50
,
node_edge_hidden_feats
=
50
,
edge_edge_hidden_feats
=
50
,
edge_out_feats
=
50
,
activation
=
F
.
relu
):
super
(
WeaveLayer
,
self
).
__init__
()
self
.
activation
=
activation
# Layers for updating node representations
self
.
node_to_node
=
nn
.
Linear
(
node_in_feats
,
node_node_hidden_feats
)
self
.
edge_to_node
=
nn
.
Linear
(
edge_in_feats
,
edge_node_hidden_feats
)
self
.
update_node
=
nn
.
Linear
(
node_node_hidden_feats
+
edge_node_hidden_feats
,
node_out_feats
)
# Layers for updating edge representations
self
.
left_node_to_edge
=
nn
.
Linear
(
node_in_feats
,
node_edge_hidden_feats
)
self
.
right_node_to_edge
=
nn
.
Linear
(
node_in_feats
,
node_edge_hidden_feats
)
self
.
edge_to_edge
=
nn
.
Linear
(
edge_in_feats
,
edge_edge_hidden_feats
)
self
.
update_edge
=
nn
.
Linear
(
2
*
node_edge_hidden_feats
+
edge_edge_hidden_feats
,
edge_out_feats
)
def
forward
(
self
,
g
,
node_feats
,
edge_feats
,
node_only
=
False
):
r
"""Update node and edge representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges in the batch of graphs.
node_only : bool
Whether to update node representations only. If False, edge representations
will be updated as well. Default to False.
Returns
-------
new_node_feats : float32 tensor of shape (V, node_out_feats)
Updated node representations.
new_edge_feats : float32 tensor of shape (E, edge_out_feats)
Updated edge representations.
"""
g
=
g
.
local_var
()
# Update node features
node_node_feats
=
self
.
activation
(
self
.
node_to_node
(
node_feats
))
g
.
edata
[
'e2n'
]
=
self
.
activation
(
self
.
edge_to_node
(
edge_feats
))
g
.
update_all
(
fn
.
copy_edge
(
'e2n'
,
'm'
),
fn
.
sum
(
'm'
,
'e2n'
))
edge_node_feats
=
g
.
ndata
.
pop
(
'e2n'
)
new_node_feats
=
self
.
activation
(
self
.
update_node
(
torch
.
cat
([
node_node_feats
,
edge_node_feats
],
dim
=
1
)))
if
node_only
:
return
new_node_feats
# Update edge features
g
.
ndata
[
'left_hv'
]
=
self
.
left_node_to_edge
(
node_feats
)
g
.
ndata
[
'right_hv'
]
=
self
.
right_node_to_edge
(
node_feats
)
g
.
apply_edges
(
fn
.
u_add_v
(
'left_hv'
,
'right_hv'
,
'first'
))
g
.
apply_edges
(
fn
.
u_add_v
(
'right_hv'
,
'left_hv'
,
'second'
))
first_edge_feats
=
self
.
activation
(
g
.
edata
.
pop
(
'first'
))
second_edge_feats
=
self
.
activation
(
g
.
edata
.
pop
(
'second'
))
third_edge_feats
=
self
.
activation
(
self
.
edge_to_edge
(
edge_feats
))
new_edge_feats
=
self
.
activation
(
self
.
update_edge
(
torch
.
cat
([
first_edge_feats
,
second_edge_feats
,
third_edge_feats
],
dim
=
1
)))
return
new_node_feats
,
new_edge_feats
class
WeaveGNN
(
nn
.
Module
):
r
"""The component of Weave for updating node and edge representations.
Weave is introduced in `Molecular Graph Convolutions: Moving Beyond Fingerprints
<https://arxiv.org/abs/1603.00856>`__.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
num_layers : int
Number of Weave layers to use, which is equivalent to the times of message passing.
Default to 2.
hidden_feats : int
Size for the hidden node and edge representations. Default to 50.
activation : callable
Activation function to be used. It cannot be None. Default to ReLU.
"""
def
__init__
(
self
,
node_in_feats
,
edge_in_feats
,
num_layers
=
2
,
hidden_feats
=
50
,
activation
=
F
.
relu
):
super
(
WeaveGNN
,
self
).
__init__
()
self
.
gnn_layers
=
nn
.
ModuleList
()
for
i
in
range
(
num_layers
):
if
i
==
0
:
self
.
gnn_layers
.
append
(
WeaveLayer
(
node_in_feats
=
node_in_feats
,
edge_in_feats
=
edge_in_feats
,
node_node_hidden_feats
=
hidden_feats
,
edge_node_hidden_feats
=
hidden_feats
,
node_out_feats
=
hidden_feats
,
node_edge_hidden_feats
=
hidden_feats
,
edge_edge_hidden_feats
=
hidden_feats
,
edge_out_feats
=
hidden_feats
,
activation
=
activation
))
else
:
self
.
gnn_layers
.
append
(
WeaveLayer
(
node_in_feats
=
hidden_feats
,
edge_in_feats
=
hidden_feats
,
node_node_hidden_feats
=
hidden_feats
,
edge_node_hidden_feats
=
hidden_feats
,
node_out_feats
=
hidden_feats
,
node_edge_hidden_feats
=
hidden_feats
,
edge_edge_hidden_feats
=
hidden_feats
,
edge_out_feats
=
hidden_feats
,
activation
=
activation
))
def
forward
(
self
,
g
,
node_feats
,
edge_feats
,
node_only
=
True
):
"""Updates node representations (and edge representations).
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges in the batch of graphs.
node_only : bool
Whether to return updated node representations only or to return both
node and edge representations. Default to True.
Returns
-------
float32 tensor of shape (V, gnn_hidden_feats)
Updated node representations.
float32 tensor of shape (E, gnn_hidden_feats), optional
This is returned only when ``node_only==False``. Updated edge representations.
"""
for
i
in
range
(
len
(
self
.
gnn_layers
)
-
1
):
node_feats
,
edge_feats
=
self
.
gnn_layers
[
i
](
g
,
node_feats
,
edge_feats
)
return
self
.
gnn_layers
[
-
1
](
g
,
node_feats
,
edge_feats
,
node_only
)
apps/life_sci/python/dgllife/model/gnn/wln.py
deleted
100644 → 0
View file @
94c67203
"""WLN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
math
import
dgl.function
as
fn
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.nn
import
Parameter
__all__
=
[
'WLN'
]
class
WLNLinear
(
nn
.
Module
):
r
"""Linear layer for WLN
Let stddev be
.. math::
\min(\frac{1.0}{\sqrt{in_feats}}, 0.1)
The weight of the linear layer is initialized from a normal distribution
with mean 0 and std as specified in stddev.
Parameters
----------
in_feats : int
Size for the input.
out_feats : int
Size for the output.
bias : bool
Whether bias will be added to the output. Default to True.
"""
def
__init__
(
self
,
in_feats
,
out_feats
,
bias
=
True
):
super
(
WLNLinear
,
self
).
__init__
()
self
.
in_feats
=
in_feats
self
.
out_feats
=
out_feats
self
.
weight
=
Parameter
(
torch
.
Tensor
(
out_feats
,
in_feats
))
if
bias
:
self
.
bias
=
Parameter
(
torch
.
Tensor
(
out_feats
))
else
:
self
.
register_parameter
(
'bias'
,
None
)
self
.
reset_parameters
()
def
reset_parameters
(
self
):
"""Initialize model parameters."""
stddev
=
min
(
1.0
/
math
.
sqrt
(
self
.
in_feats
),
0.1
)
nn
.
init
.
normal_
(
self
.
weight
,
std
=
stddev
)
if
self
.
bias
is
not
None
:
nn
.
init
.
constant_
(
self
.
bias
,
0.0
)
def
forward
(
self
,
feats
):
"""Applies the layer.
Parameters
----------
feats : float32 tensor of shape (N, *, in_feats)
N for the number of samples, * for any additional dimensions.
Returns
-------
float32 tensor of shape (N, *, out_feats)
Result of the layer.
"""
return
F
.
linear
(
feats
,
self
.
weight
,
self
.
bias
)
def
extra_repr
(
self
):
"""Return a description of the layer."""
return
'in_feats={}, out_feats={}, bias={}'
.
format
(
self
.
in_feats
,
self
.
out_feats
,
self
.
bias
is
not
None
)
class
WLN
(
nn
.
Module
):
"""Weisfeiler-Lehman Network (WLN)
WLN is introduced in `Predicting Organic Reaction Outcomes with
Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.
This class performs message passing and updates node representations.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_out_feats : int
Size for the output node representations. Default to 300.
n_layers : int
Number of times for message passing. Note that same parameters
are shared across n_layers message passing. Default to 3.
project_in_feats : bool
Whether to project input node features. If this is False, we expect node_in_feats
to be the same as node_out_feats. Default to True.
set_comparison : bool
Whether to perform final node representation update mimicking
set comparison. Default to True.
"""
def
__init__
(
self
,
node_in_feats
,
edge_in_feats
,
node_out_feats
=
300
,
n_layers
=
3
,
project_in_feats
=
True
,
set_comparison
=
True
):
super
(
WLN
,
self
).
__init__
()
self
.
n_layers
=
n_layers
self
.
project_in_feats
=
project_in_feats
if
project_in_feats
:
self
.
project_node_in_feats
=
nn
.
Sequential
(
WLNLinear
(
node_in_feats
,
node_out_feats
,
bias
=
False
),
nn
.
ReLU
()
)
else
:
assert
node_in_feats
==
node_out_feats
,
\
'Expect input node features to have the same size as that of output '
\
'node features, got {:d} and {:d}'
.
format
(
node_in_feats
,
node_out_feats
)
self
.
project_concatenated_messages
=
nn
.
Sequential
(
WLNLinear
(
edge_in_feats
+
node_out_feats
,
node_out_feats
),
nn
.
ReLU
()
)
self
.
get_new_node_feats
=
nn
.
Sequential
(
WLNLinear
(
2
*
node_out_feats
,
node_out_feats
),
nn
.
ReLU
()
)
self
.
set_comparison
=
set_comparison
if
set_comparison
:
self
.
project_edge_messages
=
WLNLinear
(
edge_in_feats
,
node_out_feats
,
bias
=
False
)
self
.
project_node_messages
=
WLNLinear
(
node_out_feats
,
node_out_feats
,
bias
=
False
)
self
.
project_self
=
WLNLinear
(
node_out_feats
,
node_out_feats
,
bias
=
False
)
def
forward
(
self
,
g
,
node_feats
,
edge_feats
):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges.
Returns
-------
float32 tensor of shape (V, node_out_feats)
Updated node representations.
"""
if
self
.
project_in_feats
:
node_feats
=
self
.
project_node_in_feats
(
node_feats
)
for
_
in
range
(
self
.
n_layers
):
g
=
g
.
local_var
()
g
.
ndata
[
'hv'
]
=
node_feats
g
.
apply_edges
(
fn
.
copy_src
(
'hv'
,
'he_src'
))
concat_edge_feats
=
torch
.
cat
([
g
.
edata
[
'he_src'
],
edge_feats
],
dim
=
1
)
g
.
edata
[
'he'
]
=
self
.
project_concatenated_messages
(
concat_edge_feats
)
g
.
update_all
(
fn
.
copy_edge
(
'he'
,
'm'
),
fn
.
sum
(
'm'
,
'hv_new'
))
node_feats
=
self
.
get_new_node_feats
(
torch
.
cat
([
node_feats
,
g
.
ndata
[
'hv_new'
]],
dim
=
1
))
if
not
self
.
set_comparison
:
return
node_feats
else
:
g
=
g
.
local_var
()
g
.
ndata
[
'hv'
]
=
self
.
project_node_messages
(
node_feats
)
g
.
edata
[
'he'
]
=
self
.
project_edge_messages
(
edge_feats
)
g
.
update_all
(
fn
.
u_mul_e
(
'hv'
,
'he'
,
'm'
),
fn
.
sum
(
'm'
,
'h_nbr'
))
h_self
=
self
.
project_self
(
node_feats
)
# (V, node_out_feats)
return
g
.
ndata
[
'h_nbr'
]
*
h_self
apps/life_sci/python/dgllife/model/model_zoo/__init__.py
deleted
100644 → 0
View file @
94c67203
"""Collection of model architectures"""
from
.jtnn
import
*
from
.dgmg
import
*
from
.attentivefp_predictor
import
*
from
.gat_predictor
import
*
from
.gcn_predictor
import
*
from
.mlp_predictor
import
*
from
.schnet_predictor
import
*
from
.mgcn_predictor
import
*
from
.mpnn_predictor
import
*
from
.acnn
import
*
from
.wln_reaction_center
import
*
from
.wln_reaction_ranking
import
*
from
.weave_predictor
import
*
from
.gin_predictor
import
*
apps/life_sci/python/dgllife/model/model_zoo/acnn.py
deleted
100644 → 0
View file @
94c67203
"""Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity"""
# pylint: disable=C0103, C0123, W0221, E1101, R1721
import
itertools
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
dgl
import
BatchedDGLHeteroGraph
from
dgl.nn.pytorch
import
AtomicConv
__all__
=
[
'ACNN'
]
def
truncated_normal_
(
tensor
,
mean
=
0.
,
std
=
1.
):
"""Fills the given tensor in-place with elements sampled from the truncated normal
distribution parameterized by mean and std.
The generated values follow a normal distribution with specified mean and
standard deviation, except that values whose magnitude is more than 2 std
from the mean are dropped.
We credit to Ruotian Luo for this implementation:
https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15.
Parameters
----------
tensor : Float32 tensor of arbitrary shape
Tensor to be filled.
mean : float
Mean of the truncated normal distribution.
std : float
Standard deviation of the truncated normal distribution.
"""
shape
=
tensor
.
shape
tmp
=
tensor
.
new_empty
(
shape
+
(
4
,)).
normal_
()
valid
=
(
tmp
<
2
)
&
(
tmp
>
-
2
)
ind
=
valid
.
max
(
-
1
,
keepdim
=
True
)[
1
]
tensor
.
data
.
copy_
(
tmp
.
gather
(
-
1
,
ind
).
squeeze
(
-
1
))
tensor
.
data
.
mul_
(
std
).
add_
(
mean
)
class
ACNNPredictor
(
nn
.
Module
):
"""Predictor for ACNN.
Parameters
----------
in_size : int
Number of radial filters used.
hidden_sizes : list of int
Specifying the hidden sizes for all layers in the predictor.
weight_init_stddevs : list of float
Specifying the standard deviations to use for truncated normal
distributions in initialzing weights for the predictor.
dropouts : list of float
Specifying the dropouts to use for all layers in the predictor.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. Default to None.
num_tasks : int
Output size.
"""
def
__init__
(
self
,
in_size
,
hidden_sizes
,
weight_init_stddevs
,
dropouts
,
features_to_use
,
num_tasks
):
super
(
ACNNPredictor
,
self
).
__init__
()
if
type
(
features_to_use
)
!=
type
(
None
):
in_size
*=
len
(
features_to_use
)
modules
=
[]
for
i
,
h
in
enumerate
(
hidden_sizes
):
linear_layer
=
nn
.
Linear
(
in_size
,
h
)
truncated_normal_
(
linear_layer
.
weight
,
std
=
weight_init_stddevs
[
i
])
modules
.
append
(
linear_layer
)
modules
.
append
(
nn
.
ReLU
())
modules
.
append
(
nn
.
Dropout
(
dropouts
[
i
]))
in_size
=
h
linear_layer
=
nn
.
Linear
(
in_size
,
num_tasks
)
truncated_normal_
(
linear_layer
.
weight
,
std
=
weight_init_stddevs
[
-
1
])
modules
.
append
(
linear_layer
)
self
.
project
=
nn
.
Sequential
(
*
modules
)
def
forward
(
self
,
batch_size
,
frag1_node_indices_in_complex
,
frag2_node_indices_in_complex
,
ligand_conv_out
,
protein_conv_out
,
complex_conv_out
):
"""Perform the prediction.
Parameters
----------
batch_size : int
Number of datapoints in a batch.
frag1_node_indices_in_complex : Int64 tensor of shape (V1)
Indices for atoms in the first fragment (protein) in the batched complex.
frag2_node_indices_in_complex : list of int of length V2
Indices for atoms in the second fragment (ligand) in the batched complex.
ligand_conv_out : Float32 tensor of shape (V2, K * T)
Updated ligand node representations. V2 for the number of atoms in the
ligand, K for the number of radial filters, and T for the number of types
of atomic numbers.
protein_conv_out : Float32 tensor of shape (V1, K * T)
Updated protein node representations. V1 for the number of
atoms in the protein, K for the number of radial filters,
and T for the number of types of atomic numbers.
complex_conv_out : Float32 tensor of shape (V1 + V2, K * T)
Updated complex node representations. V1 and V2 separately
for the number of atoms in the ligand and protein, K for
the number of radial filters, and T for the number of
types of atomic numbers.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_feats
=
self
.
project
(
ligand_conv_out
)
# (V1, O)
protein_feats
=
self
.
project
(
protein_conv_out
)
# (V2, O)
complex_feats
=
self
.
project
(
complex_conv_out
)
# (V1+V2, O)
ligand_energy
=
ligand_feats
.
reshape
(
batch_size
,
-
1
).
sum
(
-
1
,
keepdim
=
True
)
# (B, O)
protein_energy
=
protein_feats
.
reshape
(
batch_size
,
-
1
).
sum
(
-
1
,
keepdim
=
True
)
# (B, O)
complex_ligand_energy
=
complex_feats
[
frag1_node_indices_in_complex
].
reshape
(
batch_size
,
-
1
).
sum
(
-
1
,
keepdim
=
True
)
complex_protein_energy
=
complex_feats
[
frag2_node_indices_in_complex
].
reshape
(
batch_size
,
-
1
).
sum
(
-
1
,
keepdim
=
True
)
complex_energy
=
complex_ligand_energy
+
complex_protein_energy
return
complex_energy
-
(
ligand_energy
+
protein_energy
)
class
ACNN
(
nn
.
Module
):
"""Atomic Convolutional Networks.
The model was proposed in `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
The prediction proceeds as follows:
1. Perform message passing to update atom representations for the
ligand, protein and protein-ligand complex.
2. Predict the energy of atoms from their representations with an MLP.
3. Take the sum of predicted energy of atoms within each molecule for
predicted energy of the ligand, protein and protein-ligand complex.
4. Make the final prediction by subtracting the predicted ligand and protein
energy from the predicted complex energy.
Parameters
----------
hidden_sizes : list of int
``hidden_sizes[i]`` gives the size of hidden representations in the i-th
hidden layer of the MLP. By Default, ``[32, 32, 16]`` will be used.
weight_init_stddevs : list of float
``weight_init_stddevs[i]`` gives the std to initialize parameters in the
i-th layer of the MLP. Note that ``len(weight_init_stddevs) == len(hidden_sizes) + 1``
due to the output layer. By default, we use ``1 / sqrt(hidden_sizes[i])`` for hidden
layers and 0.01 for the output layer.
dropouts : list of float
``dropouts[i]`` gives the dropout in the i-th hidden layer of the MLP. By default,
no dropout is used.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. If None, we use same parameters
for all atoms regardless of their type. Default to None.
radial : list
The list consists of 3 sublists of floats, separately for the
options of interaction cutoff, the options of rbf kernel mean and the
options of rbf kernel scaling. By default,
``[[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]`` will be used.
num_tasks : int
Number of output tasks. Default to 1.
"""
def
__init__
(
self
,
hidden_sizes
=
None
,
weight_init_stddevs
=
None
,
dropouts
=
None
,
features_to_use
=
None
,
radial
=
None
,
num_tasks
=
1
):
super
(
ACNN
,
self
).
__init__
()
if
hidden_sizes
is
None
:
hidden_sizes
=
[
32
,
32
,
16
]
if
weight_init_stddevs
is
None
:
weight_init_stddevs
=
[
1.
/
float
(
np
.
sqrt
(
hidden_sizes
[
i
]))
for
i
in
range
(
len
(
hidden_sizes
))]
weight_init_stddevs
.
append
(
0.01
)
if
dropouts
is
None
:
dropouts
=
[
0.
for
_
in
range
(
len
(
hidden_sizes
))]
if
radial
is
None
:
radial
=
[[
12.0
],
[
0.0
,
2.0
,
4.0
,
6.0
,
8.0
],
[
4.0
]]
# Take the product of sets of options and get a list of 3-tuples.
radial_params
=
[
x
for
x
in
itertools
.
product
(
*
radial
)]
radial_params
=
torch
.
stack
(
list
(
map
(
torch
.
tensor
,
zip
(
*
radial_params
))),
dim
=
1
)
interaction_cutoffs
=
radial_params
[:,
0
]
rbf_kernel_means
=
radial_params
[:,
1
]
rbf_kernel_scaling
=
radial_params
[:,
2
]
self
.
ligand_conv
=
AtomicConv
(
interaction_cutoffs
,
rbf_kernel_means
,
rbf_kernel_scaling
,
features_to_use
)
self
.
protein_conv
=
AtomicConv
(
interaction_cutoffs
,
rbf_kernel_means
,
rbf_kernel_scaling
,
features_to_use
)
self
.
complex_conv
=
AtomicConv
(
interaction_cutoffs
,
rbf_kernel_means
,
rbf_kernel_scaling
,
features_to_use
)
self
.
predictor
=
ACNNPredictor
(
radial_params
.
shape
[
0
],
hidden_sizes
,
weight_init_stddevs
,
dropouts
,
features_to_use
,
num_tasks
)
def
forward
(
self
,
graph
):
"""Apply the model for prediction.
Parameters
----------
graph : DGLHeteroGraph
DGLHeteroGraph consisting of the ligand graph, the protein graph
and the complex graph, along with preprocessed features. For a batch of
protein-ligand pairs, we assume zero padding is performed so that the
number of ligand and protein atoms is the same in all pairs.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_graph
=
graph
[(
'ligand_atom'
,
'ligand'
,
'ligand_atom'
)]
ligand_graph_node_feats
=
ligand_graph
.
ndata
[
'atomic_number'
]
assert
ligand_graph_node_feats
.
shape
[
-
1
]
==
1
ligand_graph_distances
=
ligand_graph
.
edata
[
'distance'
]
ligand_conv_out
=
self
.
ligand_conv
(
ligand_graph
,
ligand_graph_node_feats
,
ligand_graph_distances
)
protein_graph
=
graph
[(
'protein_atom'
,
'protein'
,
'protein_atom'
)]
protein_graph_node_feats
=
protein_graph
.
ndata
[
'atomic_number'
]
assert
protein_graph_node_feats
.
shape
[
-
1
]
==
1
protein_graph_distances
=
protein_graph
.
edata
[
'distance'
]
protein_conv_out
=
self
.
protein_conv
(
protein_graph
,
protein_graph_node_feats
,
protein_graph_distances
)
complex_graph
=
graph
[:,
'complex'
,
:]
complex_graph_node_feats
=
complex_graph
.
ndata
[
'atomic_number'
]
assert
complex_graph_node_feats
.
shape
[
-
1
]
==
1
complex_graph_distances
=
complex_graph
.
edata
[
'distance'
]
complex_conv_out
=
self
.
complex_conv
(
complex_graph
,
complex_graph_node_feats
,
complex_graph_distances
)
frag1_node_indices_in_complex
=
torch
.
where
(
complex_graph
.
ndata
[
'_TYPE'
]
==
0
)[
0
]
frag2_node_indices_in_complex
=
list
(
set
(
range
(
complex_graph
.
number_of_nodes
()))
-
set
(
frag1_node_indices_in_complex
.
tolist
()))
# Hack the case when we are working with a single graph.
if
not
isinstance
(
graph
,
BatchedDGLHeteroGraph
):
graph
.
batch_size
=
1
return
self
.
predictor
(
graph
.
batch_size
,
frag1_node_indices_in_complex
,
frag2_node_indices_in_complex
,
ligand_conv_out
,
protein_conv_out
,
complex_conv_out
)
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment