Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
36c7b771
Unverified
Commit
36c7b771
authored
Jun 05, 2020
by
Mufei Li
Committed by
GitHub
Jun 05, 2020
Browse files
[LifeSci] Move to Independent Repo (#1592)
* Move LifeSci * Remove doc
parent
94c67203
Changes
203
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
5253 deletions
+0
-5253
apps/life_sci/python/dgllife/model/model_zoo/wln_reaction_center.py
...sci/python/dgllife/model/model_zoo/wln_reaction_center.py
+0
-177
apps/life_sci/python/dgllife/model/model_zoo/wln_reaction_ranking.py
...ci/python/dgllife/model/model_zoo/wln_reaction_ranking.py
+0
-139
apps/life_sci/python/dgllife/model/pretrain.py
apps/life_sci/python/dgllife/model/pretrain.py
+0
-194
apps/life_sci/python/dgllife/model/readout/__init__.py
apps/life_sci/python/dgllife/model/readout/__init__.py
+0
-8
apps/life_sci/python/dgllife/model/readout/attentivefp_readout.py
...e_sci/python/dgllife/model/readout/attentivefp_readout.py
+0
-134
apps/life_sci/python/dgllife/model/readout/mlp_readout.py
apps/life_sci/python/dgllife/model/readout/mlp_readout.py
+0
-67
apps/life_sci/python/dgllife/model/readout/weave_readout.py
apps/life_sci/python/dgllife/model/readout/weave_readout.py
+0
-116
apps/life_sci/python/dgllife/model/readout/weighted_sum_and_max.py
..._sci/python/dgllife/model/readout/weighted_sum_and_max.py
+0
-50
apps/life_sci/python/dgllife/utils/__init__.py
apps/life_sci/python/dgllife/utils/__init__.py
+0
-8
apps/life_sci/python/dgllife/utils/complex_to_graph.py
apps/life_sci/python/dgllife/utils/complex_to_graph.py
+0
-236
apps/life_sci/python/dgllife/utils/early_stop.py
apps/life_sci/python/dgllife/utils/early_stop.py
+0
-164
apps/life_sci/python/dgllife/utils/eval.py
apps/life_sci/python/dgllife/utils/eval.py
+0
-301
apps/life_sci/python/dgllife/utils/featurizers.py
apps/life_sci/python/dgllife/utils/featurizers.py
+0
-1654
apps/life_sci/python/dgllife/utils/mol_to_graph.py
apps/life_sci/python/dgllife/utils/mol_to_graph.py
+0
-799
apps/life_sci/python/dgllife/utils/rdkit_utils.py
apps/life_sci/python/dgllife/utils/rdkit_utils.py
+0
-187
apps/life_sci/python/dgllife/utils/splitters.py
apps/life_sci/python/dgllife/utils/splitters.py
+0
-756
apps/life_sci/python/setup.py
apps/life_sci/python/setup.py
+0
-53
apps/life_sci/python/update_version.py
apps/life_sci/python/update_version.py
+0
-51
apps/life_sci/tests/data/test_csv_dataset.py
apps/life_sci/tests/data/test_csv_dataset.py
+0
-69
apps/life_sci/tests/data/test_datasets.py
apps/life_sci/tests/data/test_datasets.py
+0
-90
No files found.
apps/life_sci/python/dgllife/model/model_zoo/wln_reaction_center.py
deleted
100644 → 0
View file @
94c67203
"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction."""
# pylint: disable= no-member, arguments-differ, invalid-name
import
dgl.function
as
fn
import
torch
import
torch.nn
as
nn
from
..gnn.wln
import
WLNLinear
,
WLN
__all__
=
[
'WLNReactionCenter'
]
# pylint: disable=W0221, E1101
class
WLNContext
(
nn
.
Module
):
"""Attention-based context computation for each node.
A context vector is computed by taking a weighted sum of node representations,
with weights computed from an attention module.
Parameters
----------
node_in_feats : int
Size for the input node features.
node_pair_in_feats : int
Size for the input features of node pairs.
"""
def
__init__
(
self
,
node_in_feats
,
node_pair_in_feats
):
super
(
WLNContext
,
self
).
__init__
()
self
.
project_feature_sum
=
WLNLinear
(
node_in_feats
,
node_in_feats
,
bias
=
False
)
self
.
project_node_pair_feature
=
WLNLinear
(
node_pair_in_feats
,
node_in_feats
)
self
.
compute_attention
=
nn
.
Sequential
(
nn
.
ReLU
(),
WLNLinear
(
node_in_feats
,
1
),
nn
.
Sigmoid
()
)
def
forward
(
self
,
batch_complete_graphs
,
node_feats
,
feat_sum
,
node_pair_feat
):
"""Compute context vectors for each node.
Parameters
----------
batch_complete_graphs : DGLGraph
A batch of fully connected graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
feat_sum : float32 tensor of shape (E_full, node_in_feats)
Sum of node_feats between each pair of nodes. E_full for the number of
edges in the batch of complete graphs.
node_pair_feat : float32 tensor of shape (E_full, node_pair_in_feats)
Input features for each pair of nodes. E_full for the number of edges in
the batch of complete graphs.
Returns
-------
node_contexts : float32 tensor of shape (V, node_in_feats)
Context vectors for nodes.
"""
with
batch_complete_graphs
.
local_scope
():
batch_complete_graphs
.
ndata
[
'hv'
]
=
node_feats
batch_complete_graphs
.
edata
[
'a'
]
=
self
.
compute_attention
(
self
.
project_feature_sum
(
feat_sum
)
+
\
self
.
project_node_pair_feature
(
node_pair_feat
)
)
batch_complete_graphs
.
update_all
(
fn
.
src_mul_edge
(
'hv'
,
'a'
,
'm'
),
fn
.
sum
(
'm'
,
'context'
))
node_contexts
=
batch_complete_graphs
.
ndata
.
pop
(
'context'
)
return
node_contexts
class WLNReactionCenter(nn.Module):
    r"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction.

    The model is introduced in `Predicting Organic Reaction Outcomes with
    Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.

    The model uses WLN to update atom representations and then predicts the
    score for each pair of atoms to form a bond.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_pair_in_feats : int
        Size for the input features of node pairs.
    node_out_feats : int
        Size for the output node representations. Default to 300.
    n_layers : int
        Number of times for message passing. Note that same parameters
        are shared across n_layers message passing. Default to 3.
    n_tasks : int
        Number of tasks for prediction. Default to 5.
    """
    def __init__(self, node_in_feats, edge_in_feats, node_pair_in_feats,
                 node_out_feats=300, n_layers=3, n_tasks=5):
        super(WLNReactionCenter, self).__init__()

        self.gnn = WLN(node_in_feats=node_in_feats,
                       edge_in_feats=edge_in_feats,
                       node_out_feats=node_out_feats,
                       n_layers=n_layers)
        self.context_module = WLNContext(node_in_feats=node_out_feats,
                                         node_pair_in_feats=node_pair_in_feats)
        self.project_feature_sum = WLNLinear(node_out_feats, node_out_feats, bias=False)
        self.project_node_pair_feature = WLNLinear(
            node_pair_in_feats, node_out_feats, bias=False)
        self.project_context_sum = WLNLinear(node_out_feats, node_out_feats)
        self.predict = nn.Sequential(
            nn.ReLU(),
            WLNLinear(node_out_feats, n_tasks)
        )

    def forward(self, batch_mol_graphs, batch_complete_graphs,
                node_feats, edge_feats, node_pair_feats):
        r"""Predict score for each pair of nodes.

        Parameters
        ----------
        batch_mol_graphs : DGLGraph
            A batch of molecular graphs.
        batch_complete_graphs : DGLGraph
            A batch of fully connected graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges.
        node_pair_feats : float32 tensor of shape (E_full, node_pair_in_feats)
            Input features for each pair of nodes. E_full for the number of edges in
            the batch of complete graphs.

        Returns
        -------
        scores : float32 tensor of shape (E_full, n_tasks)
            Predicted scores for each pair of atoms to perform one of the following
            5 actions in reaction:

            * The bond between them gets broken
            * Forming a single bond
            * Forming a double bond
            * Forming a triple bond
            * Forming an aromatic bond
        biased_scores : float32 tensor of shape (E_full, n_tasks)
            Comparing to scores, a bias is added if the pair is for a same atom.
        """
        node_feats = self.gnn(batch_mol_graphs, node_feats, edge_feats)

        # Compute context vectors for all atoms, which are weighted sum of atom
        # representations in all reactants.
        with batch_complete_graphs.local_scope():
            batch_complete_graphs.ndata['hv'] = node_feats
            batch_complete_graphs.apply_edges(fn.u_add_v('hv', 'hv', 'feature_sum'))
            feat_sum = batch_complete_graphs.edata.pop('feature_sum')
        node_contexts = self.context_module(
            batch_complete_graphs, node_feats, feat_sum, node_pair_feats)

        # Predict score
        with batch_complete_graphs.local_scope():
            batch_complete_graphs.ndata['context'] = node_contexts
            batch_complete_graphs.apply_edges(
                fn.u_add_v('context', 'context', 'context_sum'))
            scores = self.predict(
                self.project_feature_sum(feat_sum) + \
                self.project_node_pair_feature(node_pair_feats) + \
                self.project_context_sum(batch_complete_graphs.edata['context_sum'])
            )

        # Masking self loops.
        # FIX: the bias previously had a hard-coded width of 5, which breaks
        # whenever n_tasks != 5. zeros_like inherits shape, dtype and device
        # from scores, so the explicit .to(device) is no longer needed either.
        nodes = batch_complete_graphs.nodes()
        e_ids = batch_complete_graphs.edge_ids(nodes, nodes)
        bias = torch.zeros_like(scores)
        bias[e_ids, :] = 1e4
        biased_scores = scores - bias

        return scores, biased_scores
apps/life_sci/python/dgllife/model/model_zoo/wln_reaction_ranking.py
deleted
100644 → 0
View file @
94c67203
"""Weisfeiler-Lehman Network (WLN) for ranking candidate products"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
torch
import
torch.nn
as
nn
from
dgl.nn.pytorch
import
SumPooling
from
..gnn.wln
import
WLN
__all__
=
[
'WLNReactionRanking'
]
# pylint: disable=W0221, E1101
class WLNReactionRanking(nn.Module):
    r"""Weisfeiler-Lehman Network (WLN) for Candidate Product Ranking.

    The model is introduced in `Predicting Organic Reaction Outcomes with
    Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__ and then
    further improved in `A graph-convolutional neural network model for the
    prediction of chemical reactivity
    <https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract>`__

    The model updates representations of nodes in candidate products with WLN and
    predicts the score for candidate products to be the real product.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_hidden_feats : int
        Size for the hidden node representations. Default to 500.
    num_encode_gnn_layers : int
        Number of WLN layers for updating node representations. Default to 3.
    """
    def __init__(self, node_in_feats, edge_in_feats,
                 node_hidden_feats=500, num_encode_gnn_layers=3):
        super().__init__()

        # Shared encoder applied to both reactants and candidate products.
        self.gnn = WLN(node_in_feats=node_in_feats,
                       edge_in_feats=edge_in_feats,
                       node_out_feats=node_hidden_feats,
                       n_layers=num_encode_gnn_layers,
                       set_comparison=False)
        # One extra round of message passing over the feature differences.
        self.diff_gnn = WLN(node_in_feats=node_hidden_feats,
                            edge_in_feats=edge_in_feats,
                            node_out_feats=node_hidden_feats,
                            n_layers=1,
                            project_in_feats=False,
                            set_comparison=False)
        self.readout = SumPooling()
        self.predict = nn.Sequential(
            nn.Linear(node_hidden_feats, node_hidden_feats),
            nn.ReLU(),
            nn.Linear(node_hidden_feats, 1)
        )

    def forward(self, reactant_graph, reactant_node_feats, reactant_edge_feats,
                product_graphs, product_node_feats, product_edge_feats,
                candidate_scores, batch_num_candidate_products):
        r"""Predicts the score for candidate products to be the true product.

        Parameters
        ----------
        reactant_graph : DGLGraph
            DGLGraph for a batch of reactants.
        reactant_node_feats : float32 tensor of shape (V1, node_in_feats)
            Input node features for the reactants. V1 for the number of nodes.
        reactant_edge_feats : float32 tensor of shape (E1, edge_in_feats)
            Input edge features for the reactants. E1 for the number of edges in
            reactant_graph.
        product_graphs : DGLGraph
            DGLGraph for the candidate products in a batch of reactions.
        product_node_feats : float32 tensor of shape (V2, node_in_feats)
            Input node features for the candidate products. V2 for the number of nodes.
        product_edge_feats : float32 tensor of shape (E2, edge_in_feats)
            Input edge features for the candidate products. E2 for the number of edges
            in the graphs for candidate products.
        candidate_scores : float32 tensor of shape (B, 1)
            Scores for candidate products based on the model for reaction center
            prediction.
        batch_num_candidate_products : list of int
            Number of candidate products for the reactions in the batch.

        Returns
        -------
        float32 tensor of shape (B, 1)
            Predicted scores for candidate products.
        """
        # Update representations for nodes in both reactants and candidate products
        # with the shared encoder.
        encoded_reactant_feats = self.gnn(
            reactant_graph, reactant_node_feats, reactant_edge_feats)
        encoded_product_feats = self.gnn(
            product_graphs, product_node_feats, product_edge_feats)

        # Walk over the reactions in the batch, slicing out the nodes that
        # belong to each reaction and its candidate products.
        rxn_node_offset = 0
        cand_graph_offset = 0
        cand_node_offset = 0
        all_diff_feats = []
        for i, num_candidate_products in enumerate(batch_num_candidate_products):
            rxn_node_next = rxn_node_offset + reactant_graph.batch_num_nodes[i]
            cand_graph_next = cand_graph_offset + num_candidate_products
            cand_node_next = cand_node_offset + sum(
                product_graphs.batch_num_nodes[cand_graph_offset:cand_graph_next])

            # (N, node_out_feats)
            rxn_feats_i = encoded_reactant_feats[rxn_node_offset:rxn_node_next, :]
            cand_feats_i = encoded_product_feats[cand_node_offset:cand_node_next, :]

            per_rxn_shape = rxn_feats_i.shape
            # (1, N, node_out_feats) -> (B, N, node_out_feats), no copy made.
            tiled_rxn_feats = rxn_feats_i.reshape((1,) + per_rxn_shape).expand(
                (num_candidate_products,) + per_rxn_shape)
            # (B, N, node_out_feats)
            stacked_cand_feats = cand_feats_i.reshape(
                (num_candidate_products,) + per_rxn_shape)

            # Node representation difference between candidate products and reactants,
            # flattened back to (B * N, node_out_feats).
            delta = stacked_cand_feats - tiled_rxn_feats
            all_diff_feats.append(delta.reshape(-1, delta.shape[-1]))

            rxn_node_offset = rxn_node_next
            cand_graph_offset = cand_graph_next
            cand_node_offset = cand_node_next

        all_diff_feats = torch.cat(all_diff_feats, dim=0)

        # One more GNN layer for message passing with the node representation
        # difference, then a sum readout and an MLP scorer.
        refined_diff_feats = self.diff_gnn(
            product_graphs, all_diff_feats, product_edge_feats)
        candidate_product_feats = self.readout(product_graphs, refined_diff_feats)
        return self.predict(candidate_product_feats) + candidate_scores
apps/life_sci/python/dgllife/model/pretrain.py
deleted
100644 → 0
View file @
94c67203
"""Utilities for using pretrained models."""
# pylint: disable= no-member, arguments-differ, invalid-name
import
os
import
torch
import
torch.nn.functional
as
F
from
dgl.data.utils
import
_get_dgl_url
,
download
,
get_download_dir
,
extract_archive
from
rdkit
import
Chem
from
..model
import
GCNPredictor
,
GATPredictor
,
AttentiveFPPredictor
,
DGMG
,
DGLJTNNVAE
,
\
WLNReactionCenter
,
WLNReactionRanking
,
WeavePredictor
,
GIN
__all__
=
[
'load_pretrained'
]
# Mapping from pretrained model name to its checkpoint path on the DGL server.
URL = {
    'GCN_Tox21': 'dgllife/pre_trained/gcn_tox21.pth',
    'GAT_Tox21': 'dgllife/pre_trained/gat_tox21.pth',
    'Weave_Tox21': 'dgllife/pre_trained/weave_tox21.pth',
    'AttentiveFP_Aromaticity': 'dgllife/pre_trained/attentivefp_aromaticity.pth',
    'DGMG_ChEMBL_canonical': 'pre_trained/dgmg_ChEMBL_canonical.pth',
    'DGMG_ChEMBL_random': 'pre_trained/dgmg_ChEMBL_random.pth',
    'DGMG_ZINC_canonical': 'pre_trained/dgmg_ZINC_canonical.pth',
    'DGMG_ZINC_random': 'pre_trained/dgmg_ZINC_random.pth',
    'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth',
    'wln_center_uspto': 'dgllife/pre_trained/wln_center_uspto_v3.pth',
    'wln_rank_uspto': 'dgllife/pre_trained/wln_rank_uspto.pth',
    'gin_supervised_contextpred': 'dgllife/pre_trained/gin_supervised_contextpred.pth',
    'gin_supervised_infomax': 'dgllife/pre_trained/gin_supervised_infomax.pth',
    'gin_supervised_edgepred': 'dgllife/pre_trained/gin_supervised_edgepred.pth',
    'gin_supervised_masking': 'dgllife/pre_trained/gin_supervised_masking.pth',
}
def download_and_load_checkpoint(model_name, model, model_postfix,
                                 local_pretrained_path='pre_trained.pth', log=True):
    """Download a pretrained model checkpoint and load its weights.

    The model will be loaded to CPU.

    Parameters
    ----------
    model_name : str
        Name of the model.
    model : nn.Module
        Instantiated model instance.
    model_postfix : str
        Postfix for pretrained model checkpoint.
    local_pretrained_path : str
        Local name for the downloaded model checkpoint.
    log : bool
        Whether to print progress for model loading.

    Returns
    -------
    model : nn.Module
        Pretrained model.
    """
    remote_url = _get_dgl_url(model_postfix)
    # Prefix the local file with the model name so checkpoints do not collide.
    local_pretrained_path = '_'.join([model_name, local_pretrained_path])
    download(remote_url, path=local_pretrained_path, log=log)
    checkpoint = torch.load(local_pretrained_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    if log:
        print('Pretrained model loaded')
    return model
# pylint: disable=I1101
def load_pretrained(model_name, log=True):
    """Load a pretrained model.

    Parameters
    ----------
    model_name : str
        Currently supported options include

        * ``'GCN_Tox21'``: A GCN-based model for molecular property prediction on Tox21
        * ``'GAT_Tox21'``: A GAT-based model for molecular property prediction on Tox21
        * ``'Weave_Tox21'``: A Weave model for molecular property prediction on Tox21
        * ``'AttentiveFP_Aromaticity'``: An AttentiveFP model for predicting number of
          aromatic atoms on a subset of Pubmed
        * ``'DGMG_ChEMBL_canonical'``: A DGMG model trained on ChEMBL with a canonical
          atom order
        * ``'DGMG_ChEMBL_random'``: A DGMG model trained on ChEMBL for molecule generation
          with a random atom order
        * ``'DGMG_ZINC_canonical'``: A DGMG model trained on ZINC for molecule generation
          with a canonical atom order
        * ``'DGMG_ZINC_random'``: A DGMG model pre-trained on ZINC for molecule generation
          with a random atom order
        * ``'JTNN_ZINC'``: A JTNN model pre-trained on ZINC for molecule generation
        * ``'wln_center_uspto'``: A WLN model pre-trained on USPTO for reaction prediction
        * ``'wln_rank_uspto'``: A WLN model pre-trained on USPTO for candidate product ranking
        * ``'gin_supervised_contextpred'``: A GIN model pre-trained with supervised learning
          and context prediction
        * ``'gin_supervised_infomax'``: A GIN model pre-trained with supervised learning
          and deep graph infomax
        * ``'gin_supervised_edgepred'``: A GIN model pre-trained with supervised learning
          and edge prediction
        * ``'gin_supervised_masking'``: A GIN model pre-trained with supervised learning
          and attribute masking
    log : bool
        Whether to print progress for model loading.

    Returns
    -------
    model

    Raises
    ------
    RuntimeError
        If model_name is not one of the supported options.
    """
    if model_name not in URL:
        raise RuntimeError("Cannot find a pretrained model with name {}".format(model_name))

    # Instantiate the architecture matching the checkpoint. The hyperparameters
    # below must agree with those used when the checkpoints were trained.
    if model_name == 'GCN_Tox21':
        model = GCNPredictor(in_feats=74,
                             hidden_feats=[64, 64],
                             classifier_hidden_feats=64,
                             n_tasks=12)

    elif model_name == 'GAT_Tox21':
        model = GATPredictor(in_feats=74,
                             hidden_feats=[32, 32],
                             num_heads=[4, 4],
                             agg_modes=['flatten', 'mean'],
                             activations=[F.elu, None],
                             classifier_hidden_feats=64,
                             n_tasks=12)

    elif model_name == 'Weave_Tox21':
        model = WeavePredictor(node_in_feats=27,
                               edge_in_feats=7,
                               num_gnn_layers=2,
                               gnn_hidden_feats=50,
                               graph_feats=128,
                               n_tasks=12)

    elif model_name == 'AttentiveFP_Aromaticity':
        model = AttentiveFPPredictor(node_feat_size=39,
                                     edge_feat_size=10,
                                     num_layers=2,
                                     num_timesteps=2,
                                     graph_feat_size=200,
                                     n_tasks=1,
                                     dropout=0.2)

    elif model_name.startswith('DGMG'):
        # The atom vocabulary differs between the two training datasets.
        if model_name.startswith('DGMG_ChEMBL'):
            atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
        elif model_name.startswith('DGMG_ZINC'):
            atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
        bond_types = [Chem.rdchem.BondType.SINGLE,
                      Chem.rdchem.BondType.DOUBLE,
                      Chem.rdchem.BondType.TRIPLE]
        model = DGMG(atom_types=atom_types,
                     bond_types=bond_types,
                     node_hidden_size=128,
                     num_prop_rounds=2,
                     dropout=0.2)

    elif model_name == "JTNN_ZINC":
        # JTNN additionally needs a vocabulary file; fetch it on first use.
        default_dir = get_download_dir()
        vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
        if not os.path.exists(vocab_file):
            zip_file_path = '{}/jtnn.zip'.format(default_dir)
            download(_get_dgl_url('dgllife/jtnn.zip'), path=zip_file_path)
            extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))
        model = DGLJTNNVAE(vocab_file=vocab_file,
                           depth=3,
                           hidden_size=450,
                           latent_size=56)

    elif model_name == 'wln_center_uspto':
        model = WLNReactionCenter(node_in_feats=82,
                                  edge_in_feats=6,
                                  node_pair_in_feats=10,
                                  node_out_feats=300,
                                  n_layers=3,
                                  n_tasks=5)

    elif model_name == 'wln_rank_uspto':
        model = WLNReactionRanking(node_in_feats=89,
                                   edge_in_feats=5,
                                   node_hidden_feats=500,
                                   num_encode_gnn_layers=3)

    elif model_name in ['gin_supervised_contextpred',
                        'gin_supervised_infomax',
                        'gin_supervised_edgepred',
                        'gin_supervised_masking']:
        # All four GIN checkpoints share the same architecture; they differ
        # only in the self-supervised pre-training objective.
        model = GIN(num_node_emb_list=[120, 3],
                    num_edge_emb_list=[6, 3],
                    num_layers=5,
                    emb_dim=300,
                    JK='last',
                    dropout=0.5)

    return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
apps/life_sci/python/dgllife/model/readout/__init__.py
deleted
100644 → 0
View file @
94c67203
"""
Readout functions for computing molecular representations
out of node and edge representations.
"""
from
.attentivefp_readout
import
*
from
.weighted_sum_and_max
import
*
from
.mlp_readout
import
*
from
.weave_readout
import
*
apps/life_sci/python/dgllife/model/readout/attentivefp_readout.py
deleted
100644 → 0
View file @
94c67203
"""Readout for AttentiveFP"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
dgl
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
__all__
=
[
'AttentiveFPReadout'
]
# pylint: disable=W0221
class GlobalPool(nn.Module):
    """One-step readout in AttentiveFP.

    Parameters
    ----------
    feat_size : int
        Size for the input node features, graph features and output graph
        representations.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, feat_size, dropout):
        super().__init__()

        # Scores each node against the current graph state.
        self.compute_logits = nn.Sequential(
            nn.Linear(2 * feat_size, 1),
            nn.LeakyReLU()
        )
        self.project_nodes = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feat_size, feat_size)
        )
        # Updates the graph state from the attended node summary.
        self.gru = nn.GRUCell(feat_size, feat_size)

    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """Perform one-step readout.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        g_feats : float32 tensor of shape (G, graph_feat_size)
            Input graph features. G for the number of graphs.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, graph_feat_size)
            Updated graph features.
        float32 tensor of shape (V, 1)
            The weights of nodes in readout, only when ``get_node_weight`` is True.
        """
        with g.local_scope():
            # Attention logits from [broadcast graph state || node features].
            broadcast_state = dgl.broadcast_nodes(g, F.relu(g_feats))
            g.ndata['z'] = self.compute_logits(
                torch.cat([broadcast_state, node_feats], dim=1))
            # Per-graph softmax turns logits into node weights.
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)

            # Weighted sum of projected node features, then nonlinearity.
            context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))
            updated_g_feats = self.gru(context, g_feats)

            if get_node_weight:
                return updated_g_feats, g.ndata['a']
            return updated_g_feats
class AttentiveFPReadout(nn.Module):
    """Readout in AttentiveFP.

    AttentiveFP is introduced in `Pushing the Boundaries of Molecular Representation for
    Drug Discovery with the Graph Attention Mechanism
    <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__

    This class computes graph representations out of node features.

    Parameters
    ----------
    feat_size : int
        Size for the input node features, graph features and output graph
        representations.
    num_timesteps : int
        Times of updating the graph representations with GRU. Default to 2.
    dropout : float
        The probability for performing dropout. Default to 0.
    """
    def __init__(self, feat_size, num_timesteps=2, dropout=0.):
        super().__init__()

        # One GlobalPool per refinement timestep.
        self.readouts = nn.ModuleList(
            [GlobalPool(feat_size, dropout) for _ in range(num_timesteps)])

    def forward(self, g, node_feats, get_node_weight=False):
        """Computes graph representations out of node features.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        get_node_weight : bool
            Whether to get the weights of nodes in readout. Default to False.

        Returns
        -------
        g_feats : float32 tensor of shape (G, graph_feat_size)
            Graph representations computed. G for the number of graphs.
        node_weights : list of float32 tensor of shape (V, 1), optional
            This is returned when ``get_node_weight`` is ``True``.
            The list has a length ``num_timesteps`` and ``node_weights[i]``
            gives the node weights in the i-th update.
        """
        # Initialize the graph state as a plain sum of node features.
        with g.local_scope():
            g.ndata['hv'] = node_feats
            g_feats = dgl.sum_nodes(g, 'hv')

        if get_node_weight:
            node_weights = []

        # Iteratively refine the graph state with attention-based pooling.
        for readout in self.readouts:
            if get_node_weight:
                g_feats, node_weights_t = readout(g, node_feats, g_feats, get_node_weight)
                node_weights.append(node_weights_t)
            else:
                g_feats = readout(g, node_feats, g_feats)

        if get_node_weight:
            return g_feats, node_weights
        return g_feats
apps/life_sci/python/dgllife/model/readout/mlp_readout.py
deleted
100644 → 0
View file @
94c67203
"""Readout for SchNet"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
dgl
import
torch.nn
as
nn
__all__
=
[
'MLPNodeReadout'
]
# pylint: disable=W0221
class MLPNodeReadout(nn.Module):
    """MLP-based Readout.

    This layer updates node representations with a MLP and computes graph
    representations out of node representations with max, mean or sum.

    Parameters
    ----------
    node_feats : int
        Size for the input node features.
    hidden_feats : int
        Size for the hidden representations.
    graph_feats : int
        Size for the output graph representations.
    activation : callable
        Activation function. Default to None.
    mode : 'max' or 'mean' or 'sum'
        Whether to compute elementwise maximum, mean or sum of the node
        representations. Default to 'sum'.
    """
    def __init__(self, node_feats, hidden_feats, graph_feats, activation=None, mode='sum'):
        super().__init__()

        assert mode in ['max', 'mean', 'sum'], \
            "Expect mode to be 'max' or 'mean' or 'sum', got {}".format(mode)
        self.mode = mode
        self.in_project = nn.Linear(node_feats, hidden_feats)
        self.activation = activation
        self.out_project = nn.Linear(hidden_feats, graph_feats)

    def forward(self, g, node_feats):
        """Computes graph representations out of node features.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feats)
            Input node features, V for the number of nodes.

        Returns
        -------
        graph_feats : float32 tensor of shape (G, graph_feats)
            Graph representations computed. G for the number of graphs.
        """
        # Two-layer MLP on node features, with an optional nonlinearity between.
        hidden = self.in_project(node_feats)
        if self.activation is not None:
            hidden = self.activation(hidden)
        projected = self.out_project(hidden)

        # Dispatch on the pooling mode; self.mode was validated in __init__.
        pool = {'max': dgl.max_nodes,
                'mean': dgl.mean_nodes,
                'sum': dgl.sum_nodes}[self.mode]
        with g.local_scope():
            g.ndata['h'] = projected
            graph_feats = pool(g, 'h')

        return graph_feats
apps/life_sci/python/dgllife/model/readout/weave_readout.py
deleted
100644 → 0
View file @
94c67203
"""Readout for Weave"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
dgl
import
torch
import
torch.nn
as
nn
from
torch.distributions
import
Normal
__all__
=
[
'WeaveGather'
]
# pylint: disable=W0221, E1101, E1102
class
WeaveGather
(
nn
.
Module
):
r
"""Readout in Weave
Parameters
----------
node_in_feats : int
Size for the input node features.
gaussian_expand : bool
Whether to expand each dimension of node features by gaussian histogram.
Default to True.
gaussian_memberships : list of 2-tuples
For each tuple, the first and second element separately specifies the mean
and std for constructing a normal distribution. This argument comes into
effect only when ``gaussian_expand==True``. By default, we set this to be
``[(-1.645, 0.283), (-1.080, 0.170), (-0.739, 0.134), (-0.468, 0.118),
(-0.228, 0.114), (0., 0.114), (0.228, 0.114), (0.468, 0.118),
(0.739, 0.134), (1.080, 0.170), (1.645, 0.283)]``.
activation : callable
Activation function to apply. Default to tanh.
"""
def
__init__
(
self
,
node_in_feats
,
gaussian_expand
=
True
,
gaussian_memberships
=
None
,
activation
=
nn
.
Tanh
()):
super
(
WeaveGather
,
self
).
__init__
()
self
.
gaussian_expand
=
gaussian_expand
if
gaussian_expand
:
if
gaussian_memberships
is
None
:
gaussian_memberships
=
[
(
-
1.645
,
0.283
),
(
-
1.080
,
0.170
),
(
-
0.739
,
0.134
),
(
-
0.468
,
0.118
),
(
-
0.228
,
0.114
),
(
0.
,
0.114
),
(
0.228
,
0.114
),
(
0.468
,
0.118
),
(
0.739
,
0.134
),
(
1.080
,
0.170
),
(
1.645
,
0.283
)]
means
,
stds
=
map
(
list
,
zip
(
*
gaussian_memberships
))
self
.
means
=
nn
.
ParameterList
([
nn
.
Parameter
(
torch
.
tensor
(
value
),
requires_grad
=
False
)
for
value
in
means
])
self
.
stds
=
nn
.
ParameterList
([
nn
.
Parameter
(
torch
.
tensor
(
value
),
requires_grad
=
False
)
for
value
in
stds
])
self
.
to_out
=
nn
.
Linear
(
node_in_feats
*
len
(
self
.
means
),
node_in_feats
)
self
.
activation
=
activation
def
gaussian_histogram
(
self
,
node_feats
):
r
"""Constructs a gaussian histogram to capture the distribution of features
Parameters
----------
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
Returns
-------
float32 tensor of shape (V, node_in_feats * len(self.means))
Updated node representations
"""
gaussian_dists
=
[
Normal
(
self
.
means
[
i
],
self
.
stds
[
i
])
for
i
in
range
(
len
(
self
.
means
))]
max_log_probs
=
[
gaussian_dists
[
i
].
log_prob
(
self
.
means
[
i
])
for
i
in
range
(
len
(
self
.
means
))]
# Normalize the probabilities by the maximum point-wise probabilities,
# whose results will be in range [0, 1]. Note that division of probabilities
# is equivalent to subtraction of log probabilities and the latter one is cheaper.
log_probs
=
[
gaussian_dists
[
i
].
log_prob
(
node_feats
)
-
max_log_probs
[
i
]
for
i
in
range
(
len
(
self
.
means
))]
probs
=
torch
.
stack
(
log_probs
,
dim
=
2
).
exp
()
# (V, node_in_feats, len(self.means))
# Add a bias to avoid numerical issues in division
probs
=
probs
+
1e-7
# Normalize the probabilities across all Gaussian distributions
probs
=
probs
/
probs
.
sum
(
2
,
keepdim
=
True
)
return
probs
.
reshape
(
node_feats
.
shape
[
0
],
node_feats
.
shape
[
1
]
*
len
(
self
.
means
))
def forward(self, g, node_feats):
    r"""Computes graph representations out of node representations.

    Parameters
    ----------
    g : DGLGraph
        DGLGraph for a batch of graphs.
    node_feats : float32 tensor of shape (V, node_in_feats)
        Input node features. V for the number of nodes in the batch of graphs.

    Returns
    -------
    g_feats : float32 tensor of shape (G, node_in_feats)
        Output graph representations. G for the number of graphs in the batch.
    """
    if self.gaussian_expand:
        # Replace each node feature with its soft membership over the
        # configured Gaussians, multiplying the feature size by len(self.means)
        node_feats = self.gaussian_histogram(node_feats)

    # local_scope keeps the temporary 'h' field from leaking into the graph
    with g.local_scope():
        g.ndata['h'] = node_feats
        # Graph-level readout: sum node representations per graph
        g_feats = dgl.sum_nodes(g, 'h')

    if self.gaussian_expand:
        # Project the expanded representation back to node_in_feats;
        # the optional activation is only applied on this branch
        g_feats = self.to_out(g_feats)
        if self.activation is not None:
            g_feats = self.activation(g_feats)

    return g_feats
apps/life_sci/python/dgllife/model/readout/weighted_sum_and_max.py
deleted
100644 → 0
View file @
94c67203
"""Apply weighted sum and max pooling to the node representations and concatenate the results."""
# pylint: disable= no-member, arguments-differ, invalid-name
import
dgl
import
torch
import
torch.nn
as
nn
from
dgl.nn.pytorch
import
WeightAndSum
__all__
=
[
'WeightedSumAndMax'
]
# pylint: disable=W0221
class WeightedSumAndMax(nn.Module):
    r"""Graph readout concatenating a learned weighted sum of the node
    representations with their element-wise max.

    Parameters
    ----------
    in_feats : int
        Input node feature size
    """
    def __init__(self, in_feats):
        super(WeightedSumAndMax, self).__init__()
        # Learns one scalar weight per node, then sums the weighted features
        self.weight_and_sum = WeightAndSum(in_feats)

    def forward(self, bg, feats):
        """Compute the readout for a batch of graphs.

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            Node representations. N is the total number of nodes in the batch
            and M1 must match in_feats from initialization.

        Returns
        -------
        h_g : FloatTensor of shape (B, 2 * M1)
            Graph representations, B being the number of graphs in the batch.
        """
        weighted_sum = self.weight_and_sum(bg, feats)
        # local_scope prevents the temporary 'h' field from persisting on bg
        with bg.local_scope():
            bg.ndata['h'] = feats
            max_pooled = dgl.max_nodes(bg, 'h')
        h_g = torch.cat([weighted_sum, max_pooled], dim=1)
        return h_g
apps/life_sci/python/dgllife/utils/__init__.py
deleted
100644 → 0
View file @
94c67203
"""Utils for data processing."""
from
.complex_to_graph
import
*
from
.early_stop
import
*
from
.eval
import
*
from
.featurizers
import
*
from
.mol_to_graph
import
*
from
.rdkit_utils
import
*
from
.splitters
import
*
apps/life_sci/python/dgllife/utils/complex_to_graph.py
deleted
100644 → 0
View file @
94c67203
"""Convert complexes into DGLHeteroGraphs"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
dgl.backend
as
F
import
numpy
as
np
from
dgl
import
graph
,
bipartite
,
hetero_from_relations
from
..utils.mol_to_graph
import
k_nearest_neighbors
__all__
=
[
'ACNN_graph_construction_and_featurization'
]
def filter_out_hydrogens(mol):
    """Get indices for non-hydrogen atoms.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.

    Returns
    -------
    indices_left : list of int
        Indices of non-hydrogen atoms.
    """
    # Hydrogen atoms have an atomic number of 1.
    indices_left = [idx for idx, atom in enumerate(mol.GetAtoms())
                    if atom.GetAtomicNum() != 1]
    return indices_left
def get_atomic_numbers(mol, indices):
    """Get the atomic numbers for the specified atoms.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    indices : list of int
        Specifying atoms.

    Returns
    -------
    list of int
        Atomic numbers computed.
    """
    return [mol.GetAtomWithIdx(i).GetAtomicNum() for i in indices]
# pylint: disable=C0326
def ACNN_graph_construction_and_featurization(ligand_mol,
                                              protein_mol,
                                              ligand_coordinates,
                                              protein_coordinates,
                                              max_num_ligand_atoms=None,
                                              max_num_protein_atoms=None,
                                              neighbor_cutoff=12.,
                                              max_num_neighbors=12,
                                              strip_hydrogens=False):
    """Graph construction and featurization for `Atomic Convolutional Networks for
    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.

    Builds six relation graphs (ligand-only, protein-only, and four complex
    relations) over k-nearest-neighbor edges and merges them into one
    heterograph, attaching per-edge distances and per-node atomic numbers
    and existence masks.

    Parameters
    ----------
    ligand_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    protein_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    ligand_coordinates : Float Tensor of shape (V1, 3)
        Atom coordinates in a ligand.
    protein_coordinates : Float Tensor of shape (V2, 3)
        Atom coordinates in a protein.
    max_num_ligand_atoms : int or None
        Maximum number of atoms in ligands for zero padding, which should be no smaller than
        ligand_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
        Default to None.
    max_num_protein_atoms : int or None
        Maximum number of atoms in proteins for zero padding, which should be no smaller than
        protein_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
        Default to None.
    neighbor_cutoff : float
        Distance cutoff to define 'neighboring'. Default to 12.
    max_num_neighbors : int
        Maximum number of neighbors allowed for each atom. Default to 12.
    strip_hydrogens : bool
        Whether to exclude hydrogen atoms. Default to False.

    Returns
    -------
    g : DGLHeteroGraph
        Heterograph with node types 'ligand_atom' and 'protein_atom' and edge
        types 'ligand', 'protein' and four 'complex' relations. Node data hold
        'atomic_number' and 'mask'; edge data hold 'distance'.
    """
    assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
    assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
    if max_num_ligand_atoms is not None:
        assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
            'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_ligand_atoms, ligand_mol.GetNumAtoms())
    if max_num_protein_atoms is not None:
        assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
            'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_protein_atoms, protein_mol.GetNumAtoms())

    if strip_hydrogens:
        # Remove hydrogen atoms and their corresponding coordinates
        ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
        protein_atom_indices_left = filter_out_hydrogens(protein_mol)
        # assumes coordinates expose numpy-style take(indices, axis=0) — TODO confirm
        ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left, axis=0)
        protein_coordinates = protein_coordinates.take(protein_atom_indices_left, axis=0)
    else:
        ligand_atom_indices_left = list(range(ligand_mol.GetNumAtoms()))
        protein_atom_indices_left = list(range(protein_mol.GetNumAtoms()))

    # Compute number of nodes for each type; padding sizes take priority
    if max_num_ligand_atoms is None:
        num_ligand_atoms = len(ligand_atom_indices_left)
    else:
        num_ligand_atoms = max_num_ligand_atoms

    if max_num_protein_atoms is None:
        num_protein_atoms = len(protein_atom_indices_left)
    else:
        num_protein_atoms = max_num_protein_atoms

    # Construct graph for atoms in the ligand
    ligand_srcs, ligand_dsts, ligand_dists = k_nearest_neighbors(
        ligand_coordinates, neighbor_cutoff, max_num_neighbors)
    ligand_graph = graph((ligand_srcs, ligand_dsts),
                         'ligand_atom', 'ligand', num_ligand_atoms)
    ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        np.array(ligand_dists).astype(np.float32)), (-1, 1))

    # Construct graph for atoms in the protein
    protein_srcs, protein_dsts, protein_dists = k_nearest_neighbors(
        protein_coordinates, neighbor_cutoff, max_num_neighbors)
    protein_graph = graph((protein_srcs, protein_dsts),
                          'protein_atom', 'protein', num_protein_atoms)
    protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        np.array(protein_dists).astype(np.float32)), (-1, 1))

    # Construct 4 graphs for complex representation, including the connection within
    # protein atoms, the connection within ligand atoms and the connection between
    # protein and ligand atoms. kNN runs on the stacked coordinates, so atom
    # indices < num_ligand_atoms refer to ligand atoms and the rest to protein atoms.
    complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
        np.concatenate([ligand_coordinates, protein_coordinates]),
        neighbor_cutoff, max_num_neighbors)
    complex_srcs = np.array(complex_srcs)
    complex_dsts = np.array(complex_dsts)
    complex_dists = np.array(complex_dists)
    # Protein atom ids in the stacked ordering start at this offset
    offset = num_ligand_atoms

    # ('ligand_atom', 'complex', 'ligand_atom'): both endpoints below the offset
    inter_ligand_indices = np.intersect1d(
        (complex_srcs < offset).nonzero()[0],
        (complex_dsts < offset).nonzero()[0],
        assume_unique=True)
    inter_ligand_graph = graph(
        (complex_srcs[inter_ligand_indices].tolist(),
         complex_dsts[inter_ligand_indices].tolist()),
        'ligand_atom', 'complex', num_ligand_atoms)
    inter_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        complex_dists[inter_ligand_indices].astype(np.float32)), (-1, 1))

    # ('protein_atom', 'complex', 'protein_atom'): both endpoints at/above the
    # offset; subtract the offset to recover protein-local node ids
    inter_protein_indices = np.intersect1d(
        (complex_srcs >= offset).nonzero()[0],
        (complex_dsts >= offset).nonzero()[0],
        assume_unique=True)
    inter_protein_graph = graph(
        ((complex_srcs[inter_protein_indices] - offset).tolist(),
         (complex_dsts[inter_protein_indices] - offset).tolist()),
        'protein_atom', 'complex', num_protein_atoms)
    inter_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        complex_dists[inter_protein_indices].astype(np.float32)), (-1, 1))

    # ('ligand_atom', 'complex', 'protein_atom')
    ligand_protein_indices = np.intersect1d(
        (complex_srcs < offset).nonzero()[0],
        (complex_dsts >= offset).nonzero()[0],
        assume_unique=True)
    ligand_protein_graph = bipartite(
        (complex_srcs[ligand_protein_indices].tolist(),
         (complex_dsts[ligand_protein_indices] - offset).tolist()),
        'ligand_atom', 'complex', 'protein_atom',
        (num_ligand_atoms, num_protein_atoms))
    ligand_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        complex_dists[ligand_protein_indices].astype(np.float32)), (-1, 1))

    # ('protein_atom', 'complex', 'ligand_atom')
    protein_ligand_indices = np.intersect1d(
        (complex_srcs >= offset).nonzero()[0],
        (complex_dsts < offset).nonzero()[0],
        assume_unique=True)
    protein_ligand_graph = bipartite(
        ((complex_srcs[protein_ligand_indices] - offset).tolist(),
         complex_dsts[protein_ligand_indices].tolist()),
        'protein_atom', 'complex', 'ligand_atom',
        (num_protein_atoms, num_ligand_atoms))
    protein_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        complex_dists[protein_ligand_indices].astype(np.float32)), (-1, 1))

    # Merge the graphs
    g = hetero_from_relations(
        [protein_graph, ligand_graph, inter_ligand_graph,
         inter_protein_graph, ligand_protein_graph, protein_ligand_graph])

    # Get atomic numbers for all atoms left and set node features
    ligand_atomic_numbers = np.array(
        get_atomic_numbers(ligand_mol, ligand_atom_indices_left))
    # zero padding up to num_ligand_atoms
    ligand_atomic_numbers = np.concatenate([
        ligand_atomic_numbers,
        np.zeros(num_ligand_atoms - len(ligand_atom_indices_left))])
    protein_atomic_numbers = np.array(
        get_atomic_numbers(protein_mol, protein_atom_indices_left))
    # zero padding up to num_protein_atoms
    protein_atomic_numbers = np.concatenate([
        protein_atomic_numbers,
        np.zeros(num_protein_atoms - len(protein_atom_indices_left))])
    g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
        ligand_atomic_numbers.astype(np.float32)), (-1, 1))
    g.nodes['protein_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
        protein_atomic_numbers.astype(np.float32)), (-1, 1))

    # Prepare mask indicating the existence of nodes (1 for real atoms,
    # 0 for zero-padded slots)
    ligand_masks = np.zeros((num_ligand_atoms, 1))
    ligand_masks[:len(ligand_atom_indices_left), :] = 1
    g.nodes['ligand_atom'].data['mask'] = F.zerocopy_from_numpy(
        ligand_masks.astype(np.float32))
    protein_masks = np.zeros((num_protein_atoms, 1))
    protein_masks[:len(protein_atom_indices_left), :] = 1
    g.nodes['protein_atom'].data['mask'] = F.zerocopy_from_numpy(
        protein_masks.astype(np.float32))

    return g
apps/life_sci/python/dgllife/utils/early_stop.py
deleted
100644 → 0
View file @
94c67203
"""Early stopping"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
datetime
import
torch
__all__
=
[
'EarlyStopping'
]
# pylint: disable=C0103
class EarlyStopping(object):
    """Early stop tracker.

    Saves a model checkpoint whenever the validation metric improves and
    signals early stopping once no improvement has been observed for
    ``patience`` consecutive calls to :meth:`step`.

    Parameters
    ----------
    mode : str
        * 'higher': Higher metric suggests a better model
        * 'lower': Lower metric suggests a better model
    patience : int
        Number of consecutive non-improving epochs tolerated before stopping.
    filename : str or None
        Filename for storing the model checkpoint. If not specified,
        a file name starting with ``early_stop`` is generated from the
        current time.

    Examples
    --------
    >>> import torch
    >>> import torch.nn as nn
    >>> from torch.nn import MSELoss
    >>> from torch.optim import Adam
    >>> from dgllife.utils import EarlyStopping

    >>> model = nn.Linear(1, 1)
    >>> criterion = MSELoss()
    >>> # For MSE, the lower, the better
    >>> stopper = EarlyStopping(mode='lower', filename='test.pth')
    >>> optimizer = Adam(params=model.parameters(), lr=1e-3)

    >>> for epoch in range(1000):
    >>>     x = torch.randn(1, 1) # Fake input
    >>>     y = torch.randn(1, 1) # Fake label
    >>>     pred = model(x)
    >>>     loss = criterion(y, pred)
    >>>     optimizer.zero_grad()
    >>>     loss.backward()
    >>>     optimizer.step()
    >>>     early_stop = stopper.step(loss.detach().data, model)
    >>>     if early_stop:
    >>>         break

    >>> # Load the final parameters saved by the model
    >>> stopper.load_checkpoint(model)
    """
    def __init__(self, mode='higher', patience=10, filename=None):
        if filename is None:
            now = datetime.datetime.now()
            filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
                now.date(), now.hour, now.minute, now.second)

        assert mode in ['higher', 'lower']
        self.mode = mode
        # Bind the comparison matching the metric direction once, up front
        if self.mode == 'higher':
            self._check = self._check_higher
        else:
            self._check = self._check_lower

        self.patience = patience
        self.counter = 0
        self.filename = filename
        self.best_score = None
        self.early_stop = False

    def _check_higher(self, score, prev_best_score):
        """Return whether ``score`` improves on ``prev_best_score``
        when higher is better.

        Parameters
        ----------
        score : float
            New score.
        prev_best_score : float
            Previous best score.

        Returns
        -------
        bool
            Whether the new score is higher than the previous best score.
        """
        return score > prev_best_score

    def _check_lower(self, score, prev_best_score):
        """Return whether ``score`` improves on ``prev_best_score``
        when lower is better.

        Parameters
        ----------
        score : float
            New score.
        prev_best_score : float
            Previous best score.

        Returns
        -------
        bool
            Whether the new score is lower than the previous best score.
        """
        return score < prev_best_score

    def step(self, score, model):
        """Record a new validation score and update the stopping state.

        Parameters
        ----------
        score : float
            New score, typically model performance on the validation set.
        model : nn.Module
            Model instance.

        Returns
        -------
        bool
            Whether an early stop should be performed.
        """
        if self.best_score is None or self._check(score, self.best_score):
            # Improvement observed (or first call): checkpoint and reset
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

    def save_checkpoint(self, model):
        '''Saves model when the metric on the validation set gets improved.

        Parameters
        ----------
        model : nn.Module
            Model instance.
        '''
        torch.save({'model_state_dict': model.state_dict()}, self.filename)

    def load_checkpoint(self, model):
        '''Load the latest checkpoint.

        Parameters
        ----------
        model : nn.Module
            Model instance.
        '''
        model.load_state_dict(torch.load(self.filename)['model_state_dict'])
