OpenDAS / dgl / Commits / 36c7b771

Unverified commit 36c7b771, authored Jun 05, 2020 by Mufei Li, committed by GitHub on Jun 05, 2020.
Parent: 94c67203

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci
* Remove doc
Showing 20 changed files with 0 additions and 4121 deletions (+0 -4121):

examples/pytorch/model_zoo/chem/property_prediction/classification.py  +0 -105
examples/pytorch/model_zoo/chem/property_prediction/configure.py  +0 -127
examples/pytorch/model_zoo/chem/property_prediction/regression.py  +0 -124
examples/pytorch/model_zoo/chem/property_prediction/utils.py  +0 -357
examples/pytorch/model_zoo/chem/requirements.txt  +0 -3
python/dgl/contrib/deprecation.py  +0 -39
python/dgl/data/chem/README.md  +0 -34
python/dgl/data/chem/__init__.py  +0 -2
python/dgl/data/chem/datasets/__init__.py  +0 -5
python/dgl/data/chem/datasets/alchemy.py  +0 -307
python/dgl/data/chem/datasets/csv_dataset.py  +0 -137
python/dgl/data/chem/datasets/pdbbind.py  +0 -295
python/dgl/data/chem/datasets/pubchem_aromaticity.py  +0 -57
python/dgl/data/chem/datasets/tox21.py  +0 -107
python/dgl/data/chem/utils/__init__.py  +0 -5
python/dgl/data/chem/utils/complex_to_graph.py  +0 -233
python/dgl/data/chem/utils/featurizers.py  +0 -910
python/dgl/data/chem/utils/mol_to_graph.py  +0 -297
python/dgl/data/chem/utils/rdkit_utils.py  +0 -220
python/dgl/data/chem/utils/splitters.py  +0 -757
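Note: as the commit title and the deprecation messages inside the deleted files indicate, this code now lives in the standalone DGL-LifeSci package. A minimal import sketch, assuming `dgllife` is installed; the module paths below are taken from the deprecation messages in this diff and are not verified against any particular dgllife release:

```python
# Sketch only: assumes the standalone dgllife package is installed.
# Module paths follow the deprecation messages embedded in the deleted files below.
from dgllife.data import MoleculeCSVDataset, PDBBind, PubChemBioAssayAromaticity
from dgllife.data.alchemy import TencentAlchemyDataset
```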
examples/pytorch/model_zoo/chem/property_prediction/classification.py  deleted 100644 → 0

import numpy as np
import torch

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from dgl import model_zoo

from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed, \
    load_dataset_for_classification, load_model

def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data
        atom_feats = bg.ndata.pop(args['atom_data_field'])
        atom_feats, labels, masks = atom_feats.to(args['device']), \
                                    labels.to(args['device']), \
                                    masks.to(args['device'])
        logits = model(bg, atom_feats)
        # Mask non-existing labels
        loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
            epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item()))
        train_meter.update(logits, labels, masks)
    train_score = np.mean(train_meter.compute_metric(args['metric_name']))
    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
        epoch + 1, args['num_epochs'], args['metric_name'], train_score))

def run_an_eval_epoch(args, model, data_loader):
    model.eval()
    eval_meter = Meter()
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, bg, labels, masks = batch_data
            atom_feats = bg.ndata.pop(args['atom_data_field'])
            atom_feats, labels = atom_feats.to(args['device']), labels.to(args['device'])
            logits = model(bg, atom_feats)
            eval_meter.update(logits, labels, masks)
    return np.mean(eval_meter.compute_metric(args['metric_name']))

def main(args):
    args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(args['random_seed'])

    # Interchangeable with other datasets
    dataset, train_set, val_set, test_set = load_dataset_for_classification(args)
    train_loader = DataLoader(train_set, batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(val_set, batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(test_set, batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = model_zoo.chem.load_pretrained(args['exp'])
    else:
        args['n_tasks'] = dataset.n_tasks
        model = load_model(args)
        loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights.to(args['device']),
                                           reduction='none')
        optimizer = Adam(model.parameters(), lr=args['lr'])
        stopper = EarlyStopping(patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args['num_epochs'], args['metric_name'],
            val_score, args['metric_name'], stopper.best_score))
        if early_stop:
            break

    if not args['pre_trained']:
        stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('test {} {:.4f}'.format(args['metric_name'], test_score))

if __name__ == '__main__':
    import argparse

    from configure import get_exp_configure

    parser = argparse.ArgumentParser(description='Molecule Classification')
    parser.add_argument('-m', '--model', type=str, choices=['GCN', 'GAT'],
                        help='Model to use')
    parser.add_argument('-d', '--dataset', type=str, choices=['Tox21'],
                        help='Dataset to use')
    parser.add_argument('-p', '--pre-trained', action='store_true',
                        help='Whether to skip training and use a pre-trained model')
    args = parser.parse_args().__dict__
    args['exp'] = '_'.join([args['model'], args['dataset']])
    args.update(get_exp_configure(args['exp']))

    main(args)
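For reference, the script above is normally driven through its `__main__` block; the equivalent programmatic call looks roughly like the sketch below. It assumes the script is importable as `classification` and sits next to `configure.py` and `utils.py`; nothing here is new API, it just mirrors the argparse flags.

```python
# Hypothetical programmatic driver mirroring the __main__ block above.
from classification import main
from configure import get_exp_configure

args = {'model': 'GCN', 'dataset': 'Tox21', 'pre_trained': False}
args['exp'] = '_'.join([args['model'], args['dataset']])   # -> 'GCN_Tox21'
args.update(get_exp_configure(args['exp']))                # adds lr, batch_size, featurizer, ...
main(args)                                                 # train with early stopping, then test
```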
examples/pytorch/model_zoo/chem/property_prediction/configure.py  deleted 100644 → 0

from dgl.data.chem import BaseAtomFeaturizer, CanonicalAtomFeaturizer, ConcatFeaturizer, \
    atom_type_one_hot, atom_degree_one_hot, atom_formal_charge, atom_num_radical_electrons, \
    atom_hybridization_one_hot, atom_total_num_H_one_hot, BaseBondFeaturizer
from functools import partial

from utils import chirality

GCN_Tox21 = {
    'random_seed': 0,
    'batch_size': 128,
    'lr': 1e-3,
    'num_epochs': 100,
    'atom_data_field': 'h',
    'frac_train': 0.8,
    'frac_val': 0.1,
    'frac_test': 0.1,
    'in_feats': 74,
    'gcn_hidden_feats': [64, 64],
    'classifier_hidden_feats': 64,
    'patience': 10,
    'atom_featurizer': CanonicalAtomFeaturizer(),
    'metric_name': 'roc_auc'
}

GAT_Tox21 = {
    'random_seed': 0,
    'batch_size': 128,
    'lr': 1e-3,
    'num_epochs': 100,
    'atom_data_field': 'h',
    'frac_train': 0.8,
    'frac_val': 0.1,
    'frac_test': 0.1,
    'in_feats': 74,
    'gat_hidden_feats': [32, 32],
    'classifier_hidden_feats': 64,
    'num_heads': [4, 4],
    'patience': 10,
    'atom_featurizer': CanonicalAtomFeaturizer(),
    'metric_name': 'roc_auc'
}

MPNN_Alchemy = {
    'random_seed': 0,
    'batch_size': 16,
    'num_epochs': 250,
    'node_in_feats': 15,
    'edge_in_feats': 5,
    'output_dim': 12,
    'lr': 0.0001,
    'patience': 50,
    'metric_name': 'l1',
    'weight_decay': 0
}

SCHNET_Alchemy = {
    'random_seed': 0,
    'batch_size': 16,
    'num_epochs': 250,
    'norm': True,
    'output_dim': 12,
    'lr': 0.0001,
    'patience': 50,
    'metric_name': 'l1',
    'weight_decay': 0
}

MGCN_Alchemy = {
    'random_seed': 0,
    'batch_size': 16,
    'num_epochs': 250,
    'norm': True,
    'output_dim': 12,
    'lr': 0.0001,
    'patience': 50,
    'metric_name': 'l1',
    'weight_decay': 0
}

AttentiveFP_Aromaticity = {
    'random_seed': 8,
    'graph_feat_size': 200,
    'num_layers': 2,
    'num_timesteps': 2,
    'node_feat_size': 39,
    'edge_feat_size': 10,
    'output_size': 1,
    'dropout': 0.2,
    'weight_decay': 10 ** (-5.0),
    'lr': 10 ** (-2.5),
    'batch_size': 128,
    'num_epochs': 800,
    'frac_train': 0.8,
    'frac_val': 0.1,
    'frac_test': 0.1,
    'patience': 80,
    'metric_name': 'rmse',
    # Follow the atom featurization in the original work
    'atom_featurizer': BaseAtomFeaturizer(
        featurizer_funcs={'hv': ConcatFeaturizer([
            partial(atom_type_one_hot,
                    allowable_set=['B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As',
                                   'Se', 'Br', 'Te', 'I', 'At'],
                    encode_unknown=True),
            partial(atom_degree_one_hot, allowable_set=list(range(6))),
            atom_formal_charge, atom_num_radical_electrons,
            partial(atom_hybridization_one_hot, encode_unknown=True),
            lambda atom: [0],  # A placeholder for aromatic information,
            atom_total_num_H_one_hot, chirality
        ],
        )}
    ),
    'bond_featurizer': BaseBondFeaturizer({
        'he': lambda bond: [0 for _ in range(10)]
    })
}

experiment_configures = {
    'GCN_Tox21': GCN_Tox21,
    'GAT_Tox21': GAT_Tox21,
    'MPNN_Alchemy': MPNN_Alchemy,
    'SCHNET_Alchemy': SCHNET_Alchemy,
    'MGCN_Alchemy': MGCN_Alchemy,
    'AttentiveFP_Aromaticity': AttentiveFP_Aromaticity
}

def get_exp_configure(exp_name):
    return experiment_configures[exp_name]
examples/pytorch/model_zoo/chem/property_prediction/regression.py  deleted 100644 → 0

import numpy as np
import torch
import torch.nn as nn

from torch.utils.data import DataLoader

from dgl import model_zoo

from utils import Meter, set_random_seed, collate_molgraphs, EarlyStopping, \
    load_dataset_for_regression, load_model

def regress(args, model, bg):
    if args['model'] == 'MPNN':
        h = bg.ndata.pop('n_feat')
        e = bg.edata.pop('e_feat')
        h, e = h.to(args['device']), e.to(args['device'])
        return model(bg, h, e)
    elif args['model'] in ['SCHNET', 'MGCN']:
        node_types = bg.ndata.pop('node_type')
        edge_distances = bg.edata.pop('distance')
        node_types, edge_distances = node_types.to(args['device']), \
                                     edge_distances.to(args['device'])
        return model(bg, node_types, edge_distances)
    else:
        atom_feats, bond_feats = bg.ndata.pop('hv'), bg.edata.pop('he')
        atom_feats, bond_feats = atom_feats.to(args['device']), bond_feats.to(args['device'])
        return model(bg, atom_feats, bond_feats)

def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data
        labels, masks = labels.to(args['device']), masks.to(args['device'])
        prediction = regress(args, model, bg)
        loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_meter.update(prediction, labels, masks)
    total_score = np.mean(train_meter.compute_metric(args['metric_name']))
    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
        epoch + 1, args['num_epochs'], args['metric_name'], total_score))

def run_an_eval_epoch(args, model, data_loader):
    model.eval()
    eval_meter = Meter()
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, bg, labels, masks = batch_data
            labels = labels.to(args['device'])
            prediction = regress(args, model, bg)
            eval_meter.update(prediction, labels, masks)
        total_score = np.mean(eval_meter.compute_metric(args['metric_name']))
    return total_score

def main(args):
    args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(args['random_seed'])

    train_set, val_set, test_set = load_dataset_for_regression(args)
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            shuffle=True,
                            collate_fn=collate_molgraphs)
    if test_set is not None:
        test_loader = DataLoader(dataset=test_set,
                                 batch_size=args['batch_size'],
                                 collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = model_zoo.chem.load_pretrained(args['exp'])
    else:
        model = load_model(args)
        if args['model'] in ['SCHNET', 'MGCN']:
            model.set_mean_std(train_set.mean, train_set.std, args['device'])
        loss_fn = nn.MSELoss(reduction='none')
        optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
                                     weight_decay=args['weight_decay'])
        stopper = EarlyStopping(mode='lower', patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args['num_epochs'], args['metric_name'],
            val_score, args['metric_name'], stopper.best_score))
        if early_stop:
            break

    if test_set is not None:
        if not args['pre_trained']:
            stopper.load_checkpoint(model)
        test_score = run_an_eval_epoch(args, model, test_loader)
        print('test {} {:.4f}'.format(args['metric_name'], test_score))

if __name__ == "__main__":
    import argparse

    from configure import get_exp_configure

    parser = argparse.ArgumentParser(description='Molecule Regression')
    parser.add_argument('-m', '--model', type=str,
                        choices=['MPNN', 'SCHNET', 'MGCN', 'AttentiveFP'],
                        help='Model to use')
    parser.add_argument('-d', '--dataset', type=str, choices=['Alchemy', 'Aromaticity'],
                        help='Dataset to use')
    parser.add_argument('-p', '--pre-trained', action='store_true',
                        help='Whether to skip training and use a pre-trained model')
    args = parser.parse_args().__dict__
    args['exp'] = '_'.join([args['model'], args['dataset']])
    args.update(get_exp_configure(args['exp']))

    main(args)
examples/pytorch/model_zoo/chem/property_prediction/utils.py  deleted 100644 → 0

import datetime
import dgl
import numpy as np
import random
import torch
import torch.nn.functional as F

from dgl import model_zoo
from dgl.data.chem import smiles_to_bigraph, one_hot_encoding, RandomSplitter
from sklearn.metrics import roc_auc_score

def set_random_seed(seed=0):
    """Set random seed.

    Parameters
    ----------
    seed : int
        Random seed to use
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

def chirality(atom):
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
               [atom.HasProp('_ChiralityPossible')]
    except:
        return [False, False] + [atom.HasProp('_ChiralityPossible')]

class Meter(object):
    """Track and summarize model performance on a dataset for
    (multi-label) binary classification."""
    def __init__(self):
        self.mask = []
        self.y_pred = []
        self.y_true = []

    def update(self, y_pred, y_true, mask):
        """Update for the result of an iteration

        Parameters
        ----------
        y_pred : float32 tensor
            Predicted molecule labels with shape (B, T),
            B for batch size and T for the number of tasks
        y_true : float32 tensor
            Ground truth molecule labels with shape (B, T)
        mask : float32 tensor
            Mask for indicating the existence of ground
            truth labels with shape (B, T)
        """
        self.y_pred.append(y_pred.detach().cpu())
        self.y_true.append(y_true.detach().cpu())
        self.mask.append(mask.detach().cpu())

    def roc_auc_score(self):
        """Compute roc-auc score for each task.

        Returns
        -------
        list of float
            roc-auc score for all tasks
        """
        mask = torch.cat(self.mask, dim=0)
        y_pred = torch.cat(self.y_pred, dim=0)
        y_true = torch.cat(self.y_true, dim=0)
        # Todo: support categorical classes
        # This assumes binary case only
        y_pred = torch.sigmoid(y_pred)
        n_tasks = y_true.shape[1]
        scores = []
        for task in range(n_tasks):
            task_w = mask[:, task]
            task_y_true = y_true[:, task][task_w != 0].numpy()
            task_y_pred = y_pred[:, task][task_w != 0].numpy()
            scores.append(roc_auc_score(task_y_true, task_y_pred))
        return scores

    def l1_loss(self, reduction):
        """Compute l1 loss for each task.

        Returns
        -------
        list of float
            l1 loss for all tasks
        reduction : str
            * 'mean': average the metric over all labeled data points for each task
            * 'sum': sum the metric over all labeled data points for each task
        """
        mask = torch.cat(self.mask, dim=0)
        y_pred = torch.cat(self.y_pred, dim=0)
        y_true = torch.cat(self.y_true, dim=0)
        n_tasks = y_true.shape[1]
        scores = []
        for task in range(n_tasks):
            task_w = mask[:, task]
            task_y_true = y_true[:, task][task_w != 0]
            task_y_pred = y_pred[:, task][task_w != 0]
            scores.append(F.l1_loss(task_y_true, task_y_pred, reduction=reduction).item())
        return scores

    def rmse(self):
        """Compute RMSE for each task.

        Returns
        -------
        list of float
            rmse for all tasks
        """
        mask = torch.cat(self.mask, dim=0)
        y_pred = torch.cat(self.y_pred, dim=0)
        y_true = torch.cat(self.y_true, dim=0)
        n_data, n_tasks = y_true.shape
        scores = []
        for task in range(n_tasks):
            task_w = mask[:, task]
            task_y_true = y_true[:, task][task_w != 0]
            task_y_pred = y_pred[:, task][task_w != 0]
            scores.append(np.sqrt(F.mse_loss(task_y_pred, task_y_true).cpu().item()))
        return scores

    def compute_metric(self, metric_name, reduction='mean'):
        """Compute metric for each task.

        Parameters
        ----------
        metric_name : str
            Name for the metric to compute.
        reduction : str
            Only comes into effect when the metric_name is l1_loss.
            * 'mean': average the metric over all labeled data points for each task
            * 'sum': sum the metric over all labeled data points for each task

        Returns
        -------
        list of float
            Metric value for each task
        """
        assert metric_name in ['roc_auc', 'l1', 'rmse'], \
            'Expect metric name to be "roc_auc", "l1" or "rmse", got {}'.format(metric_name)
        assert reduction in ['mean', 'sum']
        if metric_name == 'roc_auc':
            return self.roc_auc_score()
        if metric_name == 'l1':
            return self.l1_loss(reduction)
        if metric_name == 'rmse':
            return self.rmse()

class EarlyStopping(object):
    """Early stop performing

    Parameters
    ----------
    mode : str
        * 'higher': Higher metric suggests a better model
        * 'lower': Lower metric suggests a better model
    patience : int
        Number of epochs to wait before early stop
        if the metric stops getting improved
    filename : str or None
        Filename for storing the model checkpoint
    """
    def __init__(self, mode='higher', patience=10, filename=None):
        if filename is None:
            dt = datetime.datetime.now()
            filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
                dt.date(), dt.hour, dt.minute, dt.second)

        assert mode in ['higher', 'lower']
        self.mode = mode
        if self.mode == 'higher':
            self._check = self._check_higher
        else:
            self._check = self._check_lower

        self.patience = patience
        self.counter = 0
        self.filename = filename
        self.best_score = None
        self.early_stop = False

    def _check_higher(self, score, prev_best_score):
        return (score > prev_best_score)

    def _check_lower(self, score, prev_best_score):
        return (score < prev_best_score)

    def step(self, score, model):
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif self._check(score, self.best_score):
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

    def save_checkpoint(self, model):
        '''Saves model when the metric on the validation set gets improved.'''
        torch.save({'model_state_dict': model.state_dict()}, self.filename)

    def load_checkpoint(self, model):
        '''Load model saved with early stopping.'''
        model.load_state_dict(torch.load(self.filename)['model_state_dict'])

def collate_molgraphs(data):
    """Batching a list of datapoints for dataloader.

    Parameters
    ----------
    data : list of 3-tuples or 4-tuples.
        Each tuple is for a single datapoint, consisting of
        a SMILES, a DGLGraph, all-task labels and optionally
        a binary mask indicating the existence of labels.

    Returns
    -------
    smiles : list
        List of smiles
    bg : DGLGraph
        The batched DGLGraph.
    labels : Tensor of dtype float32 and shape (B, T)
        Batched datapoint labels. B is len(data) and
        T is the number of total tasks.
    masks : Tensor of dtype float32 and shape (B, T)
        Batched datapoint binary mask, indicating the
        existence of labels. If binary masks are not
        provided, return a tensor with ones.
    """
    assert len(data[0]) in [3, 4], \
        'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
    if len(data[0]) == 3:
        smiles, graphs, labels = map(list, zip(*data))
        masks = None
    else:
        smiles, graphs, labels, masks = map(list, zip(*data))

    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    labels = torch.stack(labels, dim=0)

    if masks is None:
        masks = torch.ones(labels.shape)
    else:
        masks = torch.stack(masks, dim=0)
    return smiles, bg, labels, masks

def load_dataset_for_classification(args):
    """Load dataset for classification tasks.

    Parameters
    ----------
    args : dict
        Configurations.

    Returns
    -------
    dataset
        The whole dataset.
    train_set
        Subset for training.
    val_set
        Subset for validation.
    test_set
        Subset for test.
    """
    assert args['dataset'] in ['Tox21']
    if args['dataset'] == 'Tox21':
        from dgl.data.chem import Tox21
        dataset = Tox21(smiles_to_bigraph, args['atom_featurizer'])
        train_set, val_set, test_set = RandomSplitter.train_val_test_split(
            dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
            frac_test=args['frac_test'], random_state=args['random_seed'])

    return dataset, train_set, val_set, test_set

def load_dataset_for_regression(args):
    """Load dataset for regression tasks.

    Parameters
    ----------
    args : dict
        Configurations.

    Returns
    -------
    train_set
        Subset for training.
    val_set
        Subset for validation.
    test_set
        Subset for test.
    """
    assert args['dataset'] in ['Alchemy', 'Aromaticity']

    if args['dataset'] == 'Alchemy':
        from dgl.data.chem import TencentAlchemyDataset
        train_set = TencentAlchemyDataset(mode='dev')
        val_set = TencentAlchemyDataset(mode='valid')
        test_set = None

    if args['dataset'] == 'Aromaticity':
        from dgl.data.chem import PubChemBioAssayAromaticity
        dataset = PubChemBioAssayAromaticity(smiles_to_bigraph, args['atom_featurizer'],
                                             args['bond_featurizer'])
        train_set, val_set, test_set = RandomSplitter.train_val_test_split(
            dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
            frac_test=args['frac_test'], random_state=args['random_seed'])

    return train_set, val_set, test_set

def load_model(args):
    if args['model'] == 'GCN':
        model = model_zoo.chem.GCNClassifier(
            in_feats=args['in_feats'],
            gcn_hidden_feats=args['gcn_hidden_feats'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])

    if args['model'] == 'GAT':
        model = model_zoo.chem.GATClassifier(
            in_feats=args['in_feats'],
            gat_hidden_feats=args['gat_hidden_feats'],
            num_heads=args['num_heads'],
            classifier_hidden_feats=args['classifier_hidden_feats'],
            n_tasks=args['n_tasks'])

    if args['model'] == 'MPNN':
        model = model_zoo.chem.MPNNModel(
            node_input_dim=args['node_in_feats'],
            edge_input_dim=args['edge_in_feats'],
            output_dim=args['output_dim'])

    if args['model'] == 'SCHNET':
        model = model_zoo.chem.SchNet(
            norm=args['norm'], output_dim=args['output_dim'])

    if args['model'] == 'MGCN':
        model = model_zoo.chem.MGCNModel(
            norm=args['norm'], output_dim=args['output_dim'])

    if args['model'] == 'AttentiveFP':
        model = model_zoo.chem.AttentiveFP(
            node_feat_size=args['node_feat_size'],
            edge_feat_size=args['edge_feat_size'],
            num_layers=args['num_layers'],
            num_timesteps=args['num_timesteps'],
            graph_feat_size=args['graph_feat_size'],
            output_size=args['output_size'],
            dropout=args['dropout'])

    return model
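The `Meter`/`EarlyStopping` pair above is the glue used by both training scripts. Below is a compact, self-contained sketch of how they fit together, assuming the `utils.py` above is importable as `utils`; the linear model and random tensors are toy stand-ins, not part of the original example.

```python
# Toy illustration of Meter + EarlyStopping; the model and tensors are random stand-ins.
import torch
import torch.nn as nn
from utils import Meter, EarlyStopping   # the module shown above

model = nn.Linear(8, 12)                       # pretend 12-task classifier
stopper = EarlyStopping(mode='higher', patience=3)
for epoch in range(10):
    meter = Meter()
    x = torch.randn(32, 8)
    labels = torch.randint(0, 2, (32, 12)).float()
    masks = torch.ones(32, 12)                 # every label exists in this toy batch
    meter.update(model(x), labels, masks)
    val_score = sum(meter.compute_metric('roc_auc')) / 12
    if stopper.step(val_score, model):         # checkpoints whenever the score improves
        break
stopper.load_checkpoint(model)                 # restore the best weights
```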
examples/pytorch/model_zoo/chem/requirements.txt  deleted 100644 → 0

scikit-learn==0.21.2
pandas==0.25.1
requests==2.22.0
python/dgl/contrib/deprecation.py  deleted 100644 → 0

"""Decorator for deprecation message.

This is used in migrating the chem related code to DGL-LifeSci.
Todo(Mufei): remove it in v0.5.

The code is adapted from
https://stackoverflow.com/questions/2536307/
decorators-in-the-python-standard-lib-deprecated-specifically/48632082#48632082.
"""
import warnings

def deprecated(message, mode='func'):
    """Print formatted deprecation message.

    Parameters
    ----------
    message : str
    mode : str
        'func' for function and 'class' for class.

    Return
    ------
    callable
    """
    assert mode in ['func', 'class']
    def deprecated_decorator(func):
        def deprecated_func(*args, **kwargs):
            if mode == 'func':
                warnings.warn(
                    "{} is deprecated and will be removed from dgl in v0.5. {}".format(
                        func.__name__, message),
                    category=DeprecationWarning, stacklevel=2)
            else:
                warnings.warn(
                    "The class is deprecated and "
                    "will be removed from dgl in v0.5. {}".format(message),
                    category=DeprecationWarning, stacklevel=2)
            warnings.simplefilter('default', DeprecationWarning)
            return func(*args, **kwargs)
        return deprecated_func
    return deprecated_decorator
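For reference, a small usage sketch of the decorator above; the wrapped function name and message are made up, and the import path is the one that existed at the parent commit (the module is deleted by this commit).

```python
# Illustrative only; legacy_helper is a hypothetical function.
from dgl.contrib.deprecation import deprecated

@deprecated('Import the replacement from dgllife instead.', mode='func')
def legacy_helper(x):
    return x

legacy_helper(1)   # triggers a DeprecationWarning, then runs the body as usual
```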
python/dgl/data/chem/README.md  deleted 100644 → 0

# Customize Dataset

Generally we follow the practice of PyTorch: a Dataset class should implement the `__getitem__(self, index)` and `__len__(self)` methods.

```python
class CustomDataset:
    def __init__(self):
        # Initialize Dataset and preprocess data

    def __getitem__(self, index):
        # Return the corresponding DGLGraph/label needed for training/evaluation based on index
        return self.graphs[index], self.labels[index]

    def __len__(self):
        return len(self.graphs)
```

DGL supports multiple backends such as MXNet and PyTorch, so we also want the dataset to be backend agnostic.
We prefer that users store numpy arrays in the dataset and avoid operators/tensors from any specific backend.
If you want to convert a numpy array to the corresponding backend tensor, you can use the following code

```python
import dgl.backend as F

# g is a DGLGraph, h is a numpy array
g.ndata['h'] = F.zerocopy_from_numpy(h)
# Now g.ndata['h'] is a PyTorch Tensor or an MXNet NDArray depending on the backend used
```

If your dataset is in `.csv` format, you may use [`CSVDataset`](https://github.com/dmlc/dgl/blob/master/python/dgl/data/chem/csv_dataset.py).
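The skeleton in the README leaves the method bodies as comments. A filled-in version might look like the sketch below; the toy graphs and labels are placeholders, and the `DGLGraph` construction calls follow the DGL API of the version this commit targets (0.4.x), not the original README.

```python
# Minimal complete version of the README's CustomDataset skeleton (illustrative only).
import dgl
import numpy as np

class CustomDataset:
    def __init__(self):
        # Initialize Dataset and preprocess data: two tiny toy graphs with numpy labels
        g = dgl.DGLGraph()
        g.add_nodes(3)
        g.add_edges([0, 1], [1, 2])
        self.graphs = [g, g]
        self.labels = np.array([0.0, 1.0], dtype=np.float32)

    def __getitem__(self, index):
        # Return the corresponding DGLGraph/label based on index
        return self.graphs[index], self.labels[index]

    def __len__(self):
        return len(self.graphs)
```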
python/dgl/data/chem/__init__.py  deleted 100644 → 0

from .datasets import *
from .utils import *
python/dgl/data/chem/datasets/__init__.py  deleted 100644 → 0

from .csv_dataset import MoleculeCSVDataset
from .tox21 import Tox21
from .alchemy import TencentAlchemyDataset
from .pubchem_aromaticity import PubChemBioAssayAromaticity
from .pdbbind import PDBBind
python/dgl/data/chem/datasets/alchemy.py  deleted 100644 → 0

# -*- coding:utf-8 -*-
"""Example dataloader of Tencent Alchemy Dataset
https://alchemy.tencent.com/
"""
import numpy as np
import os
import os.path as osp
import pathlib
import zipfile

from collections import defaultdict

from ..utils import mol_to_complete_graph, atom_type_one_hot, \
    atom_hybridization_one_hot, atom_is_aromatic
from ...utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
from ....utils import retry_method_with_fix
from .... import backend as F
from ....contrib.deprecation import deprecated

try:
    import pandas as pd
    from rdkit import Chem
    from rdkit.Chem import ChemicalFeatures
    from rdkit import RDConfig
except ImportError:
    pass

def alchemy_nodes(mol):
    """Featurization for all atoms in a molecule. The atom indices
    will be preserved.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule object

    Returns
    -------
    atom_feats_dict : dict
        Dictionary for atom features
    """
    atom_feats_dict = defaultdict(list)
    is_donor = defaultdict(int)
    is_acceptor = defaultdict(int)

    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
    mol_feats = mol_featurizer.GetFeaturesForMol(mol)
    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1

    for i in range(len(mol_feats)):
        if mol_feats[i].GetFamily() == 'Donor':
            node_list = mol_feats[i].GetAtomIds()
            for u in node_list:
                is_donor[u] = 1
        elif mol_feats[i].GetFamily() == 'Acceptor':
            node_list = mol_feats[i].GetAtomIds()
            for u in node_list:
                is_acceptor[u] = 1

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
        atom = mol.GetAtomWithIdx(u)
        atom_type = atom.GetAtomicNum()
        num_h = atom.GetTotalNumHs()
        atom_feats_dict['node_type'].append(atom_type)

        h_u = []
        h_u += atom_type_one_hot(atom, ['H', 'C', 'N', 'O', 'F', 'S', 'Cl'])
        h_u.append(atom_type)
        h_u.append(is_acceptor[u])
        h_u.append(is_donor[u])
        h_u += atom_is_aromatic(atom)
        h_u += atom_hybridization_one_hot(atom, [Chem.rdchem.HybridizationType.SP,
                                                 Chem.rdchem.HybridizationType.SP2,
                                                 Chem.rdchem.HybridizationType.SP3])
        h_u.append(num_h)
        atom_feats_dict['n_feat'].append(F.tensor(np.asarray(h_u, dtype=np.float32)))

    atom_feats_dict['n_feat'] = F.stack(atom_feats_dict['n_feat'], dim=0)
    atom_feats_dict['node_type'] = F.tensor(np.asarray(
        atom_feats_dict['node_type'], dtype=np.int64))

    return atom_feats_dict

def alchemy_edges(mol, self_loop=False):
    """Featurization for all bonds in a molecule.
    The bond indices will be preserved.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule object
    self_loop : bool
        Whether to add self loops. Default to be False.

    Returns
    -------
    bond_feats_dict : dict
        Dictionary for bond features
    """
    bond_feats_dict = defaultdict(list)

    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
        for v in range(num_atoms):
            if u == v and not self_loop:
                continue

            e_uv = mol.GetBondBetweenAtoms(u, v)
            if e_uv is None:
                bond_type = None
            else:
                bond_type = e_uv.GetBondType()
            bond_feats_dict['e_feat'].append([
                float(bond_type == x)
                for x in (Chem.rdchem.BondType.SINGLE,
                          Chem.rdchem.BondType.DOUBLE,
                          Chem.rdchem.BondType.TRIPLE,
                          Chem.rdchem.BondType.AROMATIC,
                          None)
            ])
            bond_feats_dict['distance'].append(np.linalg.norm(geom[u] - geom[v]))

    bond_feats_dict['e_feat'] = F.tensor(
        np.asarray(bond_feats_dict['e_feat'], dtype=np.float32))
    bond_feats_dict['distance'] = F.tensor(
        np.asarray(bond_feats_dict['distance'], dtype=np.float32)).reshape(-1, 1)

    return bond_feats_dict

class TencentAlchemyDataset(object):
    """
    Developed by the Tencent Quantum Lab, the dataset lists 12 quantum mechanical
    properties of 130, 000+ organic molecules, comprising up to 12 heavy atoms
    (C, N, O, S, F and Cl), sampled from the GDBMedChem database. These properties
    have been calculated using the open-source computational chemistry program
    Python-based Simulation of Chemistry Framework (PySCF).

    For more details, check the `paper <https://arxiv.org/abs/1906.09427>`__.

    Parameters
    ----------
    mode : str
        'dev', 'valid' or 'test', separately for training, validation and test.
        Default to be 'dev'. Note that 'test' is not available as the Alchemy
        contest is ongoing.
    mol_to_graph: callable, str -> DGLGraph
        A function turning an RDKit molecule instance into a DGLGraph.
        Default to :func:`dgl.data.chem.mol_to_complete_graph`.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. By default, we construct graphs where nodes represent atoms
        and node features represent atom features. We store the atomic numbers under the
        name ``"node_type"`` and store the atom features under the name ``"n_feat"``.
        The atom features include:
        * One hot encoding for atom types
        * Atomic number of atoms
        * Whether the atom is a donor
        * Whether the atom is an acceptor
        * Whether the atom is aromatic
        * One hot encoding for atom hybridization
        * Total number of Hs on the atom
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. By default, we construct edges between every pair of atoms,
        excluding the self loops. We store the distance between the end atoms under the name
        ``"distance"`` and store the edge features under the name ``"e_feat"``. The edge
        features represent one hot encoding of edge types (bond types and non-bond edges).
    load : bool
        Whether to load the previously pre-processed dataset or pre-process from scratch.
        ``load`` should be False when we want to try different graph construction and
        featurization methods and need to preprocess from scratch. Default to True.
    """
    @deprecated('Import TencentAlchemyDataset from dgllife.data.alchemy instead.', 'class')
    def __init__(self, mode='dev', mol_to_graph=mol_to_complete_graph,
                 node_featurizer=alchemy_nodes,
                 edge_featurizer=alchemy_edges, load=True):
        if mode == 'test':
            raise ValueError('The test mode is not supported before '
                             'the Alchemy contest finishes.')

        assert mode in ['dev', 'valid', 'test'], \
            'Expect mode to be dev, valid or test, got {}.'.format(mode)

        self.mode = mode

        # Construct DGLGraphs from raw data or use the preprocessed data
        self.load = load
        file_dir = osp.join(get_download_dir(), 'Alchemy_data')

        if load:
            file_name = "%s_processed_dgl" % (mode)
        else:
            file_name = "%s_single_sdf" % (mode)
        self._file_dir = file_dir
        self.file_dir = pathlib.Path(file_dir, file_name)

        self._url = 'dataset/alchemy/'
        self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
        self._file_name = file_name

        self._load(mol_to_graph, node_featurizer, edge_featurizer)

    def _download_and_extract(self):
        download(_get_dgl_url(self._url + self._file_name + '.zip'),
                 path=str(self.zip_file_path))
        if not os.path.exists(str(self.file_dir)):
            archive = zipfile.ZipFile(self.zip_file_path)
            archive.extractall(self._file_dir)
            archive.close()

    @retry_method_with_fix(_download_and_extract)
    def _load(self, mol_to_graph, node_featurizer, edge_featurizer):
        if self.load:
            self.graphs, label_dict = load_graphs(
                osp.join(self.file_dir, "%s_graphs.bin" % self.mode))
            self.labels = label_dict['labels']
            with open(osp.join(self.file_dir, "%s_smiles.txt" % self.mode), 'r') as f:
                smiles_ = f.readlines()
                self.smiles = [s.strip() for s in smiles_]
        else:
            print('Start preprocessing dataset...')
            target_file = pathlib.Path(self.file_dir, "%s_target.csv" % self.mode)
            self.target = pd.read_csv(
                target_file, index_col=0,
                usecols=['gdb_idx',] + ['property_%d' % x for x in range(12)])
            self.target = self.target[['property_%d' % x for x in range(12)]]
            self.graphs, self.labels, self.smiles = [], [], []

            supp = Chem.SDMolSupplier(osp.join(self.file_dir, self.mode + ".sdf"))
            cnt = 0
            dataset_size = len(self.target)
            for mol, label in zip(supp, self.target.iterrows()):
                cnt += 1
                print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
                graph = mol_to_graph(mol, node_featurizer=node_featurizer,
                                     edge_featurizer=edge_featurizer)
                smiles = Chem.MolToSmiles(mol)
                self.smiles.append(smiles)
                self.graphs.append(graph)
                label = F.tensor(np.asarray(label[1].tolist(), dtype=np.float32))
                self.labels.append(label)

            save_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode),
                        self.graphs,
                        labels={'labels': F.stack(self.labels, dim=0)})
            with open(osp.join(self.file_dir, "%s_smiles.txt" % self.mode), 'w') as f:
                for s in self.smiles:
                    f.write(s + '\n')

        self.set_mean_and_std()
        print(len(self.graphs), "loaded!")

    def __getitem__(self, item):
        """Get datapoint with index

        Parameters
        ----------
        item : int
            Datapoint index

        Returns
        -------
        str
            SMILES for the ith datapoint
        DGLGraph
            DGLGraph for the ith datapoint
        Tensor of dtype float32
            Labels of the datapoint for all tasks
        """
        return self.smiles[item], self.graphs[item], self.labels[item]

    def __len__(self):
        """Length of the dataset

        Returns
        -------
        int
            Length of Dataset
        """
        return len(self.graphs)

    def set_mean_and_std(self, mean=None, std=None):
        """Set mean and std or compute from labels for future normalization.

        Parameters
        ----------
        mean : int or float
            Default to be None.
        std : int or float
            Default to be None.
        """
        labels = np.asarray([i.numpy() for i in self.labels])
        if mean is None:
            mean = np.mean(labels, axis=0)
        if std is None:
            std = np.std(labels, axis=0)
        self.mean = mean
        self.std = std
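A usage sketch for the dataset class above. Items are (smiles, DGLGraph, labels) 3-tuples, which the `collate_molgraphs` helper from the property_prediction example accepts; the sketch assumes that helper is importable as `utils.collate_molgraphs`.

```python
# Illustrative only: downloads the preprocessed 'dev' split on first use.
from torch.utils.data import DataLoader
from dgl.data.chem import TencentAlchemyDataset
from utils import collate_molgraphs   # helper from the property_prediction example above

train_set = TencentAlchemyDataset(mode='dev')    # 12 regression targets per molecule
smiles, g, labels = train_set[0]                 # one datapoint: (str, DGLGraph, float32 tensor)
loader = DataLoader(train_set, batch_size=16, shuffle=True,
                    collate_fn=collate_molgraphs)
```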
python/dgl/data/chem/datasets/csv_dataset.py  deleted 100644 → 0

from __future__ import absolute_import

import numpy as np
import os
import sys

from ...utils import save_graphs, load_graphs
from .... import backend as F
from ....contrib.deprecation import deprecated

class MoleculeCSVDataset(object):
    """MoleculeCSVDataset

    This is a general class for loading molecular data from pandas.DataFrame.

    In data pre-processing, we set non-existing labels to be 0,
    and returning mask with 1 where label exists.

    All molecules are converted into DGLGraphs. After the first-time construction, the
    DGLGraphs can be saved for reloading so that we do not need to reconstruct them every time.

    Parameters
    ----------
    df: pandas.DataFrame
        Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path).
        One column includes smiles and other columns for labels.
        Column names other than smiles column would be considered as task names.
    smiles_to_graph: callable, str -> DGLGraph
        A function turning a SMILES into a DGLGraph.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph.
    smiles_column: str
        Column name that including smiles.
    cache_file_path: str
        Path to store the preprocessed DGLGraphs. For example, this can be ``'dglgraph.bin'``.
    task_names : list of str or None
        Columns in the data frame corresponding to real-valued labels. If None, we assume
        all columns except the smiles_column are labels. Default to None.
    load : bool
        Whether to load the previously pre-processed dataset or pre-process from scratch.
        ``load`` should be False when we want to try different graph construction and
        featurization methods and need to preprocess from scratch. Default to True.
    """
    @deprecated('Import MoleculeCSVDataset from dgllife.data instead.', 'class')
    def __init__(self, df, smiles_to_graph, node_featurizer, edge_featurizer,
                 smiles_column, cache_file_path, task_names=None, load=True):
        if 'rdkit' not in sys.modules:
            from ....base import dgl_warning
            dgl_warning("Please install RDKit (Recommended Version is 2018.09.3)")

        self.df = df
        self.smiles = self.df[smiles_column].tolist()
        if task_names is None:
            self.task_names = self.df.columns.drop([smiles_column]).tolist()
        else:
            self.task_names = task_names
        self.n_tasks = len(self.task_names)
        self.cache_file_path = cache_file_path
        self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer, load)

    def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, load):
        """Pre-process the dataset

        * Convert molecules from smiles format into DGLGraphs
          and featurize their atoms
        * Set missing labels to be 0 and use a binary masking
          matrix to mask them

        Parameters
        ----------
        smiles_to_graph : callable, SMILES -> DGLGraph
            Function for converting a SMILES (str) into a DGLGraph.
        node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
            Featurization for nodes like atoms in a molecule, which can be used to update
            ndata for a DGLGraph.
        edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
            Featurization for edges like bonds in a molecule, which can be used to update
            edata for a DGLGraph.
        load : bool
            Whether to load the previously pre-processed dataset or pre-process from scratch.
            ``load`` should be False when we want to try different graph construction and
            featurization methods and need to preprocess from scratch. Default to True.
        """
        if os.path.exists(self.cache_file_path) and load:
            # DGLGraphs have been constructed before, reload them
            print('Loading previously saved dgl graphs...')
            self.graphs, label_dict = load_graphs(self.cache_file_path)
            self.labels = label_dict['labels']
            self.mask = label_dict['mask']
        else:
            print('Processing dgl graphs from scratch...')
            self.graphs = []
            for i, s in enumerate(self.smiles):
                print('Processing molecule {:d}/{:d}'.format(i + 1, len(self)))
                self.graphs.append(smiles_to_graph(s, node_featurizer=node_featurizer,
                                                   edge_featurizer=edge_featurizer))
            _label_values = self.df[self.task_names].values
            # np.nan_to_num will also turn inf into a very large number
            self.labels = F.zerocopy_from_numpy(
                np.nan_to_num(_label_values).astype(np.float32))
            self.mask = F.zerocopy_from_numpy(
                (~np.isnan(_label_values)).astype(np.float32))
            save_graphs(self.cache_file_path, self.graphs,
                        labels={'labels': self.labels, 'mask': self.mask})

    def __getitem__(self, item):
        """Get datapoint with index

        Parameters
        ----------
        item : int
            Datapoint index

        Returns
        -------
        str
            SMILES for the ith datapoint
        DGLGraph
            DGLGraph for the ith datapoint
        Tensor of dtype float32
            Labels of the datapoint for all tasks
        Tensor of dtype float32
            Binary masks indicating the existence of labels for all tasks
        """
        return self.smiles[item], self.graphs[item], self.labels[item], self.mask[item]

    def __len__(self):
        """Length of the dataset

        Returns
        -------
        int
            Length of Dataset
        """
        return len(self.smiles)
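A sketch of constructing the class above from a CSV file. The file name and column name are hypothetical; the featurizers and `smiles_to_bigraph` come from `dgl.data.chem` as used elsewhere in this diff.

```python
# Hypothetical CSV with a 'smiles' column plus one column per task.
import pandas as pd
from dgl.data.chem import MoleculeCSVDataset, smiles_to_bigraph, CanonicalAtomFeaturizer

df = pd.read_csv('my_molecules.csv')
dataset = MoleculeCSVDataset(df,
                             smiles_to_graph=smiles_to_bigraph,
                             node_featurizer=CanonicalAtomFeaturizer(),
                             edge_featurizer=None,
                             smiles_column='smiles',
                             cache_file_path='my_molecules_dglgraph.bin')
smiles, g, label, mask = dataset[0]   # mask is 1 where a label exists, 0 where it was NaN
```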
python/dgl/data/chem/datasets/pdbbind.py  deleted 100644 → 0

"""PDBBind dataset processed by MoleculeNet."""
import numpy as np
import os
import pandas as pd

from ..utils import multiprocess_load_molecules, ACNN_graph_construction_and_featurization
from ...utils import get_download_dir, download, _get_dgl_url, extract_archive
from ....utils import retry_method_with_fix
from .... import backend as F
from ....contrib.deprecation import deprecated

class PDBBind(object):
    """PDBbind dataset processed by MoleculeNet.

    The description below is mainly based on
    `[1] <https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a#cit50>`__.
    The PDBBind database consists of experimentally measured binding affinities for
    bio-molecular complexes `[2] <https://www.ncbi.nlm.nih.gov/pubmed/?term=15163179%5Buid%5D>`__,
    `[3] <https://www.ncbi.nlm.nih.gov/pubmed/?term=15943484%5Buid%5D>`__. It provides detailed
    3D Cartesian coordinates of both ligands and their target proteins derived from experimental
    (e.g., X-ray crystallography) measurements. The availability of coordinates of the
    protein-ligand complexes permits structure-based featurization that is aware of the
    protein-ligand binding geometry. The authors of
    `[1] <https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a#cit50>`__ use the
    "refined" and "core" subsets of the database
    `[4] <https://www.ncbi.nlm.nih.gov/pubmed/?term=25301850%5Buid%5D>`__, more carefully
    processed for data artifacts, as additional benchmarking targets.

    References:

        * [1] MoleculeNet: a benchmark for molecular machine learning
        * [2] The PDBbind database: collection of binding affinities for protein-ligand complexes
          with known three-dimensional structures
        * [3] The PDBbind database: methodologies and updates
        * [4] PDB-wide collection of binding data: current status of the PDBbind database

    Parameters
    ----------
    subset : str
        In MoleculeNet, we can use either the "refined" subset or the "core" subset. We can
        retrieve them by setting ``subset`` to be ``'refined'`` or ``'core'``. The size
        of the ``'core'`` set is 195 and the size of the ``'refined'`` set is 3706.
    load_binding_pocket : bool
        Whether to load binding pockets or full proteins. Default to True.
    add_hydrogens : bool
        Whether to add hydrogens via pdbfixer. Default to False.
    sanitize : bool
        Whether sanitization is performed in initializing RDKit molecule instances. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        Default to False.
    calc_charges : bool
        Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
        ``add_hydrogens`` and ``sanitize`` to be True. Default to False.
    remove_hs : bool
        Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
        slow for large molecules. Default to False.
    use_conformation : bool
        Whether we need to extract molecular conformation from proteins and ligands.
        Default to True.
    construct_graph_and_featurize : callable
        Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
        self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
        to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
    zero_padding : bool
        Whether to perform zero padding. While DGL does not necessarily require zero padding,
        pooling operations for variable length inputs can introduce stochastic behaviour, which
        is not desired for sensitive scenarios. Default to True.
    num_processes : int or None
        Number of worker processes to use. If None,
        then we will use the number of CPUs in the system. Default to 64.
    """
    @deprecated('Import PDBBind from dgllife.data instead.', 'class')
    def __init__(self, subset, load_binding_pocket=True, add_hydrogens=False, sanitize=False,
                 calc_charges=False, remove_hs=False, use_conformation=True,
                 construct_graph_and_featurize=ACNN_graph_construction_and_featurization,
                 zero_padding=True, num_processes=64):
        self.task_names = ['-logKd/Ki']
        self.n_tasks = len(self.task_names)

        self._url = 'dataset/pdbbind_v2015.tar.gz'
        root_dir_path = get_download_dir()
        data_path = root_dir_path + '/pdbbind_v2015.tar.gz'
        extracted_data_path = root_dir_path + '/pdbbind_v2015'

        if subset == 'core':
            index_label_file = extracted_data_path + '/v2015/INDEX_core_data.2013'
        elif subset == 'refined':
            index_label_file = extracted_data_path + '/v2015/INDEX_refined_data.2015'
        else:
            raise ValueError(
                'Expect the subset_choice to be either '
                'core or refined, got {}'.format(subset))

        self._data_path = data_path
        self._extracted_data_path = extracted_data_path
        self._preprocess(extracted_data_path, index_label_file, load_binding_pocket,
                         add_hydrogens, sanitize, calc_charges, remove_hs, use_conformation,
                         construct_graph_and_featurize, zero_padding, num_processes)

    def _filter_out_invalid(self, ligands_loaded, proteins_loaded, use_conformation):
        """Filter out invalid ligand-protein pairs.

        Parameters
        ----------
        ligands_loaded : list
            Each element is a 2-tuple of the RDKit molecule instance and its associated atom
            coordinates. None is used to represent invalid/non-existing molecule or coordinates.
        proteins_loaded : list
            Each element is a 2-tuple of the RDKit molecule instance and its associated atom
            coordinates. None is used to represent invalid/non-existing molecule or coordinates.
        use_conformation : bool
            Whether we need conformation information (atom coordinates) and filter out molecules
            without valid conformation.
        """
        num_pairs = len(proteins_loaded)
        self.indices, self.ligand_mols, self.protein_mols = [], [], []
        if use_conformation:
            self.ligand_coordinates, self.protein_coordinates = [], []
        else:
            # Use None for placeholders.
            self.ligand_coordinates = [None for _ in range(num_pairs)]
            self.protein_coordinates = [None for _ in range(num_pairs)]

        for i in range(num_pairs):
            ligand_mol, ligand_coordinates = ligands_loaded[i]
            protein_mol, protein_coordinates = proteins_loaded[i]
            if (not use_conformation) and all(
                    v is not None for v in [protein_mol, ligand_mol]):
                self.indices.append(i)
                self.ligand_mols.append(ligand_mol)
                self.protein_mols.append(protein_mol)
            elif all(v is not None for v in [
                    protein_mol, protein_coordinates, ligand_mol, ligand_coordinates]):
                self.indices.append(i)
                self.ligand_mols.append(ligand_mol)
                self.ligand_coordinates.append(ligand_coordinates)
                self.protein_mols.append(protein_mol)
                self.protein_coordinates.append(protein_coordinates)

    def _download_and_extract(self):
        download(_get_dgl_url(self._url), path=self._data_path)
        extract_archive(self._data_path, self._extracted_data_path)

    @retry_method_with_fix(_download_and_extract)
    def _preprocess(self, root_path, index_label_file, load_binding_pocket,
                    add_hydrogens, sanitize, calc_charges, remove_hs, use_conformation,
                    construct_graph_and_featurize, zero_padding, num_processes):
        """Preprocess the dataset.

        The pre-processing proceeds as follows:

        1. Load the dataset
        2. Clean the dataset and filter out invalid pairs
        3. Construct graphs
        4. Prepare node and edge features

        Parameters
        ----------
        root_path : str
            Root path for molecule files.
        index_label_file : str
            Path to the index file for the dataset.
        load_binding_pocket : bool
            Whether to load binding pockets or full proteins.
        add_hydrogens : bool
            Whether to add hydrogens via pdbfixer.
        sanitize : bool
            Whether sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        calc_charges : bool
            Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
            ``add_hydrogens`` and ``sanitize`` to be True.
        remove_hs : bool
            Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
            slow for large molecules.
        use_conformation : bool
            Whether we need to extract molecular conformation from proteins and ligands.
        construct_graph_and_featurize : callable
            Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
            self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
            to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
        zero_padding : bool
            Whether to perform zero padding. While DGL does not necessarily require zero padding,
            pooling operations for variable length inputs can introduce stochastic behaviour, which
            is not desired for sensitive scenarios.
        num_processes : int or None
            Number of worker processes to use. If None,
            then we will use the number of CPUs in the system.
        """
        contents = []
        with open(index_label_file, 'r') as f:
            for line in f.readlines():
                if line[0] != "#":
                    splitted_elements = line.split()
                    if len(splitted_elements) == 8:
                        # Ignore "//"
                        contents.append(splitted_elements[:5] + splitted_elements[6:])
                    else:
                        print('Incorrect data format.')
                        print(splitted_elements)
        self.df = pd.DataFrame(contents, columns=(
            'PDB_code', 'resolution', 'release_year',
            '-logKd/Ki', 'Kd/Ki', 'reference', 'ligand_name'))
        pdbs = self.df['PDB_code'].tolist()

        self.ligand_files = [os.path.join(
            root_path, 'v2015', pdb, '{}_ligand.sdf'.format(pdb)) for pdb in pdbs]
        if load_binding_pocket:
            self.protein_files = [os.path.join(
                root_path, 'v2015', pdb, '{}_pocket.pdb'.format(pdb)) for pdb in pdbs]
        else:
            self.protein_files = [os.path.join(
                root_path, 'v2015', pdb, '{}_protein.pdb'.format(pdb)) for pdb in pdbs]

        num_processes = min(num_processes, len(pdbs))
        print('Loading ligands...')
        ligands_loaded = multiprocess_load_molecules(self.ligand_files,
                                                     add_hydrogens=add_hydrogens,
                                                     sanitize=sanitize,
                                                     calc_charges=calc_charges,
                                                     remove_hs=remove_hs,
                                                     use_conformation=use_conformation,
                                                     num_processes=num_processes)
        print('Loading proteins...')
        proteins_loaded = multiprocess_load_molecules(self.protein_files,
                                                      add_hydrogens=add_hydrogens,
                                                      sanitize=sanitize,
                                                      calc_charges=calc_charges,
                                                      remove_hs=remove_hs,
                                                      use_conformation=use_conformation,
                                                      num_processes=num_processes)

        self._filter_out_invalid(ligands_loaded, proteins_loaded, use_conformation)
        self.df = self.df.iloc[self.indices]
        self.labels = F.zerocopy_from_numpy(self.df[self.task_names].values.astype(np.float32))
        print('Finished cleaning the dataset, '
              'got {:d}/{:d} valid pairs'.format(len(self), len(pdbs)))

        # Prepare zero padding
        if zero_padding:
            max_num_ligand_atoms = 0
            max_num_protein_atoms = 0
            for i in range(len(self)):
                max_num_ligand_atoms = max(
                    max_num_ligand_atoms, self.ligand_mols[i].GetNumAtoms())
                max_num_protein_atoms = max(
                    max_num_protein_atoms, self.protein_mols[i].GetNumAtoms())
        else:
            max_num_ligand_atoms = None
            max_num_protein_atoms = None

        print('Start constructing graphs and featurizing them.')
        self.graphs = []
        for i in range(len(self)):
            print('Constructing and featurizing datapoint {:d}/{:d}'.format(i + 1, len(self)))
            self.graphs.append(construct_graph_and_featurize(
                self.ligand_mols[i], self.protein_mols[i],
                self.ligand_coordinates[i], self.protein_coordinates[i],
                max_num_ligand_atoms, max_num_protein_atoms))

    def __len__(self):
        """Get the size of the dataset.

        Returns
        -------
        int
            Number of valid ligand-protein pairs in the dataset.
        """
        return len(self.indices)

    def __getitem__(self, item):
        """Get the datapoint associated with the index.

        Parameters
        ----------
        item : int
            Index for the datapoint.

        Returns
        -------
        int
            Index for the datapoint.
        rdkit.Chem.rdchem.Mol
            RDKit molecule instance for the ligand molecule.
        rdkit.Chem.rdchem.Mol
            RDKit molecule instance for the protein molecule.
        DGLHeteroGraph
            Pre-processed DGLHeteroGraph with features extracted.
        Float32 tensor
            Label for the datapoint.
        """
        return item, self.ligand_mols[item], self.protein_mols[item], \
               self.graphs[item], self.labels[item]
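A short usage sketch for the class above, using the smaller 'core' subset (195 complexes) and the default ACNN graph construction described in the docstring:

```python
# Illustrative only: downloads and preprocesses PDBBind v2015 on first use (large download).
from dgl.data.chem import PDBBind

dataset = PDBBind(subset='core', load_binding_pocket=True)
idx, ligand_mol, protein_mol, complex_graph, label = dataset[0]
print(len(dataset), label)   # number of valid pairs and the -logKd/Ki label of the first one
```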
python/dgl/data/chem/datasets/pubchem_aromaticity.py deleted 100644 → 0 View file @ 94c67203

import pandas as pd
import sys

from .csv_dataset import MoleculeCSVDataset
from ..utils import smiles_to_bigraph
from ...utils import get_download_dir, download, _get_dgl_url
from ....utils import retry_method_with_fix
from ....base import dgl_warning
from ....contrib.deprecation import deprecated

class PubChemBioAssayAromaticity(MoleculeCSVDataset):
    """Subset of PubChem BioAssay Dataset for aromaticity prediction.
    The dataset was constructed in `Pushing the Boundaries of Molecular Representation for Drug
    Discovery with the Graph Attention Mechanism.
    <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__ and is accompanied by the task of predicting
    the number of aromatic atoms in molecules.
    The dataset was constructed by sampling 3945 molecules with 0-40 aromatic atoms from the
    PubChem BioAssay dataset.
    Parameters
    ----------
    smiles_to_graph: callable, str -> DGLGraph
        A function turning smiles into a DGLGraph.
        Default to :func:`dgl.data.chem.smiles_to_bigraph`.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None.
    load : bool
        Whether to load the previously pre-processed dataset or pre-process from scratch.
        ``load`` should be False when we want to try different graph construction and
        featurization methods and need to pre-process from scratch. Default to True.
    """
    @deprecated('Import PubChemBioAssayAromaticity from dgllife.data instead.', 'class')
    def __init__(self, smiles_to_graph=smiles_to_bigraph, node_featurizer=None,
                 edge_featurizer=None, load=True):
        if 'pandas' not in sys.modules:
            dgl_warning("Please install pandas")

        self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
        data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv'
        self._data_path = data_path
        self._load(data_path, smiles_to_graph, node_featurizer, edge_featurizer, load)

    def _download(self):
        download(_get_dgl_url(self._url), path=self._data_path)

    @retry_method_with_fix(_download)
    def _load(self, data_path, smiles_to_graph, node_featurizer, edge_featurizer, load):
        df = pd.read_csv(data_path)
        super(PubChemBioAssayAromaticity, self).__init__(
            df, smiles_to_graph, node_featurizer, edge_featurizer, "cano_smiles",
            "pubchem_aromaticity_dglgraph.bin", load=load)
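As a quick illustration of the constructor arguments documented above, here is a hedged sketch; it assumes `PubChemBioAssayAromaticity`, `smiles_to_bigraph` and `CanonicalAtomFeaturizer` are re-exported from `dgl.data.chem`, as the package `__init__` files in this commit suggest.

# Hypothetical usage sketch (illustrative only), not part of the deleted file.
from dgl.data.chem import (PubChemBioAssayAromaticity, smiles_to_bigraph,
                           CanonicalAtomFeaturizer)

dataset = PubChemBioAssayAromaticity(smiles_to_graph=smiles_to_bigraph,
                                     node_featurizer=CanonicalAtomFeaturizer(),
                                     load=False)
print(len(dataset))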
python/dgl/data/chem/datasets/tox21.py deleted 100644 → 0 View file @ 94c67203

import sys

from .csv_dataset import MoleculeCSVDataset
from ..utils import smiles_to_bigraph
from ...utils import get_download_dir, download, _get_dgl_url
from .... import backend as F
from ....utils import retry_method_with_fix
from ....base import dgl_warning
from ....contrib.deprecation import deprecated

try:
    import pandas as pd
except ImportError:
    pass

class Tox21(MoleculeCSVDataset):
    """Tox21 dataset.
    The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
    initiative created a public database measuring toxicity of compounds, which
    has been used in the 2014 Tox21 Data Challenge. The dataset contains qualitative
    toxicity measurements for 8014 compounds on 12 different targets, including nuclear
    receptors and stress response pathways. Each target results in a binary label.
    A common issue for multi-task prediction is that some datapoints are not labeled for
    all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing
    labels to be 0 so that they can be placed in tensors and used for masking in loss computation.
    See examples below for more details.
    All molecules are converted into DGLGraphs. After the first-time construction,
    the DGLGraphs will be saved for reloading so that we do not need to reconstruct them every time.
    Parameters
    ----------
    smiles_to_graph: callable, str -> DGLGraph
        A function turning smiles into a DGLGraph.
        Default to :func:`dgl.data.chem.smiles_to_bigraph`.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None.
    load : bool
        Whether to load the previously pre-processed dataset or pre-process from scratch.
        ``load`` should be False when we want to try different graph construction and
        featurization methods and need to pre-process from scratch. Default to True.
    """
    @deprecated('Import Tox21 from dgllife.data instead.', 'class')
    def __init__(self, smiles_to_graph=smiles_to_bigraph, node_featurizer=None,
                 edge_featurizer=None, load=True):
        if 'pandas' not in sys.modules:
            dgl_warning("Please install pandas")

        self._url = 'dataset/tox21.csv.gz'
        data_path = get_download_dir() + '/tox21.csv.gz'
        self._data_path = data_path
        self._load(data_path, smiles_to_graph, node_featurizer, edge_featurizer, load)

    def _download(self):
        download(_get_dgl_url(self._url), path=self._data_path)

    @retry_method_with_fix(_download)
    def _load(self, data_path, smiles_to_graph, node_featurizer, edge_featurizer, load):
        df = pd.read_csv(data_path)
        self.id = df['mol_id']
        df = df.drop(columns=['mol_id'])
        super(Tox21, self).__init__(df, smiles_to_graph, node_featurizer, edge_featurizer,
                                    "smiles", "tox21_dglgraph.bin", load=load)
        self._weight_balancing()

    def _weight_balancing(self):
        """Perform re-balancing for each task.
        It's quite common that the number of positive samples and the
        number of negative samples are significantly different. To compensate
        for the class imbalance issue, we can weight each datapoint in
        loss computation.
        In particular, for each task we will set the weight of negative samples
        to be 1 and the weight of positive samples to be the number of negative
        samples divided by the number of positive samples.
        If weight balancing is performed, one attribute will be affected:
        * self._task_pos_weights is set, which is a list of positive sample weights
          for each task.
        """
        num_pos = F.sum(self.labels, dim=0)
        num_indices = F.sum(self.mask, dim=0)
        self._task_pos_weights = (num_indices - num_pos) / num_pos

    @property
    def task_pos_weights(self):
        """Get weights for positive samples on each task.
        Returns
        -------
        numpy.ndarray
            numpy array gives the weight of positive samples on all tasks
        """
        return self._task_pos_weights
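The docstrings above explain that unlabeled entries are zeroed out and masked during loss computation, and that `task_pos_weights` supplies a per-task positive-class weight. A hedged sketch of how the two could be combined with PyTorch follows; the `logits`, `labels` and `mask` tensors are assumed to come from an external training loop.

# Hypothetical usage sketch (illustrative only), not part of the deleted file.
import torch
from dgl.data.chem import Tox21

dataset = Tox21()
pos_weights = torch.as_tensor(dataset.task_pos_weights)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weights, reduction='none')
# logits, labels and mask would be (batch_size, 12) tensors from a training loop;
# masking averages the loss over labeled entries only:
# loss = (criterion(logits, labels) * mask).sum() / mask.sum()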
python/dgl/data/chem/utils/__init__.py deleted 100644 → 0 View file @ 94c67203

from .splitters import *
from .featurizers import *
from .mol_to_graph import *
from .complex_to_graph import *
from .rdkit_utils import *
python/dgl/data/chem/utils/complex_to_graph.py deleted 100644 → 0 View file @ 94c67203

"""Convert complexes into DGLHeteroGraphs"""
import numpy as np

from ..utils import k_nearest_neighbors
from .... import graph, bipartite, hetero_from_relations
from .... import backend as F
from ....contrib.deprecation import deprecated

__all__ = ['ACNN_graph_construction_and_featurization']

def filter_out_hydrogens(mol):
    """Get indices for non-hydrogen atoms.
    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    Returns
    -------
    indices_left : list of int
        Indices of non-hydrogen atoms.
    """
    indices_left = []
    for i, atom in enumerate(mol.GetAtoms()):
        atomic_num = atom.GetAtomicNum()
        # Hydrogen atoms have an atomic number of 1.
        if atomic_num != 1:
            indices_left.append(i)
    return indices_left

def get_atomic_numbers(mol, indices):
    """Get the atomic numbers for the specified atoms.
    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    indices : list of int
        Specifying atoms.
    Returns
    -------
    list of int
        Atomic numbers computed.
    """
    atomic_numbers = []
    for i in indices:
        atom = mol.GetAtomWithIdx(i)
        atomic_numbers.append(atom.GetAtomicNum())
    return atomic_numbers

@deprecated('Import it from dgllife.utils instead.')
def ACNN_graph_construction_and_featurization(ligand_mol,
                                              protein_mol,
                                              ligand_coordinates,
                                              protein_coordinates,
                                              max_num_ligand_atoms=None,
                                              max_num_protein_atoms=None,
                                              neighbor_cutoff=12.,
                                              max_num_neighbors=12,
                                              strip_hydrogens=False):
    """Graph construction and featurization for `Atomic Convolutional Networks for
    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
    Parameters
    ----------
    ligand_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    protein_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    ligand_coordinates : Float Tensor of shape (V1, 3)
        Atom coordinates in a ligand.
    protein_coordinates : Float Tensor of shape (V2, 3)
        Atom coordinates in a protein.
    max_num_ligand_atoms : int or None
        Maximum number of atoms in ligands for zero padding, which should be no smaller than
        ligand_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
        Default to None.
    max_num_protein_atoms : int or None
        Maximum number of atoms in proteins for zero padding, which should be no smaller than
        protein_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
        Default to None.
    neighbor_cutoff : float
        Distance cutoff to define 'neighboring'. Default to 12.
    max_num_neighbors : int
        Maximum number of neighbors allowed for each atom. Default to 12.
    strip_hydrogens : bool
        Whether to exclude hydrogen atoms. Default to False.
    """
    assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
    assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
    if max_num_ligand_atoms is not None:
        assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
            'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms()'
    if max_num_protein_atoms is not None:
        assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
            'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms()'

    if strip_hydrogens:
        # Remove hydrogen atoms and their corresponding coordinates
        ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
        protein_atom_indices_left = filter_out_hydrogens(protein_mol)
        ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left, axis=0)
        protein_coordinates = protein_coordinates.take(protein_atom_indices_left, axis=0)
    else:
        ligand_atom_indices_left = list(range(ligand_mol.GetNumAtoms()))
        protein_atom_indices_left = list(range(protein_mol.GetNumAtoms()))

    # Compute number of nodes for each type
    if max_num_ligand_atoms is None:
        num_ligand_atoms = len(ligand_atom_indices_left)
    else:
        num_ligand_atoms = max_num_ligand_atoms

    if max_num_protein_atoms is None:
        num_protein_atoms = len(protein_atom_indices_left)
    else:
        num_protein_atoms = max_num_protein_atoms

    # Construct graph for atoms in the ligand
    ligand_srcs, ligand_dsts, ligand_dists = k_nearest_neighbors(
        ligand_coordinates, neighbor_cutoff, max_num_neighbors)
    ligand_graph = graph((ligand_srcs, ligand_dsts),
                         'ligand_atom', 'ligand', num_ligand_atoms)
    ligand_graph.edata['distance'] = F.reshape(
        F.zerocopy_from_numpy(np.asarray(ligand_dists, dtype=np.float32)), (-1, 1))

    # Construct graph for atoms in the protein
    protein_srcs, protein_dsts, protein_dists = k_nearest_neighbors(
        protein_coordinates, neighbor_cutoff, max_num_neighbors)
    protein_graph = graph((protein_srcs, protein_dsts),
                          'protein_atom', 'protein', num_protein_atoms)
    protein_graph.edata['distance'] = F.reshape(
        F.zerocopy_from_numpy(np.asarray(protein_dists, dtype=np.float32)), (-1, 1))

    # Construct 4 graphs for complex representation, including the connection within
    # protein atoms, the connection within ligand atoms and the connection between
    # protein and ligand atoms.
    complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
        np.concatenate([ligand_coordinates, protein_coordinates]),
        neighbor_cutoff, max_num_neighbors)
    complex_srcs = np.asarray(complex_srcs)
    complex_dsts = np.asarray(complex_dsts)
    complex_dists = np.asarray(complex_dists)
    offset = num_ligand_atoms

    # ('ligand_atom', 'complex', 'ligand_atom')
    inter_ligand_indices = np.intersect1d(
        (complex_srcs < offset).nonzero()[0],
        (complex_dsts < offset).nonzero()[0],
        assume_unique=True)
    inter_ligand_graph = graph(
        (complex_srcs[inter_ligand_indices].tolist(),
         complex_dsts[inter_ligand_indices].tolist()),
        'ligand_atom', 'complex', num_ligand_atoms)
    inter_ligand_graph.edata['distance'] = F.reshape(
        F.zerocopy_from_numpy(complex_dists[inter_ligand_indices].astype(np.float32)), (-1, 1))

    # ('protein_atom', 'complex', 'protein_atom')
    inter_protein_indices = np.intersect1d(
        (complex_srcs >= offset).nonzero()[0],
        (complex_dsts >= offset).nonzero()[0],
        assume_unique=True)
    inter_protein_graph = graph(
        ((complex_srcs[inter_protein_indices] - offset).tolist(),
         (complex_dsts[inter_protein_indices] - offset).tolist()),
        'protein_atom', 'complex', num_protein_atoms)
    inter_protein_graph.edata['distance'] = F.reshape(
        F.zerocopy_from_numpy(complex_dists[inter_protein_indices].astype(np.float32)), (-1, 1))

    # ('ligand_atom', 'complex', 'protein_atom')
    ligand_protein_indices = np.intersect1d(
        (complex_srcs < offset).nonzero()[0],
        (complex_dsts >= offset).nonzero()[0],
        assume_unique=True)
    ligand_protein_graph = bipartite(
        (complex_srcs[ligand_protein_indices].tolist(),
         (complex_dsts[ligand_protein_indices] - offset).tolist()),
        'ligand_atom', 'complex', 'protein_atom',
        (num_ligand_atoms, num_protein_atoms))
    ligand_protein_graph.edata['distance'] = F.reshape(
        F.zerocopy_from_numpy(complex_dists[ligand_protein_indices].astype(np.float32)), (-1, 1))

    # ('protein_atom', 'complex', 'ligand_atom')
    protein_ligand_indices = np.intersect1d(
        (complex_srcs >= offset).nonzero()[0],
        (complex_dsts < offset).nonzero()[0],
        assume_unique=True)
    protein_ligand_graph = bipartite(
        ((complex_srcs[protein_ligand_indices] - offset).tolist(),
         complex_dsts[protein_ligand_indices].tolist()),
        'protein_atom', 'complex', 'ligand_atom',
        (num_protein_atoms, num_ligand_atoms))
    protein_ligand_graph.edata['distance'] = F.reshape(
        F.zerocopy_from_numpy(complex_dists[protein_ligand_indices].astype(np.float32)), (-1, 1))

    # Merge the graphs
    g = hetero_from_relations(
        [protein_graph, ligand_graph, inter_ligand_graph,
         inter_protein_graph, ligand_protein_graph, protein_ligand_graph])

    # Get atomic numbers for all atoms left and set node features
    ligand_atomic_numbers = np.asarray(get_atomic_numbers(ligand_mol, ligand_atom_indices_left))
    # zero padding
    ligand_atomic_numbers = np.concatenate([
        ligand_atomic_numbers, np.zeros(num_ligand_atoms - len(ligand_atom_indices_left))])
    protein_atomic_numbers = np.asarray(get_atomic_numbers(protein_mol, protein_atom_indices_left))
    # zero padding
    protein_atomic_numbers = np.concatenate([
        protein_atomic_numbers, np.zeros(num_protein_atoms - len(protein_atom_indices_left))])
    g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(
        F.zerocopy_from_numpy(ligand_atomic_numbers.astype(np.float32)), (-1, 1))
    g.nodes['protein_atom'].data['atomic_number'] = F.reshape(
        F.zerocopy_from_numpy(protein_atomic_numbers.astype(np.float32)), (-1, 1))

    # Prepare mask indicating the existence of nodes
    ligand_masks = np.zeros((num_ligand_atoms, 1))
    ligand_masks[:len(ligand_atom_indices_left), :] = 1
    g.nodes['ligand_atom'].data['mask'] = F.zerocopy_from_numpy(
        ligand_masks.astype(np.float32))
    protein_masks = np.zeros((num_protein_atoms, 1))
    protein_masks[:len(protein_atom_indices_left), :] = 1
    g.nodes['protein_atom'].data['mask'] = F.zerocopy_from_numpy(
        protein_masks.astype(np.float32))

    return g
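A hedged sketch of feeding two molecules into the function above; the file paths are placeholders and `load_molecule` is the helper from `rdkit_utils.py` later in this commit, assumed to be importable as shown.

# Hypothetical usage sketch (illustrative only), not part of the deleted file.
from dgl.data.chem.utils.rdkit_utils import load_molecule

ligand_mol, ligand_coords = load_molecule('ligand.sdf', sanitize=True)
protein_mol, protein_coords = load_molecule('pocket.pdb', sanitize=True)
g = ACNN_graph_construction_and_featurization(
    ligand_mol, protein_mol, ligand_coords, protein_coords,
    neighbor_cutoff=12., max_num_neighbors=12, strip_hydrogens=True)
print(g.ntypes, g.etypes)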
python/dgl/data/chem/utils/featurizers.py deleted 100644 → 0 View file @ 94c67203

import itertools
import numpy as np

from collections import defaultdict

from .... import backend as F
from ....contrib.deprecation import deprecated

try:
    from rdkit import Chem
    from rdkit.Chem import rdmolfiles, rdmolops
except ImportError:
    pass

__all__ = ['one_hot_encoding',
           'atom_type_one_hot',
           'atomic_number_one_hot',
           'atomic_number',
           'atom_degree_one_hot',
           'atom_degree',
           'atom_total_degree_one_hot',
           'atom_total_degree',
           'atom_implicit_valence_one_hot',
           'atom_implicit_valence',
           'atom_hybridization_one_hot',
           'atom_total_num_H_one_hot',
           'atom_total_num_H',
           'atom_formal_charge_one_hot',
           'atom_formal_charge',
           'atom_num_radical_electrons_one_hot',
           'atom_num_radical_electrons',
           'atom_is_aromatic_one_hot',
           'atom_is_aromatic',
           'atom_chiral_tag_one_hot',
           'atom_mass',
           'ConcatFeaturizer',
           'BaseAtomFeaturizer',
           'CanonicalAtomFeaturizer',
           'bond_type_one_hot',
           'bond_is_conjugated_one_hot',
           'bond_is_conjugated',
           'bond_is_in_ring_one_hot',
           'bond_is_in_ring',
           'bond_stereo_one_hot',
           'BaseBondFeaturizer',
           'CanonicalBondFeaturizer']

@deprecated('Import it from dgllife.utils instead.')
def one_hot_encoding(x, allowable_set, encode_unknown=False):
    """One-hot encoding.
    Parameters
    ----------
    x
        Value to encode.
    allowable_set : list
        The elements of the allowable_set should be of the
        same type as x.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element.
    Returns
    -------
    list
        List of boolean values where at most one value is True.
        The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
        and ``len(allowable_set) + 1`` otherwise.
    """
    if encode_unknown and (allowable_set[-1] is not None):
        allowable_set.append(None)

    if encode_unknown and (x not in allowable_set):
        x = None

    return list(map(lambda s: x == s, allowable_set))

#################################################################
# Atom featurization
#################################################################

@deprecated('Import it from dgllife.utils instead.')
def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the type of an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of str
        Atom types to consider. Default: ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``,
        ``Cl``, ``Br``, ``Mg``, ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``,
        ``K``, ``Tl``, ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
        ``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``, ``Cr``,
        ``Pt``, ``Hg``, ``Pb``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
                         'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag',
                         'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni',
                         'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb']
    return one_hot_encoding(atom.GetSymbol(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the atomic number of an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atomic numbers to consider. Default: ``1`` - ``100``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = list(range(1, 101))
    return one_hot_encoding(atom.GetAtomicNum(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atomic_number(atom):
    """Get the atomic number for an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    Returns
    -------
    list
        List containing one int only.
    """
    return [atom.GetAtomicNum()]

@deprecated('Import it from dgllife.utils instead.')
def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the degree of an atom.
    Note that the result will be different depending on whether the Hs are
    explicitly modeled in the graph.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atom degrees to consider. Default: ``0`` - ``10``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    See Also
    --------
    atom_total_degree_one_hot
    """
    if allowable_set is None:
        allowable_set = list(range(11))
    return one_hot_encoding(atom.GetDegree(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atom_degree(atom):
    """Get the degree of an atom.
    Note that the result will be different depending on whether the Hs are
    explicitly modeled in the graph.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    Returns
    -------
    list
        List containing one int only.
    See Also
    --------
    atom_total_degree
    """
    return [atom.GetDegree()]

@deprecated('Import it from dgllife.utils instead.')
def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the degree of an atom including Hs.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list
        Total degrees to consider. Default: ``0`` - ``5``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    See Also
    --------
    atom_degree_one_hot
    """
    if allowable_set is None:
        allowable_set = list(range(6))
    return one_hot_encoding(atom.GetTotalDegree(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atom_total_degree(atom):
    """The degree of an atom including Hs.
    See Also
    --------
    atom_degree
    Returns
    -------
    list
        List containing one int only.
    """
    return [atom.GetTotalDegree()]

@deprecated('Import it from dgllife.utils instead.')
def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the implicit valences of an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atom implicit valences to consider. Default: ``0`` - ``6``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = list(range(7))
    return one_hot_encoding(atom.GetImplicitValence(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atom_implicit_valence(atom):
    """Get the implicit valence of an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    Returns
    -------
    list
        List containing one int only.
    """
    return [atom.GetImplicitValence()]

@deprecated('Import it from dgllife.utils instead.')
def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the hybridization of an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of rdkit.Chem.rdchem.HybridizationType
        Atom hybridizations to consider. Default: ``Chem.rdchem.HybridizationType.SP``,
        ``Chem.rdchem.HybridizationType.SP2``, ``Chem.rdchem.HybridizationType.SP3``,
        ``Chem.rdchem.HybridizationType.SP3D``, ``Chem.rdchem.HybridizationType.SP3D2``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = [Chem.rdchem.HybridizationType.SP,
                         Chem.rdchem.HybridizationType.SP2,
                         Chem.rdchem.HybridizationType.SP3,
                         Chem.rdchem.HybridizationType.SP3D,
                         Chem.rdchem.HybridizationType.SP3D2]
    return one_hot_encoding(atom.GetHybridization(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the total number of Hs of an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Total number of Hs to consider. Default: ``0`` - ``4``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = list(range(5))
    return one_hot_encoding(atom.GetTotalNumHs(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atom_total_num_H(atom):
    """Get the total number of Hs of an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    Returns
    -------
    list
        List containing one int only.
    """
    return [atom.GetTotalNumHs()]

@deprecated('Import it from dgllife.utils instead.')
def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the formal charge of an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Formal charges to consider. Default: ``-2`` - ``2``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = list(range(-2, 3))
    return one_hot_encoding(atom.GetFormalCharge(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atom_formal_charge(atom):
    """Get formal charge for an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    Returns
    -------
    list
        List containing one int only.
    """
    return [atom.GetFormalCharge()]

@deprecated('Import it from dgllife.utils instead.')
def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the number of radical electrons of an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Number of radical electrons to consider. Default: ``0`` - ``4``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = list(range(5))
    return one_hot_encoding(atom.GetNumRadicalElectrons(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atom_num_radical_electrons(atom):
    """Get the number of radical electrons for an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    Returns
    -------
    list
        List containing one int only.
    """
    return [atom.GetNumRadicalElectrons()]

@deprecated('Import it from dgllife.utils instead.')
def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the atom is aromatic.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = [False, True]
    return one_hot_encoding(atom.GetIsAromatic(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atom_is_aromatic(atom):
    """Get whether the atom is aromatic.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    Returns
    -------
    list
        List containing one bool only.
    """
    return [atom.GetIsAromatic()]

@deprecated('Import it from dgllife.utils instead.')
def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the chiral tag of an atom.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of rdkit.Chem.rdchem.ChiralType
        Chiral tags to consider. Default: ``rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
    """
    if allowable_set is None:
        allowable_set = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                         Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                         Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                         Chem.rdchem.ChiralType.CHI_OTHER]
    return one_hot_encoding(atom.GetChiralTag(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def atom_mass(atom, coef=0.01):
    """Get the mass of an atom and scale it.
    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    coef : float
        The mass will be multiplied by ``coef``.
    Returns
    -------
    list
        List containing one float only.
    """
    return [atom.GetMass() * coef]

class ConcatFeaturizer(object):
    """Concatenate the evaluation results of multiple functions as a single feature.
    Parameters
    ----------
    func_list : list
        List of functions for computing molecular descriptors from objects of a same
        particular data type, e.g. ``rdkit.Chem.rdchem.Atom``. Each function is of signature
        ``func(data_type) -> list of float or bool or int``. The resulting order of
        the features will follow that of the functions in the list.
    """
    @deprecated('Import ConcatFeaturizer from dgllife.utils instead.', 'class')
    def __init__(self, func_list):
        self.func_list = func_list

    def __call__(self, x):
        """Featurize the input data.
        Parameters
        ----------
        x :
            Data to featurize.
        Returns
        -------
        list
            List of feature values, which can be of type bool, float or int.
        """
        return list(itertools.chain.from_iterable(
            [func(x) for func in self.func_list]))

class BaseAtomFeaturizer(object):
    """An abstract class for atom featurizers.
    Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``.
    **We assume the resulting DGLGraph will not contain any virtual nodes.**
    Parameters
    ----------
    featurizer_funcs : dict
        Mapping feature name to the featurization function.
        Each function is of signature ``func(rdkit.Chem.rdchem.Atom) -> list or 1D numpy array``.
    feat_sizes : dict
        Mapping feature name to the size of the corresponding feature. If None, they will be
        computed when needed. Default: None.
    Examples
    --------
    >>> from dgl.data.chem import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
    >>> from rdkit import Chem
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> atom_featurizer = BaseAtomFeaturizer({'mass': atom_mass, 'degree': atom_degree_one_hot})
    >>> atom_featurizer(mol)
    {'mass': tensor([[0.1201],
                     [0.1201],
                     [0.1600]]),
     'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
    """
    @deprecated('Import BaseAtomFeaturizer from dgllife.utils instead.', 'class')
    def __init__(self, featurizer_funcs, feat_sizes=None):
        self.featurizer_funcs = featurizer_funcs
        if feat_sizes is None:
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name):
        """Get the feature size for ``feat_name``.
        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.
        """
        if feat_name not in self.featurizer_funcs:
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))

        if feat_name not in self._feat_sizes:
            atom = Chem.MolFromSmiles('C').GetAtomWithIdx(0)
            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](atom))

        return self._feat_sizes[feat_name]

    def __call__(self, mol):
        """Featurize all atoms in a molecule.
        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.
        Returns
        -------
        dict
            For each function in self.featurizer_funcs with the key ``k``, store the computed
            feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
            (N, M), where N is the number of atoms in the molecule.
        """
        num_atoms = mol.GetNumAtoms()
        atom_features = defaultdict(list)

        # Compute features for each atom
        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            for feat_name, feat_func in self.featurizer_funcs.items():
                atom_features[feat_name].append(feat_func(atom))

        # Stack the features and convert them to float arrays
        processed_features = dict()
        for feat_name, feat_list in atom_features.items():
            feat = np.stack(feat_list)
            processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))

        return processed_features

class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
    """A default featurizer for atoms.
    The atom features include:
    * **One hot encoding of the atom type**. The supported atom types include
      ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``, ``Cl``, ``Br``, ``Mg``,
      ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``, ``K``, ``Tl``,
      ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
      ``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``,
      ``Cr``, ``Pt``, ``Hg``, ``Pb``.
    * **One hot encoding of the atom degree**. The supported possibilities
      include ``0 - 10``.
    * **One hot encoding of the number of implicit Hs on the atom**. The supported
      possibilities include ``0 - 6``.
    * **Formal charge of the atom**.
    * **Number of radical electrons of the atom**.
    * **One hot encoding of the atom hybridization**. The supported possibilities include
      ``SP``, ``SP2``, ``SP3``, ``SP3D``, ``SP3D2``.
    * **Whether the atom is aromatic**.
    * **One hot encoding of the number of total Hs on the atom**. The supported possibilities
      include ``0 - 4``.
    **We assume the resulting DGLGraph will not contain any virtual nodes.**
    Parameters
    ----------
    atom_data_field : str
        Name for storing atom features in DGLGraphs, default to be 'h'.
    """
    @deprecated('Import CanonicalAtomFeaturizer from dgllife.utils instead.', 'class')
    def __init__(self, atom_data_field='h'):
        super(CanonicalAtomFeaturizer, self).__init__(
            featurizer_funcs={atom_data_field: ConcatFeaturizer(
                [atom_type_one_hot,
                 atom_degree_one_hot,
                 atom_implicit_valence_one_hot,
                 atom_formal_charge,
                 atom_num_radical_electrons,
                 atom_hybridization_one_hot,
                 atom_is_aromatic,
                 atom_total_num_H_one_hot]
            )})
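Beyond the canonical featurizer above, `BaseAtomFeaturizer` and `ConcatFeaturizer` can be combined freely. A hedged sketch follows; the feature field name is arbitrary and the call prints deprecation warnings, as expected for this code.

# Hypothetical usage sketch (illustrative only), not part of the deleted file.
from rdkit import Chem

custom_featurizer = BaseAtomFeaturizer(
    {'feat': ConcatFeaturizer([atom_type_one_hot, atom_formal_charge, atom_mass])})
mol = Chem.MolFromSmiles('CCO')
print(custom_featurizer.feat_size('feat'))   # length of one concatenated feature vector
print(custom_featurizer(mol)['feat'].shape)  # (3 atoms, feature size)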
@deprecated('Import it from dgllife.utils instead.')
def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the type of a bond.
    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of Chem.rdchem.BondType
        Bond types to consider. Default: ``Chem.rdchem.BondType.SINGLE``,
        ``Chem.rdchem.BondType.DOUBLE``, ``Chem.rdchem.BondType.TRIPLE``,
        ``Chem.rdchem.BondType.AROMATIC``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = [Chem.rdchem.BondType.SINGLE,
                         Chem.rdchem.BondType.DOUBLE,
                         Chem.rdchem.BondType.TRIPLE,
                         Chem.rdchem.BondType.AROMATIC]
    return one_hot_encoding(bond.GetBondType(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the bond is conjugated.
    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = [False, True]
    return one_hot_encoding(bond.GetIsConjugated(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def bond_is_conjugated(bond):
    """Get whether the bond is conjugated.
    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    Returns
    -------
    list
        List containing one bool only.
    """
    return [bond.GetIsConjugated()]

@deprecated('Import it from dgllife.utils instead.')
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the bond is in a ring of any size.
    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = [False, True]
    return one_hot_encoding(bond.IsInRing(), allowable_set, encode_unknown)

@deprecated('Import it from dgllife.utils instead.')
def bond_is_in_ring(bond):
    """Get whether the bond is in a ring of any size.
    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    Returns
    -------
    list
        List containing one bool only.
    """
    return [bond.IsInRing()]

@deprecated('Import it from dgllife.utils instead.')
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the stereo configuration of a bond.
    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of rdkit.Chem.rdchem.BondStereo
        Stereo configurations to consider. Default: ``rdkit.Chem.rdchem.BondStereo.STEREONONE``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOANY``, ``rdkit.Chem.rdchem.BondStereo.STEREOZ``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOE``, ``rdkit.Chem.rdchem.BondStereo.STEREOCIS``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOTRANS``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)
    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = [Chem.rdchem.BondStereo.STEREONONE,
                         Chem.rdchem.BondStereo.STEREOANY,
                         Chem.rdchem.BondStereo.STEREOZ,
                         Chem.rdchem.BondStereo.STEREOE,
                         Chem.rdchem.BondStereo.STEREOCIS,
                         Chem.rdchem.BondStereo.STEREOTRANS]
    return one_hot_encoding(bond.GetStereo(), allowable_set, encode_unknown)

class BaseBondFeaturizer(object):
    """An abstract class for bond featurizers.
    Loop over all bonds in a molecule and featurize them with the ``featurizer_funcs``.
    We assume the constructed ``DGLGraph`` is a bi-directed graph where the **i** th bond in the
    molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the **(2i)**-th and **(2i+1)**-th edges
    in the DGLGraph.
    **We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
    self loops.**
    Parameters
    ----------
    featurizer_funcs : dict
        Mapping feature name to the featurization function.
        Each function is of signature ``func(rdkit.Chem.rdchem.Bond) -> list or 1D numpy array``.
    feat_sizes : dict
        Mapping feature name to the size of the corresponding feature. If None, they will be
        computed when needed. Default: None.
    Examples
    --------
    >>> from dgl.data.chem import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
    >>> from rdkit import Chem
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> bond_featurizer = BaseBondFeaturizer({'bond_type': bond_type_one_hot, 'in_ring': bond_is_in_ring})
    >>> bond_featurizer(mol)
    {'bond_type': tensor([[1., 0., 0., 0.],
                          [1., 0., 0., 0.],
                          [1., 0., 0., 0.],
                          [1., 0., 0., 0.]]),
     'in_ring': tensor([[0.], [0.], [0.], [0.]])}
    """
    @deprecated('Import BaseBondFeaturizer from dgllife.utils instead.', 'class')
    def __init__(self, featurizer_funcs, feat_sizes=None):
        self.featurizer_funcs = featurizer_funcs
        if feat_sizes is None:
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name):
        """Get the feature size for ``feat_name``.
        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.
        """
        if feat_name not in self.featurizer_funcs:
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))

        if feat_name not in self._feat_sizes:
            bond = Chem.MolFromSmiles('CO').GetBondWithIdx(0)
            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](bond))

        return self._feat_sizes[feat_name]

    def __call__(self, mol):
        """Featurize all bonds in a molecule.
        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.
        Returns
        -------
        dict
            For each function in self.featurizer_funcs with the key ``k``, store the computed
            feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
            (N, M), where N is twice the number of bonds in the molecule, i.e. one row per
            edge direction.
        """
        num_bonds = mol.GetNumBonds()
        bond_features = defaultdict(list)

        # Compute features for each bond
        for i in range(num_bonds):
            bond = mol.GetBondWithIdx(i)
            for feat_name, feat_func in self.featurizer_funcs.items():
                feat = feat_func(bond)
                bond_features[feat_name].extend([feat, feat.copy()])

        # Stack the features and convert them to float arrays
        processed_features = dict()
        for feat_name, feat_list in bond_features.items():
            feat = np.stack(feat_list)
            processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))

        return processed_features

class CanonicalBondFeaturizer(BaseBondFeaturizer):
    """A default featurizer for bonds.
    The bond features include:
    * **One hot encoding of the bond type**. The supported bond types include
      ``SINGLE``, ``DOUBLE``, ``TRIPLE``, ``AROMATIC``.
    * **Whether the bond is conjugated**.
    * **Whether the bond is in a ring of any size**.
    * **One hot encoding of the stereo configuration of a bond**. The supported bond stereo
      configurations include ``STEREONONE``, ``STEREOANY``, ``STEREOZ``, ``STEREOE``,
      ``STEREOCIS``, ``STEREOTRANS``.
    **We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
    self loops.**
    """
    @deprecated('Import CanonicalBondFeaturizer from dgllife.utils instead.', 'class')
    def __init__(self, bond_data_field='e'):
        super(CanonicalBondFeaturizer, self).__init__(
            featurizer_funcs={bond_data_field: ConcatFeaturizer(
                [bond_type_one_hot,
                 bond_is_conjugated,
                 bond_is_in_ring,
                 bond_stereo_one_hot]
            )})
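For symmetry with the atom example in the `BaseAtomFeaturizer` docstring, here is a hedged sketch of the canonical bond featurizer above; note that the feature matrix has one row per directed edge, i.e. two rows per bond.

# Hypothetical usage sketch (illustrative only), not part of the deleted file.
from rdkit import Chem

bond_featurizer = CanonicalBondFeaturizer(bond_data_field='e')
mol = Chem.MolFromSmiles('c1ccccc1O')        # phenol
feats = bond_featurizer(mol)['e']
print(feats.shape)                           # (2 * mol.GetNumBonds(), feature size)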
python/dgl/data/chem/utils/mol_to_graph.py deleted 100644 → 0 View file @ 94c67203

"""Convert molecules into DGLGraphs."""
import numpy as np

from functools import partial

from .... import DGLGraph
from ....contrib.deprecation import deprecated

try:
    import mdtraj
    from rdkit import Chem
    from rdkit.Chem import rdmolfiles, rdmolops
except ImportError:
    pass

__all__ = ['mol_to_graph',
           'smiles_to_bigraph',
           'mol_to_bigraph',
           'smiles_to_complete_graph',
           'mol_to_complete_graph',
           'k_nearest_neighbors']

@deprecated('Import it from dgllife.utils instead.')
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer):
    """Convert an RDKit molecule object into a DGLGraph and featurize for it.
    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    graph_constructor : callable
        Takes an RDKit molecule as input and returns a DGLGraph
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph.
    Returns
    -------
    g : DGLGraph
        Converted DGLGraph for the molecule
    """
    new_order = rdmolfiles.CanonicalRankAtoms(mol)
    mol = rdmolops.RenumberAtoms(mol, new_order)
    g = graph_constructor(mol)

    if node_featurizer is not None:
        g.ndata.update(node_featurizer(mol))

    if edge_featurizer is not None:
        g.edata.update(edge_featurizer(mol))

    return g

def construct_bigraph_from_mol(mol, add_self_loop=False):
    """Construct a bi-directed DGLGraph with topology only for the molecule.
    The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
    **i** th node in the returned DGLGraph.
    The **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the
    **(2i)**-th and **(2i+1)**-th edges in the returned DGLGraph. The **(2i)**-th and
    **(2i+1)**-th edges will be separately from **u** to **v** and **v** to **u**, where
    **u** is ``bond.GetBeginAtomIdx()`` and **v** is ``bond.GetEndAtomIdx()``.
    If self loops are added, the last **n** edges will separately be self loops for
    atoms ``0, 1, ..., n-1``.
    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    Returns
    -------
    g : DGLGraph
        Empty bigraph topology of the molecule
    """
    g = DGLGraph()

    # Add nodes
    num_atoms = mol.GetNumAtoms()
    g.add_nodes(num_atoms)

    # Add edges
    src_list = []
    dst_list = []
    num_bonds = mol.GetNumBonds()
    for i in range(num_bonds):
        bond = mol.GetBondWithIdx(i)
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        src_list.extend([u, v])
        dst_list.extend([v, u])
    g.add_edges(src_list, dst_list)

    if add_self_loop:
        nodes = g.nodes()
        g.add_edges(nodes, nodes)

    return g

@deprecated('Import it from dgllife.utils instead.')
def mol_to_bigraph(mol, add_self_loop=False, node_featurizer=None, edge_featurizer=None):
    """Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None.
    Returns
    -------
    g : DGLGraph
        Bi-directed DGLGraph for the molecule
    """
    return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
                        node_featurizer, edge_featurizer)

@deprecated('Import it from dgllife.utils instead.')
def smiles_to_bigraph(smiles, add_self_loop=False, node_featurizer=None, edge_featurizer=None):
    """Convert a SMILES into a bi-directed DGLGraph and featurize for it.
    Parameters
    ----------
    smiles : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None.
    Returns
    -------
    g : DGLGraph
        Bi-directed DGLGraph for the molecule
    """
    mol = Chem.MolFromSmiles(smiles)
    return mol_to_bigraph(mol, add_self_loop, node_featurizer, edge_featurizer)
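A hedged sketch of the most common entry point above, pairing the graph constructor with the canonical featurizers from `featurizers.py`; it assumes both are importable from the same `dgl.data.chem` namespace, as the package `__init__` files in this commit suggest.

# Hypothetical usage sketch (illustrative only), not part of the deleted file.
from dgl.data.chem import CanonicalAtomFeaturizer, CanonicalBondFeaturizer

g = smiles_to_bigraph('CCO',
                      node_featurizer=CanonicalAtomFeaturizer(),
                      edge_featurizer=CanonicalBondFeaturizer())
print(g.number_of_nodes(), g.number_of_edges())  # 3 atoms, 2 bonds -> 4 directed edges
print(g.ndata['h'].shape, g.edata['e'].shape)    # default field names 'h' and 'e'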
def construct_complete_graph_from_mol(mol, add_self_loop=False):
    """Construct a complete graph with topology only for the molecule
    The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
    **i** th node in the returned DGLGraph.
    The edges are in the order of (0, 0), (1, 0), (2, 0), ... (0, 1), (1, 1), (2, 1), ...
    If self loops are not created, we will not have (0, 0), (1, 1), ...
    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    Returns
    -------
    g : DGLGraph
        Empty complete graph topology of the molecule
    """
    g = DGLGraph()
    num_atoms = mol.GetNumAtoms()
    g.add_nodes(num_atoms)

    if add_self_loop:
        g.add_edges(
            [i for i in range(num_atoms) for j in range(num_atoms)],
            [j for i in range(num_atoms) for j in range(num_atoms)])
    else:
        g.add_edges(
            [i for i in range(num_atoms) for j in range(num_atoms - 1)],
            [j for i in range(num_atoms) for j in range(num_atoms) if i != j])

    return g

@deprecated('Import it from dgllife.utils instead.')
def mol_to_complete_graph(mol, add_self_loop=False, node_featurizer=None, edge_featurizer=None):
    """Convert an RDKit molecule into a complete DGLGraph and featurize for it.
    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None.
    Returns
    -------
    g : DGLGraph
        Complete DGLGraph for the molecule
    """
    return mol_to_graph(mol,
                        partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
                        node_featurizer, edge_featurizer)

@deprecated('Import it from dgllife.utils instead.')
def smiles_to_complete_graph(smiles, add_self_loop=False, node_featurizer=None,
                             edge_featurizer=None):
    """Convert a SMILES into a complete DGLGraph and featurize for it.
    Parameters
    ----------
    smiles : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None.
    Returns
    -------
    g : DGLGraph
        Complete DGLGraph for the molecule
    """
    mol = Chem.MolFromSmiles(smiles)
    return mol_to_complete_graph(mol, add_self_loop, node_featurizer, edge_featurizer)

@deprecated('Import it from dgllife.utils instead.')
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
    """Find k nearest neighbors for each atom based on the 3D coordinates.
    Parameters
    ----------
    coordinates : numpy.ndarray of shape (N, 3)
        The 3D coordinates of atoms in the molecule. N for the number of atoms.
    neighbor_cutoff : float
        Distance cutoff to define 'neighboring'.
    max_num_neighbors : int or None.
        If not None, then this specifies the maximum number of closest neighbors
        allowed for each atom.
    Returns
    -------
    srcs : list of int
        Source nodes.
    dsts : list of int
        Destination nodes.
    distances : list of float
        Distances between the end nodes.
    """
    num_atoms = coordinates.shape[0]
    traj = mdtraj.Trajectory(coordinates.reshape((1, num_atoms, 3)), None)
    neighbors = mdtraj.geometry.compute_neighborlist(traj, neighbor_cutoff)
    srcs, dsts, distances = [], [], []
    for i in range(num_atoms):
        delta = coordinates[i] - coordinates.take(neighbors[i], axis=0)
        dist = np.linalg.norm(delta, axis=1)
        if max_num_neighbors is not None and len(neighbors[i]) > max_num_neighbors:
            sorted_neighbors = list(zip(dist, neighbors[i]))
            # Sort neighbors based on distance from smallest to largest
            sorted_neighbors.sort(key=lambda tup: tup[0])
            dsts.extend([i for _ in range(max_num_neighbors)])
            srcs.extend([int(sorted_neighbors[j][1]) for j in range(max_num_neighbors)])
            distances.extend([float(sorted_neighbors[j][0]) for j in range(max_num_neighbors)])
        else:
            dsts.extend([i for _ in range(len(neighbors[i]))])
            srcs.extend(neighbors[i].tolist())
            distances.extend(dist.tolist())

    return srcs, dsts, distances
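A hedged sketch of the neighbor search above on synthetic coordinates; it requires `mdtraj`, mirroring the optional import at the top of this file, and the cutoff value is arbitrary.

# Hypothetical usage sketch (illustrative only), not part of the deleted file.
import numpy as np

coords = np.random.rand(10, 3).astype(np.float32)   # 10 synthetic atom positions
srcs, dsts, dists = k_nearest_neighbors(coords, neighbor_cutoff=1.0, max_num_neighbors=4)
# Every destination atom keeps at most 4 of its closest neighbors within the cutoff.
print(len(srcs), len(dsts), len(dists))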
python/dgl/data/chem/utils/rdkit_utils.py
deleted
100644 → 0
View file @
94c67203
"""Utils for RDKit, mostly adapted from DeepChem
(https://github.com/deepchem/deepchem/blob/master/deepchem)."""
import
warnings
from
functools
import
partial
from
multiprocessing
import
Pool
from
....contrib.deprecation
import
deprecated
try
:
import
pdbfixer
import
simtk
from
rdkit
import
Chem
from
rdkit.Chem
import
AllChem
from
StringIO
import
StringIO
except
ImportError
:
from
io
import
StringIO
__all__
=
[
'add_hydrogens_to_mol'
,
'get_mol_3D_coordinates'
,
'load_molecule'
,
'multiprocess_load_molecules'
]
@
deprecated
(
''
)
def
add_hydrogens_to_mol
(
mol
):
"""Add hydrogens to an RDKit molecule instance.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance with hydrogens added. For failures in adding hydrogens,
the original RDKit molecule instance will be returned.
"""
try
:
pdbblock
=
Chem
.
MolToPDBBlock
(
mol
)
pdb_stringio
=
StringIO
()
pdb_stringio
.
write
(
pdbblock
)
pdb_stringio
.
seek
(
0
)
fixer
=
pdbfixer
.
PDBFixer
(
pdbfile
=
pdb_stringio
)
fixer
.
findMissingResidues
()
fixer
.
findMissingAtoms
()
fixer
.
addMissingAtoms
()
fixer
.
addMissingHydrogens
(
7.4
)
hydrogenated_io
=
StringIO
()
simtk
.
openmm
.
app
.
PDBFile
.
writeFile
(
fixer
.
topology
,
fixer
.
positions
,
hydrogenated_io
)
hydrogenated_io
.
seek
(
0
)
mol
=
Chem
.
MolFromPDBBlock
(
hydrogenated_io
.
read
(),
sanitize
=
False
,
removeHs
=
False
)
pdb_stringio
.
close
()
hydrogenated_io
.
close
()
except
ValueError
:
warnings
.
warn
(
'Failed to add hydrogens to the molecule.'
)
return
mol
@
deprecated
(
'Import it from dgllife.utils.rdkit_utils instead.'
)
def
get_mol_3D_coordinates
(
mol
):
"""Get 3D coordinates of the molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
numpy.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. For failures in getting the conformations, None will be returned.
"""
try
:
conf
=
mol
.
GetConformer
()
conf_num_atoms
=
conf
.
GetNumAtoms
()
mol_num_atoms
=
mol
.
GetNumAtoms
()
assert
mol_num_atoms
==
conf_num_atoms
,
\
'Expect the number of atoms in the molecule and its conformation '
\
'to be the same, got {:d} and {:d}'
.
format
(
mol_num_atoms
,
conf_num_atoms
)
return
conf
.
GetPositions
()
except
:
warnings
.
warn
(
'Unable to get conformation of the molecule.'
)
return
None
@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
def load_molecule(molecule_file, add_hydrogens=False, sanitize=False, calc_charges=False,
                  remove_hs=False, use_conformation=True):
    """Load a molecule from a file.

    Parameters
    ----------
    molecule_file : str
        Path to file for storing a molecule, which can be of format '.mol2', '.sdf',
        '.pdbqt', or '.pdb'.
    add_hydrogens : bool
        Whether to add hydrogens via pdbfixer. Default to False.
    sanitize : bool
        Whether sanitization is performed in initializing RDKit molecule instances. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        Default to False.
    calc_charges : bool
        Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
        ``add_hydrogens`` and ``sanitize`` to be True. Default to False.
    remove_hs : bool
        Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
        slow for large molecules. Default to False.
    use_conformation : bool
        Whether we need to extract molecular conformation from proteins and ligands.
        Default to True.

    Returns
    -------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance for the loaded molecule.
    coordinates : np.ndarray of shape (N, 3) or None
        The 3D coordinates of atoms in the molecule. N for the number of atoms in
        the molecule. None will be returned if ``use_conformation`` is False or
        we failed to get conformation information.
    """
    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        mol = supplier[0]
    elif molecule_file.endswith('.pdbqt'):
        with open(molecule_file) as f:
            pdbqt_data = f.readlines()
        pdb_block = ''
        for line in pdbqt_data:
            pdb_block += '{}\n'.format(line[:66])
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
    else:
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))

    try:
        if add_hydrogens or calc_charges:
            mol = add_hydrogens_to_mol(mol)
        if sanitize or calc_charges:
            Chem.SanitizeMol(mol)
        if calc_charges:
            # Compute Gasteiger charges on the molecule.
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except:
                warnings.warn('Unable to compute charges for the molecule.')
        if remove_hs:
            mol = Chem.RemoveHs(mol)
    except:
        return None, None

    if use_conformation:
        coordinates = get_mol_3D_coordinates(mol)
    else:
        coordinates = None

    return mol, coordinates
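# --- Illustrative usage added by the editor; not part of the original commit. ---
# A minimal sketch; 'ligand.sdf' is a hypothetical path to any .mol2/.sdf/.pdbqt/.pdb file
# and the example function itself is hypothetical.
def _example_load_molecule():
    """Load a ligand file and inspect the parsed molecule and its coordinates."""
    mol, coords = load_molecule('ligand.sdf', sanitize=True, remove_hs=True)
    if mol is None:
        print('RDKit failed to parse the file.')
    else:
        print(mol.GetNumAtoms(), None if coords is None else coords.shape)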
@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
def multiprocess_load_molecules(files, add_hydrogens=False, sanitize=False, calc_charges=False,
                                remove_hs=False, use_conformation=True, num_processes=2):
    """Load molecules from files with multiprocessing.

    Parameters
    ----------
    files : list of str
        Each element is a path to a file storing a molecule, which can be of format '.mol2',
        '.sdf', '.pdbqt', or '.pdb'.
    add_hydrogens : bool
        Whether to add hydrogens via pdbfixer. Default to False.
    sanitize : bool
        Whether sanitization is performed in initializing RDKit molecule instances. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        Default to False.
    calc_charges : bool
        Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
        ``add_hydrogens`` and ``sanitize`` to be True. Default to False.
    remove_hs : bool
        Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
        slow for large molecules. Default to False.
    use_conformation : bool
        Whether we need to extract molecular conformation from proteins and ligands.
        Default to True.
    num_processes : int or None
        Number of worker processes to use. If None,
        then we will use the number of CPUs in the system. Default to 2.

    Returns
    -------
    list of 2-tuples
        The first element of each 2-tuple is an RDKit molecule instance. The second element
        of each 2-tuple is the 3D atom coordinates of the corresponding molecule if
        ``use_conformation`` is True and the coordinates have been successfully loaded.
        Otherwise, it will be None.
    """
    if num_processes == 1:
        mols_loaded = []
        for i, f in enumerate(files):
            mols_loaded.append(load_molecule(
                f, add_hydrogens=add_hydrogens, sanitize=sanitize, calc_charges=calc_charges,
                remove_hs=remove_hs, use_conformation=use_conformation))
    else:
        with Pool(processes=num_processes) as pool:
            mols_loaded = pool.map_async(partial(
                load_molecule, add_hydrogens=add_hydrogens, sanitize=sanitize,
                calc_charges=calc_charges, remove_hs=remove_hs,
                use_conformation=use_conformation), files)
            mols_loaded = mols_loaded.get()

    return mols_loaded
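# --- Illustrative usage added by the editor; not part of the original commit. ---
# A minimal sketch; 'protein.pdb' and 'ligand.sdf' are hypothetical paths and the example
# function itself is hypothetical.
def _example_multiprocess_load_molecules():
    """Load several structure files in parallel worker processes."""
    files = ['protein.pdb', 'ligand.sdf']
    results = multiprocess_load_molecules(files, sanitize=True, num_processes=2)
    for mol, coords in results:
        print(None if mol is None else mol.GetNumAtoms())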
python/dgl/data/chem/utils/splitters.py deleted 100644 → 0 View file @ 94c67203
"""Various methods for splitting chemical datasets.
We mostly adapt them from deepchem
(https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py).
"""
import
numpy
as
np
from
collections
import
defaultdict
from
functools
import
partial
from
itertools
import
accumulate
,
chain
from
...utils
import
split_dataset
,
Subset
from
....
import
backend
as
F
from
....contrib.deprecation
import
deprecated
try
:
from
rdkit
import
Chem
from
rdkit.Chem
import
rdMolDescriptors
from
rdkit.Chem.rdmolops
import
FastFindRings
from
rdkit.Chem.Scaffolds
import
MurckoScaffold
except
ImportError
:
pass
__all__
=
[
'ConsecutiveSplitter'
,
'RandomSplitter'
,
'MolecularWeightSplitter'
,
'ScaffoldSplitter'
,
'SingleTaskStratifiedSplitter'
]
def base_k_fold_split(split_method, dataset, k, log):
    """Split dataset for k-fold cross validation.

    Parameters
    ----------
    split_method : callable
        Arbitrary method for splitting the dataset
        into training, validation and test subsets.
    dataset
        We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
        gives the ith datapoint.
    k : int
        Number of folds to use and should be no smaller than 2.
    log : bool
        Whether to print a message at the start of preparing each fold.

    Returns
    -------
    all_folds : list of 2-tuples
        Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
    """
    assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
    all_folds = []
    frac_per_part = 1. / k
    for i in range(k):
        if log:
            print('Processing fold {:d}/{:d}'.format(i + 1, k))
        # We are reusing the code for train-validation-test split.
        train_set1, val_set, train_set2 = split_method(
            dataset, frac_train=i * frac_per_part, frac_val=frac_per_part,
            frac_test=1. - (i + 1) * frac_per_part)
        # For cross validation, each fold consists of only a train subset and
        # a validation subset.
        train_set = Subset(dataset, train_set1.indices + train_set2.indices)
        all_folds.append((train_set, val_set))
    return all_folds
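# --- Illustrative note added by the editor; not part of the original commit. ---
# A sketch of the fold arithmetic above: for fold i, the validation chunk covers
# [i/k, (i+1)/k) of the (re)ordered dataset and the two surrounding chunks are merged
# back into the training set. The helper function below is hypothetical.
def _example_fold_fractions(k=5, i=2):
    """Print the frac_train/frac_val/frac_test triple passed to split_method for fold i."""
    frac_per_part = 1. / k
    print(i * frac_per_part, frac_per_part, 1. - (i + 1) * frac_per_part)   # 0.4 0.2 0.4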
def train_val_test_sanity_check(frac_train, frac_val, frac_test):
    """Sanity check for train-val-test split.

    Ensure that the fractions of the dataset to use for training,
    validation and test add up to 1.

    Parameters
    ----------
    frac_train : float
        Fraction of the dataset to use for training.
    frac_val : float
        Fraction of the dataset to use for validation.
    frac_test : float
        Fraction of the dataset to use for test.
    """
    total_fraction = frac_train + frac_val + frac_test
    assert np.allclose(total_fraction, 1.), \
        'Expect the sum of fractions for training, validation and ' \
        'test to be 1, got {:.4f}'.format(total_fraction)
def indices_split(dataset, frac_train, frac_val, frac_test, indices):
    """Reorder datapoints based on the specified indices and then take consecutive
    chunks as subsets.

    Parameters
    ----------
    dataset
        We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
        gives the ith datapoint.
    frac_train : float
        Fraction of data to use for training.
    frac_val : float
        Fraction of data to use for validation.
    frac_test : float
        Fraction of data to use for test.
    indices : list or ndarray
        Indices specifying the order of datapoints.

    Returns
    -------
    list of length 3
        Subsets for training, validation and test, which are all :class:`Subset` instances.
    """
    frac_list = np.asarray([frac_train, frac_val, frac_test])
    assert np.allclose(np.sum(frac_list), 1.), \
        'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list))
    num_data = len(dataset)
    lengths = (num_data * frac_list).astype(int)
    lengths[-1] = num_data - np.sum(lengths[:-1])

    return [Subset(dataset, list(indices[offset - length:offset]))
            for offset, length in zip(accumulate(lengths), lengths)]
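# --- Illustrative note added by the editor; not part of the original commit. ---
# A self-contained sketch of the chunking arithmetic above for 10 datapoints and an
# 80/10/10 split; it mirrors indices_split without constructing Subset objects.
# The example function itself is hypothetical.
def _example_indices_chunking():
    indices = np.random.RandomState(0).permutation(10)
    frac_list = np.asarray([0.8, 0.1, 0.1])
    lengths = (10 * frac_list).astype(int)
    lengths[-1] = 10 - np.sum(lengths[:-1])
    chunks = [list(indices[offset - length:offset])
              for offset, length in zip(accumulate(lengths), lengths)]
    print([len(c) for c in chunks])   # [8, 1, 1]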
def count_and_log(message, i, total, log_every_n):
    """Print a message to reflect the progress of processing once a while.

    Parameters
    ----------
    message : str
        Message to print.
    i : int
        Current index.
    total : int
        Total count.
    log_every_n : None or int
        Molecule related computation can take a long time for a large dataset and we want
        to learn the progress of processing. This can be done by printing a message whenever
        a batch of ``log_every_n`` molecules have been processed. If None, no messages will
        be printed.
    """
    if (log_every_n is not None) and ((i + 1) % log_every_n == 0):
        print('{} {:d}/{:d}'.format(message, i + 1, total))
def prepare_mols(dataset, mols, sanitize, log_every_n=1000):
    """Prepare RDKit molecule instances.

    Parameters
    ----------
    dataset
        We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
        gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
        ith datapoint.
    mols : None or list of rdkit.Chem.rdchem.Mol
        None or pre-computed RDKit molecule instances. If not None, we expect a
        one-to-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
        ``mols[i]`` corresponds to ``dataset.smiles[i]``.
    sanitize : bool
        This argument only comes into effect when ``mols`` is None and decides whether
        sanitization is performed in initializing RDKit molecule instances. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
    log_every_n : None or int
        Molecule related computation can take a long time for a large dataset and we want
        to learn the progress of processing. This can be done by printing a message whenever
        a batch of ``log_every_n`` molecules have been processed. If None, no messages will
        be printed. Default to 1000.

    Returns
    -------
    mols : list of rdkit.Chem.rdchem.Mol
        RDKit molecule instances where there is a one-to-one correspondence between
        ``dataset.smiles`` and ``mols``, i.e. ``mols[i]`` corresponds to ``dataset.smiles[i]``.
    """
    if mols is not None:
        # Sanity check
        assert len(mols) == len(dataset), \
            'Expect mols to be of the same size as that of the dataset, ' \
            'got {:d} and {:d}'.format(len(mols), len(dataset))
    else:
        if log_every_n is not None:
            print('Start initializing RDKit molecule instances...')
        mols = []
        for i, s in enumerate(dataset.smiles):
            count_and_log('Creating RDKit molecule instance',
                          i, len(dataset.smiles), log_every_n)
            mols.append(Chem.MolFromSmiles(s, sanitize=sanitize))

    return mols
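# --- Illustrative usage added by the editor; not part of the original commit. ---
# _ToySmilesDataset is a hypothetical stand-in exposing the attributes the splitters expect;
# it is reused by the examples added after the splitter classes below.
class _ToySmilesDataset(object):
    """Minimal dataset wrapper around a list of SMILES strings."""
    def __init__(self, smiles):
        self.smiles = smiles

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, item):
        return self.smiles[item]

def _example_prepare_mols():
    dataset = _ToySmilesDataset(['CCO', 'c1ccccc1', 'CC(=O)O'])
    mols = prepare_mols(dataset, mols=None, sanitize=True, log_every_n=None)
    print(len(mols))   # 3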
class ConsecutiveSplitter(object):
    """Split datasets with the input order.

    The dataset is split without permutation, so the splitting is deterministic.
    """

    @staticmethod
    @deprecated('Import ConsecutiveSplitter from dgllife.utils.splitters instead.', 'class')
    def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1):
        """Split the dataset into three consecutive chunks for training, validation and test.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
            gives the ith datapoint.
        frac_train : float
            Fraction of data to use for training. By default, we set this to be 0.8, i.e.
            80% of the dataset is used for training.
        frac_val : float
            Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
            10% of the dataset is used for validation.
        frac_test : float
            Fraction of data to use for test. By default, we set this to be 0.1, i.e.
            10% of the dataset is used for test.

        Returns
        -------
        list of length 3
            Subsets for training, validation and test, which are all :class:`Subset` instances.
        """
        return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], shuffle=False)

    @staticmethod
    @deprecated('Import ConsecutiveSplitter from dgllife.utils.splitters instead.', 'class')
    def k_fold_split(dataset, k=5, log=True):
        """Split the dataset for k-fold cross validation by taking consecutive chunks.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
            gives the ith datapoint.
        k : int
            Number of folds to use and should be no smaller than 2. Default to be 5.
        log : bool
            Whether to print a message at the start of preparing each fold.

        Returns
        -------
        list of 2-tuples
            Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
        """
        return base_k_fold_split(ConsecutiveSplitter.train_val_test_split, dataset, k, log)
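# --- Illustrative usage added by the editor; not part of the original commit. ---
# A minimal sketch reusing the hypothetical _ToySmilesDataset stand-in defined above;
# the example function itself is hypothetical.
def _example_consecutive_split():
    dataset = _ToySmilesDataset(['CCO', 'c1ccccc1', 'CC(=O)O', 'CCN', 'CCC',
                                 'CCCl', 'CCBr', 'CCS', 'CCF', 'CCCC'])
    train_set, val_set, test_set = ConsecutiveSplitter.train_val_test_split(dataset)
    print(len(train_set), len(val_set), len(test_set))   # 8 1 1
    folds = ConsecutiveSplitter.k_fold_split(dataset, k=5, log=False)
    print(len(folds))   # 5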
class RandomSplitter(object):
    """Randomly reorder datasets and then split them.

    The dataset is split with permutation and the splitting is hence random.
    """

    @staticmethod
    @deprecated('Import RandomSplitter from dgllife.utils.splitters instead.', 'class')
    def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1,
                             frac_test=0.1, random_state=None):
        """Randomly permute the dataset and then split it into
        three consecutive chunks for training, validation and test.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
            gives the ith datapoint.
        frac_train : float
            Fraction of data to use for training. By default, we set this to be 0.8, i.e.
            80% of the dataset is used for training.
        frac_val : float
            Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
            10% of the dataset is used for validation.
        frac_test : float
            Fraction of data to use for test. By default, we set this to be 0.1, i.e.
            10% of the dataset is used for test.
        random_state : None, int or array_like, optional
            Random seed used to initialize the pseudo-random number generator.
            Can be any integer between 0 and 2**32 - 1 inclusive, an array
            (or other sequence) of such integers, or None (the default).
            If seed is None, then RandomState will try to read data from /dev/urandom
            (or the Windows analogue) if available or seed from the clock otherwise.

        Returns
        -------
        list of length 3
            Subsets for training, validation and test.
        """
        return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test],
                             shuffle=True, random_state=random_state)

    @staticmethod
    @deprecated('Import RandomSplitter from dgllife.utils.splitters instead.', 'class')
    def k_fold_split(dataset, k=5, random_state=None, log=True):
        """Randomly permute the dataset and then split it
        for k-fold cross validation by taking consecutive chunks.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
            gives the ith datapoint.
        k : int
            Number of folds to use and should be no smaller than 2. Default to be 5.
        random_state : None, int or array_like, optional
            Random seed used to initialize the pseudo-random number generator.
            Can be any integer between 0 and 2**32 - 1 inclusive, an array
            (or other sequence) of such integers, or None (the default).
            If seed is None, then RandomState will try to read data from /dev/urandom
            (or the Windows analogue) if available or seed from the clock otherwise.
        log : bool
            Whether to print a message at the start of preparing each fold. Default to True.

        Returns
        -------
        list of 2-tuples
            Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
        """
        # Permute the dataset only once so that each datapoint
        # will appear once in exactly one fold.
        indices = np.random.RandomState(seed=random_state).permutation(len(dataset))

        return base_k_fold_split(partial(indices_split, indices=indices), dataset, k, log)
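# --- Illustrative usage added by the editor; not part of the original commit. ---
# A minimal sketch reusing the hypothetical _ToySmilesDataset stand-in defined above;
# fixing random_state makes the permutation reproducible.
def _example_random_split():
    dataset = _ToySmilesDataset(['CCO', 'c1ccccc1', 'CC(=O)O', 'CCN', 'CCC',
                                 'CCCl', 'CCBr', 'CCS', 'CCF', 'CCCC'])
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(dataset, random_state=42)
    folds = RandomSplitter.k_fold_split(dataset, k=5, random_state=42, log=False)
    print(len(train_set), len(val_set), len(test_set), len(folds))   # 8 1 1 5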
class MolecularWeightSplitter(object):
    """Sort molecules based on their weights and then split them."""

    @staticmethod
    @deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
    def molecular_weight_indices(molecules, log_every_n):
        """Reorder molecules based on molecular weights.

        Parameters
        ----------
        molecules : list of rdkit.Chem.rdchem.Mol
            Pre-computed RDKit molecule instances. We expect a one-to-one
            correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``.
        log_every_n : None or int
            Molecule related computation can take a long time for a large dataset and we want
            to learn the progress of processing. This can be done by printing a message whenever
            a batch of ``log_every_n`` molecules have been processed. If None, no messages will
            be printed.

        Returns
        -------
        indices : list or ndarray
            Indices specifying the order of datapoints, which are basically
            argsort of the molecular weights.
        """
        if log_every_n is not None:
            print('Start computing molecular weights.')
        mws = []
        for i, mol in enumerate(molecules):
            count_and_log('Computing molecular weight for compound',
                          i, len(molecules), log_every_n)
            mws.append(Chem.rdMolDescriptors.CalcExactMolWt(mol))

        return np.argsort(mws)

    @staticmethod
    @deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
    def train_val_test_split(dataset, mols=None, sanitize=True, frac_train=0.8,
                             frac_val=0.1, frac_test=0.1, log_every_n=1000):
        """Sort molecules based on their weights and then split them into
        three consecutive chunks for training, validation and test.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        mols : None or list of rdkit.Chem.rdchem.Mol
            None or pre-computed RDKit molecule instances. If not None, we expect a
            one-to-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
        sanitize : bool
            This argument only comes into effect when ``mols`` is None and decides whether
            sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
            Default to be True.
        frac_train : float
            Fraction of data to use for training. By default, we set this to be 0.8, i.e.
            80% of the dataset is used for training.
        frac_val : float
            Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
            10% of the dataset is used for validation.
        frac_test : float
            Fraction of data to use for test. By default, we set this to be 0.1, i.e.
            10% of the dataset is used for test.
        log_every_n : None or int
            Molecule related computation can take a long time for a large dataset and we want
            to learn the progress of processing. This can be done by printing a message whenever
            a batch of ``log_every_n`` molecules have been processed. If None, no messages will
            be printed. Default to 1000.

        Returns
        -------
        list of length 3
            Subsets for training, validation and test, which are all :class:`Subset` instances.
        """
        # Perform sanity check first as molecule instance initialization and descriptor
        # computation can take a long time.
        train_val_test_sanity_check(frac_train, frac_val, frac_test)
        molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
        sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)

        return indices_split(dataset, frac_train, frac_val, frac_test, sorted_indices)

    @staticmethod
    @deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
    def k_fold_split(dataset, mols=None, sanitize=True, k=5, log_every_n=1000):
        """Sort molecules based on their weights and then split them
        for k-fold cross validation by taking consecutive chunks.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        mols : None or list of rdkit.Chem.rdchem.Mol
            None or pre-computed RDKit molecule instances. If not None, we expect a
            one-to-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
        sanitize : bool
            This argument only comes into effect when ``mols`` is None and decides whether
            sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
            Default to be True.
        k : int
            Number of folds to use and should be no smaller than 2. Default to be 5.
        log_every_n : None or int
            Molecule related computation can take a long time for a large dataset and we want
            to learn the progress of processing. This can be done by printing a message whenever
            a batch of ``log_every_n`` molecules have been processed. If None, no messages will
            be printed. Default to 1000.

        Returns
        -------
        list of 2-tuples
            Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
        """
        molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
        sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)

        return base_k_fold_split(partial(indices_split, indices=sorted_indices),
                                 dataset, k, log=(log_every_n is not None))
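# --- Illustrative usage added by the editor; not part of the original commit. ---
# A minimal sketch reusing the hypothetical _ToySmilesDataset stand-in defined above;
# molecules are sorted by exact molecular weight before the consecutive split.
def _example_molecular_weight_split():
    dataset = _ToySmilesDataset(['CCO', 'c1ccccc1', 'CC(=O)O', 'CCN', 'CCC',
                                 'CCCl', 'CCBr', 'CCS', 'CCF', 'CCCC'])
    train_set, val_set, test_set = MolecularWeightSplitter.train_val_test_split(
        dataset, sanitize=True, log_every_n=None)
    print(len(train_set), len(val_set), len(test_set))   # 8 1 1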
class ScaffoldSplitter(object):
    """Group molecules based on their Bemis-Murcko scaffolds and then split the groups.

    Group molecules so that all molecules in a group have the same scaffold (see reference).
    The dataset is then split at the level of groups.

    References
    ----------
    Bemis, G. W.; Murcko, M. A. “The Properties of Known Drugs.
    1. Molecular Frameworks.” J. Med. Chem. 39:2887-93 (1996).
    """

    @staticmethod
    @deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
    def get_ordered_scaffold_sets(molecules, include_chirality, log_every_n):
        """Group molecules based on their Bemis-Murcko scaffolds and
        order these groups based on their sizes.

        The order is decided by comparing the size of groups, where groups with a larger size
        are placed before the ones with a smaller size.

        Parameters
        ----------
        molecules : list of rdkit.Chem.rdchem.Mol
            Pre-computed RDKit molecule instances. We expect a one-to-one
            correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``.
        include_chirality : bool
            Whether to consider chirality in computing scaffolds.
        log_every_n : None or int
            Molecule related computation can take a long time for a large dataset and we want
            to learn the progress of processing. This can be done by printing a message whenever
            a batch of ``log_every_n`` molecules have been processed. If None, no messages will
            be printed.

        Returns
        -------
        scaffold_sets : list
            Each element of the list is a list of int,
            representing the indices of compounds with the same scaffold.
        """
        if log_every_n is not None:
            print('Start computing Bemis-Murcko scaffolds.')
        scaffolds = defaultdict(list)
        for i, mol in enumerate(molecules):
            count_and_log('Computing Bemis-Murcko for compound',
                          i, len(molecules), log_every_n)
            # For mols that have not been sanitized, we need to compute their ring information
            try:
                FastFindRings(mol)
                mol_scaffold = MurckoScaffold.MurckoScaffoldSmiles(
                    mol=mol, includeChirality=include_chirality)
                # Group molecules that have the same scaffold
                scaffolds[mol_scaffold].append(i)
            except:
                print('Failed to compute the scaffold for molecule {:d} '
                      'and it will be excluded.'.format(i + 1))

        # Order groups of molecules by first comparing the size of groups
        # and then the index of the first compound in the group.
        scaffold_sets = [
            scaffold_set for (scaffold, scaffold_set) in sorted(
                scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
        ]

        return scaffold_sets

    @staticmethod
    @deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
    def train_val_test_split(dataset, mols=None, sanitize=True, include_chirality=False,
                             frac_train=0.8, frac_val=0.1, frac_test=0.1, log_every_n=1000):
        """Split the dataset into training, validation and test set based on molecular scaffolds.

        This splitting method ensures that molecules with the same scaffold will be collectively
        in only one of the training, validation or test set. As a result, the fraction
        of dataset to use for training and validation tend to be smaller than ``frac_train``
        and ``frac_val``, while the fraction of dataset to use for test tends to be larger
        than ``frac_test``.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        mols : None or list of rdkit.Chem.rdchem.Mol
            None or pre-computed RDKit molecule instances. If not None, we expect a
            one-to-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
        sanitize : bool
            This argument only comes into effect when ``mols`` is None and decides whether
            sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
            Default to True.
        include_chirality : bool
            Whether to consider chirality in computing scaffolds. Default to False.
        frac_train : float
            Fraction of data to use for training. By default, we set this to be 0.8, i.e.
            80% of the dataset is used for training.
        frac_val : float
            Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
            10% of the dataset is used for validation.
        frac_test : float
            Fraction of data to use for test. By default, we set this to be 0.1, i.e.
            10% of the dataset is used for test.
        log_every_n : None or int
            Molecule related computation can take a long time for a large dataset and we want
            to learn the progress of processing. This can be done by printing a message whenever
            a batch of ``log_every_n`` molecules have been processed. If None, no messages will
            be printed. Default to 1000.

        Returns
        -------
        list of length 3
            Subsets for training, validation and test, which are all :class:`Subset` instances.
        """
        # Perform sanity check first as molecule related computation can take a long time.
        train_val_test_sanity_check(frac_train, frac_val, frac_test)
        molecules = prepare_mols(dataset, mols, sanitize)
        scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
            molecules, include_chirality, log_every_n)
        train_indices, val_indices, test_indices = [], [], []
        train_cutoff = int(frac_train * len(molecules))
        val_cutoff = int((frac_train + frac_val) * len(molecules))
        for group_indices in scaffold_sets:
            if len(train_indices) + len(group_indices) > train_cutoff:
                if len(train_indices) + len(val_indices) + len(group_indices) > val_cutoff:
                    test_indices.extend(group_indices)
                else:
                    val_indices.extend(group_indices)
            else:
                train_indices.extend(group_indices)
        return [Subset(dataset, train_indices),
                Subset(dataset, val_indices),
                Subset(dataset, test_indices)]

    @staticmethod
    @deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
    def k_fold_split(dataset, mols=None, sanitize=True, include_chirality=False, k=5,
                     log_every_n=1000):
        """Group molecules based on their scaffolds and sort groups based on their sizes.
        The groups are then split for k-fold cross validation.

        Same as usual k-fold splitting methods, each molecule will appear only once
        in the validation set among all folds. In addition, this method ensures that
        molecules with the same scaffold will be collectively in either the training
        set or the validation set for each fold.

        Note that the folds can be highly imbalanced depending on the
        scaffold distribution in the dataset.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        mols : None or list of rdkit.Chem.rdchem.Mol
            None or pre-computed RDKit molecule instances. If not None, we expect a
            one-to-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
            ``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
        sanitize : bool
            This argument only comes into effect when ``mols`` is None and decides whether
            sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
            Default to True.
        include_chirality : bool
            Whether to consider chirality in computing scaffolds. Default to False.
        k : int
            Number of folds to use and should be no smaller than 2. Default to be 5.
        log_every_n : None or int
            Molecule related computation can take a long time for a large dataset and we want
            to learn the progress of processing. This can be done by printing a message whenever
            a batch of ``log_every_n`` molecules have been processed. If None, no messages will
            be printed. Default to 1000.

        Returns
        -------
        list of 2-tuples
            Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
        """
        assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)

        molecules = prepare_mols(dataset, mols, sanitize)
        scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
            molecules, include_chirality, log_every_n)

        # k buckets that form a relatively balanced partition of the dataset
        index_buckets = [[] for _ in range(k)]
        for group_indices in scaffold_sets:
            bucket_chosen = int(np.argmin([len(bucket) for bucket in index_buckets]))
            index_buckets[bucket_chosen].extend(group_indices)

        all_folds = []
        for i in range(k):
            if log_every_n is not None:
                print('Processing fold {:d}/{:d}'.format(i + 1, k))
            train_indices = list(chain.from_iterable(index_buckets[:i] + index_buckets[i + 1:]))
            val_indices = index_buckets[i]
            all_folds.append((Subset(dataset, train_indices), Subset(dataset, val_indices)))

        return all_folds
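# --- Illustrative usage added by the editor; not part of the original commit. ---
# A minimal sketch reusing the hypothetical _ToySmilesDataset stand-in defined above; since
# whole scaffold groups are assigned together, the split sizes only approximate 80/10/10.
def _example_scaffold_split():
    dataset = _ToySmilesDataset(['CCO', 'c1ccccc1', 'CC(=O)O', 'c1ccccc1O', 'CCC',
                                 'CCCl', 'c1ccncc1', 'CCS', 'CCF', 'CCCC'])
    train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
        dataset, sanitize=True, log_every_n=None)
    print(len(train_set), len(val_set), len(test_set))
    folds = ScaffoldSplitter.k_fold_split(dataset, k=2, log_every_n=None)
    print(len(folds))   # 2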
class SingleTaskStratifiedSplitter(object):
    """Splits the dataset by stratification on a single task.

    We sort the molecules based on their label values for a task and then repeatedly
    take buckets of datapoints to augment the training, validation and test subsets.
    """

    @staticmethod
    @deprecated('Import SingleTaskStratifiedSplitter from '
                'dgllife.utils.splitters instead.', 'class')
    def train_val_test_split(dataset, labels, task_id, frac_train=0.8, frac_val=0.1,
                             frac_test=0.1, bucket_size=10, random_state=None):
        """Split the dataset into training, validation and test subsets as stated above.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        labels : tensor of shape (N, T)
            Dataset labels for all tasks. N for the number of datapoints and T for the number
            of tasks.
        task_id : int
            Index for the task.
        frac_train : float
            Fraction of data to use for training. By default, we set this to be 0.8, i.e.
            80% of the dataset is used for training.
        frac_val : float
            Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
            10% of the dataset is used for validation.
        frac_test : float
            Fraction of data to use for test. By default, we set this to be 0.1, i.e.
            10% of the dataset is used for test.
        bucket_size : int
            Size of bucket of datapoints. Default to 10.
        random_state : None, int or array_like, optional
            Random seed used to initialize the pseudo-random number generator.
            Can be any integer between 0 and 2**32 - 1 inclusive, an array
            (or other sequence) of such integers, or None (the default).
            If seed is None, then RandomState will try to read data from /dev/urandom
            (or the Windows analogue) if available or seed from the clock otherwise.

        Returns
        -------
        list of length 3
            Subsets for training, validation and test, which are all :class:`Subset` instances.
        """
        train_val_test_sanity_check(frac_train, frac_val, frac_test)

        if random_state is not None:
            np.random.seed(random_state)

        if not isinstance(labels, np.ndarray):
            labels = F.asnumpy(labels)
        task_labels = labels[:, task_id]
        sorted_indices = np.argsort(task_labels)

        train_bucket_cutoff = int(np.round(frac_train * bucket_size))
        val_bucket_cutoff = int(np.round(frac_val * bucket_size)) + train_bucket_cutoff

        train_indices, val_indices, test_indices = [], [], []

        while sorted_indices.shape[0] >= bucket_size:
            current_batch, sorted_indices = np.split(sorted_indices, [bucket_size])
            shuffled = np.random.permutation(range(bucket_size))
            train_indices.extend(current_batch[shuffled[:train_bucket_cutoff]].tolist())
            val_indices.extend(
                current_batch[shuffled[train_bucket_cutoff:val_bucket_cutoff]].tolist())
            test_indices.extend(current_batch[shuffled[val_bucket_cutoff:]].tolist())

        # Place the remaining samples in the training set.
        train_indices.extend(sorted_indices.tolist())

        return [Subset(dataset, train_indices),
                Subset(dataset, val_indices),
                Subset(dataset, test_indices)]

    @staticmethod
    @deprecated('Import SingleTaskStratifiedSplitter from '
                'dgllife.utils.splitters instead.', 'class')
    def k_fold_split(dataset, labels, task_id, k=5, log=True):
        """Sort molecules based on their label values for a task and then split them
        for k-fold cross validation by taking consecutive chunks.

        Parameters
        ----------
        dataset
            We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
            gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
            ith datapoint.
        labels : tensor of shape (N, T)
            Dataset labels for all tasks. N for the number of datapoints and T for the number
            of tasks.
        task_id : int
            Index for the task.
        k : int
            Number of folds to use and should be no smaller than 2. Default to be 5.
        log : bool
            Whether to print a message at the start of preparing each fold.

        Returns
        -------
        list of 2-tuples
            Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
        """
        if not isinstance(labels, np.ndarray):
            labels = F.asnumpy(labels)
        task_labels = labels[:, task_id]
        sorted_indices = np.argsort(task_labels).tolist()

        return base_k_fold_split(partial(indices_split, indices=sorted_indices), dataset, k, log)
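# --- Illustrative usage added by the editor; not part of the original commit. ---
# A minimal sketch reusing the hypothetical _ToySmilesDataset stand-in defined above; the
# labels array is random and only serves to demonstrate the per-bucket stratification.
def _example_single_task_stratified_split():
    dataset = _ToySmilesDataset(['CCO', 'c1ccccc1', 'CC(=O)O', 'CCN', 'CCC',
                                 'CCCl', 'CCBr', 'CCS', 'CCF', 'CCCC'])
    labels = np.random.RandomState(0).randn(len(dataset), 2)   # 10 datapoints, 2 tasks
    train_set, val_set, test_set = SingleTaskStratifiedSplitter.train_val_test_split(
        dataset, labels, task_id=0, random_state=0)
    print(len(train_set), len(val_set), len(test_set))   # 8 1 1
    folds = SingleTaskStratifiedSplitter.k_fold_split(dataset, labels, task_id=0, k=5)
    print(len(folds))   # 5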