OpenDAS / dgl · Commit 577cf2e6
Authored Oct 22, 2019 by Mufei Li; committed by GitHub, Oct 22, 2019

[Model Zoo] Refactor and Add Utils for Chemistry (#928)

* Refactor
* Add note
* Update
* CI

Parent: 0bf3b6dd

Showing 10 changed files with 1169 additions and 237 deletions
docs/source/api/python/data.rst  +53 -6
examples/pytorch/model_zoo/chem/generative_models/dgmg/utils.py  +2 -2
examples/pytorch/model_zoo/chem/property_prediction/classification.py  +29 -31
examples/pytorch/model_zoo/chem/property_prediction/configure.py  +16 -5
examples/pytorch/model_zoo/chem/property_prediction/regression.py  +39 -31
examples/pytorch/model_zoo/chem/property_prediction/utils.py  +150 -36
python/dgl/data/chem/alchemy.py  +38 -18
python/dgl/data/chem/csv_dataset.py  +27 -17
python/dgl/data/chem/tox21.py  +15 -8
python/dgl/data/chem/utils.py  +800 -83

docs/source/api/python/data.rst

@@ -142,19 +142,66 @@ Molecular Graphs
 To work on molecular graphs, make sure you have installed `RDKit 2018.09.3 <https://www.rdkit.org/docs/Install.html>`__.

-Featurization
-`````````````
+Featurization Utils
+```````````````````

-For the use of graph neural networks, we need to featurize nodes (atoms) and edges (bonds). Below we list some
-featurization methods/utilities:
+For the use of graph neural networks, we need to featurize nodes (atoms) and edges (bonds).
+
+General utils:

 .. autosummary::
     :toctree: ../../generated/

     chem.one_hot_encoding
+    chem.ConcatFeaturizer
+    chem.ConcatFeaturizer.__call__
+
+Utils for atom featurization:
+
+.. autosummary::
+    :toctree: ../../generated/
+
+    chem.atom_type_one_hot
+    chem.atomic_number_one_hot
+    chem.atomic_number
+    chem.atom_degree_one_hot
+    chem.atom_degree
+    chem.atom_total_degree_one_hot
+    chem.atom_total_degree
+    chem.atom_implicit_valence_one_hot
+    chem.atom_implicit_valence
+    chem.atom_hybridization_one_hot
+    chem.atom_total_num_H_one_hot
+    chem.atom_total_num_H
+    chem.atom_formal_charge_one_hot
+    chem.atom_formal_charge
+    chem.atom_num_radical_electrons_one_hot
+    chem.atom_num_radical_electrons
+    chem.atom_is_aromatic_one_hot
+    chem.atom_is_aromatic
+    chem.atom_chiral_tag_one_hot
+    chem.atom_mass
     chem.BaseAtomFeaturizer
+    chem.BaseAtomFeaturizer.feat_size
+    chem.BaseAtomFeaturizer.__call__
     chem.CanonicalAtomFeaturizer
+
+Utils for bond featurization:
+
+.. autosummary::
+    :toctree: ../../generated/
+
+    chem.bond_type_one_hot
+    chem.bond_is_conjugated_one_hot
+    chem.bond_is_conjugated
+    chem.bond_is_in_ring_one_hot
+    chem.bond_is_in_ring
+    chem.bond_stereo_one_hot
+    chem.BaseBondFeaturizer
+    chem.BaseBondFeaturizer.feat_size
+    chem.BaseBondFeaturizer.__call__
+    chem.CanonicalBondFeaturizer

 Graph Construction
 ``````````````````
@@ -164,9 +211,9 @@ Several methods for constructing DGLGraphs from SMILES/RDKit molecule objects ar
     :toctree: ../../generated/

     chem.mol_to_graph
-    chem.smile_to_bigraph
+    chem.smiles_to_bigraph
     chem.mol_to_bigraph
-    chem.smile_to_complete_graph
+    chem.smiles_to_complete_graph
     chem.mol_to_complete_graph

 Dataset Classes
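
The featurization utilities listed in this file combine one-hot encodings and scalar properties into a single per-atom feature vector. The sketch below illustrates that idea in plain Python. The function bodies, the dict-based stand-in for an atom, and the `concat_features` helper are all assumptions made for illustration; they are not the `dgl.data.chem` implementations, whose signatures may differ.

```python
def one_hot_encoding(x, allowable_set):
    """Return 0/1 flags marking which entry of allowable_set equals x.

    Illustrative only: a value outside allowable_set yields all zeros.
    """
    return [int(x == s) for s in allowable_set]


def concat_features(item, featurizers):
    """Apply each featurizer to item and concatenate the resulting lists."""
    feats = []
    for func in featurizers:
        feats.extend(func(item))
    return feats


# Hypothetical stand-in for an RDKit atom: a plain dict with the needed fields.
carbon = {"symbol": "C", "degree": 3, "aromatic": True}

featurizers = [
    lambda a: one_hot_encoding(a["symbol"], ["H", "C", "N", "O"]),
    lambda a: one_hot_encoding(a["degree"], [0, 1, 2, 3, 4]),
    lambda a: [int(a["aromatic"])],
]

features = concat_features(carbon, featurizers)
print(features)  # 4 symbol flags, 5 degree flags, 1 aromaticity flag
```

Each featurizer returns a list, so adding or removing a feature only changes the list of callables, not the code that assembles the vector.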

examples/pytorch/model_zoo/chem/generative_models/dgmg/utils.py

@@ -448,7 +448,7 @@ def get_atom_and_bond_types(smiles, log=True):
     for i, s in enumerate(smiles):
         if log:
-            print('Processing smile {:d}/{:d}'.format(i + 1, n_smiles))
+            print('Processing smiles {:d}/{:d}'.format(i + 1, n_smiles))

         mol = smiles_to_standard_mol(s)
         if mol is None:
@@ -517,7 +517,7 @@ def eval_decisions(env, decisions):
     return env.get_current_smiles()

 def get_DGMG_smile(env, mol):
-    """Mimics the reproduced SMILE with DGMG for a molecule.
+    """Mimics the reproduced SMILES with DGMG for a molecule.

     Given a molecule, we are interested in what SMILES we will
     get if we want to generate it with DGMG. This is an important

examples/pytorch/model_zoo/chem/property_prediction/classification.py

+import numpy as np
 import torch

 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 from torch.utils.data import DataLoader

 from dgl import model_zoo
-from dgl.data.utils import split_dataset

-from utils import Meter, EarlyStopping, collate_molgraphs_for_classification, set_random_seed
+from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed, \
+    load_dataset_for_classification

 def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
     model.train()
     train_meter = Meter()
     for batch_id, batch_data in enumerate(data_loader):
-        smiles, bg, labels, mask = batch_data
+        smiles, bg, labels, masks = batch_data
         atom_feats = bg.ndata.pop(args['atom_data_field'])
-        atom_feats, labels, mask = atom_feats.to(args['device']), \
-                                   labels.to(args['device']), \
-                                   mask.to(args['device'])
+        atom_feats, labels, masks = atom_feats.to(args['device']), \
+                                    labels.to(args['device']), \
+                                    masks.to(args['device'])
         logits = model(bg, atom_feats)
         # Mask non-existing labels
-        loss = (loss_criterion(logits, labels) * (mask != 0).float()).mean()
+        loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
             epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item()))
-        train_meter.update(logits, labels, mask)
-    train_roc_auc = train_meter.roc_auc_averaged_over_tasks()
-    print('epoch {:d}/{:d}, training roc-auc score {:.4f}'.format(
-        epoch + 1, args['num_epochs'], train_roc_auc))
+        train_meter.update(logits, labels, masks)
+    train_score = np.mean(train_meter.compute_metric(args['metric_name']))
+    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
+        epoch + 1, args['num_epochs'], args['metric_name'], train_score))

 def run_an_eval_epoch(args, model, data_loader):
     model.eval()
     eval_meter = Meter()
     with torch.no_grad():
         for batch_id, batch_data in enumerate(data_loader):
-            smiles, bg, labels, mask = batch_data
+            smiles, bg, labels, masks = batch_data
             atom_feats = bg.ndata.pop(args['atom_data_field'])
             atom_feats, labels = atom_feats.to(args['device']), labels.to(args['device'])
             logits = model(bg, atom_feats)
-            eval_meter.update(logits, labels, mask)
-    return eval_meter.roc_auc_averaged_over_tasks()
+            eval_meter.update(logits, labels, masks)
+    return np.mean(eval_meter.compute_metric(args['metric_name']))

 def main(args):
     args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
     set_random_seed()

     # Interchangeable with other datasets
-    if args['dataset'] == 'Tox21':
-        from dgl.data.chem import Tox21
-        dataset = Tox21()
-
-    trainset, valset, testset = split_dataset(dataset, args['train_val_test_split'])
-    train_loader = DataLoader(trainset, batch_size=args['batch_size'],
-                              collate_fn=collate_molgraphs_for_classification)
-    val_loader = DataLoader(valset, batch_size=args['batch_size'],
-                            collate_fn=collate_molgraphs_for_classification)
-    test_loader = DataLoader(testset, batch_size=args['batch_size'],
-                             collate_fn=collate_molgraphs_for_classification)
+    dataset, train_set, val_set, test_set = load_dataset_for_classification(args)
+    train_loader = DataLoader(train_set, batch_size=args['batch_size'],
+                              collate_fn=collate_molgraphs)
+    val_loader = DataLoader(val_set, batch_size=args['batch_size'],
+                            collate_fn=collate_molgraphs)
+    test_loader = DataLoader(test_set, batch_size=args['batch_size'],
+                             collate_fn=collate_molgraphs)

     if args['pre_trained']:
         args['num_epochs'] = 0
@@ -87,17 +84,18 @@ def main(args):
         run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)

         # Validation and early stop
-        val_roc_auc = run_an_eval_epoch(args, model, val_loader)
-        early_stop = stopper.step(val_roc_auc, model)
-        print('epoch {:d}/{:d}, validation roc-auc score {:.4f}, best validation roc-auc score {:.4f}'.format(
-            epoch + 1, args['num_epochs'], val_roc_auc, stopper.best_score))
+        val_score = run_an_eval_epoch(args, model, val_loader)
+        early_stop = stopper.step(val_score, model)
+        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
+            epoch + 1, args['num_epochs'], args['metric_name'],
+            val_score, args['metric_name'], stopper.best_score))
         if early_stop:
             break

     if not args['pre_trained']:
         stopper.load_checkpoint(model)
-    test_roc_auc = run_an_eval_epoch(args, model, test_loader)
-    print('test roc-auc score {:.4f}'.format(test_roc_auc))
+    test_score = run_an_eval_epoch(args, model, test_loader)
+    print('test {} {:.4f}'.format(args['metric_name'], test_score))

 if __name__ == '__main__':
     import argparse
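
The training loop in this file multiplies an elementwise loss by `(masks != 0).float()` and then calls `.mean()`. That averages over every entry in the batch, including entries whose label is missing, so the result is scaled down by the fraction of labeled entries. The sketch below reproduces the arithmetic with plain Python lists so it runs without PyTorch. Both function names are hypothetical, and `labeled_mean_loss` is shown only for comparison; it is not what the commit does.

```python
def masked_mean_loss(losses, masks):
    """Zero out losses of missing labels, then average over ALL entries.

    Mirrors (loss * (masks != 0).float()).mean() on nested lists.
    """
    total, count = 0.0, 0
    for loss_row, mask_row in zip(losses, masks):
        for loss, mask in zip(loss_row, mask_row):
            total += loss * float(mask != 0)
            count += 1
    return total / count


def labeled_mean_loss(losses, masks):
    """Comparison only: average over entries whose label exists."""
    kept = [
        loss
        for loss_row, mask_row in zip(losses, masks)
        for loss, mask in zip(loss_row, mask_row)
        if mask != 0
    ]
    return sum(kept) / len(kept)


# Two molecules, two tasks; the second task of the first molecule is unlabeled.
losses = [[0.5, 2.0], [1.5, 4.0]]
masks = [[1, 0], [1, 1]]

print(masked_mean_loss(losses, masks))   # sum of kept losses divided by 4 entries
print(labeled_mean_loss(losses, masks))  # sum of kept losses divided by 3 labels
```

Either normalization removes the gradient contribution of missing labels; they differ only in the scale of the reported loss when the labeled fraction varies between batches.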

examples/pytorch/model_zoo/chem/property_prediction/configure.py

+from dgl.data.chem import CanonicalAtomFeaturizer
+
 GCN_Tox21 = {
     'batch_size': 128,
     'lr': 1e-3,
@@ -7,7 +9,9 @@ GCN_Tox21 = {
     'in_feats': 74,
     'gcn_hidden_feats': [64, 64],
     'classifier_hidden_feats': 64,
-    'patience': 10
+    'patience': 10,
+    'atom_featurizer': CanonicalAtomFeaturizer(),
+    'metric_name': 'roc_auc'
 }

 GAT_Tox21 = {
@@ -20,15 +24,20 @@ GAT_Tox21 = {
     'gat_hidden_feats': [32, 32],
     'classifier_hidden_feats': 64,
     'num_heads': [4, 4],
-    'patience': 10
+    'patience': 10,
+    'atom_featurizer': CanonicalAtomFeaturizer(),
+    'metric_name': 'roc_auc'
 }

 MPNN_Alchemy = {
     'batch_size': 16,
     'num_epochs': 250,
+    'node_in_feats': 15,
+    'edge_in_feats': 5,
     'output_dim': 12,
     'lr': 0.0001,
-    'patience': 50
+    'patience': 50,
+    'metric_name': 'l1'
 }

 SCHNET_Alchemy = {
@@ -37,7 +46,8 @@ SCHNET_Alchemy = {
     'norm': True,
     'output_dim': 12,
     'lr': 0.0001,
-    'patience': 50
+    'patience': 50,
+    'metric_name': 'l1'
 }

 MGCN_Alchemy = {
@@ -46,7 +56,8 @@ MGCN_Alchemy = {
     'norm': True,
     'output_dim': 12,
     'lr': 0.0001,
-    'patience': 50
+    'patience': 50,
+    'metric_name': 'l1'
 }

 experiment_configures = {

examples/pytorch/model_zoo/chem/property_prediction/regression.py

+import numpy as np
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from dgl import model_zoo

-from utils import set_random_seed, collate_molgraphs_for_regression, EarlyStopping
+from utils import Meter, set_random_seed, collate_molgraphs, EarlyStopping, \
+    load_dataset_for_regression

 def regress(args, model, bg):
     if args['model'] == 'MPNN':
@@ -20,36 +22,35 @@ def regress(args, model, bg):
         return model(bg, node_types, edge_distances)

 def run_a_train_epoch(args, epoch, model, data_loader,
-                      loss_criterion, score_criterion, optimizer):
+                      loss_criterion, optimizer):
     model.train()
-    total_loss, total_score = 0, 0
+    train_meter = Meter()
+    total_loss = 0
     for batch_id, batch_data in enumerate(data_loader):
-        smiles, bg, labels = batch_data
-        labels = labels.to(args['device'])
+        smiles, bg, labels, masks = batch_data
+        labels, masks = labels.to(args['device']), masks.to(args['device'])
         prediction = regress(args, model, bg)
-        loss = loss_criterion(prediction, labels)
-        score = score_criterion(prediction, labels)
+        loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         total_loss += loss.detach().item() * bg.batch_size
-        total_score += score.detach().item() * bg.batch_size
+        train_meter.update(prediction, labels, masks)
     total_loss /= len(data_loader.dataset)
-    total_score /= len(data_loader.dataset)
-    print('epoch {:d}/{:d}, training loss {:.4f}, training score {:.4f}'.format(
-        epoch + 1, args['num_epochs'], total_loss, total_score))
+    total_score = np.mean(train_meter.compute_metric(args['metric_name']))
+    print('epoch {:d}/{:d}, training loss {:.4f}, training {} {:.4f}'.format(
+        epoch + 1, args['num_epochs'], total_loss, args['metric_name'], total_score))

-def run_an_eval_epoch(args, model, data_loader, score_criterion):
+def run_an_eval_epoch(args, model, data_loader):
     model.eval()
-    total_score = 0
+    eval_meter = Meter()
     with torch.no_grad():
         for batch_id, batch_data in enumerate(data_loader):
-            smiles, bg, labels = batch_data
+            smiles, bg, labels, masks = batch_data
             labels = labels.to(args['device'])
             prediction = regress(args, model, bg)
-            score = score_criterion(prediction, labels)
-            total_score += score.detach().item() * bg.batch_size
-        total_score /= len(data_loader.dataset)
+            eval_meter.update(prediction, labels, masks)
+        total_score = np.mean(eval_meter.compute_metric(args['metric_name']))
     return total_score

 def main(args):
@@ -57,20 +58,22 @@ def main(args):
     set_random_seed()

     # Interchangeable with other datasets
-    if args['dataset'] == 'Alchemy':
-        from dgl.data.chem import TencentAlchemyDataset
-        train_set = TencentAlchemyDataset(mode='dev')
-        val_set = TencentAlchemyDataset(mode='valid')
+    train_set, val_set, test_set = load_dataset_for_regression(args)
     train_loader = DataLoader(dataset=train_set,
                               batch_size=args['batch_size'],
-                              collate_fn=collate_molgraphs_for_regression)
+                              collate_fn=collate_molgraphs)
     val_loader = DataLoader(dataset=val_set,
                             batch_size=args['batch_size'],
-                            collate_fn=collate_molgraphs_for_regression)
+                            collate_fn=collate_molgraphs)
+    if test_set is not None:
+        test_loader = DataLoader(dataset=test_set,
+                                 batch_size=args['batch_size'],
+                                 collate_fn=collate_molgraphs)

     if args['model'] == 'MPNN':
-        model = model_zoo.chem.MPNNModel(output_dim=args['output_dim'])
+        model = model_zoo.chem.MPNNModel(node_input_dim=args['node_in_feats'],
+                                         edge_input_dim=args['edge_in_feats'],
+                                         output_dim=args['output_dim'])
     elif args['model'] == 'SCHNET':
         model = model_zoo.chem.SchNet(norm=args['norm'], output_dim=args['output_dim'])
         model.set_mean_std(train_set.mean, train_set.std, args['device'])
@@ -79,23 +82,28 @@ def main(args):
         model.set_mean_std(train_set.mean, train_set.std, args['device'])
     model.to(args['device'])

-    loss_fn = nn.MSELoss()
-    score_fn = nn.L1Loss()
+    loss_fn = nn.MSELoss(reduction='none')
     optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
     stopper = EarlyStopping(mode='lower', patience=args['patience'])
     for epoch in range(args['num_epochs']):
         # Train
-        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, score_fn, optimizer)
+        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

         # Validation and early stop
-        val_score = run_an_eval_epoch(args, model, val_loader, score_fn)
+        val_score = run_an_eval_epoch(args, model, val_loader)
         early_stop = stopper.step(val_score, model)
-        print('epoch {:d}/{:d}, validation score {:.4f}, best validation score {:.4f}'.format(
-            epoch + 1, args['num_epochs'], val_score, stopper.best_score))
+        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
+            epoch + 1, args['num_epochs'], args['metric_name'], val_score,
+            args['metric_name'], stopper.best_score))
         if early_stop:
             break

+    if test_set is not None:
+        stopper.load_checkpoint(model)
+        test_score = run_an_eval_epoch(args, model, test_loader)
+        print('test {} {:.4f}'.format(args['metric_name'], test_score))
+
 if __name__ == "__main__":
     import argparse
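
The regression script constructs `EarlyStopping(mode='lower', patience=...)`, calls `stopper.step(val_score, model)` once per epoch, and reads `stopper.best_score`. The class body is not part of this diff, so the following is only a minimal sketch of how such a stopper could behave, under the assumption that `'lower'` means a smaller validation score is better. The class name `SimpleEarlyStopping` is hypothetical and checkpoint saving is deliberately left out.

```python
class SimpleEarlyStopping:
    """Illustrative early stopping without checkpointing (assumed semantics)."""

    def __init__(self, mode="lower", patience=10):
        assert mode in ("lower", "higher")
        self.mode = mode
        self.patience = patience
        self.counter = 0        # epochs since the last improvement
        self.best_score = None  # best validation score seen so far

    def _improved(self, score):
        if self.best_score is None:
            return True
        if self.mode == "lower":
            return score < self.best_score
        return score > self.best_score

    def step(self, score):
        """Record a validation score; return True when training should stop."""
        if self._improved(score):
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = SimpleEarlyStopping(mode="lower", patience=2)
history = [stopper.step(score) for score in [0.9, 0.7, 0.8, 0.75]]
print(history, stopper.best_score)
```

With `patience=2`, the two consecutive epochs that fail to beat 0.7 trigger the stop on the fourth call.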

examples/pytorch/model_zoo/chem/property_prediction/utils.py

 import datetime
 import dgl
+import math
 import numpy as np
 import random
 import torch
-from sklearn.metrics import roc_auc_score
+import torch.nn.functional as F
+
+from dgl.data.utils import split_dataset
+from sklearn.metrics import roc_auc_score, mean_squared_error

 def set_random_seed(seed=0):
     """Set random seed.
@@ -45,13 +49,13 @@ class Meter(object):
         self.y_true.append(y_true.detach().cpu())
         self.mask.append(mask.detach().cpu())

-    def roc_auc_averaged_over_tasks(self):
-        """Compute roc-auc score for each task and return the average.
+    def roc_auc_score(self):
+        """Compute roc-auc score for each task.

         Returns
         -------
-        float
-            roc-auc score averaged over all tasks
+        list of float
+            roc-auc score for all tasks
         """
         mask = torch.cat(self.mask, dim=0)
         y_pred = torch.cat(self.y_pred, dim=0)
@@ -60,13 +64,83 @@ class Meter(object):
         # This assumes binary case only
         y_pred = torch.sigmoid(y_pred)
         n_tasks = y_true.shape[1]
-        total_score = 0
+        scores = []
+        for task in range(n_tasks):
+            task_w = mask[:, task]
+            task_y_true = y_true[:, task][task_w != 0].numpy()
+            task_y_pred = y_pred[:, task][task_w != 0].numpy()
+            scores.append(roc_auc_score(task_y_true, task_y_pred))
+        return scores
+
+    def l1_loss(self, reduction):
+        """Compute l1 loss for each task.
+
+        Returns
+        -------
+        list of float
+            l1 loss for all tasks
+        reduction : str
+            * 'mean': average the metric over all labeled data points for each task
+            * 'sum': sum the metric over all labeled data points for each task
+        """
+        mask = torch.cat(self.mask, dim=0)
+        y_pred = torch.cat(self.y_pred, dim=0)
+        y_true = torch.cat(self.y_true, dim=0)
+        n_tasks = y_true.shape[1]
+        scores = []
+        for task in range(n_tasks):
+            task_w = mask[:, task]
+            task_y_true = y_true[:, task][task_w != 0]
+            task_y_pred = y_pred[:, task][task_w != 0]
+            scores.append(F.l1_loss(task_y_true, task_y_pred, reduction=reduction).item())
+        return scores
+
+    def rmse(self):
+        """Compute RMSE for each task.
+
+        Returns
+        -------
+        list of float
+            rmse for all tasks
+        """
+        mask = torch.cat(self.mask, dim=0)
+        y_pred = torch.cat(self.y_pred, dim=0)
+        y_true = torch.cat(self.y_true, dim=0)
+        n_data, n_tasks = y_true.shape
+        scores = []
         for task in range(n_tasks):
             task_w = mask[:, task]
             task_y_true = y_true[:, task][task_w != 0].numpy()
             task_y_pred = y_pred[:, task][task_w != 0].numpy()
-            total_score += roc_auc_score(task_y_true, task_y_pred)
-        return total_score / n_tasks
+            scores.append(math.sqrt(mean_squared_error(task_y_true, task_y_pred)))
+        return scores
+
+    def compute_metric(self, metric_name, reduction='mean'):
+        """Compute metric for each task.
+
+        Parameters
+        ----------
+        metric_name : str
+            Name for the metric to compute.
+        reduction : str
+            Only comes into effect when the metric_name is l1_loss.
+            * 'mean': average the metric over all labeled data points for each task
+            * 'sum': sum the metric over all labeled data points for each task
+
+        Returns
+        -------
+        list of float
+            Metric value for each task
+        """
+        assert metric_name in ['roc_auc', 'l1', 'rmse'], \
+            'Expect metric name to be "roc_auc", "l1" or "rmse", got {}'.format(metric_name)
+        assert reduction in ['mean', 'sum']
+        if metric_name == 'roc_auc':
+            return self.roc_auc_score()
+        if metric_name == 'l1':
+            return self.l1_loss(reduction)
+        if metric_name == 'rmse':
+            return self.rmse()

 class EarlyStopping(object):
     """Early stop performing
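
The refactored `Meter` returns one score per task, computed only over the data points whose mask entry is nonzero, and callers average the list with `np.mean`. A rough standalone sketch of that per-task masking follows, using nested lists instead of tensors so it runs without PyTorch or scikit-learn. The function `per_task_metric` is hypothetical and covers only the `'l1'` (mean reduction) and `'rmse'` cases.

```python
import math


def per_task_metric(y_pred, y_true, mask, metric_name):
    """Return one score per task, ignoring entries whose mask is 0.

    Note: a task with no labeled entries would raise ZeroDivisionError here.
    """
    assert metric_name in ("l1", "rmse"), \
        'Expect metric name to be "l1" or "rmse", got {}'.format(metric_name)
    n_tasks = len(y_true[0])
    scores = []
    for task in range(n_tasks):
        # Keep prediction errors only where a label exists for this task.
        errors = [
            pred[task] - true[task]
            for pred, true, m in zip(y_pred, y_true, mask)
            if m[task] != 0
        ]
        if metric_name == "l1":
            scores.append(sum(abs(e) for e in errors) / len(errors))
        else:
            scores.append(math.sqrt(sum(e * e for e in errors) / len(errors)))
    return scores


# Three data points, two tasks; the second task of the middle point is unlabeled.
y_true = [[1.0, 2.0], [3.0, 0.0], [5.0, 6.0]]
y_pred = [[2.0, 2.0], [2.0, 9.0], [6.0, 10.0]]
mask = [[1, 1], [1, 0], [1, 1]]

l1_scores = per_task_metric(y_pred, y_true, mask, "l1")
rmse_scores = per_task_metric(y_pred, y_true, mask, "rmse")
print(l1_scores, sum(l1_scores) / len(l1_scores))
```

The large error at the masked position (9.0 against a placeholder label of 0.0) does not affect either score, which is the point of carrying the mask through to evaluation.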
@@ -131,14 +205,15 @@ class EarlyStopping(object):
         '''Load model saved with early stopping.'''
         model.load_state_dict(torch.load(self.filename)['model_state_dict'])

-def collate_molgraphs_for_classification(data):
-    """Batching a list of datapoints for dataloader in classification tasks.
+def collate_molgraphs(data):
+    """Batching a list of datapoints for dataloader.

     Parameters
     ----------
-    data : list of 4-tuples
+    data : list of 3-tuples or 4-tuples.
         Each tuple is for a single datapoint, consisting of
-        a SMILE, a DGLGraph, all-task labels and all-task weights
+        a SMILES, a DGLGraph, all-task labels and optionally
+        a binary mask indicating the existence of labels.

     Returns
     -------
@@ -149,40 +224,79 @@ def collate_molgraphs_for_classification(data):
     labels : Tensor of dtype float32 and shape (B, T)
         Batched datapoint labels. B is len(data) and
         T is the number of total tasks.
-    weights : Tensor of dtype float32 and shape (B, T)
-        Batched datapoint weights. T is the number of
-        total tasks.
+    masks : Tensor of dtype float32 and shape (B, T)
+        Batched datapoint binary mask, indicating the
+        existence of labels. If binary masks are not
+        provided, return a tensor with ones.
     """
-    smiles, graphs, labels, mask = map(list, zip(*data))
+    assert len(data[0]) in [3, 4], \
+        'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
+    if len(data[0]) == 3:
+        smiles, graphs, labels = map(list, zip(*data))
+        masks = None
+    else:
+        smiles, graphs, labels, masks = map(list, zip(*data))
+
     bg = dgl.batch(graphs)
     bg.set_n_initializer(dgl.init.zero_initializer)
     bg.set_e_initializer(dgl.init.zero_initializer)
     labels = torch.stack(labels, dim=0)
-    mask = torch.stack(mask, dim=0)
-    return smiles, bg, labels, mask
-
-def collate_molgraphs_for_regression(data):
-    """Batching a list of datapoints for dataloader in regression tasks.
+
+    if masks is None:
+        masks = torch.ones(labels.shape)
+    else:
+        masks = torch.stack(masks, dim=0)
+    return smiles, bg, labels, masks
+
+def load_dataset_for_classification(args):
+    """Load dataset for classification tasks.

     Parameters
     ----------
-    data : list of 3-tuples
-        Each tuple is for a single datapoint, consisting of
-        a SMILE, a DGLGraph and all-task labels.
+    args : dict
+        Configurations.

     Returns
     -------
-    smiles : list
-        List of smiles
-    bg : BatchedDGLGraph
-        Batched DGLGraphs
-    labels : Tensor of dtype float32 and shape (B, T)
-        Batched datapoint labels. B is len(data) and
-        T is the number of total tasks.
+    dataset
+        The whole dataset.
+    train_set
+        Subset for training.
+    val_set
+        Subset for validation.
+    test_set
+        Subset for test.
     """
-    smiles, graphs, labels = map(list, zip(*data))
-    bg = dgl.batch(graphs)
-    bg.set_n_initializer(dgl.init.zero_initializer)
-    bg.set_e_initializer(dgl.init.zero_initializer)
-    labels = torch.stack(labels, dim=0)
-    return smiles, bg, labels
+    assert args['dataset'] in ['Tox21']
+    if args['dataset'] == 'Tox21':
+        from dgl.data.chem import Tox21
+        dataset = Tox21(atom_featurizer=args['atom_featurizer'])
+        train_set, val_set, test_set = split_dataset(dataset, args['train_val_test_split'])
+
+    return dataset, train_set, val_set, test_set
+
+def load_dataset_for_regression(args):
+    """Load dataset for regression tasks.
+
+    Parameters
+    ----------
+    args : dict
+        Configurations.
+
+    Returns
+    -------
+    train_set
+        Subset for training.
+    val_set
+        Subset for validation.
+    test_set
+        Subset for test.
+    """
+    assert args['dataset'] in ['Alchemy']
+    if args['dataset'] == 'Alchemy':
+        from dgl.data.chem import TencentAlchemyDataset
+        train_set = TencentAlchemyDataset(mode='dev')
+        val_set = TencentAlchemyDataset(mode='valid')
+        test_set = None
+
+    return train_set, val_set, test_set
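
The merged `collate_molgraphs` accepts either 3-tuples (no masks, as in the regression data) or 4-tuples (with masks, as in the classification data) and always returns four values, filling in an all-ones mask when none is given. The sketch below shows only that branching, under simplifying assumptions: graphs are opaque placeholders that are returned unbatched, and labels and masks are plain lists rather than stacked tensors. The name `collate_with_optional_masks` is hypothetical.

```python
def collate_with_optional_masks(data):
    """Split a list of 3- or 4-tuples into columns, defaulting masks to ones."""
    assert len(data[0]) in (3, 4), \
        'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
    if len(data[0]) == 3:
        smiles, graphs, labels = map(list, zip(*data))
        # No masks supplied: treat every label as present.
        masks = [[1.0] * len(row) for row in labels]
    else:
        smiles, graphs, labels, masks = map(list, zip(*data))
    return smiles, graphs, labels, masks


# Regression-style batch: every label exists, so no mask is provided.
regression_batch = [
    ("CCO", "graph-0", [0.1, 0.2]),
    ("CC", "graph-1", [0.3, 0.4]),
]
# Classification-style batch: the second task of the first molecule is unlabeled.
classification_batch = [
    ("CCO", "graph-0", [1.0, 0.0], [1.0, 0.0]),
    ("CC", "graph-1", [0.0, 1.0], [1.0, 1.0]),
]

_, _, reg_labels, reg_masks = collate_with_optional_masks(regression_batch)
_, _, cls_labels, cls_masks = collate_with_optional_masks(classification_batch)
print(reg_masks, cls_masks)
```

Returning the same four-element shape in both cases is what lets the training and evaluation loops in both scripts unpack `smiles, bg, labels, masks` without checking which dataset they are running on.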

python/dgl/data/chem/alchemy.py

@@ -10,7 +10,8 @@ import pickle
 import zipfile
 from collections import defaultdict

-from .utils import mol_to_complete_graph
+from .utils import mol_to_complete_graph, atom_type_one_hot, atom_hybridization_one_hot, \
+    atom_is_aromatic
 from ..utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
 from ... import backend as F

@@ -59,25 +60,19 @@ def alchemy_nodes(mol):
     num_atoms = mol.GetNumAtoms()
     for u in range(num_atoms):
         atom = mol.GetAtomWithIdx(u)
-        symbol = atom.GetSymbol()
         atom_type = atom.GetAtomicNum()
-        aromatic = atom.GetIsAromatic()
-        hybridization = atom.GetHybridization()
         num_h = atom.GetTotalNumHs()
         atom_feats_dict['node_type'].append(atom_type)

         h_u = []
-        h_u += [int(symbol == x) for x in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']]
+        h_u += atom_type_one_hot(atom, ['H', 'C', 'N', 'O', 'F', 'S', 'Cl'])
         h_u.append(atom_type)
         h_u.append(is_acceptor[u])
         h_u.append(is_donor[u])
-        h_u.append(int(aromatic))
-        h_u += [
-            int(hybridization == x)
-            for x in (Chem.rdchem.HybridizationType.SP,
-                      Chem.rdchem.HybridizationType.SP2,
-                      Chem.rdchem.HybridizationType.SP3)
-        ]
+        h_u += atom_is_aromatic(atom)
+        h_u += atom_hybridization_one_hot(atom, [Chem.rdchem.HybridizationType.SP,
+                                                 Chem.rdchem.HybridizationType.SP2,
+                                                 Chem.rdchem.HybridizationType.SP3])
         h_u.append(num_h)
         atom_feats_dict['n_feat'].append(F.tensor(np.array(h_u).astype(np.float32)))
@@ -155,9 +150,34 @@ class TencentAlchemyDataset(object):
         contest is ongoing.
     from_raw : bool
         Whether to process the dataset from scratch or use a
-        processed one for faster speed. Default to be False.
+        processed one for faster speed. If you featurize atoms or bonds
+        differently from the defaults, set this to True.
+        Default to be False.
+    mol_to_graph: callable, rdkit.Chem.rdchem.Mol -> DGLGraph
+        A function turning an RDKit molecule instance into a DGLGraph.
+        Default to :func:`dgl.data.chem.mol_to_complete_graph`.
+    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for atoms in a molecule, which can be used to update
+        ndata for a DGLGraph. By default, we store the atomic numbers of the
+        atoms under the name ``"node_type"`` and store the atom features under
+        the name ``"n_feat"``. The atom features include:
+        * One hot encoding for atom types
+        * Atomic number of atoms
+        * Whether the atom is an acceptor
+        * Whether the atom is a donor
+        * Whether the atom is aromatic
+        * One hot encoding for atom hybridization
+        * Total number of Hs on the atom
+    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for bonds in a molecule, which can be used to update
+        edata for a DGLGraph. By default, we store the distance between the
+        end atoms under the name ``"distance"`` and store the bond features under
+        the name ``"e_feat"``. The bond features are one-hot encodings of the bond type.
     """
-    def __init__(self, mode='dev', from_raw=False):
+    def __init__(self, mode='dev', from_raw=False,
+                 mol_to_graph=mol_to_complete_graph,
+                 atom_featurizer=alchemy_nodes,
+                 bond_featurizer=alchemy_edges):
         if mode == 'test':
             raise ValueError('The test mode is not supported before '
                              'the Alchemy contest finishes.')
@@ -185,9 +205,9 @@ class TencentAlchemyDataset(object):
             archive.extractall(file_dir)
             archive.close()
-        self._load()
+        self._load(mol_to_graph, atom_featurizer, bond_featurizer)
-    def _load(self):
+    def _load(self, mol_to_graph, atom_featurizer, bond_featurizer):
         if not self.from_raw:
             self.graphs, label_dict = load_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode))
             self.labels = label_dict['labels']
@@ -210,8 +230,8 @@ class TencentAlchemyDataset(object):
             for mol, label in zip(supp, self.target.iterrows()):
                 cnt += 1
                 print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
-                graph = mol_to_complete_graph(mol, atom_featurizer=alchemy_nodes,
-                                              bond_featurizer=alchemy_edges)
+                graph = mol_to_graph(mol, atom_featurizer=atom_featurizer,
+                                     bond_featurizer=bond_featurizer)
                 smiles = Chem.MolToSmiles(mol)
                 self.smiles.append(smiles)
                 self.graphs.append(graph)
...
...
python/dgl/data/chem/csv_dataset.py
@@ -5,10 +5,8 @@ import numpy as np
 import os
 import sys
-from .utils import smile_to_bigraph
 from ..utils import save_graphs, load_graphs
 from ... import backend as F
-from ...graph import DGLGraph
 class MoleculeCSVDataset(object):
     """MoleculeCSVDataset
@@ -27,28 +25,33 @@ class MoleculeCSVDataset(object):
         Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path).
         One column includes smiles and other columns for labels.
         Column names other than smiles column would be considered as task names.
-    smile_to_graph: callable, str -> DGLGraph
-        A function turns smiles into a DGLGraph. Default one can be found
-        at python/dgl/data/chem/utils.py named with smile_to_bigraph.
-    smile_column: str
-        Column name that including smiles
+    smiles_to_graph: callable, str -> DGLGraph
+        A function turning a SMILES string into a DGLGraph.
+    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for atoms in a molecule, which can be used to update
+        ndata for a DGLGraph.
+    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for bonds in a molecule, which can be used to update
+        edata for a DGLGraph.
+    smiles_column: str
+        Name of the column that contains the SMILES strings.
     cache_file_path: str
-        Path to store the preprocessed data
+        Path to store the preprocessed data.
     """
-    def __init__(self, df, smile_to_graph=smile_to_bigraph, smile_column='smiles',
-                 cache_file_path="csvdata_dglgraph.bin"):
+    def __init__(self, df, smiles_to_graph, atom_featurizer, bond_featurizer,
+                 smiles_column, cache_file_path):
         if 'rdkit' not in sys.modules:
             from ...base import dgl_warning
             dgl_warning(
                 "Please install RDKit (Recommended Version is 2018.09.3)")
         self.df = df
-        self.smiles = self.df[smile_column].tolist()
-        self.task_names = self.df.columns.drop([smile_column]).tolist()
+        self.smiles = self.df[smiles_column].tolist()
+        self.task_names = self.df.columns.drop([smiles_column]).tolist()
         self.n_tasks = len(self.task_names)
         self.cache_file_path = cache_file_path
-        self._pre_process(smile_to_graph)
+        self._pre_process(smiles_to_graph, atom_featurizer, bond_featurizer)
-    def _pre_process(self, smile_to_graph):
+    def _pre_process(self, smiles_to_graph, atom_featurizer, bond_featurizer):
         """Pre-process the dataset
         * Convert molecules from smiles format into DGLGraphs
@@ -58,8 +61,14 @@ class MoleculeCSVDataset(object):
         Parameters
         ----------
-        smile_to_graph : callable, SMILES -> DGLGraph
-            Function for converting a SMILES (str) into a DGLGraph
+        smiles_to_graph : callable, SMILES -> DGLGraph
+            Function for converting a SMILES (str) into a DGLGraph.
+        atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+            Featurization for atoms in a molecule, which can be used to update
+            ndata for a DGLGraph.
+        bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+            Featurization for bonds in a molecule, which can be used to update
+            edata for a DGLGraph.
         """
         if os.path.exists(self.cache_file_path):
             # DGLGraphs have been constructed before, reload them
@@ -72,7 +81,8 @@ class MoleculeCSVDataset(object):
             self.graphs = []
             for i, s in enumerate(self.smiles):
                 print('Processing molecule {:d}/{:d}'.format(i+1, len(self)))
-                self.graphs.append(smile_to_graph(s))
+                self.graphs.append(smiles_to_graph(s, atom_featurizer=atom_featurizer,
+                                                   bond_featurizer=bond_featurizer))
             _label_values = self.df[self.task_names].values
             # np.nan_to_num will also turn inf into a very large number
             self.labels = F.zerocopy_from_numpy(np.nan_to_num(_label_values).astype(np.float32))
...
...
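The `_pre_process` hunks above reload cached graphs whenever `cache_file_path` exists, so the featurizers passed to the constructor only take effect when the cache is absent. The following is a minimal sketch of that behaviour using only the standard library. `load_or_build` is a hypothetical helper, and plain strings with `pickle` stand in for molecules and `save_graphs`/`load_graphs`; it is not the code from this commit.

```python
import os
import pickle
import tempfile


def load_or_build(cache_path, smiles_list, featurizer):
    """Hypothetical helper: reuse cached features if the file exists, else build them."""
    if os.path.exists(cache_path):
        # Same idea as the "reload if the cache file exists" branch:
        # the featurizer argument is ignored entirely on this path.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    features = [featurizer(s) for s in smiles_list]
    with open(cache_path, 'wb') as f:
        pickle.dump(features, f)
    return features


cache_dir = tempfile.mkdtemp()
cache_path = os.path.join(cache_dir, 'demo_cache.bin')
smiles = ['CCO', 'C']

first = load_or_build(cache_path, smiles, featurizer=len)
# A different featurizer has no effect while the old cache file is still present.
stale = load_or_build(cache_path, smiles, featurizer=str.lower)
# Removing the cache forces the new featurizer to run.
os.remove(cache_path)
fresh = load_or_build(cache_path, smiles, featurizer=str.lower)
```

In practice this suggests deleting the cache file, or pointing `cache_file_path` at a new location, after changing `atom_featurizer` or `bond_featurizer`.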
python/dgl/data/chem/tox21.py
@@ -2,7 +2,7 @@ import numpy as np
 import sys
 from .csv_dataset import MoleculeCSVDataset
-from .utils import smile_to_bigraph
+from .utils import smiles_to_bigraph
 from ..utils import get_download_dir, download, _get_dgl_url
 from ... import backend as F
@@ -30,11 +30,19 @@ class Tox21(MoleculeCSVDataset):
     Parameters
     ----------
-    smile_to_graph: callable, str -> DGLGraph
-        A function turns smiles into a DGLGraph. Default one can be found
-        at python/dgl/data/chem/utils.py named with smile_to_bigraph.
+    smiles_to_graph: callable, str -> DGLGraph
+        A function turning smiles into a DGLGraph.
+        Default to :func:`dgl.data.chem.smiles_to_bigraph`.
+    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for atoms in a molecule, which can be used to update
+        ndata for a DGLGraph. Default to None.
+    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for bonds in a molecule, which can be used to update
+        edata for a DGLGraph. Default to None.
     """
-    def __init__(self, smile_to_graph=smile_to_bigraph):
+    def __init__(self, smiles_to_graph=smiles_to_bigraph,
+                 atom_featurizer=None,
+                 bond_featurizer=None):
         if 'pandas' not in sys.modules:
             from ...base import dgl_warning
             dgl_warning("Please install pandas")
@@ -47,10 +55,10 @@ class Tox21(MoleculeCSVDataset):
         df = df.drop(columns=['mol_id'])
-        super().__init__(df, smile_to_graph, cache_file_path="tox21_dglgraph.bin")
+        super(Tox21, self).__init__(df, smiles_to_graph, atom_featurizer, bond_featurizer,
+                                    "smiles", "tox21_dglgraph.bin")
         self._weight_balancing()
     def _weight_balancing(self):
         """Perform re-balancing for each task.
@@ -72,7 +80,6 @@ class Tox21(MoleculeCSVDataset):
         num_indices = F.sum(self.mask, dim=0)
         self._task_pos_weights = (num_indices - num_pos) / num_pos
     @property
     def task_pos_weights(self):
         """Get weights for positive samples on each task
...
...
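The re-balancing line in the last hunk computes, per task, the ratio of labeled negatives to labeled positives. A small list-based sketch of that arithmetic follows. How `num_pos` is obtained is not visible in this diff, so counting positives only where the mask is set is an assumption here, and the function name is hypothetical.

```python
def task_pos_weights_sketch(labels, mask):
    """Per-task weight (num_indices - num_pos) / num_pos.

    labels[i][t] is 0 or 1; mask[i][t] is 1 when molecule i has a label
    for task t and 0 otherwise (assumed layout, for illustration only).
    """
    n_tasks = len(labels[0])
    weights = []
    for t in range(n_tasks):
        # Positives are counted only where a label is present (assumption).
        num_pos = sum(row[t] * m[t] for row, m in zip(labels, mask))
        num_indices = sum(m[t] for m in mask)
        # Note: a task without any positive label would divide by zero here.
        weights.append((num_indices - num_pos) / num_pos)
    return weights


labels = [[1, 0], [0, 0], [0, 1], [0, 1]]
mask = [[1, 1], [1, 0], [1, 1], [1, 1]]
weights = task_pos_weights_sketch(labels, mask)
```

For the first task there is one positive among four labeled molecules, giving a weight of 3.0; for the second there are two positives among three labeled molecules, giving 0.5.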
python/dgl/data/chem/utils.py
 import dgl.backend as F
+import itertools
 import numpy as np
 from functools import partial
+from collections import defaultdict
 from dgl import DGLGraph
 try:
@@ -10,42 +12,568 @@ try:
 except ImportError:
     pass
-__all__ = ['one_hot_encoding', 'BaseAtomFeaturizer', 'CanonicalAtomFeaturizer',
-           'mol_to_graph', 'smile_to_bigraph', 'mol_to_bigraph',
-           'smile_to_complete_graph', 'mol_to_complete_graph']
+__all__ = ['one_hot_encoding', 'atom_type_one_hot', 'atomic_number_one_hot', 'atomic_number',
+           'atom_degree_one_hot', 'atom_degree', 'atom_total_degree_one_hot', 'atom_total_degree',
+           'atom_implicit_valence_one_hot', 'atom_implicit_valence', 'atom_hybridization_one_hot',
+           'atom_total_num_H_one_hot', 'atom_total_num_H', 'atom_formal_charge_one_hot',
+           'atom_formal_charge', 'atom_num_radical_electrons_one_hot',
+           'atom_num_radical_electrons', 'atom_is_aromatic_one_hot', 'atom_is_aromatic',
+           'atom_chiral_tag_one_hot', 'atom_mass', 'ConcatFeaturizer', 'BaseAtomFeaturizer',
+           'CanonicalAtomFeaturizer', 'mol_to_graph', 'smiles_to_bigraph', 'mol_to_bigraph',
+           'smiles_to_complete_graph', 'mol_to_complete_graph', 'bond_type_one_hot',
+           'bond_is_conjugated_one_hot', 'bond_is_conjugated', 'bond_is_in_ring_one_hot',
+           'bond_is_in_ring', 'bond_stereo_one_hot', 'BaseBondFeaturizer',
+           'CanonicalBondFeaturizer']
-def one_hot_encoding(x, allowable_set):
+def one_hot_encoding(x, allowable_set, encode_unknown=False):
     """One-hot encoding.
     Parameters
     ----------
-    x : str, int or Chem.rdchem.HybridizationType
+    x
+        Value to encode.
     allowable_set : list
         The elements of the allowable_set should be of the
         same type as x.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element.
     Returns
     -------
     list
         List of boolean values where at most one value is True.
-        If the i-th value is True, then we must have
-        x == allowable_set[i].
+        The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
+        and ``len(allowable_set) + 1`` otherwise.
     """
+    if encode_unknown:
+        allowable_set.append(None)
+    if x not in allowable_set:
+        x = None
     return list(map(lambda s: x == s, allowable_set))
+#################################################################
+# Atom featurization
+#################################################################
+def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for the type of an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list of str
+        Atom types to consider. Default: ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``,
+        ``Cl``, ``Br``, ``Mg``, ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``,
+        ``K``, ``Tl``, ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
+        ``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``, ``Cr``,
+        ``Pt``, ``Hg``, ``Pb``.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element. (Default: False)
+    Returns
+    -------
+    list
+        List of boolean values where at most one value is True.
+    """
+    if allowable_set is None:
+        allowable_set = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
+                         'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn',
+                         'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au',
+                         'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb']
+    return one_hot_encoding(atom.GetSymbol(), allowable_set, encode_unknown)
+def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for the atomic number of an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list of int
+        Atomic numbers to consider. Default: ``1`` - ``100``.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element. (Default: False)
+    Returns
+    -------
+    list
+        List of boolean values where at most one value is True.
+    """
+    if allowable_set is None:
+        allowable_set = list(range(1, 101))
+    return one_hot_encoding(atom.GetAtomicNum(), allowable_set, encode_unknown)
+def atomic_number(atom):
+    """Get the atomic number for an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    Returns
+    -------
+    list
+        List containing one int only.
+    """
+    return [atom.GetAtomicNum()]
+def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for the degree of an atom.
+    Note that the result will be different depending on whether the Hs are
+    explicitly modeled in the graph.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list of int
+        Atom degrees to consider. Default: ``0`` - ``10``.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element. (Default: False)
+    Returns
+    -------
+    list
+        List of boolean values where at most one value is True.
+    See Also
+    --------
+    atom_total_degree_one_hot
+    """
+    if allowable_set is None:
+        allowable_set = list(range(11))
+    return one_hot_encoding(atom.GetDegree(), allowable_set, encode_unknown)
+def atom_degree(atom):
+    """Get the degree of an atom.
+    Note that the result will be different depending on whether the Hs are
+    explicitly modeled in the graph.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    Returns
+    -------
+    list
+        List containing one int only.
+    See Also
+    --------
+    atom_total_degree
+    """
+    return [atom.GetDegree()]
+def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for the degree of an atom including Hs.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list
+        Total degrees to consider. Default: ``0`` - ``5``.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element. (Default: False)
+    See Also
+    --------
+    atom_degree_one_hot
+    """
+    if allowable_set is None:
+        allowable_set = list(range(6))
+    return one_hot_encoding(atom.GetTotalDegree(), allowable_set, encode_unknown)
+def atom_total_degree(atom):
+    """Get the degree of an atom including Hs.
+    See Also
+    --------
+    atom_degree
+    Returns
+    -------
+    list
+        List containing one int only.
+    """
+    return [atom.GetTotalDegree()]
+def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for the implicit valences of an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list of int
+        Atom implicit valences to consider. Default: ``0`` - ``6``.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element. (Default: False)
+    Returns
+    -------
+    list
+        List of boolean values where at most one value is True.
+    """
+    if allowable_set is None:
+        allowable_set = list(range(7))
+    return one_hot_encoding(atom.GetImplicitValence(), allowable_set, encode_unknown)
+def atom_implicit_valence(atom):
+    """Get the implicit valence of an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    Returns
+    -------
+    list
+        List containing one int only.
+    """
+    return [atom.GetImplicitValence()]
+def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for the hybridization of an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list of rdkit.Chem.rdchem.HybridizationType
+        Atom hybridizations to consider. Default: ``Chem.rdchem.HybridizationType.SP``,
+        ``Chem.rdchem.HybridizationType.SP2``, ``Chem.rdchem.HybridizationType.SP3``,
+        ``Chem.rdchem.HybridizationType.SP3D``, ``Chem.rdchem.HybridizationType.SP3D2``.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element. (Default: False)
+    Returns
+    -------
+    list
+        List of boolean values where at most one value is True.
+    """
+    if allowable_set is None:
+        allowable_set = [Chem.rdchem.HybridizationType.SP,
+                         Chem.rdchem.HybridizationType.SP2,
+                         Chem.rdchem.HybridizationType.SP3,
+                         Chem.rdchem.HybridizationType.SP3D,
+                         Chem.rdchem.HybridizationType.SP3D2]
+    return one_hot_encoding(atom.GetHybridization(), allowable_set, encode_unknown)
+def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for the total number of Hs of an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list of int
+        Total number of Hs to consider. Default: ``0`` - ``4``.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element. (Default: False)
+    Returns
+    -------
+    list
+        List of boolean values where at most one value is True.
+    """
+    if allowable_set is None:
+        allowable_set = list(range(5))
+    return one_hot_encoding(atom.GetTotalNumHs(), allowable_set, encode_unknown)
+def atom_total_num_H(atom):
+    """Get the total number of Hs of an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    Returns
+    -------
+    list
+        List containing one int only.
+    """
+    return [atom.GetTotalNumHs()]
+def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for the formal charge of an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list of int
+        Formal charges to consider. Default: ``-2`` - ``2``.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element. (Default: False)
+    Returns
+    -------
+    list
+        List of boolean values where at most one value is True.
+    """
+    if allowable_set is None:
+        allowable_set = list(range(-2, 3))
+    return one_hot_encoding(atom.GetFormalCharge(), allowable_set, encode_unknown)
+def atom_formal_charge(atom):
+    """Get formal charge for an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    Returns
+    -------
+    list
+        List containing one int only.
+    """
+    return [atom.GetFormalCharge()]
+def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for the number of radical electrons of an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list of int
+        Number of radical electrons to consider. Default: ``0`` - ``4``.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element. (Default: False)
+    Returns
+    -------
+    list
+        List of boolean values where at most one value is True.
+    """
+    if allowable_set is None:
+        allowable_set = list(range(5))
+    return one_hot_encoding(atom.GetNumRadicalElectrons(), allowable_set, encode_unknown)
+def atom_num_radical_electrons(atom):
+    """Get the number of radical electrons for an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    Returns
+    -------
+    list
+        List containing one int only.
+    """
+    return [atom.GetNumRadicalElectrons()]
+def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for whether the atom is aromatic.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list of bool
+        Conditions to consider. Default: ``False`` and ``True``.
+    encode_unknown : bool
+        If True, map inputs not in the allowable set to the
+        additional last element. (Default: False)
+    Returns
+    -------
+    list
+        List of boolean values where at most one value is True.
+    """
+    if allowable_set is None:
+        allowable_set = [False, True]
+    return one_hot_encoding(atom.GetIsAromatic(), allowable_set, encode_unknown)
+def atom_is_aromatic(atom):
+    """Get whether the atom is aromatic.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    Returns
+    -------
+    list
+        List containing one bool only.
+    """
+    return [atom.GetIsAromatic()]
+def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
+    """One hot encoding for the chiral tag of an atom.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    allowable_set : list of rdkit.Chem.rdchem.ChiralType
+        Chiral tags to consider. Default: ``rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED``,
+        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
+        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
+        ``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
+    """
+    if allowable_set is None:
+        allowable_set = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
+                         Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
+                         Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
+                         Chem.rdchem.ChiralType.CHI_OTHER]
+    return one_hot_encoding(atom.GetChiralTag(), allowable_set, encode_unknown)
+def atom_mass(atom, coef=0.01):
+    """Get the mass of an atom and scale it.
+    Parameters
+    ----------
+    atom : rdkit.Chem.rdchem.Atom
+        RDKit atom instance.
+    coef : float
+        The mass will be multiplied by ``coef``.
+    Returns
+    -------
+    list
+        List containing one float only.
+    """
+    return [atom.GetMass() * coef]
+class ConcatFeaturizer(object):
+    """Concatenate the evaluation results of multiple functions as a single feature.
+    Parameters
+    ----------
+    func_list : list
+        List of functions for computing molecular descriptors from objects of a same
+        particular data type, e.g. ``rdkit.Chem.rdchem.Atom``. Each function is of signature
+        ``func(data_type) -> list of float or bool or int``. The resulting order of
+        the features will follow that of the functions in the list.
+    """
+    def __init__(self, func_list):
+        self.func_list = func_list
+    def __call__(self, x):
+        """Featurize the input data.
+        Parameters
+        ----------
+        x :
+            Data to featurize.
+        Returns
+        -------
+        list
+            List of feature values, which can be of type bool, float or int.
+        """
+        return list(itertools.chain.from_iterable(
+            [func(x) for func in self.func_list]))
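To illustrate how `ConcatFeaturizer` chains per-atom functions into one flat feature list, here is a self-contained sketch that avoids the RDKit dependency. The class body restates the one in the diff; `FakeAtom`, `symbol_one_hot` and `degree` are hypothetical stand-ins for an RDKit atom and the featurization functions above.

```python
import itertools


class ConcatFeaturizer(object):
    """Concatenate the outputs of several functions (restated from the diff)."""
    def __init__(self, func_list):
        self.func_list = func_list

    def __call__(self, x):
        return list(itertools.chain.from_iterable(
            [func(x) for func in self.func_list]))


class FakeAtom(object):
    """Hypothetical stand-in exposing two of the RDKit Atom getters."""
    def __init__(self, symbol, degree):
        self._symbol = symbol
        self._degree = degree

    def GetSymbol(self):
        return self._symbol

    def GetDegree(self):
        return self._degree


def symbol_one_hot(atom):
    # One-hot over a tiny illustrative vocabulary.
    return [atom.GetSymbol() == s for s in ['C', 'N', 'O']]


def degree(atom):
    # Single-element list, mirroring the "list containing one int" convention.
    return [atom.GetDegree()]


featurizer = ConcatFeaturizer([symbol_one_hot, degree])
features = featurizer(FakeAtom('O', 1))
```

The result keeps the order of `func_list`: the three one-hot entries come first, followed by the degree. Each function must return a list, even for a single value, or the chaining step would fail.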
 class BaseAtomFeaturizer(object):
-    """An abstract class for atom featurizers
-    All atom featurizers that map a molecule to atom features should subclass it.
-    All subclasses should overwrite ``_featurize_atom``, which featurizes a single
-    atom and ``__call__``, which featurizes all atoms in a molecule.
-    """
-    def _featurize_atom(self, atom):
-        return NotImplementedError
+    """An abstract class for atom featurizers.
+    Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``.
+    **We assume the resulting DGLGraph will not contain any virtual nodes.**
+    Parameters
+    ----------
+    featurizer_funcs : dict
+        Mapping feature name to the featurization function.
+        Each function is of signature ``func(rdkit.Chem.rdchem.Atom) -> list or 1D numpy array``.
+    feat_sizes : dict
+        Mapping feature name to the size of the corresponding feature. If None, they will be
+        computed when needed. Default: None.
+    Examples
+    --------
+    >>> from dgl.data.chem import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
+    >>> from rdkit import Chem
+    >>> mol = Chem.MolFromSmiles('CCO')
+    >>> atom_featurizer = BaseAtomFeaturizer({'mass': atom_mass, 'degree': atom_degree_one_hot})
+    >>> atom_featurizer(mol)
+    {'mass': tensor([[0.1201],
+                     [0.1201],
+                     [0.1600]]),
+     'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
+                       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
+                       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
+    """
+    def __init__(self, featurizer_funcs, feat_sizes=None):
+        self.featurizer_funcs = featurizer_funcs
+        if feat_sizes is None:
+            feat_sizes = dict()
+        self._feat_sizes = feat_sizes
+    def feat_size(self, feat_name):
+        """Get the feature size for ``feat_name``.
+        Returns
+        -------
+        int
+            Feature size for the feature with name ``feat_name``.
+        """
+        if feat_name not in self.featurizer_funcs:
+            raise ValueError('Expect feat_name to be in {}, got {}'.format(
+                list(self.featurizer_funcs.keys()), feat_name))
+        if feat_name not in self._feat_sizes:
+            atom = Chem.MolFromSmiles('C').GetAtomWithIdx(0)
+            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](atom))
+        return self._feat_sizes[feat_name]
     def __call__(self, mol):
-        return NotImplementedError
+        """Featurize all atoms in a molecule.
+        Parameters
+        ----------
+        mol : rdkit.Chem.rdchem.Mol
+            RDKit molecule instance.
+        Returns
+        -------
+        dict
+            For each function in self.featurizer_funcs with the key ``k``, store the computed
+            feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
+            (N, M), where N is the number of atoms in the molecule.
+        """
+        num_atoms = mol.GetNumAtoms()
+        atom_features = defaultdict(list)
+        # Compute features for each atom
+        for i in range(num_atoms):
+            atom = mol.GetAtomWithIdx(i)
+            for feat_name, feat_func in self.featurizer_funcs.items():
+                atom_features[feat_name].append(feat_func(atom))
+        # Stack the features and convert them to float arrays
+        processed_features = dict()
+        for feat_name, feat_list in atom_features.items():
+            feat = np.stack(feat_list)
+            processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))
+        return processed_features
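The control flow of `BaseAtomFeaturizer` (loop over atoms, group results by feature name, compute a feature size lazily by probing one atom) can be sketched without RDKit or a tensor backend. In the sketch below, `ListAtomFeaturizer` is a hypothetical analogue in which a "molecule" is simply a list of atomic numbers and features stay as nested Python lists; it is an illustration under those assumptions, not the class from the commit.

```python
from collections import defaultdict


class ListAtomFeaturizer(object):
    """Hypothetical list-based analogue of the featurizer loop."""
    def __init__(self, featurizer_funcs):
        self.featurizer_funcs = featurizer_funcs
        self._feat_sizes = dict()

    def feat_size(self, feat_name):
        if feat_name not in self.featurizer_funcs:
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))
        if feat_name not in self._feat_sizes:
            probe_atom = 6  # carbon, playing the role of the probe atom from 'C'
            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](probe_atom))
        return self._feat_sizes[feat_name]

    def __call__(self, mol):
        atom_features = defaultdict(list)
        # One row per atom, grouped under the feature name.
        for atom in mol:
            for feat_name, feat_func in self.featurizer_funcs.items():
                atom_features[feat_name].append([float(v) for v in feat_func(atom)])
        return dict(atom_features)


featurizer = ListAtomFeaturizer({
    'number': lambda z: [z],
    'is_c_or_o': lambda z: [z == 6, z == 8],
})
feats = featurizer([6, 6, 8])  # atomic numbers of C, C, O
size = featurizer.feat_size('is_c_or_o')
```

Each entry of `feats` has one row per atom, which corresponds to the (N, M) shape described in the docstring above. Asking for the size of an unregistered feature name raises `ValueError` in this sketch.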
 class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
     """A default featurizer for atoms.
@@ -70,55 +598,207 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
     * **One hot encoding of the number of total Hs on the atom**. The supported possibilities
       include ``0 - 4``.
+    **We assume the resulting DGLGraph will not contain any virtual nodes.**
     Parameters
     ----------
     atom_data_field : str
         Name for storing atom features in DGLGraphs, default to be 'h'.
     """
     def __init__(self, atom_data_field='h'):
-        super(CanonicalAtomFeaturizer, self).__init__()
-        self.atom_data_field = atom_data_field
+        super(CanonicalAtomFeaturizer, self).__init__(
+            featurizer_funcs={atom_data_field: ConcatFeaturizer(
+                [atom_type_one_hot,
+                 atom_degree_one_hot,
+                 atom_implicit_valence_one_hot,
+                 atom_formal_charge,
+                 atom_num_radical_electrons,
+                 atom_hybridization_one_hot,
+                 atom_is_aromatic,
+                 atom_total_num_H_one_hot]
+            )})
@
property
def
bond_type_one_hot
(
bond
,
allowable_set
=
None
,
encode_unknown
=
False
):
def
feat_size
(
self
):
"""One hot encoding for the type of a bond.
"""Returns feature size"""
return
74
def
_featurize_atom
(
self
,
atom
):
Parameters
"""Featurize an atom
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of Chem.rdchem.BondType
Bond types to consider. Default: ``Chem.rdchem.BondType.SINGLE``,
``Chem.rdchem.BondType.DOUBLE``, ``Chem.rdchem.BondType.TRIPLE``,
``Chem.rdchem.BondType.AROMATIC``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if
allowable_set
is
None
:
allowable_set
=
[
Chem
.
rdchem
.
BondType
.
SINGLE
,
Chem
.
rdchem
.
BondType
.
DOUBLE
,
Chem
.
rdchem
.
BondType
.
TRIPLE
,
Chem
.
rdchem
.
BondType
.
AROMATIC
]
return
one_hot_encoding
(
bond
.
GetBondType
(),
allowable_set
,
encode_unknown
)
def
bond_is_conjugated_one_hot
(
bond
,
allowable_set
=
None
,
encode_unknown
=
False
):
"""One hot encoding for whether the bond is conjugated.
Parameters
Parameters
----------
----------
atom : rdkit.Chem.rdchem.Atom
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of bool
Conditions to consider. Default: ``False`` and ``True``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
    if allowable_set is None:
        allowable_set = [False, True]
    return one_hot_encoding(bond.GetIsConjugated(), allowable_set, encode_unknown)
def bond_is_conjugated(bond):
"""Get whether the bond is conjugated.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
Returns
Returns
-------
-------
results : list
list
List of feature values, including boolean values and numbers
List containing one bool only.
"""
"""
        atom_types = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br',
    return [bond.GetIsConjugated()]
                      'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se',
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
                      'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd',
    """One hot encoding for whether the bond is in a ring of any size.
                      'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb']
    Parameters
        results = one_hot_encoding(atom.GetSymbol(), atom_types) + \
    ----------
                  one_hot_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + \
    bond : rdkit.Chem.rdchem.Bond
                  one_hot_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \
        RDKit bond instance.
                  [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \
    allowable_set : list of bool
                  one_hot_encoding(atom.GetHybridization(),
        Conditions to consider. Default: ``False`` and ``True``.
                                   [Chem.rdchem.HybridizationType.SP,
    encode_unknown : bool
                                    Chem.rdchem.HybridizationType.SP2,
        If True, map inputs not in the allowable set to the
                                    Chem.rdchem.HybridizationType.SP3,
        additional last element. (Default: False)
                                    Chem.rdchem.HybridizationType.SP3D,
    Returns
                                    Chem.rdchem.HybridizationType.SP3D2]) + \
    -------
                  [atom.GetIsAromatic()] + \
    list
                  one_hot_encoding(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = [False, True]
    return one_hot_encoding(bond.IsInRing(), allowable_set, encode_unknown)
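The removed `feat_size` property of `CanonicalAtomFeaturizer` hard-codes 74. That number can be recovered by summing the lengths of the pieces concatenated in the removed `_featurize_atom` body. The check below is only arithmetic over the lists visible in this diff; the dictionary keys are labels chosen here for illustration.

```python
# Lengths of the feature groups concatenated in the removed _featurize_atom code.
feature_group_sizes = {
    "atom_type_one_hot": 43,        # 43 element symbols in atom_types
    "degree_one_hot": 11,           # degrees 0..10
    "implicit_valence_one_hot": 7,  # implicit valences 0..6
    "formal_charge": 1,
    "num_radical_electrons": 1,
    "hybridization_one_hot": 5,     # SP, SP2, SP3, SP3D, SP3D2
    "is_aromatic": 1,
    "total_num_H_one_hot": 5,       # 0..4 hydrogens
}

total_size = sum(feature_group_sizes.values())
print(total_size)
```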
def bond_is_in_ring(bond):
    """Get whether the bond is in a ring of any size.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.

    Returns
    -------
    list
        List containing one bool only.
    """
    return [bond.IsInRing()]
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the stereo configuration of a bond.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of rdkit.Chem.rdchem.BondStereo
        Stereo configurations to consider. Default: ``rdkit.Chem.rdchem.BondStereo.STEREONONE``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOANY``, ``rdkit.Chem.rdchem.BondStereo.STEREOZ``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOE``, ``rdkit.Chem.rdchem.BondStereo.STEREOCIS``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOTRANS``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = [Chem.rdchem.BondStereo.STEREONONE,
                         Chem.rdchem.BondStereo.STEREOANY,
                         Chem.rdchem.BondStereo.STEREOZ,
                         Chem.rdchem.BondStereo.STEREOE,
                         Chem.rdchem.BondStereo.STEREOCIS,
                         Chem.rdchem.BondStereo.STEREOTRANS]
    return one_hot_encoding(bond.GetStereo(), allowable_set, encode_unknown)
class BaseBondFeaturizer(object):
    """An abstract class for bond featurizers.

    Loop over all bonds in a molecule and featurize them with the ``featurizer_funcs``.
    We assume the constructed ``DGLGraph`` is a bi-directed graph where the **i**-th bond in the
    molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the **(2i)**-th and **(2i+1)**-th edges
    in the DGLGraph.

    **We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
    self loops.**

    Parameters
    ----------
    featurizer_funcs : dict
        Mapping feature name to the featurization function.
        Each function is of signature ``func(rdkit.Chem.rdchem.Bond) -> list or 1D numpy array``.
    feat_sizes : dict
        Mapping feature name to the size of the corresponding feature. If None, they will be
        computed when needed. Default: None.

    Examples
    --------
    >>> from dgl.data.chem import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
    >>> from rdkit import Chem
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> bond_featurizer = BaseBondFeaturizer({'bond_type': bond_type_one_hot, 'in_ring': bond_is_in_ring})
    >>> bond_featurizer(mol)
    {'bond_type': tensor([[1., 0., 0., 0.],
                          [1., 0., 0., 0.],
                          [1., 0., 0., 0.],
                          [1., 0., 0., 0.]]),
     'in_ring': tensor([[0.], [0.], [0.], [0.]])}
    """
    def __init__(self, featurizer_funcs, feat_sizes=None):
        self.featurizer_funcs = featurizer_funcs
        if feat_sizes is None:
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name):
        """Get the feature size for ``feat_name``.

        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.
        """
        if feat_name not in self.featurizer_funcs:
            # Raise (not return) the error so that callers actually see the failure.
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))
return
results
        if feat_name not in self._feat_sizes:
            bond = Chem.MolFromSmiles('CO').GetBondWithIdx(0)
            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](bond))

        return self._feat_sizes[feat_name]
    def __call__(self, mol):
    def __call__(self, mol):
        """Featurize a molecule
        """Featurize all bonds in a molecule.
Parameters
Parameters
----------
----------
...
@@ -128,18 +808,55 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
...
@@ -128,18 +808,55 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
Returns
Returns
-------
-------
dict
dict
Atom features of shape (N, 74),
For each function in self.featurizer_funcs with the key ``k``, store the computed
where N is the number of atoms in the molecule
feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
(N, M), where N is twice the number of bonds in the molecule, since each bond yields two directed edges.
"""
"""
        num_atoms = mol.GetNumAtoms()
        num_bonds = mol.GetNumBonds()
        atom_features = []
        bond_features = defaultdict(list)
        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
        # Compute features for each bond
            atom_features.append(self._featurize_atom(atom))
        for i in range(num_bonds):
        atom_features = np.stack(atom_features)
            bond = mol.GetBondWithIdx(i)
        atom_features = F.zerocopy_from_numpy(atom_features.astype(np.float32))
            for feat_name, feat_func in self.featurizer_funcs.items():
                feat = feat_func(bond)
                bond_features[feat_name].extend([feat, feat.copy()])

        # Stack the features and convert them to float arrays
        processed_features = dict()
        for feat_name, feat_list in bond_features.items():
            feat = np.stack(feat_list)
            processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))

        return processed_features
        return {self.atom_data_field: atom_features}
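The new `__call__` appends every bond's feature twice, so that rows 2i and 2i+1 of the result line up with the two directed edges created for bond i. A minimal sketch of that bookkeeping follows, using plain lists instead of tensors. The function name and the toy input are made up for illustration and are not part of DGL.

```python
from collections import defaultdict


def duplicate_bond_features(per_bond_features):
    """Return one row per directed edge: bond i -> rows 2i and 2i+1.

    per_bond_features is a hypothetical list of dicts, one dict per bond,
    mapping a feature name to a list of numbers.
    """
    edge_features = defaultdict(list)
    for bond_feats in per_bond_features:
        for name, feat in bond_feats.items():
            # Same values for both directions; copy so the two rows do not alias.
            edge_features[name].extend([feat, list(feat)])
    return dict(edge_features)


# Two bonds (for example the C-C and C-O bonds of 'CCO') with a toy "in_ring" flag.
bonds = [{"in_ring": [0.0]}, {"in_ring": [0.0]}]
edges = duplicate_bond_features(bonds)
print(len(edges["in_ring"]))
```

With two bonds the result has four rows, which matches the four-row tensors in the `BaseBondFeaturizer` docstring example.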
class CanonicalBondFeaturizer(BaseBondFeaturizer):
    """A default featurizer for bonds.

    The bond features include:
    * **One hot encoding of the bond type**. The supported bond types include
      ``SINGLE``, ``DOUBLE``, ``TRIPLE``, ``AROMATIC``.
    * **Whether the bond is conjugated.**
    * **Whether the bond is in a ring of any size.**
    * **One hot encoding of the stereo configuration of a bond**. The supported bond stereo
      configurations include ``STEREONONE``, ``STEREOANY``, ``STEREOZ``, ``STEREOE``,
      ``STEREOCIS``, ``STEREOTRANS``.

    **We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
    self loops.**
    """
    def __init__(self, bond_data_field='e'):
        super(CanonicalBondFeaturizer, self).__init__(
            featurizer_funcs={bond_data_field: ConcatFeaturizer(
                [bond_type_one_hot,
                 bond_is_conjugated,
                 bond_is_in_ring,
                 bond_stereo_one_hot]
            )})
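`ConcatFeaturizer` is defined elsewhere in this file. Assuming it simply concatenates the lists returned by its functions, the canonical bond feature would have 4 + 1 + 1 + 6 = 12 entries with the default allowable sets. The sketch below checks that arithmetic with stand-in functions that only reproduce the output lengths described in the docstring. None of these names come from DGL or RDKit.

```python
def concat_features(funcs, obj):
    """Hypothetical stand-in for ConcatFeaturizer: chain each function's output list."""
    result = []
    for func in funcs:
        result.extend(func(obj))
    return result


# Stand-ins that ignore their input and return lists of the documented lengths.
def fake_bond_type_one_hot(bond):
    return [True, False, False, False]   # SINGLE, DOUBLE, TRIPLE, AROMATIC


def fake_is_conjugated(bond):
    return [False]


def fake_is_in_ring(bond):
    return [False]


def fake_stereo_one_hot(bond):
    return [True] + [False] * 5          # six stereo configurations


features = concat_features(
    [fake_bond_type_one_hot, fake_is_conjugated, fake_is_in_ring, fake_stereo_one_hot],
    None,  # no real bond object is needed for this length check
)
print(len(features))
```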
#################################################################
# DGLGraph Construction
#################################################################
def mol_to_graph(mol, graph_constructor, atom_featurizer, bond_featurizer):
def mol_to_graph(mol, graph_constructor, atom_featurizer, bond_featurizer):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
...
@@ -193,7 +910,7 @@ def construct_bigraph_from_mol(mol, add_self_loop=False):
...
@@ -193,7 +910,7 @@ def construct_bigraph_from_mol(mol, add_self_loop=False):
mol : rdkit.Chem.rdchem.Mol
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
RDKit molecule holder
add_self_loop : bool
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Whether to add self loops in DGLGraphs. Default to False.
Returns
Returns
-------
-------
...
@@ -225,7 +942,7 @@ def construct_bigraph_from_mol(mol, add_self_loop=False):
...
@@ -225,7 +942,7 @@ def construct_bigraph_from_mol(mol, add_self_loop=False):
    return g
    return g
def mol_to_bigraph(mol, add_self_loop=False,
def mol_to_bigraph(mol, add_self_loop=False,
                   atom_featurizer=CanonicalAtomFeaturizer(),
                   atom_featurizer=None,
                   bond_featurizer=None):
                   bond_featurizer=None):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
...
@@ -234,13 +951,13 @@ def mol_to_bigraph(mol, add_self_loop=False,
...
@@ -234,13 +951,13 @@ def mol_to_bigraph(mol, add_self_loop=False,
mol : rdkit.Chem.rdchem.Mol
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
RDKit molecule holder
add_self_loop : bool
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
edata for a DGLGraph. Default to None.
Returns
Returns
-------
-------
...
@@ -250,30 +967,30 @@ def mol_to_bigraph(mol, add_self_loop=False,
...
@@ -250,30 +967,30 @@ def mol_to_bigraph(mol, add_self_loop=False,
    return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
    return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
                        atom_featurizer, bond_featurizer)
                        atom_featurizer, bond_featurizer)
def smile_to_bigraph(smile, add_self_loop=False,
def smiles_to_bigraph(smiles, add_self_loop=False,
                     atom_featurizer=CanonicalAtomFeaturizer(),
                      atom_featurizer=None,
                     bond_featurizer=None):
                      bond_featurizer=None):
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it.
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it.
Parameters
Parameters
----------
----------
smile : str
smiles : str
String of SMILES
String of SMILES
add_self_loop : bool
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
edata for a DGLGraph. Default to None.
Returns
Returns
-------
-------
g : DGLGraph
g : DGLGraph
Bi-directed DGLGraph for the molecule
Bi-directed DGLGraph for the molecule
"""
"""
    mol = Chem.MolFromSmiles(smile)
    mol = Chem.MolFromSmiles(smiles)
    return mol_to_bigraph(mol, add_self_loop, atom_featurizer, bond_featurizer)
    return mol_to_bigraph(mol, add_self_loop, atom_featurizer, bond_featurizer)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
def construct_complete_graph_from_mol(mol, add_self_loop=False):
...
@@ -290,7 +1007,7 @@ def construct_complete_graph_from_mol(mol, add_self_loop=False):
...
@@ -290,7 +1007,7 @@ def construct_complete_graph_from_mol(mol, add_self_loop=False):
mol : rdkit.Chem.rdchem.Mol
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
RDKit molecule holder
add_self_loop : bool
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Whether to add self loops in DGLGraphs. Default to False.
Returns
Returns
-------
-------
...
@@ -324,13 +1041,13 @@ def mol_to_complete_graph(mol, add_self_loop=False,
...
@@ -324,13 +1041,13 @@ def mol_to_complete_graph(mol, add_self_loop=False,
mol : rdkit.Chem.rdchem.Mol
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
RDKit molecule holder
add_self_loop : bool
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
edata for a DGLGraph. Default to None.
Returns
Returns
-------
-------
...
@@ -340,28 +1057,28 @@ def mol_to_complete_graph(mol, add_self_loop=False,
...
@@ -340,28 +1057,28 @@ def mol_to_complete_graph(mol, add_self_loop=False,
    return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
    return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
                        atom_featurizer, bond_featurizer)
                        atom_featurizer, bond_featurizer)
def smile_to_complete_graph(smile, add_self_loop=False,
def smiles_to_complete_graph(smiles, add_self_loop=False,
                            atom_featurizer=None,
                             atom_featurizer=None,
                            bond_featurizer=None):
                             bond_featurizer=None):
"""Convert a SMILES into a complete DGLGraph and featurize for it.
"""Convert a SMILES into a complete DGLGraph and featurize for it.
Parameters
Parameters
----------
----------
smile : str
smiles : str
String of SMILES
String of SMILES
add_self_loop : bool
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
edata for a DGLGraph. Default to None.
Returns
Returns
-------
-------
g : DGLGraph
g : DGLGraph
Complete DGLGraph for the molecule
Complete DGLGraph for the molecule
"""
"""
    mol = Chem.MolFromSmiles(smile)
    mol = Chem.MolFromSmiles(smiles)
    return mol_to_complete_graph(mol, add_self_loop, atom_featurizer, bond_featurizer)
    return mol_to_complete_graph(mol, add_self_loop, atom_featurizer, bond_featurizer)