apps/life_sci/python/dgllife/utils/eval.py
deleted
100644 → 0
View file @
94c67203
"""Evaluation of model performance."""
# pylint: disable= no-member, arguments-differ, invalid-name
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
scipy.stats
import
pearsonr
from
sklearn.metrics
import
roc_auc_score
__all__
=
[
'Meter'
]
# pylint: disable=E1101
class Meter(object):
    """Track and summarize model performance on a dataset for (multi-label) prediction.

    Predictions and ground truths are accumulated mini-batch by mini-batch via
    :meth:`update` and summarized with one of four supported metrics:

    * ``pearson r2``
    * ``mae``
    * ``rmse``
    * ``roc auc score``

    When multitask labels were normalized during training, pass the training
    ``mean`` and ``std`` so the normalization can be undone on the predicted
    labels before evaluation.

    Parameters
    ----------
    mean : torch.float32 tensor of shape (T) or None.
        Mean of existing training labels across tasks if not ``None``. ``T`` for the
        number of tasks. Default to ``None`` and we assume no label normalization has been
        performed.
    std : torch.float32 tensor of shape (T)
        Std of existing training labels across tasks if not ``None``. Default to ``None``
        and we assume no label normalization has been performed.

    Examples
    --------
    >>> import torch
    >>> from dgllife.utils import Meter

    >>> meter = Meter()
    >>> for batch_id in range(10):
    >>>     batch_label = torch.randn(3, 3)
    >>>     batch_pred = torch.randn(3, 3)
    >>>     meter.update(batch_pred, batch_label)
    >>> # Per-task MAE
    >>> meter.compute_metric('mae')
    >>> # MAE averaged / summed over tasks
    >>> meter.compute_metric('mae', reduction='mean')
    >>> meter.compute_metric('mae', reduction='sum')
    """
    def __init__(self, mean=None, std=None):
        self.mask = []
        self.y_pred = []
        self.y_true = []

        if mean is None or std is None:
            self.mean = None
            self.std = None
        else:
            # Accumulated results live on CPU, so keep the statistics there too
            self.mean = mean.cpu()
            self.std = std.cpu()

    def update(self, y_pred, y_true, mask=None):
        """Record the results of one iteration.

        Parameters
        ----------
        y_pred : float32 tensor
            Predicted labels with shape ``(B, T)``,
            ``B`` for number of graphs in the batch and ``T`` for the number of tasks.
        y_true : float32 tensor
            Ground truth labels with shape ``(B, T)``.
        mask : None or float32 tensor
            Binary mask of shape ``(B, T)`` indicating the existence of ground
            truth labels. If None, all labels are assumed to exist and a
            one-tensor is used as placeholder.
        """
        self.y_pred.append(y_pred.detach().cpu())
        self.y_true.append(y_true.detach().cpu())
        if mask is None:
            self.mask.append(torch.ones(self.y_pred[-1].shape))
        else:
            self.mask.append(mask.detach().cpu())

    def _finalize(self):
        """Concatenate accumulated results and undo label normalization.

        Returns
        -------
        mask : float32 tensor
            Binary mask indicating the existence of ground truth labels with
            shape (B, T), B for batch size and T for the number of tasks.
        y_pred : float32 tensor
            Predicted labels with shape (B, T).
        y_true : float32 tensor
            Ground truth labels with shape (B, T).
        """
        mask = torch.cat(self.mask, dim=0)
        y_pred = torch.cat(self.y_pred, dim=0)
        y_true = torch.cat(self.y_true, dim=0)

        if (self.mean is not None) and (self.std is not None):
            # Ground truths were normalized with the training statistics;
            # map predictions back to the original label scale.
            y_pred = y_pred * self.std + self.mean

        return mask, y_pred, y_true

    def _reduce_scores(self, scores, reduction='none'):
        """Reduce per-task scores as requested.

        Parameters
        ----------
        scores : list of float
            Scores for all tasks.
        reduction : 'none' or 'mean' or 'sum'
            Controls the form of scores for all tasks.

        Returns
        -------
        float or list of float
            The raw list, its mean, or its sum depending on ``reduction``.
        """
        if reduction == 'none':
            return scores
        if reduction == 'mean':
            return np.mean(scores)
        if reduction == 'sum':
            return np.sum(scores)
        raise ValueError(
            "Expect reduction to be 'none', 'mean' or 'sum', got {}".format(reduction))

    def multilabel_score(self, score_func, reduction='none'):
        """Evaluate for multi-label prediction.

        Parameters
        ----------
        score_func : callable
            Takes task-specific ground truth and predicted labels (1D tensors)
            and returns a float score, or None to skip the task.
        reduction : 'none' or 'mean' or 'sum'
            Controls the form of scores for all tasks.

        Returns
        -------
        float or list of float
            Per-task scores, optionally reduced.
        """
        mask, y_pred, y_true = self._finalize()
        num_tasks = y_true.shape[1]
        scores = []
        for t in range(num_tasks):
            # Only evaluate on entries whose label exists for this task
            keep = mask[:, t] != 0
            task_score = score_func(y_true[:, t][keep], y_pred[:, t][keep])
            if task_score is not None:
                scores.append(task_score)
        return self._reduce_scores(scores, reduction)

    def pearson_r2(self, reduction='none'):
        """Compute squared Pearson correlation coefficient.

        Parameters
        ----------
        reduction : 'none' or 'mean' or 'sum'
            Controls the form of scores for all tasks.

        Returns
        -------
        float or list of float
            Per-task scores, optionally reduced.
        """
        def score(y_true, y_pred):
            return pearsonr(y_true.numpy(), y_pred.numpy())[0] ** 2
        return self.multilabel_score(score, reduction)

    def mae(self, reduction='none'):
        """Compute mean absolute error.

        Parameters
        ----------
        reduction : 'none' or 'mean' or 'sum'
            Controls the form of scores for all tasks.

        Returns
        -------
        float or list of float
            Per-task scores, optionally reduced.
        """
        def score(y_true, y_pred):
            return F.l1_loss(y_true, y_pred).data.item()
        return self.multilabel_score(score, reduction)

    def rmse(self, reduction='none'):
        """Compute root mean square error.

        Parameters
        ----------
        reduction : 'none' or 'mean' or 'sum'
            Controls the form of scores for all tasks.

        Returns
        -------
        float or list of float
            Per-task scores, optionally reduced.
        """
        def score(y_true, y_pred):
            return np.sqrt(F.mse_loss(y_pred, y_true).cpu().item())
        return self.multilabel_score(score, reduction)

    def roc_auc_score(self, reduction='none'):
        """Compute roc-auc score for binary classification.

        Tasks whose labels contain a single class are skipped with a warning,
        since ROC-AUC is undefined in that case.

        Parameters
        ----------
        reduction : 'none' or 'mean' or 'sum'
            Controls the form of scores for all tasks.

        Returns
        -------
        float or list of float
            Per-task scores, optionally reduced.
        """
        # Todo: This function only supports binary classification and we may need
        # to support categorical classes.
        assert (self.mean is None) and (self.std is None), \
            'Label normalization should not be performed for binary classification.'
        def score(y_true, y_pred):
            if len(y_true.unique()) == 1:
                print('Warning: Only one class {} present in y_true for a task. '
                      'ROC AUC score is not defined in that case.'.format(y_true[0]))
                return None
            return roc_auc_score(y_true.long().numpy(), torch.sigmoid(y_pred).numpy())
        return self.multilabel_score(score, reduction)

    def compute_metric(self, metric_name, reduction='none'):
        """Compute metric based on metric name.

        Parameters
        ----------
        metric_name : str
            * ``'r2'``: compute squared Pearson correlation coefficient
            * ``'mae'``: compute mean absolute error
            * ``'rmse'``: compute root mean square error
            * ``'roc_auc_score'``: compute roc-auc score
        reduction : 'none' or 'mean' or 'sum'
            Controls the form of scores for all tasks.

        Returns
        -------
        float or list of float
            Per-task scores, optionally reduced.
        """
        supported = {
            'r2': self.pearson_r2,
            'mae': self.mae,
            'rmse': self.rmse,
            'roc_auc_score': self.roc_auc_score,
        }
        if metric_name not in supported:
            raise ValueError('Expect metric_name to be "r2" or "mae" or "rmse" '
                             'or "roc_auc_score", got {}'.format(metric_name))
        return supported[metric_name](reduction)
apps/life_sci/python/dgllife/utils/featurizers.py
deleted
100644 → 0
View file @
94c67203
"""Node and edge featurization for molecular graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name
import
itertools
import
os.path
as
osp
from
collections
import
defaultdict
from
functools
import
partial
from
rdkit
import
Chem
,
RDConfig
from
rdkit.Chem
import
AllChem
,
ChemicalFeatures
import
numpy
as
np
import
torch
import
dgl.backend
as
F
# Public API of this module: one-hot helpers, atom/bond descriptor functions,
# and the featurizer classes composed from them.
__all__ = ['one_hot_encoding',
           'atom_type_one_hot',
           'atomic_number_one_hot',
           'atomic_number',
           'atom_degree_one_hot',
           'atom_degree',
           'atom_total_degree_one_hot',
           'atom_total_degree',
           'atom_explicit_valence_one_hot',
           'atom_explicit_valence',
           'atom_implicit_valence_one_hot',
           'atom_implicit_valence',
           'atom_hybridization_one_hot',
           'atom_total_num_H_one_hot',
           'atom_total_num_H',
           'atom_formal_charge_one_hot',
           'atom_formal_charge',
           'atom_num_radical_electrons_one_hot',
           'atom_num_radical_electrons',
           'atom_is_aromatic_one_hot',
           'atom_is_aromatic',
           'atom_is_in_ring_one_hot',
           'atom_is_in_ring',
           'atom_chiral_tag_one_hot',
           'atom_mass',
           'ConcatFeaturizer',
           'BaseAtomFeaturizer',
           'CanonicalAtomFeaturizer',
           'WeaveAtomFeaturizer',
           'PretrainAtomFeaturizer',
           'bond_type_one_hot',
           'bond_is_conjugated_one_hot',
           'bond_is_conjugated',
           'bond_is_in_ring_one_hot',
           'bond_is_in_ring',
           'bond_stereo_one_hot',
           'bond_direction_one_hot',
           'BaseBondFeaturizer',
           'CanonicalBondFeaturizer',
           'WeaveEdgeFeaturizer',
           'PretrainBondFeaturizer']
def one_hot_encoding(x, allowable_set, encode_unknown=False):
    """One-hot encoding.

    Parameters
    ----------
    x
        Value to encode.
    allowable_set : list
        The elements of the allowable_set should be of the
        same type as x. The list is not modified.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element.

    Returns
    -------
    list
        List of boolean values where at most one value is True.
        The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
        and ``len(allowable_set) + 1`` otherwise.

    Examples
    --------
    >>> from dgllife.utils import one_hot_encoding
    >>> one_hot_encoding('C', ['C', 'O'])
    [True, False]
    >>> one_hot_encoding('S', ['C', 'O'])
    [False, False]
    >>> one_hot_encoding('S', ['C', 'O'], encode_unknown=True)
    [False, False, True]
    """
    if encode_unknown and (allowable_set[-1] is not None):
        # Extend a copy rather than mutating the caller's list. The original
        # code appended None in place, silently growing shared (or default)
        # allowable sets by one entry per call.
        allowable_set = allowable_set + [None]

    if encode_unknown and (x not in allowable_set):
        x = None

    return list(map(lambda s: x == s, allowable_set))
#################################################################
# Atom featurization
#################################################################
def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the type of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of str
        Atom types to consider. Default: ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``,
        ``Cl``, ``Br``, ``Mg``, ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``,
        ``K``, ``Tl``, ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
        ``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``, ``Cr``,
        ``Pt``, ``Hg``, ``Pb``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atomic_number
    atomic_number_one_hot
    """
    if allowable_set is None:
        allowable_set = [
            'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
            'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn',
            'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au',
            'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb']
    symbol = atom.GetSymbol()
    return one_hot_encoding(symbol, allowable_set, encode_unknown)
def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the atomic number of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atomic numbers to consider. Default: ``1`` - ``100``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atomic_number
    atom_type_one_hot
    """
    num = atom.GetAtomicNum()
    if allowable_set is None:
        allowable_set = list(range(1, 101))
    return one_hot_encoding(num, allowable_set, encode_unknown)
def atomic_number(atom):
    """Get the atomic number for an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atomic_number_one_hot
    atom_type_one_hot
    """
    num = atom.GetAtomicNum()
    return [num]
def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the degree of an atom.

    Note that the result will be different depending on whether the Hs are
    explicitly modeled in the graph.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atom degrees to consider. Default: ``0`` - ``10``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_degree
    atom_total_degree
    atom_total_degree_one_hot
    """
    degree = atom.GetDegree()
    if allowable_set is None:
        allowable_set = list(range(11))
    return one_hot_encoding(degree, allowable_set, encode_unknown)
def atom_degree(atom):
    """Get the degree of an atom.

    Note that the result will be different depending on whether the Hs are
    explicitly modeled in the graph.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_degree_one_hot
    atom_total_degree
    atom_total_degree_one_hot
    """
    degree = atom.GetDegree()
    return [degree]
def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the degree of an atom including Hs.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list
        Total degrees to consider. Default: ``0`` - ``5``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_degree
    atom_degree_one_hot
    atom_total_degree
    """
    total_degree = atom.GetTotalDegree()
    if allowable_set is None:
        allowable_set = list(range(6))
    return one_hot_encoding(total_degree, allowable_set, encode_unknown)
def atom_total_degree(atom):
    """The degree of an atom including Hs.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_total_degree_one_hot
    atom_degree
    atom_degree_one_hot
    """
    total_degree = atom.GetTotalDegree()
    return [total_degree]
def atom_explicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the explicit valence of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atom explicit valences to consider. Default: ``1`` - ``6``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_explicit_valence
    """
    valences = list(range(1, 7)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetExplicitValence(), valences, encode_unknown)
def atom_explicit_valence(atom):
    """Get the explicit valence of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_explicit_valence_one_hot
    """
    valence = atom.GetExplicitValence()
    return [valence]
def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the implicit valence of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atom implicit valences to consider. Default: ``0`` - ``6``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_implicit_valence
    """
    valences = list(range(7)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetImplicitValence(), valences, encode_unknown)
def atom_implicit_valence(atom):
    """Get the implicit valence of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_implicit_valence_one_hot
    """
    valence = atom.GetImplicitValence()
    return [valence]
# pylint: disable=I1101
def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the hybridization of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of rdkit.Chem.rdchem.HybridizationType
        Atom hybridizations to consider. Default: ``Chem.rdchem.HybridizationType.SP``,
        ``Chem.rdchem.HybridizationType.SP2``, ``Chem.rdchem.HybridizationType.SP3``,
        ``Chem.rdchem.HybridizationType.SP3D``, ``Chem.rdchem.HybridizationType.SP3D2``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    """
    if allowable_set is None:
        allowable_set = [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
        ]
    return one_hot_encoding(atom.GetHybridization(), allowable_set, encode_unknown)
def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the total number of Hs of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Total number of Hs to consider. Default: ``0`` - ``4``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_total_num_H
    """
    num_Hs = list(range(5)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetTotalNumHs(), num_Hs, encode_unknown)
def atom_total_num_H(atom):
    """Get the total number of Hs of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_total_num_H_one_hot
    """
    num_Hs = atom.GetTotalNumHs()
    return [num_Hs]
def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the formal charge of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Formal charges to consider. Default: ``-2`` - ``2``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_formal_charge
    """
    charges = list(range(-2, 3)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetFormalCharge(), charges, encode_unknown)
def atom_formal_charge(atom):
    """Get formal charge for an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_formal_charge_one_hot
    """
    charge = atom.GetFormalCharge()
    return [charge]
def atom_partial_charge(atom):
    """Get Gasteiger partial charge for an atom.

    For using this function, you must have called
    ``AllChem.ComputeGasteigerCharges(mol)`` to compute Gasteiger charges.

    Occasionally Gasteiger computation yields nan or infinity; those values
    are mapped to 0.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one float only.
    """
    # RDKit stores the computed charge as a string property on the atom.
    charge_str = atom.GetProp('_GasteigerCharge')
    if charge_str in ('-nan', 'nan', '-inf', 'inf'):
        charge_str = 0
    return [float(charge_str)]
def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the number of radical electrons of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Number of radical electrons to consider. Default: ``0`` - ``4``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_num_radical_electrons
    """
    counts = list(range(5)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetNumRadicalElectrons(), counts, encode_unknown)
def atom_num_radical_electrons(atom):
    """Get the number of radical electrons for an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_num_radical_electrons_one_hot
    """
    count = atom.GetNumRadicalElectrons()
    return [count]
def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the atom is aromatic.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_is_aromatic
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetIsAromatic(), conditions, encode_unknown)
def atom_is_aromatic(atom):
    """Get whether the atom is aromatic.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one bool only.

    See Also
    --------
    atom_is_aromatic_one_hot
    """
    aromatic = atom.GetIsAromatic()
    return [aromatic]
def atom_is_in_ring_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the atom is in ring.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_is_in_ring
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    return one_hot_encoding(atom.IsInRing(), conditions, encode_unknown)
def atom_is_in_ring(atom):
    """Get whether the atom is in ring.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one bool only.

    See Also
    --------
    atom_is_in_ring_one_hot
    """
    in_ring = atom.IsInRing()
    return [in_ring]
def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the chiral tag of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of rdkit.Chem.rdchem.ChiralType
        Chiral tags to consider. Default: ``rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    """
    if allowable_set is None:
        allowable_set = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER,
        ]
    return one_hot_encoding(atom.GetChiralTag(), allowable_set, encode_unknown)
def atom_mass(atom, coef=0.01):
    """Get the mass of an atom and scale it.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    coef : float
        The mass will be multiplied by ``coef``.

    Returns
    -------
    list
        List containing one float only.
    """
    mass = atom.GetMass()
    return [mass * coef]
class ConcatFeaturizer(object):
    """Concatenate the evaluation results of multiple functions as a single feature.

    Parameters
    ----------
    func_list : list
        List of functions for computing molecular descriptors from objects of a same
        particular data type, e.g. ``rdkit.Chem.rdchem.Atom``. Each function is of
        signature ``func(data_type) -> list of float or bool or int``. The resulting
        order of the features will follow that of the functions in the list.
    """
    def __init__(self, func_list):
        self.func_list = func_list

    def __call__(self, x):
        """Featurize the input data.

        Parameters
        ----------
        x
            Data to featurize.

        Returns
        -------
        list
            List of feature values, which can be of type bool, float or int.
        """
        # Flatten the per-function feature lists into one list, in order.
        feats = []
        for func in self.func_list:
            feats.extend(func(x))
        return feats
class BaseAtomFeaturizer(object):
    """An abstract class for atom featurizers.

    Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``.

    **We assume the resulting DGLGraph will not contain any virtual nodes and a node i in the
    graph corresponds to exactly atom i in the molecule.**

    Parameters
    ----------
    featurizer_funcs : dict
        Mapping feature name to the featurization function.
        Each function is of signature ``func(rdkit.Chem.rdchem.Atom) -> list or 1D numpy array``.
    feat_sizes : dict
        Mapping feature name to the size of the corresponding feature. If None, they will be
        computed when needed. Default: None.

    Examples
    --------
    >>> from dgllife.utils import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
    >>> from rdkit import Chem
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> atom_featurizer = BaseAtomFeaturizer({'mass': atom_mass, 'degree': atom_degree_one_hot})
    >>> atom_featurizer(mol)
    {'mass': tensor([[0.1201],
                     [0.1201],
                     [0.1600]]),
     'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
    >>> # Get feature size for atom mass
    >>> print(atom_featurizer.feat_size('mass'))
    1
    >>> # Get feature size for atom degree
    >>> print(atom_featurizer.feat_size('degree'))
    11

    See Also
    --------
    CanonicalAtomFeaturizer
    """
    def __init__(self, featurizer_funcs, feat_sizes=None):
        self.featurizer_funcs = featurizer_funcs
        if feat_sizes is None:
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name=None):
        """Get the feature size for ``feat_name``.

        When there is only one feature, users do not need to provide ``feat_name``.

        Parameters
        ----------
        feat_name : str
            Feature for query. Default to None.

        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.

        Raises
        ------
        ValueError
            If ``feat_name`` is not a key in ``self.featurizer_funcs``.
        """
        if feat_name is None:
            assert len(self.featurizer_funcs) == 1, \
                'feat_name should be provided if there are more than one features'
            feat_name = list(self.featurizer_funcs.keys())[0]

        if feat_name not in self.featurizer_funcs:
            # Bug fix: the original code returned the ValueError instance
            # instead of raising it, silently masking invalid queries.
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))

        if feat_name not in self._feat_sizes:
            # Compute the size lazily by featurizing a probe carbon atom.
            atom = Chem.MolFromSmiles('C').GetAtomWithIdx(0)
            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](atom))

        return self._feat_sizes[feat_name]

    def __call__(self, mol):
        """Featurize all atoms in a molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            For each function in self.featurizer_funcs with the key ``k``, store the computed
            feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
            (N, M), where N is the number of atoms in the molecule.
        """
        num_atoms = mol.GetNumAtoms()
        atom_features = defaultdict(list)

        # Compute features for each atom
        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            for feat_name, feat_func in self.featurizer_funcs.items():
                atom_features[feat_name].append(feat_func(atom))

        # Stack the features and convert them to float arrays
        processed_features = dict()
        for feat_name, feat_list in atom_features.items():
            feat = np.stack(feat_list)
            processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))

        return processed_features
class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
    """A default featurizer for atoms.

    The atom features include:

    * **One hot encoding of the atom type**. The supported atom types include
      ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``, ``Cl``, ``Br``, ``Mg``,
      ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``, ``K``, ``Tl``,
      ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
      ``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``,
      ``Cr``, ``Pt``, ``Hg``, ``Pb``.
    * **One hot encoding of the atom degree**. The supported possibilities
      include ``0 - 10``.
    * **One hot encoding of the number of implicit Hs on the atom**. The supported
      possibilities include ``0 - 6``.
    * **Formal charge of the atom**.
    * **Number of radical electrons of the atom**.
    * **One hot encoding of the atom hybridization**. The supported possibilities include
      ``SP``, ``SP2``, ``SP3``, ``SP3D``, ``SP3D2``.
    * **Whether the atom is aromatic**.
    * **One hot encoding of the number of total Hs on the atom**. The supported possibilities
      include ``0 - 4``.

    **We assume the resulting DGLGraph will not contain any virtual nodes.**

    Parameters
    ----------
    atom_data_field : str
        Name for storing atom features in DGLGraphs, default to 'h'.

    Examples
    --------
    >>> from rdkit import Chem
    >>> from dgllife.utils import CanonicalAtomFeaturizer
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')
    >>> # Get feature size for nodes
    >>> print(atom_featurizer.feat_size('feat'))
    74

    See Also
    --------
    BaseAtomFeaturizer
    """
    def __init__(self, atom_data_field='h'):
        # Order of the functions determines the order of the concatenated features.
        canonical_funcs = [
            atom_type_one_hot,
            atom_degree_one_hot,
            atom_implicit_valence_one_hot,
            atom_formal_charge,
            atom_num_radical_electrons,
            atom_hybridization_one_hot,
            atom_is_aromatic,
            atom_total_num_H_one_hot,
        ]
        super(CanonicalAtomFeaturizer, self).__init__(
            featurizer_funcs={atom_data_field: ConcatFeaturizer(canonical_funcs)})
class WeaveAtomFeaturizer(object):
    """Atom featurizer in Weave.

    The atom featurization performed in `Molecular Graph Convolutions: Moving Beyond Fingerprints
    <https://arxiv.org/abs/1603.00856>`__, which considers:

    * atom types
    * chirality
    * formal charge
    * partial charge
    * aromatic atom
    * hybridization
    * hydrogen bond donor
    * hydrogen bond acceptor
    * the number of rings the atom belongs to for ring size between 3 and 8

    Parameters
    ----------
    atom_data_field : str
        Name for storing atom features in DGLGraphs, default to 'h'.
    atom_types : list of str or None
        Atom types to consider for one-hot encoding. If None, we will use a default
        choice of ``'H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I'``.
    chiral_types : list of Chem.rdchem.ChiralType or None
        Atom chirality to consider for one-hot encoding. If None, we will use a default
        choice of ``Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``.
    hybridization_types : list of Chem.rdchem.HybridizationType or None
        Atom hybridization types to consider for one-hot encoding. If None, we will use a
        default choice of ``Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3``.
    """
    def __init__(self, atom_data_field='h', atom_types=None, chiral_types=None,
                 hybridization_types=None):
        super(WeaveAtomFeaturizer, self).__init__()

        self._atom_data_field = atom_data_field

        if atom_types is None:
            atom_types = ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I']
        self._atom_types = atom_types

        if chiral_types is None:
            chiral_types = [Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW]
        self._chiral_types = chiral_types

        if hybridization_types is None:
            hybridization_types = [Chem.rdchem.HybridizationType.SP,
                                   Chem.rdchem.HybridizationType.SP2,
                                   Chem.rdchem.HybridizationType.SP3]
        self._hybridization_types = hybridization_types

        # Per-atom featurizer for the properties that can be read directly off an
        # RDKit atom. Atom types outside of atom_types fall into an extra "unknown"
        # slot (encode_unknown=True); chirality and hybridization do not.
        self._featurizer = ConcatFeaturizer([
            partial(atom_type_one_hot, allowable_set=atom_types, encode_unknown=True),
            partial(atom_chiral_tag_one_hot, allowable_set=chiral_types),
            atom_formal_charge, atom_partial_charge, atom_is_aromatic,
            partial(atom_hybridization_one_hot, allowable_set=hybridization_types)
        ])

    def feat_size(self):
        """Get the feature size.

        Returns
        -------
        int
            Feature size.
        """
        # Featurize a probe molecule and read the size off the result.
        mol = Chem.MolFromSmiles('C')
        feats = self(mol)[self._atom_data_field]

        return feats.shape[-1]

    def get_donor_acceptor_info(self, mol_feats):
        """Bookkeep whether an atom is donor/acceptor for hydrogen bonds.

        Parameters
        ----------
        mol_feats : tuple of rdkit.Chem.rdMolChemicalFeatures.MolChemicalFeature
            Features for molecules.

        Returns
        -------
        is_donor : dict
            Mapping atom ids to binary values indicating whether atoms
            are donors for hydrogen bonds
        is_acceptor : dict
            Mapping atom ids to binary values indicating whether atoms
            are acceptors for hydrogen bonds
        """
        # defaultdict(bool) makes unlisted atom ids default to False.
        is_donor = defaultdict(bool)
        is_acceptor = defaultdict(bool)
        # Get hydrogen bond donor/acceptor information
        for feats in mol_feats:
            if feats.GetFamily() == 'Donor':
                nodes = feats.GetAtomIds()
                for u in nodes:
                    is_donor[u] = True
            elif feats.GetFamily() == 'Acceptor':
                nodes = feats.GetAtomIds()
                for u in nodes:
                    is_acceptor[u] = True

        return is_donor, is_acceptor

    def __call__(self, mol):
        """Featurizes the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping atom_data_field as specified in the input argument to the atom
            features, which is a float32 tensor of shape (N, M), N is the number of
            atoms and M is the feature size.
        """
        atom_features = []

        # atom_partial_charge below reads the '_GasteigerCharge' property,
        # which is populated by this call.
        AllChem.ComputeGasteigerCharges(mol)
        num_atoms = mol.GetNumAtoms()

        # Get information for donor and acceptor
        fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
        mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
        mol_feats = mol_featurizer.GetFeaturesForMol(mol)
        is_donor, is_acceptor = self.get_donor_acceptor_info(mol_feats)

        # Get a symmetrized smallest set of smallest rings
        # Following the practice from Chainer Chemistry (https://github.com/chainer/
        # chainer-chemistry/blob/da2507b38f903a8ee333e487d422ba6dcec49b05/chainer_chemistry/
        # dataset/preprocessors/weavenet_preprocessor.py)
        sssr = Chem.GetSymmSSSR(mol)

        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            # Features that can be computed directly from RDKit atom instances, which is a list
            feats = self._featurizer(atom)
            # Donor/acceptor indicator
            feats.append(float(is_donor[i]))
            feats.append(float(is_acceptor[i]))
            # Count the number of rings the atom belongs to for ring size between 3 and 8
            count = [0 for _ in range(3, 9)]
            for ring in sssr:
                ring_size = len(ring)
                if i in ring and 3 <= ring_size <= 8:
                    count[ring_size - 3] += 1
            feats.extend(count)
            atom_features.append(feats)
        atom_features = np.stack(atom_features)

        return {self._atom_data_field: F.zerocopy_from_numpy(atom_features.astype(np.float32))}
class PretrainAtomFeaturizer(object):
    """AtomFeaturizer in Strategies for Pre-training Graph Neural Networks.

    The atom featurization performed in `Strategies for Pre-training Graph Neural Networks
    <https://arxiv.org/abs/1905.12265>`__, which considers:

    * atomic number
    * chirality

    Parameters
    ----------
    atomic_number_types : list of int or None
        Atomic number types to consider for one-hot encoding. If None, we will use a default
        choice of 1-118.
    chiral_types : list of Chem.rdchem.ChiralType or None
        Atom chirality to consider for one-hot encoding. If None, we will use a default
        choice of ``Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.rdchem.ChiralType.CHI_OTHER``.
    """
    def __init__(self, atomic_number_types=None, chiral_types=None):
        if atomic_number_types is None:
            atomic_number_types = list(range(1, 119))
        self._atomic_number_types = atomic_number_types

        if chiral_types is None:
            chiral_types = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                            Chem.rdchem.ChiralType.CHI_OTHER]
        self._chiral_types = chiral_types

    def __call__(self, mol):
        """Featurizes the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping 'atomic_number' and 'chirality_type' to separately an int64 tensor
            of shape (N, 1), N is the number of atoms
        """
        atom_features = []
        num_atoms = mol.GetNumAtoms()
        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            # NOTE(review): list.index raises ValueError for an atomic number or
            # chiral tag outside the configured allowable lists — there is no
            # "unknown" fallback here; confirm callers only pass supported atoms.
            atom_features.append([
                self._atomic_number_types.index(atom.GetAtomicNum()),
                self._chiral_types.index(atom.GetChiralTag())
            ])
        atom_features = np.stack(atom_features)
        atom_features = F.zerocopy_from_numpy(atom_features.astype(np.int64))

        return {
            'atomic_number': atom_features[:, 0],
            'chirality_type': atom_features[:, 1]
        }
def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the type of a bond.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of Chem.rdchem.BondType
        Bond types to consider. Default: ``Chem.rdchem.BondType.SINGLE``,
        ``Chem.rdchem.BondType.DOUBLE``, ``Chem.rdchem.BondType.TRIPLE``,
        ``Chem.rdchem.BondType.AROMATIC``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    """
    if allowable_set is None:
        allowable_set = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC,
        ]
    return one_hot_encoding(bond.GetBondType(), allowable_set, encode_unknown)
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the bond is conjugated.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    bond_is_conjugated
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    return one_hot_encoding(bond.GetIsConjugated(), conditions, encode_unknown)
def bond_is_conjugated(bond):
    """Get whether the bond is conjugated.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.

    Returns
    -------
    list
        List containing one bool only.

    See Also
    --------
    bond_is_conjugated_one_hot
    """
    conjugated = bond.GetIsConjugated()
    return [conjugated]
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the bond is in a ring of any size.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    bond_is_in_ring
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    return one_hot_encoding(bond.IsInRing(), conditions, encode_unknown)
def bond_is_in_ring(bond):
    """Get whether the bond is in a ring of any size.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.

    Returns
    -------
    list
        List containing one bool only.

    See Also
    --------
    bond_is_in_ring_one_hot
    """
    in_ring = bond.IsInRing()
    return [in_ring]
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the stereo configuration of a bond.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of rdkit.Chem.rdchem.BondStereo
        Stereo configurations to consider. Default: ``rdkit.Chem.rdchem.BondStereo.STEREONONE``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOANY``, ``rdkit.Chem.rdchem.BondStereo.STEREOZ``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOE``, ``rdkit.Chem.rdchem.BondStereo.STEREOCIS``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOTRANS``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    """
    if allowable_set is None:
        allowable_set = [
            Chem.rdchem.BondStereo.STEREONONE,
            Chem.rdchem.BondStereo.STEREOANY,
            Chem.rdchem.BondStereo.STEREOZ,
            Chem.rdchem.BondStereo.STEREOE,
            Chem.rdchem.BondStereo.STEREOCIS,
            Chem.rdchem.BondStereo.STEREOTRANS,
        ]
    return one_hot_encoding(bond.GetStereo(), allowable_set, encode_unknown)
def bond_direction_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the direction of a bond.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of Chem.rdchem.BondDir
        Bond directions to consider. Default: ``Chem.rdchem.BondDir.NONE``,
        ``Chem.rdchem.BondDir.ENDUPRIGHT``, ``Chem.rdchem.BondDir.ENDDOWNRIGHT``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    """
    if allowable_set is None:
        allowable_set = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT,
        ]
    return one_hot_encoding(bond.GetBondDir(), allowable_set, encode_unknown)
class BaseBondFeaturizer(object):
    """An abstract class for bond featurizers.

    Loop over all bonds in a molecule and featurize them with the ``featurizer_funcs``.
    We assume the constructed ``DGLGraph`` is a bi-directed graph where the **i** th bond in the
    molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the **(2i)**-th and **(2i+1)**-th edges
    in the DGLGraph.

    **We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
    self loops.**

    Parameters
    ----------
    featurizer_funcs : dict
        Mapping feature name to the featurization function.
        Each function is of signature ``func(rdkit.Chem.rdchem.Bond) -> list or 1D numpy array``.
    feat_sizes : dict
        Mapping feature name to the size of the corresponding feature. If None, they will be
        computed when needed. Default: None.

    Examples
    --------
    >>> from dgllife.utils import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
    >>> from rdkit import Chem
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> bond_featurizer = BaseBondFeaturizer({'type': bond_type_one_hot, 'ring': bond_is_in_ring})
    >>> bond_featurizer(mol)
    {'type': tensor([[1., 0., 0., 0.],
                     [1., 0., 0., 0.],
                     [1., 0., 0., 0.],
                     [1., 0., 0., 0.]]),
     'ring': tensor([[0.], [0.], [0.], [0.]])}
    >>> # Get feature size
    >>> bond_featurizer.feat_size('type')
    4
    >>> bond_featurizer.feat_size('ring')
    1

    See Also
    --------
    CanonicalBondFeaturizer
    """
    def __init__(self, featurizer_funcs, feat_sizes=None):
        self.featurizer_funcs = featurizer_funcs
        if feat_sizes is None:
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name=None):
        """Get the feature size for ``feat_name``.

        When there is only one feature, users do not need to provide ``feat_name``.

        Parameters
        ----------
        feat_name : str
            Feature for query. Default to None.

        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.

        Raises
        ------
        ValueError
            If ``feat_name`` is not a key in ``self.featurizer_funcs``.
        """
        if feat_name is None:
            assert len(self.featurizer_funcs) == 1, \
                'feat_name should be provided if there are more than one features'
            feat_name = list(self.featurizer_funcs.keys())[0]

        if feat_name not in self.featurizer_funcs:
            # Bug fix: the original code returned the ValueError instance
            # instead of raising it, silently masking invalid queries.
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))

        if feat_name not in self._feat_sizes:
            # Compute the size lazily by featurizing a probe C-O bond.
            bond = Chem.MolFromSmiles('CO').GetBondWithIdx(0)
            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](bond))

        return self._feat_sizes[feat_name]

    def __call__(self, mol):
        """Featurize all bonds in a molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            For each function in self.featurizer_funcs with the key ``k``, store the computed
            feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
            (N, M), where N is the number of edges (twice the number of bonds).
        """
        num_bonds = mol.GetNumBonds()
        bond_features = defaultdict(list)

        # Compute features for each bond
        for i in range(num_bonds):
            bond = mol.GetBondWithIdx(i)
            for feat_name, feat_func in self.featurizer_funcs.items():
                feat = feat_func(bond)
                # Each bond maps to two directed edges, so duplicate the feature.
                bond_features[feat_name].extend([feat, feat.copy()])

        # Stack the features and convert them to float arrays
        processed_features = dict()
        for feat_name, feat_list in bond_features.items():
            feat = np.stack(feat_list)
            processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))

        return processed_features
class CanonicalBondFeaturizer(BaseBondFeaturizer):
    """A default featurizer for bonds.

    The bond features include:

    * **One hot encoding of the bond type**. The supported bond types include
      ``SINGLE``, ``DOUBLE``, ``TRIPLE``, ``AROMATIC``.
    * **Whether the bond is conjugated.**
    * **Whether the bond is in a ring of any size.**
    * **One hot encoding of the stereo configuration of a bond**. The supported bond stereo
      configurations include ``STEREONONE``, ``STEREOANY``, ``STEREOZ``, ``STEREOE``,
      ``STEREOCIS``, ``STEREOTRANS``.

    **We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
    self loops.**

    Parameters
    ----------
    bond_data_field : str
        Name for storing bond features in DGLGraphs. Default to ``'e'``.

    Examples
    --------
    >>> from dgllife.utils import CanonicalBondFeaturizer
    >>> from rdkit import Chem
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> bond_featurizer = CanonicalBondFeaturizer(bond_data_field='feat')
    >>> bond_featurizer(mol)
    {'feat': tensor([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                     [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                     [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                     [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])}
    >>> # Get feature size: query with the data field used at construction
    >>> bond_featurizer.feat_size('feat')
    12

    See Also
    --------
    BaseBondFeaturizer
    """
    def __init__(self, bond_data_field='e'):
        super(CanonicalBondFeaturizer, self).__init__(
            featurizer_funcs={bond_data_field: ConcatFeaturizer(
                [bond_type_one_hot,
                 bond_is_conjugated,
                 bond_is_in_ring,
                 bond_stereo_one_hot]
            )})
# pylint: disable=E1102
class WeaveEdgeFeaturizer(object):
    """Edge featurizer in Weave.

    The edge featurization is introduced in `Molecular Graph Convolutions:
    Moving Beyond Fingerprints <https://arxiv.org/abs/1603.00856>`__.

    This featurization is performed for a complete graph of atoms with self loops added,
    which considers:

    * Number of bonds between each pairs of atoms
    * One-hot encoding of bond type if a bond exists between a pair of atoms
    * Whether a pair of atoms belongs to a same ring

    Parameters
    ----------
    edge_data_field : str
        Name for storing edge features in DGLGraphs, default to ``'e'``.
    max_distance : int
        Maximum number of bonds to consider between each pair of atoms.
        Default to 7.
    bond_types : list of Chem.rdchem.BondType or None
        Bond types to consider for one hot encoding. If None, we consider by
        default single, double, triple and aromatic bonds.
    """
    def __init__(self, edge_data_field='e', max_distance=7, bond_types=None):
        super(WeaveEdgeFeaturizer, self).__init__()

        self._edge_data_field = edge_data_field
        self._max_distance = max_distance
        if bond_types is None:
            bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                          Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
        self._bond_types = bond_types

    def feat_size(self):
        """Get the feature size.

        Returns
        -------
        int
            Feature size.
        """
        # Featurize a one-atom molecule once and read off the feature width.
        mol = Chem.MolFromSmiles('C')
        feats = self(mol)[self._edge_data_field]

        return feats.shape[-1]

    def __call__(self, mol):
        """Featurizes the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping self._edge_data_field to a float32 tensor of shape (N, M), where
            N is the number of atom pairs (V^2 including self pairs) and M is the
            feature size.
        """
        # Part 1 based on number of bonds between each pair of atoms
        distance_matrix = torch.from_numpy(Chem.GetDistanceMatrix(mol))
        # Change shape from (V, V) to (V^2, 1)
        distance_matrix = distance_matrix.float().reshape(-1, 1)
        # Elementwise compare if distance is bigger than 0, 1, ..., max_distance - 1;
        # broadcasting against (max_distance,) yields a (V^2, max_distance) indicator.
        distance_indicators = (distance_matrix >
                               torch.arange(0, self._max_distance).float()).float()

        # Part 2 for one hot encoding of bond type.
        num_atoms = mol.GetNumAtoms()
        bond_indicators = torch.zeros(num_atoms, num_atoms, len(self._bond_types))
        for bond in mol.GetBonds():
            bond_type_encoding = torch.tensor(
                bond_type_one_hot(bond, allowable_set=self._bond_types)).float()
            begin_atom_idx, end_atom_idx = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            # A bond is undirected: mark both (u, v) and (v, u).
            bond_indicators[begin_atom_idx, end_atom_idx] = bond_type_encoding
            bond_indicators[end_atom_idx, begin_atom_idx] = bond_type_encoding
        # Reshape from (V, V, num_bond_types) to (V^2, num_bond_types)
        bond_indicators = bond_indicators.reshape(-1, len(self._bond_types))

        # Part 3 for whether a pair of atoms belongs to a same ring.
        sssr = Chem.GetSymmSSSR(mol)
        ring_mate_indicators = torch.zeros(num_atoms, num_atoms, 1)
        for ring in sssr:
            ring = list(ring)
            num_atoms_in_ring = len(ring)
            for i in range(num_atoms_in_ring):
                # Mark atom ring[i] as a ring mate of every atom in the ring
                # (including itself).
                ring_mate_indicators[ring[i], torch.tensor(ring)] = 1
        ring_mate_indicators = ring_mate_indicators.reshape(-1, 1)

        return {self._edge_data_field: torch.cat([distance_indicators, bond_indicators,
                                                  ring_mate_indicators], dim=1)}
class PretrainBondFeaturizer(object):
    """BondFeaturizer in Strategies for Pre-training Graph Neural Networks.

    The bond featurization performed in `Strategies for Pre-training Graph Neural Networks
    <https://arxiv.org/abs/1905.12265>`__, which considers:

    * bond type
    * bond direction

    Parameters
    ----------
    bond_types : list of Chem.rdchem.BondType or None
        Bond types to consider. Default to ``Chem.rdchem.BondType.SINGLE``,
        ``Chem.rdchem.BondType.DOUBLE``, ``Chem.rdchem.BondType.TRIPLE``,
        ``Chem.rdchem.BondType.AROMATIC``.
    bond_direction_types : list of Chem.rdchem.BondDir or None
        Bond directions to consider. Default to ``Chem.rdchem.BondDir.NONE``,
        ``Chem.rdchem.BondDir.ENDUPRIGHT``, ``Chem.rdchem.BondDir.ENDDOWNRIGHT``.
    self_loop : bool
        Whether self loops will be added. Default to True.
    """
    def __init__(self, bond_types=None, bond_direction_types=None, self_loop=True):
        if bond_types is None:
            bond_types = [
                Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC
            ]
        self._bond_types = bond_types

        if bond_direction_types is None:
            bond_direction_types = [
                Chem.rdchem.BondDir.NONE,
                Chem.rdchem.BondDir.ENDUPRIGHT,
                Chem.rdchem.BondDir.ENDDOWNRIGHT
            ]
        self._bond_direction_types = bond_direction_types
        self._self_loop = self_loop

    def __call__(self, mol):
        """Featurizes the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping 'bond_type' and 'bond_direction_type' separately to an int64
            tensor of shape (N,), where N is the number of edges (2 per bond, plus
            one self loop per atom when ``self_loop`` is True).
        """
        edge_features = []
        num_bonds = mol.GetNumBonds()
        # Compute features for each bond; each bond yields two directed edges.
        for i in range(num_bonds):
            bond = mol.GetBondWithIdx(i)
            bond_feats = [
                self._bond_types.index(bond.GetBondType()),
                self._bond_direction_types.index(bond.GetBondDir())
            ]
            edge_features.extend([bond_feats, bond_feats.copy()])

        if num_bonds == 0:
            # Bug fix: previously this path assigned self_loop_features, which
            # was undefined when self_loop=False. Start from an empty tensor so
            # both configurations work.
            edge_features = torch.zeros((0, 2), dtype=torch.int64)
        else:
            edge_features = np.stack(edge_features)
            edge_features = F.zerocopy_from_numpy(edge_features.astype(np.int64))

        if self._self_loop:
            self_loop_features = torch.zeros((mol.GetNumAtoms(), 2), dtype=torch.int64)
            # Self loops use an out-of-vocabulary bond-type index.
            self_loop_features[:, 0] = len(self._bond_types)
            # Bug fix: the concatenation is now guarded so that self_loop=False
            # no longer raises a NameError on self_loop_features.
            edge_features = torch.cat([edge_features, self_loop_features], dim=0)

        return {'bond_type': edge_features[:, 0],
                'bond_direction_type': edge_features[:, 1]}
apps/life_sci/python/dgllife/utils/mol_to_graph.py
deleted
100644 → 0
View file @
94c67203
"""Convert molecules into DGLGraphs."""
# pylint: disable= no-member, arguments-differ, invalid-name
from
functools
import
partial
import
torch
from
dgl
import
DGLGraph
from
rdkit
import
Chem
from
rdkit.Chem
import
rdmolfiles
,
rdmolops
from
sklearn.neighbors
import
NearestNeighbors
__all__
=
[
'mol_to_graph'
,
'smiles_to_bigraph'
,
'mol_to_bigraph'
,
'smiles_to_complete_graph'
,
'mol_to_complete_graph'
,
'k_nearest_neighbors'
,
'mol_to_nearest_neighbor_graph'
,
'smiles_to_nearest_neighbor_graph'
]
# pylint: disable=I1101
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer,
                 canonical_atom_order):
    """Convert an RDKit molecule object into a DGLGraph and featurize for it.

    This function can be used to construct any arbitrary ``DGLGraph`` from an
    RDKit molecule instance.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    graph_constructor : callable
        Takes an RDKit molecule as input and returns a DGLGraph
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed.

    Returns
    -------
    g : DGLGraph
        Converted DGLGraph for the molecule

    See Also
    --------
    mol_to_bigraph
    mol_to_complete_graph
    mol_to_nearest_neighbor_graph
    """
    if canonical_atom_order:
        # Renumber atoms so that graph construction is order-independent.
        mol = rdmolops.RenumberAtoms(mol, rdmolfiles.CanonicalRankAtoms(mol))

    graph = graph_constructor(mol)

    # Attach node and edge features when the corresponding featurizer is given.
    for featurizer, frame in ((node_featurizer, graph.ndata),
                              (edge_featurizer, graph.edata)):
        if featurizer is not None:
            frame.update(featurizer(mol))

    return graph
def construct_bigraph_from_mol(mol, add_self_loop=False):
    """Construct a bi-directed DGLGraph with topology only for the molecule.

    The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
    **i** th node in the returned DGLGraph.

    The **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the
    **(2i)**-th and **(2i+1)**-th edges in the returned DGLGraph. The **(2i)**-th and
    **(2i+1)**-th edges will be separately from **u** to **v** and **v** to **u**, where
    **u** is ``bond.GetBeginAtomIdx()`` and **v** is ``bond.GetEndAtomIdx()``.

    If self loops are added, the last **n** edges will separately be self loops for
    atoms ``0, 1, ..., n-1``.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.

    Returns
    -------
    g : DGLGraph
        Empty bigraph topology of the molecule
    """
    graph = DGLGraph()
    # Add nodes first so that atoms without bonds are still represented.
    graph.add_nodes(mol.GetNumAtoms())

    # Each bond contributes a pair of directed edges (u, v) and (v, u).
    srcs, dsts = [], []
    for bond in mol.GetBonds():
        u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        srcs.extend([u, v])
        dsts.extend([v, u])
    graph.add_edges(srcs, dsts)

    if add_self_loop:
        all_nodes = graph.nodes()
        graph.add_edges(all_nodes, all_nodes)

    return graph
def mol_to_bigraph(mol, add_self_loop=False, node_featurizer=None,
                   edge_featurizer=None, canonical_atom_order=True):
    """Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.

    Returns
    -------
    g : DGLGraph
        Bi-directed DGLGraph for the molecule

    Examples
    --------
    >>> from rdkit import Chem
    >>> from dgllife.utils import mol_to_bigraph
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> g = mol_to_bigraph(mol)
    >>> print(g)
    DGLGraph(num_nodes=3, num_edges=4,
             ndata_schemes={}
             edata_schemes={})

    See Also
    --------
    smiles_to_bigraph
    """
    # Freeze add_self_loop into the constructor and delegate to the generic helper.
    constructor = partial(construct_bigraph_from_mol, add_self_loop=add_self_loop)
    return mol_to_graph(mol, constructor, node_featurizer,
                        edge_featurizer, canonical_atom_order)
def smiles_to_bigraph(smiles, add_self_loop=False, node_featurizer=None,
                      edge_featurizer=None, canonical_atom_order=True):
    """Convert a SMILES into a bi-directed DGLGraph and featurize for it.

    Parameters
    ----------
    smiles : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.

    Returns
    -------
    g : DGLGraph
        Bi-directed DGLGraph for the molecule

    Examples
    --------
    >>> from dgllife.utils import smiles_to_bigraph
    >>> g = smiles_to_bigraph('CCO')
    >>> print(g)
    DGLGraph(num_nodes=3, num_edges=4,
             ndata_schemes={}
             edata_schemes={})

    See Also
    --------
    mol_to_bigraph
    """
    # Parse the SMILES and delegate to the molecule-based variant.
    mol = Chem.MolFromSmiles(smiles)
    return mol_to_bigraph(mol,
                          add_self_loop=add_self_loop,
                          node_featurizer=node_featurizer,
                          edge_featurizer=edge_featurizer,
                          canonical_atom_order=canonical_atom_order)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
    """Construct a complete graph with topology only for the molecule

    The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
    **i** th node in the returned DGLGraph.

    The edges follow the order in which the pairs are generated: for each source
    atom i in increasing order, edges (i, 0), (i, 1), ..., (i, num_atoms - 1) are
    added, skipping (i, i) unless ``add_self_loop`` is True.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.

    Returns
    -------
    g : DGLGraph
        Empty complete graph topology of the molecule
    """
    num_atoms = mol.GetNumAtoms()
    src = []
    dst = []
    for i in range(num_atoms):
        for j in range(num_atoms):
            if i != j or add_self_loop:
                src.append(i)
                dst.append(j)

    g = DGLGraph()
    # Bug fix: add nodes explicitly. Constructing DGLGraph directly from the
    # edge list infers the node count from edge endpoints, which drops nodes
    # for a single-atom (or empty) molecule without self loops.
    g.add_nodes(num_atoms)
    g.add_edges(src, dst)

    return g
def mol_to_complete_graph(mol, add_self_loop=False, node_featurizer=None,
                          edge_featurizer=None, canonical_atom_order=True):
    """Convert an RDKit molecule into a complete DGLGraph and featurize for it.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.

    Returns
    -------
    g : DGLGraph
        Complete DGLGraph for the molecule

    Examples
    --------
    >>> from rdkit import Chem
    >>> from dgllife.utils import mol_to_complete_graph
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> g = mol_to_complete_graph(mol)
    >>> print(g)
    DGLGraph(num_nodes=3, num_edges=6,
             ndata_schemes={}
             edata_schemes={})

    See Also
    --------
    smiles_to_complete_graph
    """
    # Freeze add_self_loop into the constructor and delegate to the generic helper.
    constructor = partial(construct_complete_graph_from_mol,
                          add_self_loop=add_self_loop)
    return mol_to_graph(mol, constructor, node_featurizer,
                        edge_featurizer, canonical_atom_order)
def smiles_to_complete_graph(smiles, add_self_loop=False, node_featurizer=None,
                             edge_featurizer=None, canonical_atom_order=True):
    """Convert a SMILES into a complete DGLGraph and featurize for it.

    Parameters
    ----------
    smiles : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.

    Returns
    -------
    g : DGLGraph
        Complete DGLGraph for the molecule

    Examples
    --------
    >>> from dgllife.utils import smiles_to_complete_graph
    >>> g = smiles_to_complete_graph('CCO')
    >>> print(g)
    DGLGraph(num_nodes=3, num_edges=6,
             ndata_schemes={}
             edata_schemes={})

    See Also
    --------
    mol_to_complete_graph
    """
    # Parse the SMILES and delegate to the molecule-based variant.
    mol = Chem.MolFromSmiles(smiles)
    return mol_to_complete_graph(mol,
                                 add_self_loop=add_self_loop,
                                 node_featurizer=node_featurizer,
                                 edge_featurizer=edge_featurizer,
                                 canonical_atom_order=canonical_atom_order)
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None,
                        p_distance=2, self_loops=False):
    """Find k nearest neighbors for each atom

    We do not guarantee that the edges are sorted according to the distance
    between atoms.

    Parameters
    ----------
    coordinates : numpy.ndarray of shape (N, D)
        The coordinates of atoms in the molecule. N for the number of atoms
        and D for the dimensions of the coordinates.
    neighbor_cutoff : float
        If the distance between a pair of nodes is larger than neighbor_cutoff,
        they will not be considered as neighboring nodes.
    max_num_neighbors : int or None.
        If not None, then this specifies the maximum number of neighbors
        allowed for each atom. Default to None.
    p_distance : int
        We compute the distance between neighbors using Minkowski (:math:`l_p`)
        distance. When ``p_distance = 1``, Minkowski distance is equivalent to
        Manhattan distance. When ``p_distance = 2``, Minkowski distance is
        equivalent to the standard Euclidean distance. Default to 2.
    self_loops : bool
        Whether to allow a node to be its own neighbor. Default to False.

    Returns
    -------
    srcs : list of int
        Source nodes.
    dsts : list of int
        Destination nodes, corresponding to ``srcs``.
    distances : list of float
        Distances between the end nodes, corresponding to ``srcs`` and ``dsts``.

    See Also
    --------
    get_mol_3d_coordinates
    mol_to_nearest_neighbor_graph
    smiles_to_nearest_neighbor_graph
    """
    num_atoms = coordinates.shape[0]
    model = NearestNeighbors(radius=neighbor_cutoff, p=p_distance)
    model.fit(coordinates)
    dists_, nbrs = model.radius_neighbors(coordinates)
    srcs, dsts, dists = [], [], []
    for i in range(num_atoms):
        dists_i = dists_[i].tolist()
        nbrs_i = nbrs[i].tolist()
        if not self_loops:
            # Bug fix: remove the self pair by locating the neighbor whose index
            # equals i. The previous dists_i.remove(0) dropped the FIRST zero
            # distance, which could belong to a different, coincident atom and
            # leave the dists/nbrs lists misaligned.
            self_pos = nbrs_i.index(i)
            nbrs_i.pop(self_pos)
            dists_i.pop(self_pos)
        if max_num_neighbors is not None and len(nbrs_i) > max_num_neighbors:
            packed_nbrs = list(zip(dists_i, nbrs_i))
            # Sort neighbors based on distance from smallest to largest
            packed_nbrs.sort(key=lambda tup: tup[0])
            dists_i, nbrs_i = map(list, zip(*packed_nbrs))
            dsts.extend([i] * max_num_neighbors)
            srcs.extend(nbrs_i[:max_num_neighbors])
            dists.extend(dists_i[:max_num_neighbors])
        else:
            dsts.extend([i] * len(nbrs_i))
            srcs.extend(nbrs_i)
            dists.extend(dists_i)

    return srcs, dsts, dists
# pylint: disable=E1102
def mol_to_nearest_neighbor_graph(mol,
                                  coordinates,
                                  neighbor_cutoff,
                                  max_num_neighbors=None,
                                  p_distance=2,
                                  add_self_loop=False,
                                  node_featurizer=None,
                                  edge_featurizer=None,
                                  canonical_atom_order=True,
                                  keep_dists=False,
                                  dist_field='dist'):
    """Convert an RDKit molecule into a nearest neighbor graph and featurize for it.

    Different from bigraph and complete graph, the nearest neighbor graph
    may not be symmetric since i is the closest neighbor of j does not
    necessarily suggest the other way.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    coordinates : numpy.ndarray of shape (N, D)
        The coordinates of atoms in the molecule. N for the number of atoms
        and D for the dimensions of the coordinates.
    neighbor_cutoff : float
        If the distance between a pair of nodes is larger than neighbor_cutoff,
        they will not be considered as neighboring nodes.
    max_num_neighbors : int or None.
        If not None, then this specifies the maximum number of neighbors
        allowed for each atom. Default to None.
    p_distance : int
        We compute the distance between neighbors using Minkowski (:math:`l_p`)
        distance. When ``p_distance = 1``, Minkowski distance is equivalent to
        Manhattan distance. When ``p_distance = 2``, Minkowski distance is
        equivalent to the standard Euclidean distance. Default to 2.
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.
    keep_dists : bool
        Whether to store the distance between neighboring atoms in ``edata`` of the
        constructed DGLGraphs. Default to False.
    dist_field : str
        Field for storing distance between neighboring atoms in ``edata``. This comes
        into effect only when ``keep_dists=True``. Default to ``'dist'``.

    Returns
    -------
    g : DGLGraph
        Nearest neighbor DGLGraph for the molecule

    See Also
    --------
    get_mol_3d_coordinates
    k_nearest_neighbors
    smiles_to_nearest_neighbor_graph
    """
    if canonical_atom_order:
        new_order = rdmolfiles.CanonicalRankAtoms(mol)
        mol = rdmolops.RenumberAtoms(mol, new_order)

    srcs, dsts, dists = k_nearest_neighbors(coordinates=coordinates,
                                            neighbor_cutoff=neighbor_cutoff,
                                            max_num_neighbors=max_num_neighbors,
                                            p_distance=p_distance,
                                            self_loops=add_self_loop)
    g = DGLGraph()

    # Add nodes first since some nodes may be completely isolated
    num_atoms = mol.GetNumAtoms()
    g.add_nodes(num_atoms)

    # Add edges
    g.add_edges(srcs, dsts)

    if node_featurizer is not None:
        g.ndata.update(node_featurizer(mol))

    if edge_featurizer is not None:
        g.edata.update(edge_featurizer(mol))

    if keep_dists:
        # Bug fix: the original assert message contained a '{}' placeholder but
        # never called .format, so the field name was not reported.
        assert dist_field not in g.edata, \
            'Expect {} to be reserved for distance between ' \
            'neighboring atoms.'.format(dist_field)
        g.edata[dist_field] = torch.tensor(dists).float().reshape(-1, 1)

    return g
def smiles_to_nearest_neighbor_graph(smiles,
                                     coordinates,
                                     neighbor_cutoff,
                                     max_num_neighbors=None,
                                     p_distance=2,
                                     add_self_loop=False,
                                     node_featurizer=None,
                                     edge_featurizer=None,
                                     canonical_atom_order=True,
                                     keep_dists=False,
                                     dist_field='dist'):
    """Convert a SMILES into a nearest neighbor graph and featurize for it.

    Different from bigraph and complete graph, the nearest neighbor graph
    may not be symmetric since i is the closest neighbor of j does not
    necessarily suggest the other way.

    Parameters
    ----------
    smiles : str
        String of SMILES
    coordinates : numpy.ndarray of shape (N, D)
        The coordinates of atoms in the molecule. N for the number of atoms
        and D for the dimensions of the coordinates.
    neighbor_cutoff : float
        If the distance between a pair of nodes is larger than neighbor_cutoff,
        they will not be considered as neighboring nodes.
    max_num_neighbors : int or None.
        If not None, then this specifies the maximum number of neighbors
        allowed for each atom. Default to None.
    p_distance : int
        We compute the distance between neighbors using Minkowski (:math:`l_p`)
        distance. When ``p_distance = 1``, Minkowski distance is equivalent to
        Manhattan distance. When ``p_distance = 2``, Minkowski distance is
        equivalent to the standard Euclidean distance. Default to 2.
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.
    keep_dists : bool
        Whether to store the distance between neighboring atoms in ``edata`` of the
        constructed DGLGraphs. Default to False.
    dist_field : str
        Field for storing distance between neighboring atoms in ``edata``. This comes
        into effect only when ``keep_dists=True``. Default to ``'dist'``.

    Returns
    -------
    g : DGLGraph
        Nearest neighbor DGLGraph for the molecule

    Examples
    --------
    >>> from dgllife.utils import get_mol_3d_coordinates, smiles_to_nearest_neighbor_graph
    >>> from rdkit import Chem
    >>> from rdkit.Chem import AllChem
    >>> smiles = 'CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C'
    >>> mol = Chem.MolFromSmiles(smiles)
    >>> AllChem.EmbedMolecule(mol)
    >>> AllChem.MMFFOptimizeMolecule(mol)
    >>> coords = get_mol_3d_coordinates(mol)
    >>> g = smiles_to_nearest_neighbor_graph(smiles, coords, neighbor_cutoff=1.25)
    >>> print(g)
    DGLGraph(num_nodes=23, num_edges=6,
             ndata_schemes={}
             edata_schemes={})

    See Also
    --------
    get_mol_3d_coordinates
    k_nearest_neighbors
    mol_to_nearest_neighbor_graph
    """
    # Parse the SMILES and delegate to the molecule-based variant.
    mol = Chem.MolFromSmiles(smiles)
    return mol_to_nearest_neighbor_graph(mol,
                                         coordinates,
                                         neighbor_cutoff,
                                         max_num_neighbors=max_num_neighbors,
                                         p_distance=p_distance,
                                         add_self_loop=add_self_loop,
                                         node_featurizer=node_featurizer,
                                         edge_featurizer=edge_featurizer,
                                         canonical_atom_order=canonical_atom_order,
                                         keep_dists=keep_dists,
                                         dist_field=dist_field)
apps/life_sci/python/dgllife/utils/rdkit_utils.py
deleted
100644 → 0
View file @
94c67203
"""Utils for RDKit, mostly adapted from DeepChem
(https://github.com/deepchem/deepchem/blob/master/deepchem)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import
warnings
from
functools
import
partial
from
multiprocessing
import
Pool
from
rdkit
import
Chem
from
rdkit.Chem
import
AllChem
__all__
=
[
'get_mol_3d_coordinates'
,
'load_molecule'
,
'multiprocess_load_molecules'
]
# pylint: disable=W0702
def get_mol_3d_coordinates(mol):
    """Get 3D coordinates of the molecule.

    This function requires that molecular conformation has been initialized.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.

    Returns
    -------
    numpy.ndarray of shape (N, 3) or None
        The 3D coordinates of atoms in the molecule. N for the number of atoms in
        the molecule. For failures in getting the conformations, None will be returned.

    Examples
    --------
    An error will occur in the example below since the molecule object does not
    carry conformation information.

    >>> from rdkit import Chem
    >>> from dgllife.utils import get_mol_3d_coordinates
    >>> mol = Chem.MolFromSmiles('CCO')

    Below we give a working example based on molecule conformation initialized from calculation.

    >>> from rdkit.Chem import AllChem
    >>> AllChem.EmbedMolecule(mol)
    >>> AllChem.MMFFOptimizeMolecule(mol)
    >>> coords = get_mol_3d_coordinates(mol)
    >>> print(coords)
    array([[ 1.20967478, -0.25802181,  0.        ],
           [-0.05021255,  0.57068079,  0.        ],
           [-1.15946223, -0.31265898,  0.        ]])
    """
    try:
        conf = mol.GetConformer()
        conf_num_atoms = conf.GetNumAtoms()
        mol_num_atoms = mol.GetNumAtoms()
        assert mol_num_atoms == conf_num_atoms, \
            'Expect the number of atoms in the molecule and its conformation ' \
            'to be the same, got {:d} and {:d}'.format(mol_num_atoms, conf_num_atoms)
        return conf.GetPositions()
    # Catch Exception (not a bare except) so that KeyboardInterrupt and
    # SystemExit still propagate; any RDKit/assertion failure is treated
    # as "no conformation available".
    except Exception:
        warnings.warn('Unable to get conformation of the molecule.')
        return None
# pylint: disable=E1101
def load_molecule(molecule_file, sanitize=False, calc_charges=False,
                  remove_hs=False, use_conformation=True):
    """Load a molecule from a file of format ``.mol2`` or ``.sdf`` or ``.pdbqt`` or ``.pdb``.

    Parameters
    ----------
    molecule_file : str
        Path to file for storing a molecule, which can be of format ``.mol2`` or ``.sdf``
        or ``.pdbqt`` or ``.pdb``.
    sanitize : bool
        Whether sanitization is performed in initializing RDKit molecule instances. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        Default to False.
    calc_charges : bool
        Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
        ``sanitize`` to be True. Default to False.
    remove_hs : bool
        Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
        slow for large molecules. Default to False.
    use_conformation : bool
        Whether we need to extract molecular conformation from proteins and ligands.
        Default to True.

    Returns
    -------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance for the loaded molecule.
    coordinates : np.ndarray of shape (N, 3) or None
        The 3D coordinates of atoms in the molecule. N for the number of atoms in
        the molecule. None will be returned if ``use_conformation`` is False or
        we failed to get conformation information.

    Raises
    ------
    ValueError
        If the file format is not among ``.mol2``, ``.sdf``, ``.pdbqt`` and ``.pdb``.
    """
    # Sanitization and hydrogen removal are deferred to the try block below
    # so that a failure there yields (None, None) rather than an exception.
    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        mol = supplier[0]
    elif molecule_file.endswith('.pdbqt'):
        with open(molecule_file) as file:
            pdbqt_data = file.readlines()
        pdb_block = ''
        # Only the first 66 columns of a PDBQT line are valid PDB content.
        for line in pdbqt_data:
            pdb_block += '{}\n'.format(line[:66])
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
    else:
        # Bug fix: the exception was previously returned instead of raised,
        # handing callers a ValueError instance in place of (mol, coordinates).
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))

    try:
        if sanitize or calc_charges:
            Chem.SanitizeMol(mol)

        if calc_charges:
            # Compute Gasteiger charges on the molecule.
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except Exception:
                warnings.warn('Unable to compute charges for the molecule.')

        if remove_hs:
            mol = Chem.RemoveHs(mol)
    # Narrowed from a bare except: processing failures yield (None, None),
    # but KeyboardInterrupt/SystemExit still propagate.
    except Exception:
        return None, None

    if use_conformation:
        coordinates = get_mol_3d_coordinates(mol)
    else:
        coordinates = None

    return mol, coordinates
def multiprocess_load_molecules(files, sanitize=False, calc_charges=False,
                                remove_hs=False, use_conformation=True,
                                num_processes=2):
    """Load molecules from files with multiprocessing, which can be of format ``.mol2`` or
    ``.sdf`` or ``.pdbqt`` or ``.pdb``.

    Parameters
    ----------
    files : list of str
        Each element is a path to a file storing a molecule, which can be of format ``.mol2``,
        ``.sdf``, ``.pdbqt``, or ``.pdb``.
    sanitize : bool
        Whether sanitization is performed in initializing RDKit molecule instances. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        Default to False.
    calc_charges : bool
        Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
        ``sanitize`` to be True. Default to False.
    remove_hs : bool
        Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
        slow for large molecules. Default to False.
    use_conformation : bool
        Whether we need to extract molecular conformation from proteins and ligands.
        Default to True.
    num_processes : int or None
        Number of worker processes to use. If None,
        then we will use the number of CPUs in the system. Default to 2.

    Returns
    -------
    list of 2-tuples
        The first element of each 2-tuple is an RDKit molecule instance. The second element
        of each 2-tuple is the 3D atom coordinates of the corresponding molecule if
        use_conformation is True and the coordinates has been successfully loaded. Otherwise,
        it will be None.
    """
    if num_processes == 1:
        # Serial path: avoid the overhead of spawning worker processes.
        mols_loaded = [load_molecule(f, sanitize=sanitize, calc_charges=calc_charges,
                                     remove_hs=remove_hs, use_conformation=use_conformation)
                       for f in files]
    else:
        # Pool(processes=None) uses os.cpu_count() workers.
        with Pool(processes=num_processes) as pool:
            mols_loaded = pool.map_async(partial(
                load_molecule, sanitize=sanitize, calc_charges=calc_charges,
                remove_hs=remove_hs, use_conformation=use_conformation), files)
            mols_loaded = mols_loaded.get()

    return mols_loaded
apps/life_sci/python/dgllife/utils/splitters.py
deleted
100644 → 0
View file @
94c67203
"""Various methods for splitting chemical datasets.
We mostly adapt them from deepchem
(https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py).
"""
# pylint: disable= no-member, arguments-differ, invalid-name
# pylint: disable=E0611
from
collections
import
defaultdict
from
functools
import
partial
from
itertools
import
accumulate
,
chain
from
rdkit
import
Chem
from
rdkit.Chem
import
rdMolDescriptors
from
rdkit.Chem.rdmolops
import
FastFindRings
from
rdkit.Chem.Scaffolds
import
MurckoScaffold
import
dgl.backend
as
F
import
numpy
as
np
from
dgl.data.utils
import
split_dataset
,
Subset
__all__
=
[
'ConsecutiveSplitter'
,
'RandomSplitter'
,
'MolecularWeightSplitter'
,
'ScaffoldSplitter'
,
'SingleTaskStratifiedSplitter'
]
def base_k_fold_split(split_method, dataset, k, log):
    """Split dataset for k-fold cross validation.

    Parameters
    ----------
    split_method : callable
        Arbitrary method for splitting the dataset
        into training, validation and test subsets.
    dataset
        We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
        gives the ith datapoint.
    k : int
        Number of folds to use and should be no smaller than 2.
    log : bool
        Whether to print a message at the start of preparing each fold.

    Returns
    -------
    all_folds : list of 2-tuples
        Each element of the list represents a fold and is a 2-tuple (train_set, val_set),
        which are all :class:`Subset` instances.
    """
    assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)

    fold_frac = 1. / k
    folds = []
    for fold_id in range(k):
        if log:
            print('Processing fold {:d}/{:d}'.format(fold_id + 1, k))
        # Reuse the three-way splitter: the middle chunk plays the role of the
        # validation set and the two flanking chunks are merged for training.
        left_part, val_subset, right_part = split_method(
            dataset,
            frac_train=fold_id * fold_frac,
            frac_val=fold_frac,
            frac_test=1. - (fold_id + 1) * fold_frac)
        merged_indices = np.concatenate(
            [left_part.indices, right_part.indices]).astype(np.int64)
        folds.append((Subset(dataset, merged_indices), val_subset))

    return folds
def train_val_test_sanity_check(frac_train, frac_val, frac_test):
    """Sanity check for train-val-test split

    Ensure that the fractions of the dataset to use for training,
    validation and test add up to 1.

    Parameters
    ----------
    frac_train : float
        Fraction of the dataset to use for training.
    frac_val : float
        Fraction of the dataset to use for validation.
    frac_test : float
        Fraction of the dataset to use for test.
    """
    frac_sum = frac_train + frac_val + frac_test
    error_message = ('Expect the sum of fractions for training, validation and '
                     'test to be 1, got {:.4f}'.format(frac_sum))
    # allclose rather than == to tolerate floating point rounding.
    assert np.allclose(frac_sum, 1.), error_message
def indices_split(dataset, frac_train, frac_val, frac_test, indices):
    """Reorder datapoints based on the specified indices and then take consecutive
    chunks as subsets.

    Parameters
    ----------
    dataset
        We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
        gives the ith datapoint.
    frac_train : float
        Fraction of data to use for training.
    frac_val : float
        Fraction of data to use for validation.
    frac_test : float
        Fraction of data to use for test.
    indices : list or ndarray
        Indices specifying the order of datapoints.

    Returns
    -------
    list of length 3
        Subsets for training, validation and test, which are all :class:`Subset` instances.
    """
    frac_list = np.array([frac_train, frac_val, frac_test])
    assert np.allclose(np.sum(frac_list), 1.), \
        'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list))

    num_data = len(dataset)
    lengths = (num_data * frac_list).astype(int)
    # Dump any rounding remainder into the last chunk so the lengths sum to num_data.
    lengths[-1] = num_data - np.sum(lengths[:-1])

    subsets = []
    # accumulate yields running totals, i.e. the exclusive end of each chunk.
    for chunk_end, chunk_len in zip(accumulate(lengths), lengths):
        subsets.append(Subset(dataset, list(indices[chunk_end - chunk_len:chunk_end])))
    return subsets
def count_and_log(message, i, total, log_every_n):
    """Print a message to reflect the progress of processing once a while.

    Parameters
    ----------
    message : str
        Message to print.
    i : int
        Current index.
    total : int
        Total count.
    log_every_n : None or int
        Molecule related computation can take a long time for a large dataset and we want
        to learn the progress of processing. This can be done by printing a message whenever
        a batch of ``log_every_n`` molecules have been processed. If None, no messages will
        be printed.
    """
    # Guard clause: logging disabled entirely.
    if log_every_n is None:
        return
    if (i + 1) % log_every_n == 0:
        print('{} {:d}/{:d}'.format(message, i + 1, total))
def prepare_mols(dataset, mols, sanitize, log_every_n=1000):
    """Prepare RDKit molecule instances.

    Parameters
    ----------
    dataset
        We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
        gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
        ith datapoint.
    mols : None or list of rdkit.Chem.rdchem.Mol
        None or pre-computed RDKit molecule instances. If not None, we expect a
        one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
        ``mols[i]`` corresponds to ``dataset.smiles[i]``.
    sanitize : bool
        This argument only comes into effect when ``mols`` is None and decides whether
        sanitization is performed in initializing RDKit molecule instances. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
    log_every_n : None or int
        Molecule related computation can take a long time for a large dataset and we want
        to learn the progress of processing. This can be done by printing a message whenever
        a batch of ``log_every_n`` molecules have been processed. If None, no messages will
        be printed. Default to 1000.

    Returns
    -------
    mols : list of rdkit.Chem.rdchem.Mol
        RDkit molecule instances where there is a one-on-one correspondence between
        ``dataset.smiles`` and ``mols``, i.e. ``mols[i]`` corresponds to ``dataset.smiles[i]``.
    """
    if mols is not None:
        # Pre-computed molecules provided: only verify the correspondence holds.
        assert len(mols) == len(dataset), \
            'Expect mols to be of the same size as that of the dataset, ' \
            'got {:d} and {:d}'.format(len(mols), len(dataset))
        return mols

    if log_every_n is not None:
        print('Start initializing RDKit molecule instances...')
    mols = []
    num_smiles = len(dataset.smiles)
    for i, smiles in enumerate(dataset.smiles):
        count_and_log('Creating RDKit molecule instance', i, num_smiles, log_every_n)
        mols.append(Chem.MolFromSmiles(smiles, sanitize=sanitize))

    return mols
class ConsecutiveSplitter(object):
    """Split datasets with the input order.

    The dataset is split without permutation, so the splitting is deterministic.
    """

    @staticmethod
    def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1):
        """Split the dataset into three consecutive chunks for training, validation and test.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
            gives the ith datapoint.
        frac_train : float
            Fraction of data to use for training. Default to 0.8.
        frac_val : float
            Fraction of data to use for validation. Default to 0.1.
        frac_test : float
            Fraction of data to use for test. Default to 0.1.

        Returns
        -------
        list of length 3
            Subsets for training, validation and test that also have ``len(dataset)`` and
            ``dataset[i]`` behaviors
        """
        frac_list = [frac_train, frac_val, frac_test]
        # shuffle=False keeps the original order, making the split deterministic.
        return split_dataset(dataset, frac_list=frac_list, shuffle=False)

    @staticmethod
    def k_fold_split(dataset, k=5, log=True):
        """Split the dataset for k-fold cross validation by taking consecutive chunks.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
            gives the ith datapoint.
        k : int
            Number of folds to use and should be no smaller than 2. Default to be 5.
        log : bool
            Whether to print a message at the start of preparing each fold.

        Returns
        -------
        list of 2-tuples
            Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
            ``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
        """
        return base_k_fold_split(
            ConsecutiveSplitter.train_val_test_split, dataset, k, log)
class RandomSplitter(object):
    """Randomly reorder datasets and then split them.

    The dataset is split with permutation and the splitting is hence random.
    """

    @staticmethod
    def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1,
                             frac_test=0.1, random_state=None):
        """Randomly permute the dataset and then split it into
        three consecutive chunks for training, validation and test.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
            gives the ith datapoint.
        frac_train : float
            Fraction of data to use for training. Default to 0.8.
        frac_val : float
            Fraction of data to use for validation. Default to 0.1.
        frac_test : float
            Fraction of data to use for test. Default to 0.1.
        random_state : None, int or array_like, optional
            Random seed used to initialize the pseudo-random number generator.
            Can be any integer between 0 and 2**32 - 1 inclusive, an array
            (or other sequence) of such integers, or None (the default).
            If seed is None, then RandomState will try to read data from /dev/urandom
            (or the Windows analogue) if available or seed from the clock otherwise.

        Returns
        -------
        list of length 3
            Subsets for training, validation and test, which also have ``len(dataset)``
            and ``dataset[i]`` behaviors.
        """
        frac_list = [frac_train, frac_val, frac_test]
        return split_dataset(dataset, frac_list=frac_list,
                             shuffle=True, random_state=random_state)

    @staticmethod
    def k_fold_split(dataset, k=5, random_state=None, log=True):
        """Randomly permute the dataset and then split it
        for k-fold cross validation by taking consecutive chunks.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
            gives the ith datapoint.
        k : int
            Number of folds to use and should be no smaller than 2. Default to be 5.
        random_state : None, int or array_like, optional
            Random seed used to initialize the pseudo-random number generator.
            Can be any integer between 0 and 2**32 - 1 inclusive, an array
            (or other sequence) of such integers, or None (the default).
            If seed is None, then RandomState will try to read data from /dev/urandom
            (or the Windows analogue) if available or seed from the clock otherwise.
        log : bool
            Whether to print a message at the start of preparing each fold. Default to True.

        Returns
        -------
        list of 2-tuples
            Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
            ``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
        """
        # Permute once up front so that every datapoint lands in exactly one
        # validation fold across the k folds.
        rng = np.random.RandomState(seed=random_state)
        shuffled = rng.permutation(len(dataset))
        return base_k_fold_split(
            partial(indices_split, indices=shuffled), dataset, k, log)
# pylint: disable=I1101
class MolecularWeightSplitter(object):
    """Sort molecules based on their weights and then split them."""

    @staticmethod
    def molecular_weight_indices(molecules, log_every_n):
        """Reorder molecules based on molecular weights.

        Parameters
        ----------
        molecules : list of rdkit.Chem.rdchem.Mol
            Pre-computed RDKit molecule instances. We expect a one-on-one
            correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``.
        log_every_n : None or int
            Print a progress message whenever a batch of ``log_every_n`` molecules
            have been processed. If None, no messages will be printed.

        Returns
        -------
        indices : list or ndarray
            Indices specifying the order of datapoints, which are basically
            argsort of the molecular weights.
        """
        if log_every_n is not None:
            print('Start computing molecular weights.')

        num_mols = len(molecules)
        weights = []
        for idx, molecule in enumerate(molecules):
            count_and_log('Computing molecular weight for compound',
                          idx, num_mols, log_every_n)
            weights.append(rdMolDescriptors.CalcExactMolWt(molecule))

        return np.argsort(weights)

    @staticmethod
    def train_val_test_split(dataset, mols=None, sanitize=True, frac_train=0.8,
                             frac_val=0.1, frac_test=0.1, log_every_n=1000):
        """Sort molecules based on their weights and then split them into
        three consecutive chunks for training, validation and test.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        mols : None or list of rdkit.Chem.rdchem.Mol
            None or pre-computed RDKit molecule instances. If not None, we expect a
            one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
        sanitize : bool
            This argument only comes into effect when ``mols`` is None and decides whether
            sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
            Default to be True.
        frac_train : float
            Fraction of data to use for training. Default to 0.8.
        frac_val : float
            Fraction of data to use for validation. Default to 0.1.
        frac_test : float
            Fraction of data to use for test. Default to 0.1.
        log_every_n : None or int
            Print a progress message whenever a batch of ``log_every_n`` molecules
            have been processed. If None, no messages will be printed. Default to 1000.

        Returns
        -------
        list of length 3
            Subsets for training, validation and test, which also have ``len(dataset)``
            and ``dataset[i]`` behaviors
        """
        # Fail fast on bad fractions before any expensive molecule computation.
        train_val_test_sanity_check(frac_train, frac_val, frac_test)
        molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
        sorted_indices = MolecularWeightSplitter.molecular_weight_indices(
            molecules, log_every_n)

        return indices_split(dataset, frac_train, frac_val, frac_test, sorted_indices)

    @staticmethod
    def k_fold_split(dataset, mols=None, sanitize=True, k=5, log_every_n=1000):
        """Sort molecules based on their weights and then split them
        for k-fold cross validation by taking consecutive chunks.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        mols : None or list of rdkit.Chem.rdchem.Mol
            None or pre-computed RDKit molecule instances. If not None, we expect a
            one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
        sanitize : bool
            This argument only comes into effect when ``mols`` is None and decides whether
            sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
            Default to be True.
        k : int
            Number of folds to use and should be no smaller than 2. Default to be 5.
        log_every_n : None or int
            Print a progress message whenever a batch of ``log_every_n`` molecules
            have been processed. If None, no messages will be printed. Default to 1000.

        Returns
        -------
        list of 2-tuples
            Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
            ``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
        """
        molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
        sorted_indices = MolecularWeightSplitter.molecular_weight_indices(
            molecules, log_every_n)

        return base_k_fold_split(
            partial(indices_split, indices=sorted_indices), dataset, k,
            log=(log_every_n is not None))
# pylint: disable=W0702
class ScaffoldSplitter(object):
    """Group molecules based on their Bemis-Murcko scaffolds and then split the groups.

    Group molecules so that all molecules in a group have a same scaffold (see reference).
    The dataset is then split at the level of groups.

    References
    ----------
    Bemis, G. W.; Murcko, M. A. "The Properties of Known Drugs.
    1. Molecular Frameworks." J. Med. Chem. 39:2887-93 (1996).
    """

    @staticmethod
    def get_ordered_scaffold_sets(molecules, include_chirality, log_every_n):
        """Group molecules based on their Bemis-Murcko scaffolds and
        order these groups based on their sizes.

        The order is decided by comparing the size of groups, where groups with a larger size
        are placed before the ones with a smaller size.

        Parameters
        ----------
        molecules : list of rdkit.Chem.rdchem.Mol
            Pre-computed RDKit molecule instances. We expect a one-on-one
            correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``.
        include_chirality : bool
            Whether to consider chirality in computing scaffolds.
        log_every_n : None or int
            Molecule related computation can take a long time for a large dataset and we want
            to learn the progress of processing. This can be done by printing a message whenever
            a batch of ``log_every_n`` molecules have been processed. If None, no messages will
            be printed.

        Returns
        -------
        scaffold_sets : list
            Each element of the list is a list of int,
            representing the indices of compounds with a same scaffold.
        """
        if log_every_n is not None:
            print('Start computing Bemis-Murcko scaffolds.')
        scaffolds = defaultdict(list)
        for i, mol in enumerate(molecules):
            count_and_log('Computing Bemis-Murcko for compound',
                          i, len(molecules), log_every_n)
            # For mols that have not been sanitized, we need to compute their ring information
            try:
                FastFindRings(mol)
                mol_scaffold = MurckoScaffold.MurckoScaffoldSmiles(
                    mol=mol, includeChirality=include_chirality)
                # Group molecules that have the same scaffold
                scaffolds[mol_scaffold].append(i)
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # are not swallowed; failed molecules are simply excluded.
            except Exception:
                print('Failed to compute the scaffold for molecule {:d} '
                      'and it will be excluded.'.format(i + 1))

        # Order groups of molecules by first comparing the size of groups
        # and then the index of the first compound in the group.
        scaffold_sets = [
            scaffold_set for (scaffold, scaffold_set) in sorted(
                scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
        ]
        return scaffold_sets

    @staticmethod
    def train_val_test_split(dataset, mols=None, sanitize=True,
                             include_chirality=False, frac_train=0.8,
                             frac_val=0.1, frac_test=0.1, log_every_n=1000):
        """Split the dataset into training, validation and test set based on molecular scaffolds.

        This spliting method ensures that molecules with a same scaffold will be collectively
        in only one of the training, validation or test set. As a result, the fraction
        of dataset to use for training and validation tend to be smaller than ``frac_train``
        and ``frac_val``, while the fraction of dataset to use for test tends to be larger
        than ``frac_test``.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        mols : None or list of rdkit.Chem.rdchem.Mol
            None or pre-computed RDKit molecule instances. If not None, we expect a
            one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
        sanitize : bool
            This argument only comes into effect when ``mols`` is None and decides whether
            sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
            Default to True.
        include_chirality : bool
            Whether to consider chirality in computing scaffolds. Default to False.
        frac_train : float
            Fraction of data to use for training. Default to 0.8.
        frac_val : float
            Fraction of data to use for validation. Default to 0.1.
        frac_test : float
            Fraction of data to use for test. Default to 0.1.
        log_every_n : None or int
            Molecule related computation can take a long time for a large dataset and we want
            to learn the progress of processing. This can be done by printing a message whenever
            a batch of ``log_every_n`` molecules have been processed. If None, no messages will
            be printed. Default to 1000.

        Returns
        -------
        list of length 3
            Subsets for training, validation and test, which also have ``len(dataset)`` and
            ``dataset[i]`` behaviors
        """
        # Perform sanity check first as molecule related computation can take a long time.
        train_val_test_sanity_check(frac_train, frac_val, frac_test)
        # Bug fix: forward log_every_n so the caller's logging preference
        # (including None for silence) also applies to molecule preparation.
        molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
        scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
            molecules, include_chirality, log_every_n)
        train_indices, val_indices, test_indices = [], [], []
        train_cutoff = int(frac_train * len(molecules))
        val_cutoff = int((frac_train + frac_val) * len(molecules))
        # Greedily assign whole scaffold groups: a group that would overflow
        # the training budget spills into validation, then into test.
        for group_indices in scaffold_sets:
            if len(train_indices) + len(group_indices) > train_cutoff:
                if len(train_indices) + len(val_indices) + len(group_indices) > val_cutoff:
                    test_indices.extend(group_indices)
                else:
                    val_indices.extend(group_indices)
            else:
                train_indices.extend(group_indices)
        return [Subset(dataset, train_indices),
                Subset(dataset, val_indices),
                Subset(dataset, test_indices)]

    @staticmethod
    def k_fold_split(dataset, mols=None, sanitize=True,
                     include_chirality=False, k=5, log_every_n=1000):
        """Group molecules based on their scaffolds and sort groups based on their sizes.
        The groups are then split for k-fold cross validation.

        Same as usual k-fold splitting methods, each molecule will appear only once
        in the validation set among all folds. In addition, this method ensures that
        molecules with a same scaffold will be collectively in either the training
        set or the validation set for each fold.

        Note that the folds can be highly imbalanced depending on the
        scaffold distribution in the dataset.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        mols : None or list of rdkit.Chem.rdchem.Mol
            None or pre-computed RDKit molecule instances. If not None, we expect a
            one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
        sanitize : bool
            This argument only comes into effect when ``mols`` is None and decides whether
            sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
            Default to True.
        include_chirality : bool
            Whether to consider chirality in computing scaffolds. Default to False.
        k : int
            Number of folds to use and should be no smaller than 2. Default to be 5.
        log_every_n : None or int
            Molecule related computation can take a long time for a large dataset and we want
            to learn the progress of processing. This can be done by printing a message whenever
            a batch of ``log_every_n`` molecules have been processed. If None, no messages will
            be printed. Default to 1000.

        Returns
        -------
        list of 2-tuples
            Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
            ``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
        """
        assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)

        # Bug fix: forward log_every_n (previously dropped, see train_val_test_split).
        molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
        scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
            molecules, include_chirality, log_every_n)

        # k buckets that form a relatively balanced partition of the dataset:
        # each scaffold group goes to the currently smallest bucket.
        index_buckets = [[] for _ in range(k)]
        for group_indices in scaffold_sets:
            bucket_chosen = int(np.argmin([len(bucket) for bucket in index_buckets]))
            index_buckets[bucket_chosen].extend(group_indices)

        all_folds = []
        for i in range(k):
            if log_every_n is not None:
                print('Processing fold {:d}/{:d}'.format(i + 1, k))
            # Fold i validates on bucket i and trains on all remaining buckets.
            train_indices = list(chain.from_iterable(
                index_buckets[:i] + index_buckets[i + 1:]))
            val_indices = index_buckets[i]
            all_folds.append((Subset(dataset, train_indices),
                              Subset(dataset, val_indices)))
        return all_folds
class SingleTaskStratifiedSplitter(object):
    """Splits the dataset by stratification on a single task.

    Datapoints are sorted by their label values for the chosen task; buckets of
    consecutive datapoints are then shuffled and dealt out to the training,
    validation and test subsets so each subset covers the full label range.
    """

    @staticmethod
    def train_val_test_split(dataset, labels, task_id, frac_train=0.8, frac_val=0.1,
                             frac_test=0.1, bucket_size=10, random_state=None):
        """Split the dataset into training, validation and test subsets as stated above.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        labels : tensor of shape (N, T)
            Dataset labels for all tasks. N for the number of datapoints and T for
            the number of tasks.
        task_id : int
            Index for the task to stratify on.
        frac_train : float
            Fraction of data to use for training. Default to 0.8.
        frac_val : float
            Fraction of data to use for validation. Default to 0.1.
        frac_test : float
            Fraction of data to use for test. Default to 0.1.
        bucket_size : int
            Size of bucket of datapoints. Default to 10.
        random_state : None, int or array_like, optional
            Random seed used to initialize the pseudo-random number generator.
            Can be any integer between 0 and 2**32 - 1 inclusive, an array
            (or other sequence) of such integers, or None (the default).

        Returns
        -------
        list of length 3
            Subsets for training, validation and test, which also have ``len(dataset)``
            and ``dataset[i]`` behaviors
        """
        train_val_test_sanity_check(frac_train, frac_val, frac_test)

        if random_state is not None:
            # NOTE(review): this seeds NumPy's *global* RNG, affecting other users
            # of np.random in the same process.
            np.random.seed(random_state)

        if not isinstance(labels, np.ndarray):
            labels = F.asnumpy(labels)
        # Indices of all datapoints, ordered by label value for the task.
        order = np.argsort(labels[:, task_id])

        # Within each shuffled bucket: first `train_cut` entries go to training,
        # the next `val_cut - train_cut` to validation, and the rest to test.
        train_cut = int(np.round(frac_train * bucket_size))
        val_cut = int(np.round(frac_val * bucket_size)) + train_cut

        train_indices, val_indices, test_indices = [], [], []
        num_full_buckets = order.shape[0] // bucket_size
        for b in range(num_full_buckets):
            bucket = order[b * bucket_size:(b + 1) * bucket_size]
            shuffled = np.random.permutation(range(bucket_size))
            train_indices.extend(bucket[shuffled[:train_cut]].tolist())
            val_indices.extend(bucket[shuffled[train_cut:val_cut]].tolist())
            test_indices.extend(bucket[shuffled[val_cut:]].tolist())

        # Leftover datapoints that do not fill a whole bucket go to training.
        train_indices.extend(order[num_full_buckets * bucket_size:].tolist())

        return [Subset(dataset, train_indices),
                Subset(dataset, val_indices),
                Subset(dataset, test_indices)]

    @staticmethod
    def k_fold_split(dataset, labels, task_id, k=5, log=True):
        """Sort molecules based on their label values for a task and then split them
        for k-fold cross validation by taking consecutive chunks.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        labels : tensor of shape (N, T)
            Dataset labels for all tasks. N for the number of datapoints and T for
            the number of tasks.
        task_id : int
            Index for the task.
        k : int
            Number of folds to use and should be no smaller than 2. Default to be 5.
        log : bool
            Whether to print a message at the start of preparing each fold.

        Returns
        -------
        list of 2-tuples
            Each element of the list represents a fold and is a 2-tuple ``(train_set, val_set)``.
            ``train_set`` and ``val_set`` also have ``len(dataset)`` and ``dataset[i]`` behaviors.
        """
        if not isinstance(labels, np.ndarray):
            labels = F.asnumpy(labels)
        sorted_indices = np.argsort(labels[:, task_id]).tolist()

        # Delegate chunking into consecutive folds to the shared helper.
        return base_k_fold_split(partial(indices_split, indices=sorted_indices),
                                 dataset, k, log)
apps/life_sci/python/setup.py
deleted
100644 → 0
View file @
94c67203
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import
os
from
setuptools
import
find_packages
from
setuptools
import
setup
CURRENT_DIR
=
os
.
path
.
dirname
(
__file__
)
def get_lib_path(current_dir=None):
    """Get the dgllife package version from ``dgllife/libinfo.py``.

    The module is executed in an isolated namespace rather than imported,
    since importing it would trigger ``__init__.py`` and pull in heavy
    dependencies at setup time.

    Parameters
    ----------
    current_dir : str or None
        Directory containing the ``dgllife`` package. Defaults to the
        directory of this setup script (backward-compatible behavior).

    Returns
    -------
    str
        The version string defined by ``__version__`` in libinfo.py.
    """
    if current_dir is None:
        current_dir = CURRENT_DIR
    # We can not import `libinfo.py` directly since __init__.py
    # will be invoked, which introduces dependencies.
    libinfo_py = os.path.join(current_dir, './dgllife/libinfo.py')
    libinfo = {'__file__': libinfo_py}
    # Use a with-block so the file handle is closed deterministically
    # (the original `open(...).read()` leaked the handle).
    with open(libinfo_py, "rb") as f:
        exec(compile(f.read(), libinfo_py, 'exec'), libinfo, libinfo)
    version = libinfo['__version__']
    return version
# Resolve the package version once so it can be embedded in the metadata below.
VERSION = get_lib_path()

# Standard setuptools packaging configuration for the dgllife package.
setup(name='dgllife',
      version=VERSION,
      description='DGL-based package for Life Science',
      keywords=['pytorch', 'dgl', 'graph-neural-networks', 'life-science',
                'drug-discovery'],
      maintainer='DGL Team',
      # Ship only the dgllife package and its subpackages; exclude anything
      # else find_packages() may discover next to this script.
      packages=[package for package in find_packages()
                if package.startswith('dgllife')],
      install_requires=[
          'scikit-learn>=0.22.2',
          'pandas',
          'requests>=2.22.0',
          'tqdm',
          'numpy>=1.14.0',
          'scipy>=1.1.0',
          'networkx>=2.1',
      ],
      url='https://github.com/dmlc/dgl/tree/master/apps/life_sci',
      classifiers=[
          'Development Status :: 3 - Alpha',
          'Programming Language :: Python :: 3',
          'License :: OSI Approved :: Apache Software License'],
      license='APACHE')
apps/life_sci/python/update_version.py
deleted
100644 → 0
View file @
94c67203
"""
This is the global script that set the version information of DGL-LifeSci.
This script runs and update all the locations that related to versions
List of affected files:
- app-root/python/dgllife/__init__.py
- app-root/conda/dgllife/meta.yaml
"""
import
os
import
re
__version__
=
"0.2.2"
print
(
__version__
)
# Implementations
def update(file_name, pattern, repl):
    """Replace the version captured by ``pattern`` in ``file_name`` with ``repl``.

    Exactly one line of the file must match ``pattern``. The file is rewritten
    only when the matched version actually differs from ``repl``; progress is
    reported on stdout either way.

    Parameters
    ----------
    file_name : str
        Path of the file to update.
    pattern : str
        Regex whose single match per file is the current version string
        (typically a lookbehind pattern such as ``(?<=__version__ = ")...``).
    repl : str
        New version string.

    Raises
    ------
    RuntimeError
        If the number of matching lines is not exactly one.
    """
    # Renamed from `update`, which shadowed this very function.
    new_lines = []
    hit_counter = 0
    need_update = False
    # Context manager closes the handle (the bare `open` iteration leaked it).
    with open(file_name) as in_file:
        for line in in_file:
            result = re.findall(pattern, line)
            if result:
                # A single line must contain a single version occurrence.
                assert len(result) == 1
                hit_counter += 1
                if result[0] != repl:
                    line = re.sub(pattern, repl, line)
                    need_update = True
                    print("%s: %s->%s" % (file_name, result[0], repl))
                else:
                    print("%s: version is already %s" % (file_name, repl))
            new_lines.append(line)
    if hit_counter != 1:
        raise RuntimeError("Cannot find version in %s" % file_name)
    # Rewrite only on an actual change to avoid touching file mtimes needlessly.
    if need_update:
        with open(file_name, "w") as output_file:
            for line in new_lines:
                output_file.write(line)
def main():
    """Propagate ``__version__`` into every file that embeds the version string."""
    # Directory of this script; expanduser handles a leading ~ in __file__.
    curr_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    # The app root is assumed to be one level above this script's directory.
    proj_root = os.path.abspath(os.path.join(curr_dir, ".."))

    # python path
    update(os.path.join(proj_root, "python/dgllife/libinfo.py"),
           r"(?<=__version__ = \")[.0-9a-z]+", __version__)

    # conda
    update(os.path.join(proj_root, "conda/dgllife/meta.yaml"),
           "(?<=version: \")[.0-9a-z]+", __version__)

if __name__ == '__main__':
    main()
apps/life_sci/tests/data/test_csv_dataset.py
deleted
100644 → 0
View file @
94c67203
import
os
import
pandas
as
pd
from
dgllife.data.csv_dataset
import
*
from
dgllife.utils.featurizers
import
*
from
dgllife.utils.mol_to_graph
import
*
def test_data_frame():
    """Build a tiny two-molecule DataFrame with two numeric task columns."""
    columns = {
        'smiles': ['CCO', 'CO'],
        'task1': [0, 2],
        'task2': [1, 3],
    }
    return pd.DataFrame(columns)
def remove_file(fname):
    """Delete ``fname`` if possible, silently ignoring OS-level failures.

    Best-effort cleanup helper for test artifacts. A missing file (or any
    other removal failure) is deliberately swallowed.
    """
    # EAFP: attempting the removal directly avoids the check-then-act race
    # of the original `isfile` test; FileNotFoundError is an OSError subclass.
    try:
        os.remove(fname)
    except OSError:
        pass
def test_mol_csv():
    """Exercise MoleculeCSVDataset construction, task selection and cache reload."""
    df = test_data_frame()
    fname = 'test.bin'

    # Full featurization over both task columns.
    ds = MoleculeCSVDataset(df=df,
                            smiles_to_graph=smiles_to_bigraph,
                            node_featurizer=CanonicalAtomFeaturizer(),
                            edge_featurizer=CanonicalBondFeaturizer(),
                            smiles_column='smiles',
                            cache_file_path=fname)
    assert ds.task_names == ['task1', 'task2']
    s, g, y, w = ds[0]
    assert y.shape[0] == 2
    assert w.shape[0] == 2
    assert 'h' in g.ndata
    assert 'e' in g.edata

    # Test task_names: restricting to a subset of the columns.
    ds = MoleculeCSVDataset(df=df,
                            smiles_to_graph=smiles_to_bigraph,
                            node_featurizer=None,
                            edge_featurizer=None,
                            smiles_column='smiles',
                            cache_file_path=fname,
                            task_names=['task1'])
    assert ds.task_names == ['task1']

    # Test load: with load=True the cached graphs (which carry edge
    # features from the first construction) are reused.
    ds = MoleculeCSVDataset(df=df,
                            smiles_to_graph=smiles_to_bigraph,
                            node_featurizer=CanonicalAtomFeaturizer(),
                            edge_featurizer=None,
                            smiles_column='smiles',
                            cache_file_path=fname,
                            load=True)
    s, g, y, w = ds[0]
    assert 'h' in g.ndata
    assert 'e' in g.edata

    # With load=False the graphs are re-featurized from scratch, so no
    # edge features are present.
    ds = MoleculeCSVDataset(df=df,
                            smiles_to_graph=smiles_to_bigraph,
                            node_featurizer=CanonicalAtomFeaturizer(),
                            edge_featurizer=None,
                            smiles_column='smiles',
                            cache_file_path=fname,
                            load=False)
    s, g, y, w = ds[0]
    assert 'h' in g.ndata
    assert 'e' not in g.edata

    remove_file(fname)

if __name__ == '__main__':
    test_mol_csv()
apps/life_sci/tests/data/test_datasets.py
deleted
100644 → 0
View file @
94c67203
import
os
from
dgllife.data
import
*
from
dgllife.data.uspto
import
get_bond_changes
,
process_file
def remove_file(fname):
    """Delete ``fname`` if possible, silently ignoring OS-level failures.

    Best-effort cleanup helper for test artifacts. A missing file (or any
    other removal failure) is deliberately swallowed.
    """
    # EAFP: attempting the removal directly avoids the check-then-act race
    # of the original `isfile` test; FileNotFoundError is an OSError subclass.
    try:
        os.remove(fname)
    except OSError:
        pass
def test_pubchem_aromaticity():
    """Smoke test: construct the dataset, then drop its graph cache file."""
    print('Test pubchem aromaticity')
    _ = PubChemBioAssayAromaticity()  # construction itself is the test
    remove_file('pubchem_aromaticity_dglgraph.bin')
def test_tox21():
    """Smoke test: construct the Tox21 dataset, then drop its cache file."""
    print('Test Tox21')
    _ = Tox21()  # construction itself is the test
    remove_file('tox21_dglgraph.bin')
def test_alchemy():
    """Smoke test: build the Alchemy validation split, with and without
    explicitly disabling cache loading."""
    print('Test Alchemy')
    _ = TencentAlchemyDataset(mode='valid',
                              node_featurizer=None,
                              edge_featurizer=None)
    _ = TencentAlchemyDataset(mode='valid',
                              node_featurizer=None,
                              edge_featurizer=None,
                              load=False)
def test_pdbbind():
    """Smoke test: load the PDBBind core subset with hydrogens removed."""
    print('Test PDBBind')
    _ = PDBBind(subset='core', remove_hs=True)  # construction itself is the test
def test_wln_reaction():
    """Exercise the USPTO reaction-prediction utilities and the two WLN datasets.

    Covers: get_bond_changes on a raw reaction string, process_file round-trip
    (each processed line is the reaction plus its ';'-joined bond changes),
    and construction of WLNCenterDataset / WLNRankDataset from the small
    two-reaction files written here. All temporary files are cleaned up.
    """
    print('Test datasets for reaction prediction with WLN')
    # Two atom-mapped reaction SMILES, each terminated by a newline as the
    # file-based APIs expect one reaction per line.
    reaction1 = '[CH2:15]([CH:16]([CH3:17])[CH3:18])[Mg+:19].[CH2:20]1[O:21][CH2:22][CH2:23]' \
                '[CH2:24]1.[Cl-:14].[OH:1][c:2]1[n:3][cH:4][c:5]([C:6](=[O:7])[N:8]([O:9]' \
                '[CH3:10])[CH3:11])[cH:12][cH:13]1>>[OH:1][c:2]1[n:3][cH:4][c:5]([C:6](=[O:7])' \
                '[CH2:15][CH:16]([CH3:17])[CH3:18])[cH:12][cH:13]1\n'
    reaction2 = '[CH3:14][NH2:15].[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])' \
                '[cH:10][cH:11][c:12]1[Cl:13].[OH2:16]>>[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]' \
                '([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[NH:15][CH3:14]\n'
    reactions = [reaction1, reaction2]

    # Test utility functions
    # Expected: bond 12-13 broken (order 0.0), bond 12-15 formed (order 1.0).
    assert get_bond_changes(reaction2) == {('12', '13', 0.0), ('12', '15', 1.0)}
    with open('test.txt', 'w') as f:
        for reac in reactions:
            f.write(reac)
    process_file('test.txt')
    # Each processed line should be "<reaction> <changes>" with changes
    # formatted as atom1-atom2-order and joined by ';'.
    with open('test.txt.proc', 'r') as f:
        lines = f.readlines()
        for i in range(len(lines)):
            l = lines[i].strip()
            react = reactions[i].strip()
            bond_changes = get_bond_changes(react)
            assert l == '{} {}'.format(
                react,
                ';'.join(['{}-{}-{}'.format(x[0], x[1], x[2]) for x in bond_changes]))
    remove_file('test.txt.proc')

    # Test configured dataset
    dataset = WLNCenterDataset('test.txt', 'test_graphs.bin')
    remove_file('test_graphs.bin')
    with open('test_candidate_bond_changes.txt', 'w') as f:
        # One line per reaction; `reac` itself is unused, only the count matters.
        for reac in reactions:
            # simulate fake candidate bond changes
            candidate_string = ''
            for i in range(2):
                candidate_string += '{} {} {:.1f} {:.3f};'.format(i + 1, i + 2, 0.0, 0.234)
            candidate_string += '\n'
            f.write(candidate_string)
    dataset = WLNRankDataset('test.txt.proc', 'test_candidate_bond_changes.txt', 'train')
    remove_file('test.txt')
    remove_file('test.txt.proc')
    remove_file('test_graphs.bin')
    remove_file('test_candidate_bond_changes.txt')

if __name__ == '__main__':
    test_pubchem_aromaticity()
    test_tox21()
    test_alchemy()
    test_pdbbind()
    test_wln_reaction()
Prev
1
2
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment